在PyTorch中,nn.Parameter
是一个特殊的Tensor,它是nn.Module
中可训练参数的一种特殊类型。nn.Parameter
对象由nn.Module
的构造函数自动识别并将其注册为模型的可训练参数。
要使用nn.Parameter
,首先需要创建一个nn.Parameter
对象,并将其作为模型的属性。下面是一个简单的示例:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.weight = nn.Parameter(torch.rand(3, 4)) # 创建一个参数
def forward(self, x):
out = torch.matmul(x, self.weight)
return out
model = MyModel()
print(model.weight) # 打印参数
在上面的示例中,我们定义了一个MyModel
类,它继承自nn.Module
。在构造函数__init__
中,我们创建了一个nn.Parameter
对象self.weight
,它是一个形状为(3, 4)
的随机初始化的Tensor。
在forward
方法中,我们可以使用self.weight
参数进行计算。在模型创建完毕后,我们可以通过model.weight
来访问这个参数。
需要注意的是,nn.Parameter
对象会自动被注册为模型的可训练参数,并且在模型的parameters()
方法中可以访问到。此外,nn.Parameter
对象还会自动具有梯度计算的功能,可以通过backward()
方法自动计算梯度。