PyTorch (26): Autoencoders


    I. Autoencoder (AE)

    A plain autoencoder compresses each 784-pixel MNIST image into a 20-dimensional code and reconstructs the image from that code.

    1. AE.py

    import torch
    from torch import nn
    
    class AE(nn.Module):
        def __init__(self):
            super(AE, self).__init__()
    
            #[b, 784] => [b, 20]
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 20),
                nn.ReLU()
            )
    
            #[b, 20] => [b, 784]
            self.decoder = nn.Sequential(
                nn.Linear(20, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid(),
            )
    
        def forward(self, x):
            """
            :param x: [b, 1, 28, 28]
            :return: reconstruction [b, 1, 28, 28] and None (a placeholder that
                     keeps the interface consistent with the VAE's KL term)
            """
            batchsz = x.shape[0]
            # flatten
            x = x.view(batchsz, 784)
            # encoder
            x = self.encoder(x)
            # decoder
            x = self.decoder(x)
            # reshape
            x = x.view(batchsz, 1, 28, 28)

            return x, None
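
    A quick sanity check (a minimal sketch, not part of the original post; it assumes the class above is saved as AE.py): a dummy batch pushed through the model should come back with the input shape.

    import torch
    from AE import AE

    model = AE()
    x = torch.randn(4, 1, 28, 28)  # dummy batch of 4 single-channel 28x28 images
    x_hat, _ = model(x)            # second return value is the None placeholder
    print(x_hat.shape)             # torch.Size([4, 1, 28, 28])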

    2. main.py

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets
    from AE import AE  # the file from section 1
    from torch import nn, optim
    import visdom

    def main():
        mnist_train = datasets.MNIST("mnist", True, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    
        mnist_test = datasets.MNIST("mnist", False, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
    
        x, _ = next(iter(mnist_train))
        print(x.shape)

        model = AE()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        viz = visdom.Visdom()
        for epoch in range(1000):
            for batchidx, (x, _) in enumerate(mnist_train):
                # x: [b, 1, 28, 28], pixels in [0, 1] to match the Sigmoid output
                x_hat, _ = model(x)
                loss = criterion(x_hat, x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss of the last batch in this epoch
            print(epoch, "loss:", loss.item())
            x, _ = next(iter(mnist_test))
            with torch.no_grad():
                x_hat, _ = model(x)
            viz.images(x, nrow=8, win="x", opts=dict(title="x"))
            viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))
    
    if __name__ == '__main__':
        main()
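
    The script assumes a visdom server is already running (start one with: python -m visdom.server); otherwise the viz.images calls will fail. As an alternative sketch, torchvision's save_image can write the same image grids to disk (the stand-in tensors and file names below are illustrative, not from the original post):

    import torch
    from torchvision.utils import save_image

    x = torch.rand(32, 1, 28, 28)      # stands in for the test batch
    x_hat = torch.rand(32, 1, 28, 28)  # stands in for its reconstruction
    save_image(x, "x.png", nrow=8)
    save_image(x_hat, "x_hat.png", nrow=8)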

    II. Variational Autoencoder (VAE)

    The VAE replaces the deterministic 20-dimensional code with a 10-dimensional Gaussian: the encoder outputs a mean and a standard deviation, sampling uses the reparameterization trick, and a KL term pulls the latent distribution toward N(0, I).

    1. The model (vae.py)

    import torch
    from torch import nn
    
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
    
        # [b, 784] => [b, 20]: first 10 values are mu, last 10 are sigma
        # mu:    [b, 10]
        # sigma: [b, 10]
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
            # no final activation: mu must be free to take negative values
            nn.Linear(64, 20)
        )
    
        # [b, 10] => [b, 784]
            self.decoder = nn.Sequential(
                nn.Linear(10, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid(),
            )
    
        def forward(self, x):
            """
            :param x: [b, 1, 28, 28]
            :return: reconstruction [b, 1, 28, 28] and the KL divergence term
            """
            batchsz = x.shape[0]
            # flatten
            x = x.view(batchsz, 784)
            # encoder: [b, 20], the concatenation of mean and sigma
            h_ = self.encoder(x)
            # [b, 20] => [b, 10] and [b, 10]
            mu, sigma = h_.chunk(2, dim=1)
            # reparameterization trick: h = mu + sigma * epsilon, epsilon ~ N(0, 1), [b, 10]
            h = mu + sigma * torch.randn_like(sigma)

            # closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged per pixel so that
            # its scale matches the mean-reduced MSE reconstruction loss
            kld = 0.5 * torch.sum(
                torch.pow(mu, 2) +
                torch.pow(sigma, 2) -
                torch.log(1e-8 + torch.pow(sigma, 2)) - 1
            ) / (batchsz * 28 * 28)

            # decoder
            x = self.decoder(h)
            # reshape
            x = x.view(batchsz, 1, 28, 28)

            return x, kld
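
    The kld expression above is the closed-form KL divergence between N(mu, sigma^2) and the standard normal prior N(0, 1), i.e. 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1) summed over the latent dimensions (the 1e-8 guards the log against sigma = 0); dividing by batchsz*28*28 puts it on the same per-pixel scale as the mean-reduced MSE loss. A small sketch (not in the original post) checks the hand-written formula against torch.distributions:

    import torch
    from torch.distributions import Normal, kl_divergence

    mu = torch.randn(32, 10)
    sigma = torch.rand(32, 10) + 0.1  # keep sigma positive for the library call

    # the closed-form expression used in forward(), without the per-pixel normalization
    kld_manual = 0.5 * torch.sum(mu.pow(2) + sigma.pow(2) - torch.log(sigma.pow(2)) - 1)

    # the same quantity computed by torch.distributions
    kld_lib = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).sum()

    print(torch.allclose(kld_manual, kld_lib, atol=1e-4))  # True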

    2. Training script (main.py)

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets
    from vae import VAE
    from torch import nn, optim
    import visdom

    def main():
        mnist_train = datasets.MNIST("mnist", True, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    
        mnist_test = datasets.MNIST("mnist", False, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
    
        x, _ = next(iter(mnist_train))
        print(x.shape)

        model = VAE()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        viz = visdom.Visdom()
        for epoch in range(1000):
            for batchidx, (x, _) in enumerate(mnist_train):
                # x: [b, 1, 28, 28]
                x_hat, kld = model(x)
                loss = criterion(x_hat, x)

                if kld is not None:
                    # total loss = reconstruction loss + weighted KL term
                    loss = loss + 1.0 * kld

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss and KL of the last batch in this epoch
            print(epoch, "loss:", loss.item(), "kld:", kld.item())
            x, _ = next(iter(mnist_test))
            with torch.no_grad():
                x_hat, _ = model(x)
            viz.images(x, nrow=8, win="x", opts=dict(title="x"))
            viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))
    
    if __name__ == '__main__':
        main()
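
    Because the latent code is pulled toward N(0, I), a trained VAE can also generate digits without any input image: sample z from the prior and run only the decoder. A minimal sketch (assuming the model object trained in main above):

    with torch.no_grad():
        z = torch.randn(8, 10)                # 8 samples from the N(0, I) prior
        samples = model.decoder(z)            # [8, 784]
        samples = samples.view(8, 1, 28, 28)  # back to image shape
    # viz.images(samples, nrow=8, win="samples", opts=dict(title="samples"))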

     

Original post: https://www.cnblogs.com/zhangxianrong/p/14956162.html