在PyTorch中,正则化是一种常用的技术,用于防止模型过拟合。常见的正则化方法包括L1正则化和L2正则化。
在PyTorch中,可以使用nn.Module
的add_weight()
方法为模型参数添加正则化项。例如,以下代码为模型的权重添加了L2正则化项:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
self.fc1 = nn.Linear(128 * 25 * 25, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 10)
# 添加L2正则化项
for param in self.parameters():
param.requires_grad = True
param.register_hook(lambda x: x * (1 - 0.001))
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 128 * 25 * 25)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
在上面的代码中,我们使用了register_hook()
方法为每个参数添加了一个钩子函数,该函数将参数乘以一个因子(在这里是1 - 0.001
),从而实现了L2正则化。
除了L2正则化外,还可以使用其他正则化方法,例如L1正则化和Dropout。在PyTorch中,这些方法也可以很容易地实现。