温馨提示×

首页 > 教程 > AI深度学习 > PyTorch教程 > 定义模型结构

定义模型结构

在PyTorch中定义模型结构通常需要创建一个继承自nn.Module的类。这个类需要实现两个方法:__init__方法用来定义模型的结构,forward方法用来定义模型的前向传播过程。下面是一个简单的例子,展示如何定义一个基本的神经网络模型:

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 输入层到隐藏层的全连接层
        self.relu = nn.ReLU()  # 激活函数
        self.fc2 = nn.Linear(128, 10)  # 隐藏层到输出层的全连接层
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = SimpleNN()

在这个例子中,我们定义了一个包含一个隐藏层的全连接神经网络模型。模型的输入维度是784,输出维度是10。在__init__方法中,我们定义了两个全连接层和一个ReLU激活函数。在forward方法中,我们定义了模型的前向传播过程。

除了上面这种简单的方式,还可以使用nn.Sequential来定义模型结构,这种方式更加简洁。下面是一个使用nn.Sequential的例子:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),  # 输入层到隐藏层的全连接层
    nn.ReLU(),  # 激活函数
    nn.Linear(128, 10)  # 隐藏层到输出层的全连接层
)

使用nn.Sequential方式定义模型结构时,不需要定义forward方法,PyTorch会自动根据模型的结构生成前向传播过程。

无论是使用自定义的类还是nn.Sequential,都可以很方便地定义PyTorch模型的结构。建议根据具体的任务需求选择合适的方式。