要实现自定义数据集类,需要继承PyTorch中的Dataset类,并重写其中的两个方法:len__和__getitem。下面是一个简单的例子,演示如何实现一个自定义数据集类:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data_point = self.data[index]
target = self.targets[index]
return data_point, target
在上面的例子中,CustomDataset类接收两个参数data和targets作为初始化参数,分别表示数据和标签。然后重写了__len__方法,返回数据集的长度,重写了__getitem__方法,根据索引index返回对应的数据点和标签。
使用这个自定义数据集类的方法如下:
data = [...] # your data
targets = [...] # your targets
custom_dataset = CustomDataset(data, targets)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)
for data, target in data_loader:
# do something with data and target
这样就可以通过自定义数据集类来加载自己的数据集,并使用DataLoader来批量加载数据。