GAN最不好理解的就是Loss函数的定义和训练过程,这里用一段代码来辅助理解,就能明白到底是怎么回事。其实GAN的损失函数并没有特殊之处,就是常用的binary_crossentropy,关键在于训练过程中存在两个神经网络和两个损失函数。
np.random.seed(42)
tf.random.set_seed(42)
codings_size = 30
generator = keras.models.Sequential([
keras.layers.Dense(100, activation="selu", input_shape=[codings_size]),
keras.layers.Dense(150, activation="selu"),
keras.layers.Dense(28 * 28, activation="sigmoid"),
keras.layers.Reshape([28, 28])
])
discriminator = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(150, activation="selu"),
keras.layers.Dense(100, activation="selu"),
keras.layers.Dense(1, activation="sigmoid")
])
gan = keras.models.Sequential([generator, discriminator])
discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
这里generator并不用compile,因为gan网络已经compile了。具体原因见下文。
训练过程的代码如下
def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
generator, discriminator = gan.layers
for epoch in range(n_epochs):
print("Epoch {}/{}".format(epoch + 1, n_epochs)) # not shown in the book
for X_batch in dataset:
# phase 1 - training the discriminator
noise = tf.random.normal(shape=[batch_size, codings_size])
generated_images = generator(noise)
X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.trainable = True
discriminator.train_on_batch(X_fake_and_real, y1)
# phase 2 - training the generator
noise = tf.random.normal(shape=[batch_size, codings_size])
y2 = tf.constant([[1.]] * batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y2)
plot_multiple_images(generated_images, 8) # not shown
plt.show() # not shown
第一阶段(discriminator训练)
# phase 1 - training the discriminator
noise = tf.random.normal(shape=[batch_size, codings_size])
generated_images = generator(noise)
X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.trainable = True
discriminator.train_on_batch(X_fake_and_real, y1)
这个阶段首先生成数量相同的真实图片和假图片,concat在一起,即X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
。然后是label,真图片的label是1,假图片的label是0。
然后是迅速阶段,首先将discrinimator设置为可训练,discriminator.trainable = True
,然后开始阶段。第一个阶段的训练过程只训练discriminator,discriminator.train_on_batch(X_fake_and_real, y1)
,而不是整个GAN网络gan
。
第二阶段(generator训练)
# phase 2 - training the generator
noise = tf.random.normal(shape=[batch_size, codings_size])
y2 = tf.constant([[1.]] * batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y2)
在第二阶段首先生成假图片,但是不再生成真图片。把假图片的label全部设置为1,并把discriminator的权重冻结,即discriminator.trainable = False
。这一步很关键,应该这么理解:
前面第一阶段的是discriminator的训练,使真图片的预测值尽量接近1,假图片的预测值尽量接近0,以此来达到优化损失函数的目的。现在将discrinimator的权重冻结,网络中输入假图片,并故意把label设置为1。
注意,在整个gan网络中,从上向下的顺序是先通过geneartor,再通过discriminator,即gan = keras.models.Sequential([generator, discriminator])
。第二个阶段将discrinimator冻结,并训练网络gan.train_on_batch(noise, y2)
。如果generator生成的图片足够真实,经过discrinimator后label会尽可能接近1。由于故意把y2的label设置为1,所以如果genrator生成的图片足够真实,此时generator训练已经达到最优状态,不会大幅度更新权重;如果genrator生成的图片不够真实,经过discriminator之后,预测值会接近0,由于y2的label是1,相当于预测值不准确,这时候gan网络的损失函数较大,generator会通过更新generator的权重来降低损失函数。
之后,重新回到第一阶段训练discriminator,然后第二阶段训练generator。假设整个GAN网络达到理想状态,这时候generator产生的假图片,经过discriminator之后,预测值应该是0.5。假如这个值小于0.5,证明generator不是特别准确,在第二阶段训练过程中,generator的权重会被继续更新。假如这个值大于0.5,证明discriminator不是特别准确,在第一阶段训练中,discriminator的权征会被继续更新。
简单说,对于一张generator生成的假图片,discriminator会尽量把预测值拉下拉,generator会尽量把预测值往上扯,类似一个拔河的过程,最后达到均衡状态,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。