这篇文章主要介绍“如何理解Python中的pyTorch权重衰减与L2范数正则化”,在日常操作中,相信很多人在如何理解Python中的pyTorch权重衰减与L2范数正则化问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”如何理解Python中的pyTorch权重衰减与L2范数正则化”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!
下面进行一个高维线性实验
假设我们的真实方程是:
假设feature数200,训练样本和测试样本各20个
num_train,num_test = 10,10 num_features = 200 true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01 true_b = torch.tensor(0.5) samples = torch.normal(0,1,(num_train+num_test,num_features)) noise = torch.normal(0,0.01,(num_train+num_test,1)) labels = samples.matmul(true_w) + true_b + noise train_samples, train_labels= samples[:num_train],labels[:num_train] test_samples, test_labels = samples[num_train:],labels[num_train:]
def loss_function(predict,label,w,lambd):
loss = (predict - label) ** 2
loss = loss.mean() + lambd * (w**2).mean()
return loss
def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
plt.figure(figsize=(3,3))
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.semilogy(x_val,y_val)
if x2_val and y2_val:
plt.semilogy(x2_val,y2_val)
plt.legend(legend)
plt.show()
def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
b = torch.tensor(0.,requires_grad=True)
optimizer = torch.optim.Adam([w,b],lr=0.05)
train_loss = []
test_loss = []
for epoch in range(num_epoch):
predict = train_samples.matmul(w) + b
epoch_train_loss = loss_function(predict,train_labels,w,lambd)
optimizer.zero_grad()
epoch_train_loss.backward()
optimizer.step()
test_predict = test_sapmles.matmul(w) + b
epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
train_loss.append(epoch_train_loss.item())
test_loss.append(epoch_test_loss.item())
semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])
可以发现加了正则项的模型,在测试集上的loss确实下降了
到此,关于“如何理解Python中的pyTorch权重衰减与L2范数正则化”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注亿速云网站,小编会继续努力为大家带来更多实用的文章!
亿速云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。