温馨提示×

pytorch怎么改变tensor维度

小亿
81
2024-12-31 05:09:09
栏目: 深度学习

在PyTorch中,可以使用unsqueeze()squeeze()方法来改变Tensor的维度。

unsqueeze()方法可以在指定位置插入一个新的维度,其值都为1。例如,如果你有一个形状为(3,)的Tensor,你可以使用unsqueeze()方法将其形状变为(1, 3):

import torch

x = torch.tensor([1, 2, 3])
y = x.unsqueeze(0)
print(y.shape)  # 输出(1, 3)

squeeze()方法可以删除指定位置的维度,其大小为1。例如,如果你有一个形状为(1, 3)的Tensor,你可以使用squeeze()方法将其形状变为(3,):

import torch

x = torch.tensor([[1, 2, 3]])
y = x.squeeze(0)
print(y.shape)  # 输出(3,)

注意,unsqueeze()squeeze()方法都不会改变Tensor中的数据,只会改变其形状。

0