温馨提示×

首页 > 教程 > AI深度学习 > PyTorch教程 > 正则化(Regularization)

正则化(Regularization)

正则化是一种常用的技术,用于减少模型的过拟合,提高模型的泛化能力。在PyTorch中,我们可以通过在损失函数中添加正则化项来实现正则化。

在深度学习中,常用的正则化方法有L1正则化和L2正则化。L1正则化会使模型的权重稀疏化,即让一部分权重变为0,从而减少模型的复杂度;而L2正则化会惩罚模型参数的平方和,使模型参数更加平滑,从而减少过拟合。

下面是一个简单的示例,展示如何在PyTorch中实现L2正则化:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

model = Net()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001) # 添加L2正则化项

# 训练模型
for epoch in range(num_epochs):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # 添加L2正则化项
        l2_reg = torch.tensor(0.)
        for param in model.parameters():
            l2_reg += torch.norm(param)
        
        loss += weight_decay * l2_reg
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

在上面的示例中,我们在优化器中设置了weight_decay参数,用来控制L2正则化项的权重。然后在每次计算损失函数时,我们计算模型所有参数的L2范数,并加到损失函数中,从而实现L2正则化。

除了在优化器中设置weight_decay参数外,我们也可以使用torch.optim中的其他优化器类,如AdamW,它在Adam的基础上加入了权重衰减(weight decay)功能,相当于L2正则化。

正则化是提高模型泛化能力的重要手段,可以帮助我们有效减少过拟合问题。在实际应用中,可以根据具体情况选择适合的正则化方法和参数设置。