tf.argmax用法

tech2024-10-21  8

tf.argmax():返回跨张量轴的最大值的索引。同tf.compat.v1.arg_max及tf.compat.v1.argmax

tf.compat.v1.argmax( input, axis=None, name=None, dimension=None, output_type=tf.dtypes.int64 )

注意,如果是ties,则不能保证返回值的同一性。

用法:

import tensorflow as tf a = [1, 10, 26.9, 2.8, 166.32, 62.3] b = tf.math.argmax(input = a) c = tf.keras.backend.eval(b) # c = 4 # here a[4] = 166.32 which is the largest element of a across axis 0

 

Args

input一个张量Tensor. 必须是以下类型之一: float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,half,uint32,uint64,bool。axis一个张量Tensor. 必须是以下类型之一: int32, int64。 int32 或 int64必须在[-rank(input), rank(input)]范围内. 描述要减少输入张量的哪条轴。 对于矢量,请使用axis = 0。output_type可选的 tf.DType 来自: tf.int32, tf.int64. 默认为 tf.int64.name操作的名称(可选)。

Returns

类型为output_type的Tensor。

 即:

axis = 0时,返回每一列最大值的位置索引axis = 1时,返回每一行最大值的位置索引axis = 2、3、4...,即为多维张量时,同理推断

 

重点:结果进行降维,再按照对应维度求最大值索引

 

import tensorflow as tf a = tf.constant([1., 2., 3., 0., 9.]) b = tf.constant([[1, 2, 3], [3, 2, 1], [4, 5, 6], [6, 5, 4]]) with tf.Session() as sess: print(sess.run(a)) print(sess.run(tf.argmax(a, 0))) print("*" * 20) print(sess.run(b)) print(sess.run(tf.argmax(b, 0))) print("*" * 20) print(sess.run(b)) print(sess.run(tf.argmax(b, 1))) # 输出结果: # [1. 2. 3. 0. 9.] # 4 # ******************** # [[1 2 3] # [3 2 1] # [4 5 6] # [6 5 4]] # [3 2 2] # ******************** # [[1 2 3] # [3 2 1] # [4 5 6] # [6 5 4]] # [2 0 2 0] import tensorflow as tf c = tf.constant( [[[1, 2, 3], [2, 3, 5], [2, 2, 2]], [[5, 4, 3], [8, 7, 2], [1, 2, 3]], [[5, 4, 6], [10, 7, 30], [1, 2, 3]]]) with tf.Session() as sess: print(sess.run(c)) print("*" * 20) print(sess.run(tf.argmax(c, 0))) print("*" * 20) print(sess.run(tf.argmax(c, 1))) print("*" * 20) print(sess.run(tf.argmax(c, 2))) print("*" * 20) # 输出结果: # # [[[ 1 2 3] # [ 2 3 5] # [ 2 2 2]] # # [[ 5 4 3] # [ 8 7 2] # [ 1 2 3]] # # [[ 5 4 6] # [10 7 30] # [ 1 2 3]]] # ******************** # [[1 1 2] # [2 1 2] # [0 0 1]] # ******************** # [[1 1 1] # [1 1 0] # [1 1 1]] # ******************** # [[2 2 0] # [0 0 2] # [2 2 2]] # ********************

 

最新回复(0)