温馨提示×

pytorch全连接层能自定义吗

小樊
81
2024-12-25 19:46:24
栏目: 深度学习

当然可以!PyTorch中的全连接层可以通过继承nn.Module类并实现自己的前向传播函数来自定义。以下是一个简单的自定义全连接层的示例:

import torch
import torch.nn as nn

class CustomLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(CustomLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def forward(self, x):
        # 在这里可以添加自定义的前向传播逻辑
        return self.linear(x)

# 示例用法
in_features = 784
out_features = 10
model = CustomLinear(in_features, out_features)
input_data = torch.randn(1, in_features)
output_data = model(input_data)
print(output_data.shape)  # 输出: torch.Size([1, 10])

在这个示例中,我们定义了一个名为CustomLinear的自定义全连接层,它接受输入特征数in_features和输出特征数out_features作为参数。在forward方法中,我们可以添加自定义的前向传播逻辑。在这个简单的例子中,我们只是直接调用了nn.Linear的前向传播函数。

0