• Implementing a Simple Variational Autoencoder (VAE) in PyTorch


          In the previous post we introduced and implemented the autoencoder; this post implements a variational autoencoder (Variational AutoEncoder, VAE) in PyTorch. The difference between a VAE and a plain autoencoder is that the VAE adds a constraint during encoding, forcing the latent vectors it produces to roughly follow a standard normal distribution. As a result, when we want to generate a new image, we only need to feed the decoder a random latent vector drawn from a standard normal distribution.

          In practice the encoder does not produce the latent vector directly; it produces two vectors, one for the mean and one for the standard deviation (the code below actually outputs the log variance, which is more numerically stable), and the latent vector is then synthesized from these two statistics: draw a sample from a standard normal distribution, multiply it by the standard deviation, and add the mean. For more details on variational autoencoders, see Chapter 6 of Liao Xingyu's 《深度学习之PyTorch》 (Deep Learning with PyTorch); the code below comes from that book, with a few modifications.
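The sampling step described above is the "reparameterization trick", and it can be sketched in a few lines (the tensor names and shapes here are illustrative, not from the code below):

```python
import torch

# Instead of sampling z ~ N(mu, sigma^2) directly, sample eps ~ N(0, 1) and
# compute z = mu + sigma * eps, which keeps the result differentiable
# with respect to mu and sigma.
mu = torch.zeros(4, 20)         # predicted means for a batch of 4
logvar = torch.zeros(4, 20)     # predicted log-variances

std = torch.exp(0.5 * logvar)   # sigma = exp(log(sigma^2) / 2)
eps = torch.randn_like(std)     # standard-normal noise, same shape and device
z = mu + eps * std              # sampled latent vectors, shape (4, 20)
```

With `logvar = 0` the standard deviation is exactly 1, so here `z` is simply a standard-normal sample shifted by `mu`.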

    import os
    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms as tfs
    from torchvision.utils import save_image
    
    # Hyper parameters
    EPOCH = 1
    LR = 1e-3
    BATCHSIZE = 128
    
    im_tfs = tfs.Compose([
        tfs.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                           # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
        tfs.Normalize([0.5], [0.5])   # rescale [0.0, 1.0] to [-1.0, 1.0]
    ])
    
    train_set = MNIST(
        root='/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/',   # mnist has been downloaded before, use it directly
        train=True,
        transform=im_tfs,
    )
    train_loader = DataLoader(train_set, batch_size=BATCHSIZE, shuffle=True)
    
    
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
    
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)   # mean
            self.fc22 = nn.Linear(400, 20)   # log variance
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)
    
        def encode(self, x):
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)
    
        def reparametrize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)    # sigma = exp(log(sigma^2) / 2)
            eps = torch.randn_like(std)      # standard-normal noise, created on the same device as std
            return mu + eps * std            # scale by the std and shift by the mean: z ~ N(mu, sigma^2)
    
        def decode(self, z):
            h3 = F.relu(self.fc3(z))
            return torch.tanh(self.fc4(h3))
    
        def forward(self, x):
            mu, logvar = self.encode(x)          # encode
            z = self.reparametrize(mu, logvar)   # reparameterize into a sample from N(mu, sigma^2)
            return self.decode(z), mu, logvar    # decode, and also return the mean and log variance
    
    
    net = VAE()  # instantiate the network
    if torch.cuda.is_available():
        net = net.cuda()
    
    reconstruction_function = nn.MSELoss(reduction='sum')   # sum over all elements (size_average=False is deprecated)
    
    
    def loss_function(recon_x, x, mu, logvar):
        """
        recon_x: generating images
        x: origin images
        mu: latent mean
        logvar: latent log variance
        """
        MSE = reconstruction_function(recon_x, x)
        # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # KL divergence
        return MSE + KLD
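As a side check (not part of the original code), the closed-form KL term used in `loss_function` can be verified against `torch.distributions`, which computes the same divergence between N(mu, sigma^2) and N(0, 1):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(3, 5)
logvar = torch.randn(3, 5)

# Closed form: KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_manual = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# The same quantity via torch.distributions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_lib = kl_divergence(q, p).sum()

print(torch.allclose(kld_manual, kld_lib, atol=1e-5))  # True
```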
    
    
    optimizer = torch.optim.Adam(net.parameters(), lr=LR)
    
    
    def to_img(x):   # x shape (batchsize, 28*28), each pixel value in [-1., 1.]
        '''
        Convert the network output back into image tensors.
        '''
        x = 0.5 * (x + 1.)
        x = x.clamp(0, 1)
        x = x.view(x.shape[0], 1, 28, 28)
        return x
    
    
    for epoch in range(EPOCH):
        for iteration, (im, y) in enumerate(train_loader):
            im = im.view(im.shape[0], -1)
            if torch.cuda.is_available():
                im = im.cuda()
            recon_im, mu, logvar = net(im)
            loss = loss_function(recon_im, im, mu, logvar) / im.shape[0]   # average the loss over the batch
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if iteration % 100 == 0:
                print('epoch: {:2d} | iteration: {:4d} | Loss: {:.4f}'.format(epoch, iteration, loss.item()))
                save = to_img(recon_im.detach().cpu())
                if not os.path.exists('./vae_img'):
                    os.mkdir('./vae_img')
                save_image(save, './vae_img/image_{}_{}.png'.format(epoch, iteration))
    
    
    # test: decode a random standard-normal latent code into an image
    code = torch.randn(1, 20)
    if torch.cuda.is_available():
        code = code.cuda()
    with torch.no_grad():
        out = net.decode(code)
    img = to_img(out.cpu())
    save_image(img, './vae_img/test_img.png')
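Beyond decoding a single random code, the latent space can be explored by decoding points along a line between two random codes. A minimal sketch (the model-dependent calls are commented out because they assume the trained `net`, `to_img`, and `save_image` from above):

```python
import torch

steps = 8
z0, z1 = torch.randn(1, 20), torch.randn(1, 20)

# Linear interpolation between the two latent codes
codes = torch.cat([z0 + (z1 - z0) * t / (steps - 1) for t in range(steps)], dim=0)
print(codes.shape)  # torch.Size([8, 20])

# With the trained model from above:
# with torch.no_grad():
#     imgs = to_img(net.decode(codes).cpu())   # shape (8, 1, 28, 28)
# save_image(imgs, './vae_img/interpolation.png', nrow=steps)
```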
  • Original source: https://www.cnblogs.com/picassooo/p/12601785.html