gather就是按行取值:
a1 = [[1,2], [3, 4], [5, 6]]
a2 = tf.gather(tf.constant(a1), [0, 1])
print(a2)
输出:
tf.Tensor(
[[1 2]
[3 4]], shape=(2, 2), dtype=int32)
相当于:
a1[:2]
gather就是按行取值:
a1 = [[1,2], [3, 4], [5, 6]]
a2 = tf.gather(tf.constant(a1), [0, 1])
print(a2)
输出:
tf.Tensor(
[[1 2]
[3 4]], shape=(2, 2), dtype=int32)
相当于:
a1[:2]