在PyTorch中,可以使用torch.cat()
函数来拼接张量。torch.cat()
函数需要指定沿着哪个维度进行拼接。以下是一些示例:
import torch
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
result = torch.cat((tensor1, tensor2), dim=0)
print(result.shape) # 输出:(4, 3)
import torch
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
result = torch.cat((tensor1, tensor2), dim=1)
print(result.shape) # 输出:(2, 6)
import torch
tensor1 = torch.randn(2, 3, 4)
tensor2 = torch.randn(2, 3, 4)
result = torch.cat((tensor1, tensor2), dim=2)
print(result.shape) # 输出:(2, 3, 8)
请注意,要沿指定维度拼接张量,它们的形状必须相同。例如,如果沿第一个维度拼接,张量的形状必须为(batch_size, input_dim1, input_dim2, ...)
。