• pytorch(二十六):自动编码器


    一、自动编码器

    1、auto_encoder.py(main.py 中通过 `from auto_encoder import AE` 导入,文件名需与之一致)

    import torch
    from torch import nn
    
    class AE(nn.Module):
        """Fully connected autoencoder for 28x28 single-channel images.

        The encoder compresses each flattened image (784 values) down to a
        20-dimensional code; the decoder expands the code back to 784 values
        squashed into (0, 1) by a final sigmoid.
        """

        def __init__(self):
            super().__init__()
            # 784 -> 256 -> 64 -> 20, with a ReLU after every Linear layer
            # (including the last one, so the code itself is non-negative).
            self.encoder = nn.Sequential(
                nn.Linear(784, 256), nn.ReLU(),
                nn.Linear(256, 64), nn.ReLU(),
                nn.Linear(64, 20), nn.ReLU(),
            )
            # 20 -> 64 -> 256 -> 784; the Sigmoid keeps pixel outputs in (0, 1).
            self.decoder = nn.Sequential(
                nn.Linear(20, 64), nn.ReLU(),
                nn.Linear(64, 256), nn.ReLU(),
                nn.Linear(256, 784), nn.Sigmoid(),
            )

        def forward(self, x):
            """Reconstruct a batch of images.

            :param x: tensor of shape [b, 1, 28, 28]
            :return: tuple (reconstruction of shape [b, 1, 28, 28], None); the
                second slot is always None here (a plain autoencoder has no
                extra loss term), so callers can unpack (output, extra) uniformly.
            """
            n = x.shape[0]
            code = self.encoder(x.view(n, 784))
            recon = self.decoder(code)
            return recon.view(n, 1, 28, 28), None

    2、main.py

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets
    from auto_encoder import AE
    from torch import nn, optim
    import visdom
    def main():
        """Train the plain autoencoder on MNIST and visualise reconstructions.

        Each epoch makes one pass over the training set, prints the loss of the
        final training batch, then sends one test batch and its reconstruction
        to the visdom windows "x" and "x_hat".
        """
        def make_loader(is_train):
            # Root directory "mnist" (relative to the working directory);
            # download=True lets torchvision fetch the data if it is missing.
            dataset = datasets.MNIST("mnist", is_train, transform=transforms.Compose([
                transforms.ToTensor()
            ]), download=True)
            return DataLoader(dataset, batch_size=32, shuffle=True)

        train_loader = make_loader(True)
        test_loader = make_loader(False)

        # Sanity check of the batch shape, e.g. [32, 1, 28, 28].
        sample, _ = next(iter(train_loader))
        print(sample.shape)

        model = AE()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        viz = visdom.Visdom()

        for epoch in range(1000):
            for images, _ in train_loader:
                # images: [b, 1, 28, 28]
                reconstruction, _ = model(images)
                loss = criterion(reconstruction, images)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            print(epoch, "loss:", loss.item())

            images, _ = next(iter(test_loader))
            with torch.no_grad():
                reconstruction, _ = model(images)
            viz.images(images, nrow=8, win="x", opts=dict(title="x"))
            viz.images(reconstruction, nrow=8, win="x_hat", opts=dict(title="x_hat"))


    if __name__ == '__main__':
        main()

    二、变分自动编码器

    1、模型(vae.py,运行程序中通过 `from vae import VAE` 导入)

    import torch
    from torch import nn
    import numpy as np
    
    class VAE(nn.Module):
        """Fully connected variational autoencoder for 28x28 single-channel images.

        The encoder maps a flattened image (784 values) to 20 numbers, which are
        split into a 10-d mean ``mu`` and a 10-d ``sigma``. A latent code is
        sampled with the reparameterization trick and decoded back to 784
        values in (0, 1). ``forward`` also returns the KL divergence to N(0, I),
        normalised per pixel.
        """

        def __init__(self):
            super(VAE, self).__init__()

            # [b, 784] => [b, 20], later split into mu: [b, 10] and sigma: [b, 10].
            # There is deliberately no activation after the last Linear layer.
            # A trailing ReLU would force mu >= 0, although the prior N(0, 1) is
            # centred at 0. It would also clamp sigma to exactly 0, with zero
            # gradient through the ReLU, whenever its pre-activation is negative.
            # ReLU has no parameters, so removing it keeps the parameter names
            # (encoder.0 / encoder.2 / encoder.4) and hence the state_dict keys.
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 20),
            )

            # [b, 10] => [b, 784]; the Sigmoid keeps pixel outputs in (0, 1).
            self.decoder = nn.Sequential(
                nn.Linear(10, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid(),
            )

        def forward(self, x):
            """Encode, sample a latent code, and decode a batch of images.

            :param x: tensor of shape [b, 1, 28, 28]
            :return: tuple ``(x_hat, kld)``. ``x_hat`` has shape [b, 1, 28, 28].
                ``kld`` is a scalar tensor: the closed-form
                KL(N(mu, sigma^2) || N(0, 1)) summed over the batch and the
                latent dimensions, then divided by b * 28 * 28 (i.e. per pixel,
                presumably to match a mean-reduced pixel loss such as MSELoss).
            """
            batchsz = x.shape[0]
            # flatten
            x = x.view(batchsz, 784)
            # encoder: [b, 20], holding mu and sigma side by side
            h_ = self.encoder(x)
            # [b, 20] => [b, 10] and [b, 10]
            mu, sigma = h_.chunk(2, dim=1)
            # Reparameterization trick, epsilon ~ N(0, 1), [b, 10].
            # sigma is not constrained to be positive. epsilon is symmetric and
            # the KL below only uses sigma^2, so |sigma| acts as the std.
            h = mu + sigma * torch.randn_like(sigma)

            # Closed-form KL to the standard normal; 1e-8 guards log(0).
            kld = 0.5 * torch.sum(
                torch.pow(mu, 2) +
                torch.pow(sigma, 2) -
                torch.log(1e-8 + torch.pow(sigma, 2)) - 1
            ) / (batchsz * 28 * 28)

            # decoder
            x = self.decoder(h)
            # reshape
            x = x.view(batchsz, 1, 28, 28)

            return x, kld

    2、运行程序

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets
    from vae import VAE
    from torch import nn, optim
    import visdom
    def main():
        """Train the VAE on MNIST and visualise test reconstructions each epoch.

        The training loss is the per-pixel MSE reconstruction loss plus
        ``1.0 * kld`` whenever the model returns a KL term. The term is skipped
        when the model returns ``None`` in that slot. After every epoch the
        last training batch's loss is printed, and one shuffled test batch plus
        its reconstruction is sent to the visdom windows "x" and "x_hat".
        """
        mnist_train = datasets.MNIST("mnist", True, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

        mnist_test = datasets.MNIST("mnist", False, transform=transforms.Compose([
            transforms.ToTensor()
        ]), download=True)
        mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

        # Sanity check of the batch shape, e.g. [32, 1, 28, 28].
        x, _ = next(iter(mnist_train))
        print(x.shape)

        model = VAE()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        viz = visdom.Visdom()
        for epoch in range(1000):
            for x, _ in mnist_train:
                # [b, 1, 28, 28]
                x_hat, kld = model(x)
                loss = criterion(x_hat, x)

                if kld is not None:
                    loss = loss + 1.0 * kld

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # The loop above allows kld to be None (e.g. a model returning
            # (x_hat, None)), so the log line must not call kld.item()
            # unconditionally.
            if kld is not None:
                print(epoch, "loss:", loss.item(), kld.item())
            else:
                print(epoch, "loss:", loss.item())
            x, _ = next(iter(mnist_test))
            with torch.no_grad():
                x_hat, _ = model(x)
            viz.images(x, nrow=8, win="x", opts=dict(title="x"))
            viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))


    if __name__ == '__main__':
        main()

     

  • 相关阅读:
    bzoj 1697: [Usaco2007 Feb]Cow Sorting牛排序【置换群】
    【20】AngularJS 参考手册
    【19】AngularJS 应用
    【18】AngularJS 包含
    【17】AngularJS Bootstrap
    【16】AngularJS API
    【15】AngularJS 输入验证
    【14】AngularJS 表单
    【13】AngularJS 模块
    【12】AngularJS 事件
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14956162.html
Copyright © 2020-2023  润新知