首先检测TPU存在:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver() #如果先前设置好了TPU_NAME环境变量,不需要再给参数.
tpu的返回值为1 or 0 ,1则检测到了TPU.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
with strategy.scope():
#define a model
#compile it
#train it
因为这目前还是一个实验功能,代码实现可能过一段时间就变了,看官方给的通知吧.