nn.Linear
是 PyTorch 中的一个类,用来定义一个线性变换(线性层)的操作。
具体来说,nn.Linear
用于定义一个线性映射,将输入张量的每个元素与权重矩阵相乘,并加上偏置向量。其功能可以总结如下:
线性变换:将输入张量与权重矩阵相乘,得到输出张量。输入张量的形状为 (batch_size, input_size)
,权重矩阵的形状为 (output_size, input_size)
。输出张量的形状为 (batch_size, output_size)
。
加偏置:将输出张量加上一个偏置向量,该偏置向量的形状为 (output_size,)
。偏置向量会被广播到每个样本的输出上。
自动创建参数:nn.Linear
创建线性层时会自动创建权重矩阵和偏置向量,并将它们保存在模型的参数列表中。
自动梯度计算:通过 PyTorch 的自动求导机制,nn.Linear
可以自动计算权重矩阵和偏置向量的梯度,并进行优化。
nn.Linear
通常在神经网络模型中被用作全连接层(全连接神经网络),用来将输入特征映射到输出特征。