背景
使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整。收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样。据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟。
对代码说明
共三个实验,分布写在代码中的(一)(二)(三)三个地方。运行实验时注释掉其他两个
实验及其结果
实验(三):
不使用zero_grad()时,grad累加在一起,官网是使用accumulate 来表述的,所以不太清楚是取的和还是均值(这两种最有可能)。
不使用zero_grad()时,是直接叠加add的方式累加的。
tensor([[[ 1., 1.],……torch.Size([2, 2, 2]) 0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * tensor([[[ 2., 2.],…… torch.Size([2, 2, 2]) 1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * tensor([[[ 3., 3.],…… torch.Size([2, 2, 2]) 2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
实验(二):
单卡上不同的batchsize对梯度是怎么作用的。 mini-batch SGD中的batch是加快训练,同时保持一定的噪声。但设置不同的batchsize的权重的梯度是怎么计算的呢。
设置运行实验(二),可以看到结果如下:所以单卡batchsize计算梯度是取均值的
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
实验(一):
多gpu情况下,梯度怎么合并在一起的。
在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是当设置g=2,实验一运行时,结果也是取均值的,类同于实验(二)
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
实验代码
import torch import torch.nn as nn from torch.autograd import Variable class model(nn.Module): def __init__(self, w): super(model, self).__init__() self.w = w def forward(self, xx): b, c, _, _ = xx.shape # extra = xx.device.index + 1 ## 实验(一) y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra) return y.reshape(len(xx), -1) g = 1 x = Variable(torch.ones(2, 1, 2, 2)) # x[1] += 1 ## 实验(二) w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True) # optim = torch.optim.SGD({'params': x}, lr = 0.01 momentum = 0.9 M = model(w) M = torch.nn.DataParallel(M, device_ids=range(g)) for i in range(3): b = len(x) z = M(x) zz = z.sum(1) l = (zz - Variable(torch.ones(b).cuda())).mean() # zz.backward(Variable(torch.ones(b).cuda())) l.backward() print(w.grad, w.grad.shape) # w.grad.zero_() ## 实验(三) print(i, b, '* * ' * 20)
以上这篇对pytorch中的梯度更新方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持亿速云。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。