tf.argmax()函数原型:
def argmax(input, axis=None, name=None, dimension=None, output_type=dtypes.int64)
作用是返回每列/行的最大值的索引。
input是一个张量,
axis是0或1,0返回各列最大值索引,1返回各行最大值索引。
其他3个参数不常用,常用写法是 a = tf.argmax(tensor, 1)。
import tensorflow as tf sess = tf.InteractiveSession() a = tf.constant([[12, 3, 9], [3, 6, 13]]) b_1 = tf.argmax(a, 0) # 返回ndarry,元素是每列的最大值索引 b_2 = tf.argmax(a, 1) print(b_1) # >>array([0, 1, 1], dtype=int64) print(b_2) # >>array([0, 2], dtype=int64)