在PyTorch中,可以使用unsqueeze()
和squeeze()
方法来改变Tensor的维度。
unsqueeze()
方法可以在指定位置插入一个新的维度,其值都为1。例如,如果你有一个形状为(3,)的Tensor,你可以使用unsqueeze()
方法将其形状变为(1, 3):
import torch
x = torch.tensor([1, 2, 3])
y = x.unsqueeze(0)
print(y.shape) # 输出(1, 3)
squeeze()
方法可以删除指定位置的维度,其大小为1。例如,如果你有一个形状为(1, 3)的Tensor,你可以使用squeeze()
方法将其形状变为(3,):
import torch
x = torch.tensor([[1, 2, 3]])
y = x.squeeze(0)
print(y.shape) # 输出(3,)
注意,unsqueeze()
和squeeze()
方法都不会改变Tensor中的数据,只会改变其形状。