pytorch中stack和cat的区别

tech2026-03-02  3

torch.cat((tensor1,tensor2), dim)

将两个tensor连接起来,具体如何连接见下面例子

x = torch.rand((2,2,3)) y = torch.rand((2,2,3)) print("x:",x) print("y:",y) print("dim=0:", torch.cat((x,y),dim=0)) print("dim=1:", torch.cat((x,y), dim=1)) print("dim=2:", torch.cat((x, y), dim=2))

输出:

x: tensor([[[0.2571, 0.9011, 0.7935], [0.9308, 0.3267, 0.3290]], [[0.6155, 0.4739, 0.7251], [0.8025, 0.0424, 0.8101]]]) y: tensor([[[0.8813, 0.1149, 0.7757], [0.4733, 0.9003, 0.3300]], [[0.2597, 0.5810, 0.2507], [0.1220, 0.2260, 0.5620]]]) dim=0: tensor([[[0.2571, 0.9011, 0.7935], [0.9308, 0.3267, 0.3290]], [[0.6155, 0.4739, 0.7251], [0.8025, 0.0424, 0.8101]], [[0.8813, 0.1149, 0.7757], [0.4733, 0.9003, 0.3300]], [[0.2597, 0.5810, 0.2507], [0.1220, 0.2260, 0.5620]]]) dim=1: tensor([[[0.2571, 0.9011, 0.7935], [0.9308, 0.3267, 0.3290], [0.8813, 0.1149, 0.7757], [0.4733, 0.9003, 0.3300]], [[0.6155, 0.4739, 0.7251], [0.8025, 0.0424, 0.8101], [0.2597, 0.5810, 0.2507], [0.1220, 0.2260, 0.5620]]]) dim=2: tensor([[[0.2571, 0.9011, 0.7935, 0.8813, 0.1149, 0.7757], [0.9308, 0.3267, 0.3290, 0.4733, 0.9003, 0.3300]], [[0.6155, 0.4739, 0.7251, 0.2597, 0.5810, 0.2507], [0.8025, 0.0424, 0.8101, 0.1220, 0.2260, 0.5620]]]) [Finished in 2.1s]

torch.stack((tensor1, tensor2), dim)

x = torch.rand((2,2,3)) y = torch.rand((2,2,3)) print("x:",x) print("y:",y) print("dim=0:", torch.stack((x,y),dim=0)) print("dim=1:", torch.stack((x,y), dim=1)) print("dim=2:", torch.stack((x, y), dim=2)) print("dim=3", torch.stack((x, y), dim=3))

输出:

x: tensor([[[0.5099, 0.3434, 0.3731], [0.8523, 0.4672, 0.4163]], [[0.3364, 0.4910, 0.2302], [0.7896, 0.8119, 0.3978]]]) y: tensor([[[0.3843, 0.7627, 0.9757], [0.0065, 0.5462, 0.2765]], [[0.1890, 0.1698, 0.4486], [0.3459, 0.5552, 0.1908]]]) dim=0: tensor([[[[0.5099, 0.3434, 0.3731], [0.8523, 0.4672, 0.4163]], [[0.3364, 0.4910, 0.2302], [0.7896, 0.8119, 0.3978]]], [[[0.3843, 0.7627, 0.9757], [0.0065, 0.5462, 0.2765]], [[0.1890, 0.1698, 0.4486], [0.3459, 0.5552, 0.1908]]]]) dim=1: tensor([[[[0.5099, 0.3434, 0.3731], [0.8523, 0.4672, 0.4163]], [[0.3843, 0.7627, 0.9757], [0.0065, 0.5462, 0.2765]]], [[[0.3364, 0.4910, 0.2302], [0.7896, 0.8119, 0.3978]], [[0.1890, 0.1698, 0.4486], [0.3459, 0.5552, 0.1908]]]]) dim=2: tensor([[[[0.5099, 0.3434, 0.3731], [0.3843, 0.7627, 0.9757]], [[0.8523, 0.4672, 0.4163], [0.0065, 0.5462, 0.2765]]], [[[0.3364, 0.4910, 0.2302], [0.1890, 0.1698, 0.4486]], [[0.7896, 0.8119, 0.3978], [0.3459, 0.5552, 0.1908]]]]) dim=3 tensor([[[[0.5099, 0.3843], [0.3434, 0.7627], [0.3731, 0.9757]], [[0.8523, 0.0065], [0.4672, 0.5462], [0.4163, 0.2765]]], [[[0.3364, 0.1890], [0.4910, 0.1698], [0.2302, 0.4486]], [[0.7896, 0.3459], [0.8119, 0.5552], [0.3978, 0.1908]]]]) [Finished in 2.2s]

注意stack和cat的区别: stack操作后会在原来的基础上再增加一维,比如原来两个tensor的维度都是3维,经过stack后的结果为4维tensor; 而cat操作其结果和原来的tensor保持一致,具体stack和cat如何连接两个tensor见上方例子。

#参考链接 https://blog.csdn.net/orangerfun/article/details/104012365

最新回复(0)