要用 TPU 训练 TensorFlow 模型,只能使用静态图。也就是说,要先通过 Keras 的 Sequential 或函数式 API 定义模型,而不能直接使用子类化(重写)的 Model 类。例子如下,其中包含自定义层,以及与子像素卷积相关的操作(此处用 tf.nn.space_to_depth 实现 pixel_unshuffle)。需要注意的是,TensorFlow 的 pixel_shuffle 通道顺序与 PyTorch 不同,具体怎么不同这里不记录了,可以直接实验一下。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers, losses, optimizers

# Size of the dummy training set: ten batches of BATCH_SIZE samples each.
BATCH_SIZE = 4096
NUM_SAMPLES = BATCH_SIZE * 10

# Resolve the TPU, connect this process to it and initialise the TPU system
# before building the distribution strategy.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

# Newer TensorFlow releases expose TPUStrategy directly under tf.distribute
# and deprecate the experimental alias; fall back to the alias on releases
# that only provide the experimental name.
_tpu_strategy_cls = (
    getattr(tf.distribute, "TPUStrategy", None)
    or tf.distribute.experimental.TPUStrategy
)
strategy = _tpu_strategy_cls(tpu)


def pixel_unshuffle(x, scale):
    """Move ``scale`` x ``scale`` spatial blocks of ``x`` into the channel axis.

    Thin wrapper around ``tf.nn.space_to_depth``. In this script a
    48x48x3 feature map with ``scale=2`` becomes 24x24x12.

    Args:
        x: Input tensor passed straight to ``tf.nn.space_to_depth``.
        scale: Block size used for the rearrangement.

    Returns:
        The rearranged tensor.
    """
    return tf.nn.space_to_depth(x, scale)


class MyDense(layers.Layer):
    """Custom layer: 3x3 convolution, pixel-unshuffle by 2, then a floor at 50."""

    def __init__(self, **kwargs):
        """Create the layer.

        Args:
            **kwargs: Standard Keras ``Layer`` keyword arguments
                (e.g. ``name``), forwarded to the base class.
        """
        super().__init__(**kwargs)
        # 3 output channels, 3x3 kernel, stride 1, 'same' padding.
        self.layer = layers.Conv2D(3, 3, 1, 'same')

    def call(self, inp):
        """Apply convolution, pixel-unshuffle and element-wise maximum."""
        x = self.layer(inp)
        x = pixel_unshuffle(x, 2)
        # Element-wise lower bound of 50 on the activations.
        x = tf.maximum(x, 50)
        return x


# The model must be created and compiled inside the strategy scope so its
# variables are placed on the TPU.
with strategy.scope():
    inputs = keras.Input(shape=[48, 48, 3])
    outputs = MyDense()(inputs)
    model = Model(inputs, outputs)
    model.compile(optimizers.SGD(), losses.MSE)

# Dummy all-zero data. Allocating float32 directly avoids first creating a
# float64 array twice the size and then copying it with ``astype``.
x = np.zeros([NUM_SAMPLES, 48, 48, 3], dtype=np.float32)
# Targets match the model output: spatial size halved, channels x4.
y = np.zeros([NUM_SAMPLES, 24, 24, 12], dtype=np.float32)

model.fit(x, y, epochs=50, batch_size=BATCH_SIZE)