本文小编为大家详细介绍“PyTorch的TensorDataset功能怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“PyTorch的TensorDataset功能怎么使用”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。
TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。
该类通过每一个 tensor 的第一个维度进行索引。
因此,该类中的 tensor 第一维度必须相等。
from torch.utils.data import TensorDataset import torch from torch.utils.data import DataLoader a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66]) train_ids = TensorDataset(a, b) # 切片输出 print(train_ids[0:2]) print('=' * 80) # 循环取数据 for x_train, y_label in train_ids: print(x_train, y_label) # DataLoader进行数据封装 print('=' * 80) train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True) for i, data in enumerate(train_loader, 1): # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签) x_data, label = data print(' batch:{0} x_data:{1} label: {2}'.format(i, x_data, label))
运行结果:
(tensor([[1, 2, 3],
[4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
batch:1 x_data:tensor([[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6]]) label: tensor([44, 44, 55, 55])
batch:2 x_data:tensor([[4, 5, 6],
[7, 8, 9],
[7, 8, 9],
[7, 8, 9]]) label: tensor([55, 66, 66, 66])
batch:3 x_data:tensor([[1, 2, 3],
[1, 2, 3],
[7, 8, 9],
[4, 5, 6]]) label: tensor([44, 44, 66, 55])
注意:TensorDataset 中的参数必须是 tensor
Pytorch中,TensorDataset()可以快速构建训练所用的数据,不用使用自建的Mydataset(),如果没有熟悉适用的dataset可以使用TensorDataset()作为暂时替代。
只需要把data和label作为参数输入,就可以快速构建,之后便可以用Dataloader处理。
import numpy as np from torch.utils.data import DataLoader, TensorDataset data = np.loadtxt('x.txt') label = np.loadtxt('y.txt') data = torch.tensor(data) label = torch.tensor(label) train_data = TensorDataset(data, label) train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
读到这里,这篇“PyTorch的TensorDataset功能怎么使用”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注亿速云行业资讯频道。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。