pytorch中repeat()函数

tech2026-06-08  4

pytorch中的repeat()函数可以对张量进行复制。

当参数只有两个时,第一个参数表示的是复制后的列数,第二个参数表示复制后的行数。 当参数有三个时,第一个参数表示的是复制后的通道数,第二个参数表示的是复制后的列数,第三个参数表示复制后的行数。

接下来我们举一个例子来直观理解一下:

>>> x = torch.tensor([6,7,8]) >>> x.repeat(4,2) tensor([[6, 7, 8, 6, 7, 8], [6, 7, 8, 6, 7, 8], [6, 7, 8, 6, 7, 8], [6, 7, 8, 6, 7, 8]]) >>> x.repeat(4,2,1) tensor([[[6, 7, 8], [6, 7, 8]], [[6, 7, 8], [6, 7, 8]], [[6, 7, 8], [6, 7, 8]], [[6, 7, 8], [6, 7, 8]]]) >>> x.repeat(4,2,1).size() torch.Size([4, 2, 3])
最新回复(0)