温馨提示×

pytorch全连接神经网络怎么构建

小樊
81
2024-12-26 14:39:47
栏目: 深度学习

在PyTorch中,构建全连接神经网络(也称为密集层或线性层)非常简单。以下是一个简单的示例,展示了如何构建一个具有单个全连接层的网络:

import torch
import torch.nn as nn

# 定义一个简单的全连接神经网络
class FullyConnectedNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FullyConnectedNetwork, self).__init__()
        
        # 创建一个全连接层,输入大小为 input_size,输出大小为 hidden_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        
        # 创建另一个全连接层,输入大小为 hidden_size,输出大小为 output_size
        self.fc2 = nn.Linear(hidden_size, output_size)
        
        # 可选:添加激活函数,例如 ReLU
        self.relu = nn.ReLU()
    
    def forward(self, x):
        # 通过第一个全连接层
        x = self.fc1(x)
        
        # 可选:应用激活函数
        x = self.relu(x)
        
        # 通过第二个全连接层
        x = self.fc2(x)
        
        return x

# 参数设置
input_size = 784  # 例如,MNIST数据集的图像大小为28x28
hidden_size = 128
output_size = 10  # 例如,MNIST数据集有10个类别

# 创建网络实例
network = FullyConnectedNetwork(input_size, hidden_size, output_size)

# 打印网络结构
print(network)

在这个示例中,我们定义了一个名为FullyConnectedNetwork的类,它继承自nn.Module。我们在__init__方法中定义了两个全连接层(fc1fc2),并在forward方法中定义了数据通过网络的方向。

你可以根据需要调整网络结构,例如添加更多的全连接层、使用不同的激活函数等。

0