温馨提示×

python的nn.linear有什么功能

小亿
460
2023-12-22 10:20:59
栏目: 编程语言

nn.Linear 是 PyTorch 中的一个类,用来定义一个线性变换(线性层)的操作。

具体来说,nn.Linear 用于定义一个线性映射,将输入张量的每个元素与权重矩阵相乘,并加上偏置向量。其功能可以总结如下:

  1. 线性变换:将输入张量与权重矩阵相乘,得到输出张量。输入张量的形状为 (batch_size, input_size),权重矩阵的形状为 (output_size, input_size)。输出张量的形状为 (batch_size, output_size)

  2. 加偏置:将输出张量加上一个偏置向量,该偏置向量的形状为 (output_size,)。偏置向量会被广播到每个样本的输出上。

  3. 自动创建参数:nn.Linear 创建线性层时会自动创建权重矩阵和偏置向量,并将它们保存在模型的参数列表中。

  4. 自动梯度计算:通过 PyTorch 的自动求导机制,nn.Linear 可以自动计算权重矩阵和偏置向量的梯度,并进行优化。

nn.Linear 通常在神经网络模型中被用作全连接层(全连接神经网络),用来将输入特征映射到输出特征。

0