温馨提示×

如何在Torch中定义一个神经网络模型

小樊
53
2024-03-30 19:29:59
栏目: 深度学习

在Torch中定义一个神经网络模型通常需要使用nn.Module类。下面是一个示例代码,展示了如何定义一个简单的全连接神经网络模型:

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNN()

在上面的代码中,我们定义了一个名为SimpleNN的神经网络模型,它包含两个全连接层和一个ReLU激活函数。在__init__方法中,我们定义了模型的各个层,然后在forward方法中定义了数据在模型中的流动路径。

需要注意的是,在定义神经网络模型时,通常需要继承nn.Module类,并实现__init__forward方法。__init__方法用于初始化模型的结构,forward方法用于定义数据在模型中的传播路径。

0