• Auto-Encoders实战


    • Auto-Encoder

    • Variational Auto-Encoders




    import os

    # Set before importing TensorFlow so the log filter (2 = hide INFO and
    # WARNING) also applies to messages emitted while TensorFlow loads.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import tensorflow as tf
    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import Sequential, layers
    from PIL import Image
    from matplotlib import pyplot as plt

    # The script relies on TensorFlow 2.x eager APIs (GradientTape, tf.data).
    # An explicit raise is used because `assert` is stripped under `python -O`.
    if not tf.__version__.startswith('2.'):
        raise RuntimeError('TensorFlow 2.x is required, found %s' % tf.__version__)
    def save_images(imgs, name):
        """Tile the first 100 images of ``imgs`` into a 10x10 grid and save it.

        Args:
            imgs: indexable collection of at least 100 images. Each is
                presumably a 28x28 uint8 array; the caller in this script
                passes ``np.uint8`` data reshaped to [b, 28, 28].
            name: output file path. PIL picks the file format from the
                extension.
        """
        new_im = Image.new('L', (280, 280))  # 10 x 10 tiles of 28 x 28 pixels
        index = 0
        # `i` is the horizontal offset and `j` the vertical one, so the grid
        # is filled column by column.
        for i in range(0, 280, 28):
            for j in range(0, 280, 28):
                im = Image.fromarray(imgs[index], mode='L')
                new_im.paste(im, (i, j))
                index += 1
        # Previously the grid was built but never written, leaving `name` unused.
        new_im.save(name)
    # ---- hyper-parameters ----
    h_dim = 20  # latent code size: 784 inputs are compressed to 20 dimensions
    batchsz = 512
    lr = 1e-3

    # ---- data: Fashion-MNIST, pixels rescaled from [0, 255] to [0, 1] ----
    (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
    x_train = x_train.astype(np.float32) / 255.
    x_test = x_test.astype(np.float32) / 255.

    # An auto-encoder is unsupervised, so only the images enter the pipelines.
    train_db = (
        tf.data.Dataset.from_tensor_slices(x_train)
        .shuffle(batchsz * 5)
        .batch(batchsz)
    )
    test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

    print(x_train.shape, y_train.shape)
    print(x_test.shape, y_test.shape)
    class AE(keras.Model):
        """Fully-connected auto-encoder for flattened 28x28 images.

        The encoder compresses [b, 784] inputs to an ``h_dim``-dimensional code
        (``h_dim`` is the module-level hyper-parameter). The decoder maps the
        code back to [b, 784]. The decoder's last layer has no activation, so
        ``call`` returns logits; apply ``tf.sigmoid`` to get pixel values.
        """

        def __init__(self):
            super(AE, self).__init__()
            # Encoder: 784 -> 256 -> 128 -> h_dim
            self.encoder = Sequential([
                layers.Dense(256, activation=tf.nn.relu),
                layers.Dense(128, activation=tf.nn.relu),
                layers.Dense(h_dim),
            ])
            # Decoder: h_dim -> 128 -> 256 -> 784 (linear output = logits)
            self.decoder = Sequential([
                layers.Dense(128, activation=tf.nn.relu),
                layers.Dense(256, activation=tf.nn.relu),
                layers.Dense(784),
            ])

        def call(self, inputs, training=None):
            """Encode then decode ``inputs``; returns reconstruction logits."""
            # [b, 784] ==> [b, h_dim]
            h = self.encoder(inputs)
            # [b, h_dim] ==> [b, 784]
            x_hat = self.decoder(h)
            return x_hat
    model = AE()
    # Build with an explicit input shape; TensorFlow shapes are best given as tuples.
    model.build(input_shape=(None, 784))
    # Prints the layer / parameter table shown in the output below.
    model.summary()
    (60000, 28, 28) (60000,)
    (10000, 28, 28) (10000,)
    Model: "ae"
    Layer (type)                 Output Shape              Param #   
    sequential (Sequential)      multiple                  236436    
    sequential_1 (Sequential)    multiple                  237200    
    Total params: 473,636
    Trainable params: 473,636
    Non-trainable params: 0


    # `lr` is a deprecated alias; `learning_rate` is the documented keyword.
    optimizer = tf.optimizers.Adam(learning_rate=lr)
    # Create the output directory once, up front, instead of checking every step.
    os.makedirs('ae_images', exist_ok=True)
    for epoch in range(10):
        for step, x in enumerate(train_db):
            # Flatten each image: [b, 28, 28] ==> [b, 784]
            x = tf.reshape(x, [-1, 784])
            with tf.GradientTape() as tape:
                # The decoder ends in a linear layer, so the model outputs logits.
                x_rec_logits = model(x)
                # Per-sample reconstruction error, computed from logits for
                # numerical stability: shape [b].
                rec_loss = tf.losses.binary_crossentropy(
                    x, x_rec_logits, from_logits=True)
                # Average over the batch; reduce_min would back-propagate only
                # through the single best-reconstructed sample.
                rec_loss = tf.reduce_mean(rec_loss)
            grads = tape.gradient(rec_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if step % 100 == 0:
                print(epoch, step, float(rec_loss))

        # Evaluation: reconstruct the first test batch once per epoch (this
        # previously ran inside the step loop, i.e. after every batch).
        x = next(iter(test_db))
        logits = model(tf.reshape(x, [-1, 784]))
        x_hat = tf.sigmoid(logits)
        # [b, 784] ==> [b, 28, 28]
        x_hat = tf.reshape(x_hat, [-1, 28, 28])
        # Only reconstructions are saved: save_images tiles the first 100 images,
        # so prepending the originals would hide the reconstructions entirely.
        imgs = (x_hat.numpy() * 255.).astype(np.uint8)  # PIL 'L' needs uint8
        save_images(imgs, 'ae_images/rec_epoch_%d.png' % epoch)
    0 0 0.09717604517936707
    0 100 0.12493347376585007
    1 0 0.09747321903705597
    1 100 0.12291513383388519
    2 0 0.10048121958971024
    2 100 0.12292417883872986
    3 0 0.10093794018030167
    3 100 0.12260882556438446
    4 0 0.10006923228502274
    4 100 0.12275046110153198
    5 0 0.0993042066693306
    5 100 0.12257824838161469
    6 0 0.0967678651213646
    6 100 0.12443818897008896
    7 0 0.0965462476015091
    7 100 0.12179268896579742
    8 0 0.09197664260864258
    8 100 0.12110235542058945
    9 0 0.0913471132516861
    9 100 0.12342415750026703
  • 相关阅读:
    由jQuery Validation Remote验证引起的错误(MVC3 jQuery.validate.unobtrusive)
    Asp.Net MVC 必备插件MVC Route Visualizer(Visual Studio 2012 版)
    2012 LinkCoder Jeffrey Richter:Win 8应用开发与.NET4.5
    WCF应用:宿主与调用纯代码示例(Host &Client code only sample)
    Nexus 7 入手风波记
    [转]使用HyperV BPA(Best Practices Analyzer最佳化分析工具)
    [转]Installing and Configuring target iSCSI server on Windows Server 2012
  • 原文地址:https://www.cnblogs.com/abdm-989/p/14123449.html
Copyright © 2020-2023  润新知