• GAN生成式对抗网络(三)——mnist数据生成


    通过GAN生成式对抗网络,产生mnist数据

    引入包,数据约定等

    import numpy as np
    import matplotlib.pyplot as plt
    import input_data  #读取数据的一个工具文件,不影响理解
    import tensorflow as tf
    
    
    # 获取数据
    mnist = input_data.read_data_sets('data/', one_hot=True)
    trainimg = mnist.train.images
    
    X = mnist.train.images[:, :]
    batch_size = 64
    
    #用来返回真实数据
    def iterate_minibatch(x, batch_size, shuffle=True):
        indices = np.arange(x.shape[0])
        if shuffle:
            np.random.shuffle(indices)
        for i in range(0, x.shape[0]-1000, batch_size):
            temp = x[indices[i:i + batch_size], :]
            temp = np.array(temp) * 2 - 1
            yield np.reshape(temp, (-1, 28, 28, 1))
    

    GAN对象结构

    class GAN(object):
        def __init__(self):
            #初始函数,在这里对初始化模型
        def netG(self, z):
            #生成器模型
        def netD(self, x, reuse=False):
            #判别器模型
    

    生成器函数

    对随机值z(维度为1,100),进行包装,伪造,产生伪造数据。
    包装过程概括为:全连接->reshape->反卷积
    包装过程中使用了batch_normalization,Leaky ReLU,dropout,tanh等技巧

       #对随机值z(维度为1,100),进行包装,伪造,产生伪造数据。
        #包装过程概括为:全连接->reshape->反卷积
        #包装过程中使用了batch_normalization,Leaky ReLU,dropout,tanh等技巧
        def netG(self,z,alpha=0.01):
            with tf.variable_scope('generator') as scope:
                layer1 = tf.layers.dense(z, 4 * 4 * 512)  # 这是一个全连接层,输出 (n,4*4*512)
                layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
                # batch normalization
                layer1 = tf.layers.batch_normalization(layer1, training=True)  # 做BN标准化处理
                # Leaky ReLU
                layer1 = tf.maximum(alpha * layer1, layer1)
                # dropout
                layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
    
                # 4 x 4 x 512 to 7 x 7 x 256
                layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
                layer2 = tf.layers.batch_normalization(layer2, training=True)
                layer2 = tf.maximum(alpha * layer2, layer2)
                layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
    
                # 7 x 7 256 to 14 x 14 x 128
                layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
                layer3 = tf.layers.batch_normalization(layer3, training=True)
                layer3 = tf.maximum(alpha * layer3, layer3)
                layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
    
                # 14 x 14 x 128 to 28 x 28 x 1
                logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
                # MNIST原始数据集的像素范围在0-1,这里的生成图片范围为(-1,1)
                # 因此在训练时,记住要把MNIST像素范围进行resize
                outputs = tf.tanh(logits)
    
                return outputs
    
    

    判别器函数

    通过深度卷积+全连接的形式,判别器将输入分类为真数据,还是假数据。

        def netD(self, x, reuse=False,alpha=0.01):
            with tf.variable_scope('discriminator') as scope:
                if reuse:
                    scope.reuse_variables()
                layer1 = tf.layers.conv2d(x, 128, 3, strides=2, padding='same')
                layer1 = tf.maximum(alpha * layer1, layer1)
                layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
    
                # 14 x 14 x 128 to 7 x 7 x 256
                layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
                layer2 = tf.layers.batch_normalization(layer2, training=True)
                layer2 = tf.maximum(alpha * layer2, layer2)
                layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
    
                # 7 x 7 x 256 to 4 x 4 x 512
                layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
                layer3 = tf.layers.batch_normalization(layer3, training=True)
                layer3 = tf.maximum(alpha * layer3, layer3)
                layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
    
                # 4 x 4 x 512 to 4*4*512 x 1
                flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
                f = tf.layers.dense(flatten, 1)
                return f
    

    初始化函数

    有一个前置训练,将真实数据喂给判别器,训练判别器的鉴别能力

        # 有一个前置训练,将真实数据喂给判别器,训练判别器的鉴别能力
        def __init__(self):
            self.z = tf.placeholder(tf.float32, shape=[batch_size, 100], name='z')  # 随机输入值
            self.x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='real_x')  # 图片值
    
            self.fake_x = self.netG(self.z)  # 将随机输入,包装为伪造图片值
    
            self.pre_logits = self.netD(self.x, reuse=False)  # 判别器预训练时,判别器对真实数据的判别情况-未sigmoid处理
            self.real_logits = self.netD(self.x, reuse=True)  # 判别器对真实数据的判别情况-未sigmoid处理
            self.fake_logits = self.netD(self.fake_x, reuse=True)  # 判别器对伪造数据的判别情况-未sigmoid处理
    
            # 预训练时判别器,判别器将真实数据判定为真的得分情况。
            self.loss_pre_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pre_logits,
                                                                                     labels=tf.ones_like(self.pre_logits)))
            # 训练时,判别器将真实数据判定为真,将伪造数据判定为假的得分情况。
            self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                                 labels=tf.ones_like(self.real_logits))) + 
                          tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                                 labels=tf.zeros_like(self.fake_logits)))
            # 训练时,生成器伪造的数据,被判定为真实数据的得分情况。
            self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                                 labels=tf.ones_like(self.fake_logits)))
    
            # 获取生成器和判定器对应的变量地址,用于更新变量
            t_vars = tf.trainable_variables()
            self.g_vars = [var for var in t_vars if var.name.startswith("generator")]
            self.d_vars = [var for var in t_vars if var.name.startswith("discriminator")]
    
    

    开始训练

    gan = DCGAN()
    #预训练时的梯度优化函数
    d_pre_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_pre_D, var_list=gan.d_vars)
    #判别器的梯度优化函数
    d_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_D, var_list=gan.d_vars)
    #预训练时的梯度优化函数
    g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars)
    
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        #对判别器的预训练,训练了两个epoch
        for i in range(2):
            print('判别器初始训练,第' + str(i) + '次包')
            for x_batch in iterate_minibatch(X, batch_size=batch_size):
                loss_pre_D, _ = sess.run([gan.pre_logits, d_pre_optim],
                                         feed_dict={
                                             gan.x: x_batch
                                         })
        #训练5个epoch
        for epoch in range(5):
            print('对抗' + str(epoch) + '次包')
            avg_loss = 0
            count = 0
            for x_batch in iterate_minibatch(X, batch_size=batch_size):
                z_batch = np.random.uniform(-1, 1, size=(batch_size, 100))  # 随机起点值
    
                loss_D, _ = sess.run([gan.loss_D, d_optim],
                                     feed_dict={
                                         gan.z: z_batch,
                                         gan.x: x_batch
                                     })
    
                loss_G, _ = sess.run([gan.loss_G, g_optim],
                                     feed_dict={
                                         gan.z: z_batch,
                                         # gan.x: np.zeros(z_batch.shape)
                                     })
    
                avg_loss += loss_D
                count += 1
    
            # 显示预测情况
            if True:
                avg_loss /= count
                z = np.random.normal(size=(batch_size, 100))
                excerpt = np.random.randint(100, size=batch_size)
                needTest = np.reshape(X[excerpt, :], (-1, 28, 28, 1))
                fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                            feed_dict={gan.z: z, gan.x: needTest})
                # accuracy = (np.sum(real_logits > 0.5) + np.sum(fake_logits < 0.5)) / (2 * batch_size)
                print('real_logits')
                print(len(real_logits))
                print('fake_logits')
                print(len(fake_logits))
                print('
    discriminator loss at epoch %d: %f' % (epoch, avg_loss))
                # print('
    discriminator accuracy at epoch %d: %f' % (epoch, accuracy))
                print('----')
                print()
    
                # curr_img = np.reshape(trainimg[i, :], (28, 28))  # 28 by 28 matrix
                curr_img = np.reshape(fake_x[0], (28, 28))
                plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
                plt.show()
                curr_img2 = np.reshape(fake_x[10], (28, 28))
                plt.matshow(curr_img2, cmap=plt.get_cmap('gray'))
                plt.show()
                curr_img3 = np.reshape(fake_x[20], (28, 28))
                plt.matshow(curr_img3, cmap=plt.get_cmap('gray'))
                plt.show()
    
                curr_img4 = np.reshape(fake_x[30], (28, 28))
                plt.matshow(curr_img4, cmap=plt.get_cmap('gray'))
                plt.show()
    
                curr_img5 = np.reshape(fake_x[40], (28, 28))
                plt.matshow(curr_img5, cmap=plt.get_cmap('gray'))
                plt.show()
                # plt.figure(figsize=(28, 28))
    
                # plt.title("" + str(i) + "th Training Data "
                #           + "Label is " + str(curr_label))
                # print("" + str(i) + "th Training Data "
                #       + "Label is " + str(curr_label))
    
                # plt.scatter(X[:, 0], X[:, 1])
                # plt.scatter(fake_x[:, 0], fake_x[:, 1])
                # plt.show()
    

    结果

    下载链接

  • 相关阅读:
    Codeforces Round #370 (Div. 2)
    Codeforces Round #425 (Div. 2)
    变量调节器
    Smarty基础
    流程
    iframe 内联框架
    权限:改变权限
    权限:查找
    html 框架
    Jcrop+uploadify+php实现上传头像预览裁剪
  • 原文地址:https://www.cnblogs.com/panfengde/p/10021461.html
Copyright © 2020-2023  润新知