温馨提示×

pytorch mnist分类数据如何预处理

小樊
82
2024-12-26 15:25:47
栏目: 深度学习

在PyTorch中,对MNIST手写数字分类数据进行预处理的步骤如下:

  1. 导入相关库:
import torch
from torchvision import datasets, transforms
  1. 下载并加载MNIST数据集:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

在这个例子中,我们首先导入了所需的库,然后使用transforms.Compose()函数定义了一个预处理管道。这个管道包括两个步骤:

  • transforms.ToTensor():将图像数据从(28, 28)的numpy数组转换为PyTorch张量,数值范围从[0, 255]缩放到[0, 1]。
  • transforms.Normalize((0.1307,), (0.3081,)):对图像数据进行归一化处理,使用MNIST数据集的均值和标准差作为参数。

接下来,我们分别加载了训练集和测试集,并使用预处理管道对它们进行了处理。

0