温馨提示×

pytorch张量如何拼接

小樊
81
2024-12-25 20:54:26
栏目: 深度学习

在PyTorch中,可以使用torch.cat()函数来拼接张量。torch.cat()函数需要指定沿着哪个维度进行拼接。以下是一些示例:

  1. 沿第一个维度(axis=0)拼接两个相同形状的张量:
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

result = torch.cat((tensor1, tensor2), dim=0)
print(result.shape)  # 输出:(4, 3)
  1. 沿第二个维度(axis=1)拼接两个相同形状的张量:
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

result = torch.cat((tensor1, tensor2), dim=1)
print(result.shape)  # 输出:(2, 6)
  1. 沿第三个维度(axis=2)拼接两个相同形状的张量:
import torch

tensor1 = torch.randn(2, 3, 4)
tensor2 = torch.randn(2, 3, 4)

result = torch.cat((tensor1, tensor2), dim=2)
print(result.shape)  # 输出:(2, 3, 8)

请注意,要沿指定维度拼接张量,它们的形状必须相同。例如,如果沿第一个维度拼接,张量的形状必须为(batch_size, input_dim1, input_dim2, ...)

0