当然可以!PyTorch中的全连接层可以通过继承nn.Module
类并实现自己的前向传播函数来自定义。以下是一个简单的自定义全连接层的示例:
import torch
import torch.nn as nn
class CustomLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(CustomLinear, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias)
def forward(self, x):
# 在这里可以添加自定义的前向传播逻辑
return self.linear(x)
# 示例用法
in_features = 784
out_features = 10
model = CustomLinear(in_features, out_features)
input_data = torch.randn(1, in_features)
output_data = model(input_data)
print(output_data.shape) # 输出: torch.Size([1, 10])
在这个示例中,我们定义了一个名为CustomLinear
的自定义全连接层,它接受输入特征数in_features
和输出特征数out_features
作为参数。在forward
方法中,我们可以添加自定义的前向传播逻辑。在这个简单的例子中,我们只是直接调用了nn.Linear
的前向传播函数。