• TensorFlow 2.0 —— GAN 实战代码


    from  tensorflow import keras
    import tensorflow as tf
    from  tensorflow.keras import layers
    import numpy as np
    import os
    import matplotlib.pyplot as plt
    
    # Low-level runtime setup: require at least one GPU and let TensorFlow grow
    # its GPU memory allocation on demand instead of reserving it all up front.
    #
    # To pin a specific card, uncomment the two lines below. NOTE: CUDA
    # environment variables typically only take effect if set before TensorFlow
    # first enumerates/initializes the GPUs, i.e. before list_physical_devices.
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the second GPU
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    if not physical_devices:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise RuntimeError("Not enough GPU hardware devices available")
    # TensorFlow requires memory growth to be configured identically on every
    # visible GPU, so apply it to all of them, not just the first.
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
    
    # Tile a batch of generated images into a single grid picture.
    def my_save_img(data, save_path, grid=10):
        """Tile the first ``grid * grid`` images of ``data`` into one picture and save it.

        Args:
            data: Batch of images indexed as ``data[i]`` -> (height, width, channels),
                with pixel values expected in [-1, 1] (the generator's tanh range).
                Anything ``np.asarray`` accepts works, e.g. a NumPy array or an eager
                TensorFlow tensor. Tile height/width are taken from the data.
            save_path: Destination file passed to ``plt.imsave``. Its parent directory
                is created if it does not exist yet.
            grid: Number of rows and of columns in the mosaic (default 10, i.e. up to
                100 tiles). Cells without a corresponding image stay black.
        """
        images = np.asarray(data[:grid * grid])
        tile_h, tile_w = images.shape[1:3]
        # RGB canvas; a single-channel tile is broadcast across all three channels.
        mosaic = np.zeros((grid * tile_h, grid * tile_w, 3))
        for index, image in enumerate(images):
            row, col = divmod(index, grid)
            # Map [-1, 1] -> [0, 1]. Clipping guards against values slightly out of
            # range, which plt.imsave rejects for float RGB images.
            mosaic[row * tile_h:(row + 1) * tile_h,
                   col * tile_w:(col + 1) * tile_w, :] = np.clip((image + 1) / 2, 0.0, 1.0)

        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.imsave(save_path, mosaic)
    
    class Generator(keras.Model):
        """Generator that maps a latent vector to a 64x64, 3-channel image.

        Shape flow (b = batch size):
            z [b, z_dim] -> Dense -> [b, 3*3*512] -> reshape -> [b, 3, 3, 512]
            -> Tconv1 (kernel 3, stride 3) + BN -> [b, 9, 9, 256]
            -> Tconv2 (kernel 5, stride 2) + BN -> [b, 21, 21, 128]
            -> Tconv3 (kernel 4, stride 3)      -> [b, 64, 64, 3], then tanh -> [-1, 1].
        """

        def __init__(self):
            super().__init__()
            self.fc = layers.Dense(3 * 3 * 512)

            self.Tconv1 = layers.Conv2DTranspose(256, kernel_size=3, strides=3, padding='valid')
            self.bn1 = layers.BatchNormalization()

            self.Tconv2 = layers.Conv2DTranspose(128, kernel_size=5, strides=2, padding='valid')
            self.bn2 = layers.BatchNormalization()

            self.Tconv3 = layers.Conv2DTranspose(3, kernel_size=4, strides=3, padding='valid')

        def call(self, inputs, training=None, mask=None):
            # Project the latent vector and lay it out as a 3x3x512 feature map.
            hidden = tf.nn.leaky_relu(tf.reshape(self.fc(inputs), [-1, 3, 3, 512]))

            # Two upsampling stages: transposed conv -> batch norm -> leaky ReLU.
            for deconv, norm in ((self.Tconv1, self.bn1), (self.Tconv2, self.bn2)):
                hidden = tf.nn.leaky_relu(norm(deconv(hidden), training=training))

            # Final projection to 3 channels; tanh squashes pixels into [-1, 1].
            return tf.tanh(self.Tconv3(hidden))
    
    class Discriminator(keras.Model):
        """Convolutional discriminator that scores images with a single logit.

        Shape flow for a [b, 64, 64, 3] input (b = batch size):
            -> conv1 (kernel 5, stride 3)      -> [b, 20, 20, 64]
            -> conv2 (kernel 5, stride 3) + BN -> [b, 6, 6, 128]
            -> conv3 (kernel 5, stride 3) + BN -> [b, 1, 1, 256]
            -> flatten -> [b, 256] -> Dense -> [b, 1] raw logits (no sigmoid applied).
        """

        def __init__(self):
            super().__init__()
            self.conv1 = layers.Conv2D(64, kernel_size=5, strides=3, padding='valid')

            self.conv2 = layers.Conv2D(128, kernel_size=5, strides=3, padding='valid')
            self.bn2 = layers.BatchNormalization()

            self.conv3 = layers.Conv2D(256, kernel_size=5, strides=3, padding='valid')
            self.bn3 = layers.BatchNormalization()

            self.flatten = layers.Flatten()
            self.fc = layers.Dense(1)

        def call(self, inputs, training=None, mask=None):
            # The first conv layer has no batch normalization.
            features = tf.nn.leaky_relu(self.conv1(inputs))
            # Remaining stages: conv -> batch norm -> leaky ReLU.
            for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
                features = tf.nn.leaky_relu(norm(conv(features), training=training))

            # Flatten the feature maps to [b, -1], then emit one logit per image.
            return self.fc(self.flatten(features))
    
    def _discriminator_loss(d_real_logits, d_fake_logits):
        """Return the scalar discriminator loss: real images -> label 1, fakes -> label 0."""
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
        # Generated images must be labelled 0; labelling them 1 would train the
        # discriminator to accept fakes, and the GAN would never learn.
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
        return d_loss_real + d_loss_fake


    def _generator_loss(d_fake_logits):
        """Return the scalar generator loss: push D's score on fakes towards label 1."""
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))


    def main():
        """Train the GAN on images from ``img.npy`` and save sample grids to ``g_pic2/``.

        Every 10 steps the discriminator/generator losses are printed; every 50
        steps a 10x10 grid of generated samples is written as ``gan-<step>.png``.
        """
        # Hyperparameters
        z_dim = 100
        epochs = 3000000  # number of training steps (one batch each), not full passes
        batch_size = 1024
        # NOTE(review): the DCGAN paper uses 2e-4 with beta_1=0.5; 2e-3 may be unstable.
        learning_rate = 0.002
        is_training = True
        sample_dir = 'g_pic2'

        # NOTE(review): assumes img.npy holds images shaped [N, 64, 64, 3] scaled to
        # [-1, 1] to match the generator's tanh output -- confirm against data prep.
        img_data = np.load('img.npy')
        train_db = tf.data.Dataset.from_tensor_slices(img_data).shuffle(10000).batch(batch_size)
        sample = next(iter(train_db))
        print(sample.shape, tf.reduce_max(sample).numpy(),
              tf.reduce_min(sample).numpy())

        # Repeat forever and pull batches by hand; the loop below counts steps.
        train_db = train_db.repeat()
        db_iter = iter(train_db)

        d = Discriminator()
        g = Generator()

        # Separate optimizers so each network only updates its own weights.
        g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
        d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

        # Sample grids are written here; create it up front so saving cannot fail.
        os.makedirs(sample_dir, exist_ok=True)

        for epoch in range(epochs):
            batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
            batch_x = next(db_iter)

            # Train D: treat real images as real and generated images as fake.
            with tf.GradientTape() as tape:
                fake_image = g(batch_z, training=is_training)
                d_fake_logits = d(fake_image, training=is_training)
                d_real_logits = d(batch_x, training=is_training)
                # Scalar mean loss: tape.gradient on a non-scalar target would
                # silently differentiate the sum over the batch instead.
                d_loss = _discriminator_loss(d_real_logits, d_fake_logits)
            grads = tape.gradient(d_loss, d.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, d.trainable_variables))

            # Train G: make D classify freshly generated images as real.
            with tf.GradientTape() as tape:
                fake_image = g(batch_z, training=is_training)
                d_fake_logits = d(fake_image, training=is_training)
                g_loss = _generator_loss(d_fake_logits)
            grads = tape.gradient(g_loss, g.trainable_variables)
            g_optimizer.apply_gradients(zip(grads, g.trainable_variables))

            if epoch % 10 == 0:
                print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
                if epoch % 50 == 0:
                    # Draw z from the same U(-1, 1) prior used in training
                    # (tf.random.uniform defaults to [0, 1)). my_save_img tiles a
                    # 10x10 grid, so 100 samples are enough.
                    z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
                    fake_image = g(z, training=False)
                    img_path = os.path.join(sample_dir, 'gan-%d.png' % epoch)
                    my_save_img(fake_image, img_path)


    if __name__ == '__main__':
        main()
  • 相关阅读:
    hdu 4521 小明系列问题——小明序列(线段树 or DP)
    hdu 1115 Lifting the Stone
    hdu 5476 Explore Track of Point(2015上海网络赛)
    Codeforces 527C Glass Carving
    hdu 4414 Finding crosses
    LA 5135 Mining Your Own Business
    uva 11324 The Largest Clique
    hdu 4288 Coder
    PowerShell随笔3 ---别名
    PowerShell随笔2---初始命令
  • 原文地址:https://www.cnblogs.com/cxhzy/p/14268107.html
Copyright © 2020-2023  润新知