温馨提示×

PyTorch中的数据加载器怎么使用

小亿
91
2024-03-14 13:34:28
栏目: 深度学习

在PyTorch中,数据加载器可以通过torch.utils.data.DataLoader来实现。数据加载器可以帮助用户批量加载数据,并可以在训练过程中对数据进行随机排列、并行加载等操作。

下面是一个简单的示例,演示如何使用数据加载器来加载一个简单的数据集:

import torch
from torch.utils.data import Dataset, DataLoader

# 创建一个自定义的数据集类
class CustomDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 3)  # 100个3维的随机数据
        self.targets = torch.randint(0, 2, (100,))  # 100个随机目标标签

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# 创建数据集实例
dataset = CustomDataset()

# 创建数据加载器实例
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 遍历数据加载器
for i, (data, target) in enumerate(data_loader):
    print(f'Batch {i}:')
    print('Data:', data)
    print('Target:', target)

在上述示例中,首先定义了一个自定义的数据集类CustomDataset,然后创建了一个数据集实例dataset。接着利用DataLoader类来创建一个数据加载器实例data_loader,并指定了批量大小为32且开启了数据随机排列。最后通过对数据加载器进行遍历,便可以逐批次地获取数据和标签。

0