>>> import torch
>>> x = torch.arange(15).view(3, 5) * 1.0
>>> print(x)
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.]])
torch.mean(input, dim, keepdim)
>>> x_mean0 = torch.mean(x, dim=0, keepdim=True)
>>> print(x_mean0)
tensor([[5., 6., 7., 8., 9.]])
>>>
>>> x_mean1 = torch.mean(x, dim=1, keepdim=True)
>>> print(x_mean1)
tensor([[ 2.],
        [ 7.],
        [12.]])
>>>
torch.max(input, dim, keepdim)
>>> values0, indices0 = torch.max(x, dim=0, keepdim=True)
>>> print(values0)
tensor([[10., 11., 12., 13., 14.]])
>>> print(indices0)
tensor([[2, 2, 2, 2, 2]])
>>>
>>> values1, indices1 = torch.max(x, dim=1, keepdim=True)
>>> print(values1)
tensor([[ 4.],
        [ 9.],
        [14.]])
>>> print(indices1)
tensor([[4],
        [4],
        [4]])
>>>
keepdim的作用
>>> x_mean = torch.mean(x, dim=0, keepdim=True)
>>> print(x_mean)
tensor([[5., 6., 7., 8., 9.]])
>>> x_mean = torch.mean(x, dim=0, keepdim=False)
>>> print(x_mean)
tensor([5., 6., 7., 8., 9.])
>>>
转载请注明原文地址:https://tech.qufami.com/read-7238.html