温馨提示×

pytorch张量如何拼接

小樊
84
2024-12-25 20:54:26
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中,可以使用torch.cat()函数来拼接张量。torch.cat()函数需要指定沿着哪个维度进行拼接。以下是一些示例:

  1. 沿第一个维度(axis=0)拼接两个相同形状的张量:
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

result = torch.cat((tensor1, tensor2), dim=0)
print(result.shape)  # 输出:(4, 3)
  1. 沿第二个维度(axis=1)拼接两个相同形状的张量:
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

result = torch.cat((tensor1, tensor2), dim=1)
print(result.shape)  # 输出:(2, 6)
  1. 沿第三个维度(axis=2)拼接两个相同形状的张量:
import torch

tensor1 = torch.randn(2, 3, 4)
tensor2 = torch.randn(2, 3, 4)

result = torch.cat((tensor1, tensor2), dim=2)
print(result.shape)  # 输出:(2, 3, 8)

请注意,要沿指定维度拼接张量,它们的形状必须相同。例如,如果沿第一个维度拼接,张量的形状必须为(batch_size, input_dim1, input_dim2, ...)

亿速云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读:pytorch张量拼接怎么实现

0