torch.mean和torch.max函数在二维矩阵上的用法实例

tech2022-11-04  92

>>> import torch
>>> x = torch.arange(15).view(3, 5) * 1.0  # 乘以1.0把整型(int64)转成浮点型，因为torch.mean只支持浮点（或复数）类型
>>> print(x)
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.]])

torch.mean(input, dim, keepdim)：沿dim维求平均值

>>> x_mean0 = torch.mean(x, dim=0, keepdim=True)
>>> print(x_mean0)
tensor([[5., 6., 7., 8., 9.]])

>>> x_mean1 = torch.mean(x, dim=1, keepdim=True)
>>> print(x_mean1)
tensor([[ 2.],
        [ 7.],
        [12.]])

torch.max(input, dim, keepdim)：沿dim维求最大值，指定dim时返回(values, indices)两个张量，分别是最大值及其所在位置的下标

>>> values0, indices0 = torch.max(x, dim=0, keepdim=True)
>>> print(values0)
tensor([[10., 11., 12., 13., 14.]])
>>> print(indices0)
tensor([[2, 2, 2, 2, 2]])

>>> values1, indices1 = torch.max(x, dim=1, keepdim=True)
>>> print(values1)
tensor([[ 4.],
        [ 9.],
        [14.]])
>>> print(indices1)
tensor([[4],
        [4],
        [4]])

keepdim的作用

>>> x_mean = torch.mean(x, dim=0, keepdim=True)   # keepdim=True保留原维度
>>> print(x_mean)
tensor([[5., 6., 7., 8., 9.]])
# 结果是2维，形状为(1, 5)

>>> x_mean = torch.mean(x, dim=0, keepdim=False)  # keepdim=False不保留原维度
>>> print(x_mean)
tensor([5., 6., 7., 8., 9.])
# 结果是1维，形状为(5,)
最新回复(0)