在PyTorch中,可以使用torch.cat()
函数来拼接张量。torch.cat()
函数需要指定沿着哪个维度进行拼接。以下是一些示例:
import torch
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
result = torch.cat((tensor1, tensor2), dim=0)
print(result.shape) # 输出:(4, 3)
import torch
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
result = torch.cat((tensor1, tensor2), dim=1)
print(result.shape) # 输出:(2, 6)
import torch
tensor1 = torch.randn(2, 3, 4)
tensor2 = torch.randn(2, 3, 4)
result = torch.cat((tensor1, tensor2), dim=2)
print(result.shape) # 输出:(2, 3, 8)
请注意,要沿指定维度拼接张量,它们的形状必须相同。例如,如果沿第一个维度拼接,张量的形状必须为(batch_size, input_dim1, input_dim2, ...)
。
亿速云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>
推荐阅读:pytorch张量拼接怎么实现