import numpy as np import tensorflow as tf # ************************* gather()根据索引提取数据 ***************************** a = tf.range(5) print('原张量a:',a) b = tf.gather(a,indices = [0,1,4]) print('根据索引提取出的数据b为: ',b) a = tf.reshape(tf.range(24),shape = (4,6)) print('多维张量a:',a) b = tf.gather(a,axis = 0,indices = [0,1,3]) c = tf.gather(a,axis = 1,indices = [0,1,3]) print('axis = 0根据索引提取出的数据b为: ',b) print('axis = 1根据索引提取出的数据c为: ',c) print() print('gather_nd()同时采样多个点') print('原张量a:',a) b = tf.gather_nd(a,[[0,1],[1,2],[2,0],[3,2]]) print('gather_nd采样后结果b:',b)