import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched dataset of RGB images rescaled to [-1, 1].

    Parameters
    ----------
    img_paths : list of str
        Paths of the image files (PNG or JPEG).
    batch_size : int
        Number of images per batch.
    resize : int
        Every image is resized to ``resize x resize`` pixels.
    drop_remainder : bool
        Whether the final, smaller batch of each pass is dropped.
    shuffle : bool
        Whether the file order is shuffled before decoding.
    repeat : int or None
        Number of passes over the data; ``None`` repeats indefinitely.

    Returns
    -------
    dataset : tf.data.Dataset
        Batches of float images shaped (batch, resize, resize, 3), values in [-1, 1].
    img_shape : tuple
        ``(resize, resize, 3)``.
    len_dataset : int
        Number of batches in ONE pass over ``img_paths`` (not multiplied by ``repeat``).
    """
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)  # keep values inside the 8-bit range after resizing
        img = img / 127.5 - 1  # [0, 255] -> [-1, 1]

        return img

    dataset = disk_image_batch_dataset(img_paths, batch_size, drop_remainder=drop_remainder,
                                       map_fn=_map_fn, shuffle=shuffle, repeat=repeat)
    img_shape = (resize, resize, 3)

    n_images = len(img_paths)
    if drop_remainder:
        len_dataset = n_images // batch_size
    else:
        # The trailing partial batch is kept, so round up.
        len_dataset = (n_images + batch_size - 1) // batch_size

    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Apply shuffle -> filter/map -> batch -> repeat -> prefetch to ``dataset``.

    Parameters
    ----------
    dataset : tf.data.Dataset
        Element-wise (unbatched) input dataset.
    batch_size : int
        Number of elements per batch.
    drop_remainder : bool
        Whether the final, smaller batch is dropped.
    n_prefetch_batch : int
        Number of batches to prefetch.
    filter_fn, map_fn : callable or None
        Optional element predicate / transformation.
    n_map_threads : int or None
        Parallelism for ``map``; defaults to the number of CPUs.
    filter_after_map : bool
        If True, ``map_fn`` runs before ``filter_fn`` (slower, but lets the
        predicate see mapped elements); otherwise the filter runs first.
    shuffle : bool
        Whether to shuffle elements before mapping.
    shuffle_buffer_size : int or None
        Shuffle buffer size; defaults to ``max(batch_size * 128, 2048)``.
    repeat : int or None
        Passed to ``Dataset.repeat``; ``None`` repeats indefinitely.
    """
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()

    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # minimum buffer size is 2048

    # Shuffling before `map`/`filter` is cheaper because those steps can be costly.
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

    else:  # slower: every element is mapped, including those later filtered out
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of in-memory data.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
        Sliced along the first axis with ``Dataset.from_tensor_slices``.

    All remaining parameters are forwarded unchanged to ``batch_dataset``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data)
    dataset = batch_dataset(dataset, batch_size,
                            drop_remainder=drop_remainder,
                            n_prefetch_batch=n_prefetch_batch,
                            filter_fn=filter_fn,
                            map_fn=map_fn,
                            n_map_threads=n_map_threads,
                            filter_after_map=filter_after_map,
                            shuffle=shuffle,
                            shuffle_buffer_size=shuffle_buffer_size,
                            repeat=repeat)

    return dataset


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of on-disk images (PNG, JPEG, and other formats `tf.io.decode_image` reads).

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists, optional
        Sliced alongside ``img_paths``; each element becomes ``(img, *label)``.
    map_fn : callable, optional
        Applied to the decoded ``(img, *label)``; fused with decoding into one ``map``.

    All remaining parameters are forwarded to ``memory_data_batch_dataset``.
    """
    if labels is None:
        memory_data = img_paths

    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):
        img = tf.io.read_file(path)
        # decode_image detects PNG/JPEG/BMP/GIF from the file header; channels fixed to 3.
        # expand_animations=False keeps a rank-3 result so `tf.image.resize` can follow.
        img = tf.io.decode_image(img, channels=3, expand_animations=False)
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn` into a single map step
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset
import tensorflow as tf
from tensorflow.keras import layers, Model


class Generator(Model):
    """DCGAN generator: latent vector (b, z_dim) -> image (b, 64, 64, 3) in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        filters = 64
        # Transposed conv 1: filters*8 output channels, 4x4 kernel, stride 1,
        # 'valid' padding, no bias (BatchNorm follows).
        self.conv1 = layers.Conv2DTranspose(filters * 8, 4, 1, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        # Transposed conv 2
        self.conv2 = layers.Conv2DTranspose(filters * 4, 4, 2, 'same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        # Transposed conv 3
        self.conv3 = layers.Conv2DTranspose(filters * 2, 4, 2, 'same', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        # Transposed conv 4
        self.conv4 = layers.Conv2DTranspose(filters * 1, 4, 2, 'same', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        # Transposed conv 5: output RGB
        self.conv5 = layers.Conv2DTranspose(3, 4, 2, 'same', use_bias=False)

    def call(self, inputs, training=None):
        x = inputs  # (b, z_dim)
        # Reshape to a 4D 1x1 feature map for the transposed convolutions: (b, 1, 1, z_dim).
        # -1 lets the batch dimension be unknown (e.g. build(input_shape=(None, z_dim))).
        x = tf.reshape(x, (-1, 1, 1, x.shape[-1]))
        # NOTE(review): ReLU directly on the latent input zeroes all negative
        # components of z; kept to preserve the existing model's behavior.
        x = tf.nn.relu(x)
        # transposed conv - BN - ReLU: (b, 4, 4, 512)
        x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        # transposed conv - BN - ReLU: (b, 8, 8, 256)
        x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
        # transposed conv - BN - ReLU: (b, 16, 16, 128)
        x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
        # transposed conv - BN - ReLU: (b, 32, 32, 64)
        x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
        # transposed conv: (b, 64, 64, 3)
        x = self.conv5(x)
        x = tf.tanh(x)  # outputs in [-1, 1], matching the input preprocessing

        return x


class Discriminator(Model):
    """DCGAN discriminator: image (b, 64, 64, 3) -> one real/fake logit (b, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        filters = 64
        # Conv 1
        self.conv1 = layers.Conv2D(filters, 4, 2, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        # Conv 2
        self.conv2 = layers.Conv2D(filters * 2, 4, 2, 'valid', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        # Conv 3
        self.conv3 = layers.Conv2D(filters * 4, 4, 2, 'valid', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        # Conv 4
        self.conv4 = layers.Conv2D(filters * 8, 3, 1, 'valid', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        # Conv 5
        self.conv5 = layers.Conv2D(filters * 16, 3, 1, 'valid', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        # Global average pooling over the spatial dims
        self.pool = layers.GlobalAveragePooling2D()
        # Flatten layer
        self.flatten = layers.Flatten()
        # Single-logit binary classification head
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        # Shapes below assume a 64x64x3 input.
        # conv - BN - LeakyReLU: (b, 31, 31, 64)
        x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training=training))
        # conv - BN - LeakyReLU: (b, 14, 14, 128)
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        # conv - BN - LeakyReLU: (b, 6, 6, 256)
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        # conv - BN - LeakyReLU: (b, 4, 4, 512)
        x = tf.nn.leaky_relu(self.bn4(self.conv4(x), training=training))
        # conv - BN - LeakyReLU: (b, 2, 2, 1024)
        x = tf.nn.leaky_relu(self.bn5(self.conv5(x), training=training))
        # global average pooling: (b, 1024)
        x = self.pool(x)
        # flatten (already 2D after pooling)
        x = self.flatten(x)
        # logits: (b, 1024) => (b, 1)
        logits = self.fc(x)

        return logits
import os
import glob
import numpy as np

import tensorflow as tf
from tensorflow import keras

from GAN import Generator, Discriminator
from Dataset import make_anime_dataset

from PIL import Image
import scipy.misc  # NOTE(review): unused here; scipy.misc is deprecated in recent SciPy releases
import matplotlib.pyplot as plt


def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """Discriminator loss: BCE(real logits, 1) + BCE(fake logits, 0).

    Returns a scalar tensor.
    """
    # Sample fake images from the generator
    fake_image = generator(batch_z, training=is_training)
    # Score fake and real images
    d_fake_logits = discriminator(fake_image, training=is_training)
    d_real_logits = discriminator(batch_x, training=is_training)
    # Real images should be classified as 1, fake images as 0
    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)

    return d_loss_fake + d_loss_real


def celoss_ones(logits):
    """Mean sigmoid cross-entropy of ``logits`` against an all-ones target."""
    y = tf.ones_like(logits)
    loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)

    return tf.reduce_mean(loss)


def celoss_zeros(logits):
    """Mean sigmoid cross-entropy of ``logits`` against an all-zeros target."""
    y = tf.zeros_like(logits)
    loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)

    return tf.reduce_mean(loss)


def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Generator loss: push the discriminator to label generated images as real (1)."""
    fake_image = generator(batch_z, training=is_training)
    d_fake_logits = discriminator(fake_image, training=is_training)

    return celoss_ones(d_fake_logits)


def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of images into a grid and save it to ``image_path``.

    Parameters
    ----------
    val_out : np.ndarray
        Images shaped (n, h, w, c) with values in [-1, 1].
    val_block_size : int
        Images per grid row; trailing images that do not fill a whole row are dropped.
    image_path : str
        Output file; its parent directory is created if missing.
    color_mode : str
        Unused; kept for call-site compatibility.

    Raises
    ------
    ValueError
        If ``val_out`` holds fewer than ``val_block_size`` images.
    """
    images = ((val_out + 1.0) * 127.5).astype(np.uint8)  # [-1, 1] -> [0, 255]
    n, h, w, c = images.shape
    n_rows = n // val_block_size
    if n_rows == 0:
        raise ValueError('need at least %d images to fill one row, got %d' % (val_block_size, n))

    grid = images[:n_rows * val_block_size].reshape(n_rows, val_block_size, h, w, c)
    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows*h, cols*w, c)
    grid = grid.transpose(0, 2, 1, 3, 4).reshape(n_rows * h, val_block_size * w, c)
    if c == 1:
        grid = np.squeeze(grid, axis=2)

    out_dir = os.path.dirname(image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(grid).save(image_path)


d_losses, g_losses = [], []


def draw():
    """Plot the recorded discriminator/generator losses and save the figure."""
    plt.figure()
    plt.plot(d_losses, 'b', label='discriminator')
    plt.plot(g_losses, 'r', label='generator')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('exam11.1_train_test_VAE.png')
    plt.show()


def main():
    """Train the DCGAN, periodically logging, saving sample grids and checkpoints."""
    batch_size = 64
    learning_rate = 0.0002
    z_dim = 100
    is_training = True
    epochs = 300
    log_every = 100  # epochs between logs / sample grids / checkpoints

    # NOTE(review): this pattern appears to have lost its path separators
    # (likely r'G:\...\faces\*.jpg' was intended) -- confirm.
    img_paths = glob.glob(r'G:2020pythonfacesfaces*.jpg')
    print('images num:', len(img_paths))
    if not img_paths:
        raise FileNotFoundError('no training images matched the glob pattern')
    # Build the dataset; returns the Dataset object and the image shape
    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size, resize=64)  # (64, 64, 64, 3) (64, 64, 3)
    sample = next(iter(dataset))  # one batch: (64, 64, 64, 3)
    print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())  # (64, 64, 64, 3) 1.0 -1.0
    dataset = dataset.repeat(100)  # loop over the data repeatedly
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(4, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(4, 64, 64, 3))
    # Separate optimizers for the generator and the discriminator
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # generator.load_weights('exam11.1_generator.ckpt')
    # discriminator.load_weights('exam11.1_discriminator.ckpt')
    # print('Loaded chpt!!')

    for epoch in range(epochs):
        # 1. Train the discriminator (5 steps per generator step)
        for _ in range(5):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)  # real images
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # 2. Train the generator (needs only latent samples, no real images)
        batch_z = tf.random.normal([batch_size, z_dim])
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        d_losses.append(float(d_loss))
        g_losses.append(float(g_loss))

        if epoch % log_every == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
            # Visualize a 10x10 grid of samples
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            sample_path = os.path.join('gan_images', 'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, sample_path, color_mode='P')

        # Checkpoint periodically and always after the final epoch.
        if epoch % log_every == 0 or epoch == epochs - 1:
            generator.save_weights('exam11.1_generator.ckpt')
            discriminator.save_weights('exam11.1_discriminator.ckpt')


if __name__ == '__main__':
    main()
    draw()
import tensorflow as tf
from tensorflow.keras import layers, Model


class Generator(Model):
    """WGAN generator mapping a latent batch (e.g. (b, 100)) to images (b, 64, 64, 3).

    Spatial path: dense -> (b, 3, 3, 512) -> (b, 9, 9, 256) -> (b, 21, 21, 128)
    -> (b, 64, 64, 3), with a final tanh so pixels lie in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        # z -> 3*3*512 features, reshaped to a 3x3 map in `call`
        self.fc = layers.Dense(3*3*512)
        # 'valid' transposed conv output side: (in - 1) * stride + kernel
        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')  # 3 -> 9
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')  # 9 -> 21
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')  # 21 -> 64

    def call(self, inputs, training=None):
        features = tf.nn.leaky_relu(tf.reshape(self.fc(inputs), [-1, 3, 3, 512]))

        for upsample, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = tf.nn.leaky_relu(norm(upsample(features), training=training))

        return tf.tanh(self.conv3(features))


class Discriminator(Model):
    """WGAN critic mapping images (b, 64, 64, 3) to one unbounded score (b, 1).

    For a 64x64 input the 'valid' stride-3 convolutions give 20x20 -> 6x6 -> 1x1.

    NOTE(review): the WGAN-GP paper advises against batch normalization in the
    critic (the gradient penalty is per-sample; layer norm is suggested
    instead). Left unchanged to keep existing checkpoints loadable.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()

        # (b, h, w, c) -> (b, h*w*c) -> (b, 1)
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        features = tf.nn.leaky_relu(self.conv1(inputs))

        for downsample, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            features = tf.nn.leaky_relu(norm(downsample(features), training=training))

        return self.fc(self.flatten(features))


def main():
    """Smoke test: run both networks on random inputs and print the results."""
    d = Discriminator()
    g = Generator()

    images = tf.random.normal([2, 64, 64, 3])
    latent = tf.random.normal([2, 100])

    print(d(images))
    print(g(latent).shape)


if __name__ == '__main__':
    main()
import os
import glob
import numpy as np

import tensorflow as tf
from tensorflow import keras

from WGAN import Generator, Discriminator
from Dataset import make_anime_dataset

from PIL import Image
import matplotlib.pyplot as plt


def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """WGAN-GP critic loss: mean(D(fake)) - mean(D(real)) + 10 * gradient penalty.

    No cross-entropy is used: the critic maximizes its score on real images and
    minimizes it on fake ones, regularized by the gradient penalty.

    Returns
    -------
    (loss, gp) : tuple of scalar tensors
    """
    fake_image = generator(batch_z, training=is_training)  # fake samples
    d_fake_logits = discriminator(fake_image, training=is_training)  # critic on fakes
    d_real_logits = discriminator(batch_x, training=is_training)  # critic on reals
    gp = gradient_penalty(discriminator, batch_x, fake_image)
    loss = tf.reduce_mean(d_fake_logits) - tf.reduce_mean(d_real_logits) + 10. * gp

    return loss, gp


def celoss_ones(logits):
    """Mean sigmoid cross-entropy of ``logits`` against an all-ones target (unused by WGAN-GP)."""
    y = tf.ones_like(logits)
    loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)

    return tf.reduce_mean(loss)


def celoss_zeros(logits):
    """Mean sigmoid cross-entropy of ``logits`` against an all-zeros target (unused by WGAN-GP)."""
    y = tf.zeros_like(logits)
    loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)

    return tf.reduce_mean(loss)


def gradient_penalty(discriminator, batch_x, fake_image):
    """WGAN-GP penalty: mean over the batch of (||d D(x_hat) / d x_hat||_2 - 1)^2.

    ``x_hat`` is a per-sample random linear interpolation between real and fake images.
    """
    batchsz = batch_x.shape[0]

    # One interpolation coefficient per sample, broadcast [b, 1, 1, 1] -> [b, h, w, c]
    t = tf.random.uniform([batchsz, 1, 1, 1])
    t = tf.broadcast_to(t, batch_x.shape)

    interpolated = t * batch_x + (1 - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch([interpolated])
        # NOTE(review): called without `training`, unlike the critic calls in
        # d_loss_fn, so BatchNorm presumably runs in inference mode here -- confirm intent.
        d_interpolated_logits = discriminator(interpolated)
    grads = tape.gradient(d_interpolated_logits, interpolated)

    # Per-sample L2 norm: [b, h, w, c] -> [b, -1] -> [b].
    # The epsilon keeps sqrt's gradient finite if a sample's gradient is exactly
    # zero (tf.norm yields NaN gradients there).
    grads = tf.reshape(grads, [grads.shape[0], -1])
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)

    return tf.reduce_mean((grad_norm - 1.) ** 2)


def g_loss_fn(generator, discriminator, batch_z, is_training):
    """WGAN generator loss: maximize the critic's score on generated images."""
    fake_image = generator(batch_z, training=is_training)
    d_fake_logits = discriminator(fake_image, training=is_training)

    return - tf.reduce_mean(d_fake_logits)


def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of images into a grid and save it to ``image_path``.

    Parameters
    ----------
    val_out : np.ndarray
        Images shaped (n, h, w, c) with values in [-1, 1].
    val_block_size : int
        Images per grid row; trailing images that do not fill a whole row are dropped.
    image_path : str
        Output file; its parent directory is created if missing.
    color_mode : str
        Unused; kept for call-site compatibility.

    Raises
    ------
    ValueError
        If ``val_out`` holds fewer than ``val_block_size`` images.
    """
    images = ((val_out + 1.0) * 127.5).astype(np.uint8)  # [-1, 1] -> [0, 255]
    n, h, w, c = images.shape
    n_rows = n // val_block_size
    if n_rows == 0:
        raise ValueError('need at least %d images to fill one row, got %d' % (val_block_size, n))

    grid = images[:n_rows * val_block_size].reshape(n_rows, val_block_size, h, w, c)
    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows*h, cols*w, c)
    grid = grid.transpose(0, 2, 1, 3, 4).reshape(n_rows * h, val_block_size * w, c)
    if c == 1:
        grid = np.squeeze(grid, axis=2)

    out_dir = os.path.dirname(image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(grid).save(image_path)


d_losses, g_losses = [], []


def draw():
    """Plot the recorded critic/generator losses and save the figure."""
    plt.figure()
    plt.plot(d_losses, 'b', label='discriminator')
    plt.plot(g_losses, 'r', label='generator')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('exam11.2_train_test_VAE.png')
    plt.show()


def main():
    """Train the WGAN-GP, periodically logging, saving sample grids and checkpoints."""
    batch_size = 512
    learning_rate = 0.002
    z_dim = 100
    is_training = True
    epochs = 300
    log_every = 100  # epochs between logs / sample grids / checkpoints

    # NOTE(review): this pattern appears to have lost its path separators
    # (likely r'G:\...\faces\*.jpg' was intended) -- confirm.
    img_paths = glob.glob(r'G:2020pythonfacesfaces*.jpg')
    print('images num:', len(img_paths))  # images num: 51223
    if not img_paths:
        raise FileNotFoundError('no training images matched the glob pattern')
    # Build the dataset; returns the Dataset object and the image shape
    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size, resize=64)  # (512, 64, 64, 3) (64, 64, 3)
    sample = next(iter(dataset))  # one batch: (512, 64, 64, 3)
    print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())  # (512, 64, 64, 3) 1.0 -1.0
    dataset = dataset.repeat(100)  # loop over the data repeatedly
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))
    # Separate optimizers for the generator and the critic
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # generator.load_weights('exam11.1_generator.ckpt')
    # discriminator.load_weights('exam11.1_discriminator.ckpt')
    # print('Loaded chpt!!')

    for epoch in range(epochs):
        # Latent samples from U[-1, 1]
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # Critic step
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # Generator step
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Record losses so draw() has something to plot
        d_losses.append(float(d_loss))
        g_losses.append(float(g_loss))

        if epoch % log_every == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss), 'gp:', float(gp))
            # Sample from the same U[-1, 1] distribution used in training
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)

            fake_image = generator(z, training=False)
            sample_path = os.path.join('images', 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, sample_path, color_mode='P')

        # Checkpoint periodically and always after the final epoch.
        if epoch % log_every == 0 or epoch == epochs - 1:
            generator.save_weights('exam11.2_generator.ckpt')
            discriminator.save_weights('exam11.2_discriminator.ckpt')


if __name__ == '__main__':
    main()
    draw()