I. Autoencoder
1. AE.py
from torch import nn


class AE(nn.Module):

    def __init__(self):
        super(AE, self).__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return: reconstruction [b, 1, 28, 28] and None (placeholder for the KL term the VAE returns)
        """
        batchsz = x.shape[0]
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x, None
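A quick sanity check on the shapes (a minimal sketch, not part of the original files; it assumes the class above is saved as AE.py):

import torch
from AE import AE

model = AE()
x = torch.randn(4, 1, 28, 28)   # a fake batch of 4 grayscale "images"
x_hat, _ = model(x)             # second return value is the None placeholder
print(x_hat.shape)              # torch.Size([4, 1, 28, 28])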
2. main.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from AE import AE
from torch import nn, optim
import visdom


def main():
    mnist_train = datasets.MNIST("mnist", train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST("mnist", train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))
    print(x.shape)

    model = AE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # x: [b, 1, 28, 28]
            x_hat, _ = model(x)
            loss = criterion(x_hat, x)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, "loss:", loss.item())

        # visualize reconstructions on one test batch
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))


if __name__ == '__main__':
    main()
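Note that visdom needs its server running before the script starts (python -m visdom.server), otherwise the viz.images calls have nowhere to send the plots. The script only eyeballs reconstructions; below is a hedged sketch of a quantitative check you could append after the training loop (it reuses model, mnist_test, and nn from the script above):

# sketch: average per-pixel reconstruction MSE over the full test set
model.eval()
total, count = 0.0, 0
with torch.no_grad():
    for x, _ in mnist_test:
        x_hat, _ = model(x)
        total += nn.functional.mse_loss(x_hat, x, reduction="sum").item()
        count += x.numel()
print("test per-pixel MSE:", total / count)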
II. Variational Autoencoder
1. Model (vae.py)
import torch
from torch import nn


class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # mu:    [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 10] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return: reconstruction [b, 1, 28, 28] and the KL divergence term
        """
        batchsz = x.shape[0]
        # flatten
        x = x.view(batchsz, 784)
        # encoder: [b, 20], holding both mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparameterization trick: epsilon ~ N(0, 1), h: [b, 10]
        h = mu + sigma * torch.randn_like(sigma)

        # KL(N(mu, sigma^2) || N(0, 1)), averaged over batch and pixels
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz * 28 * 28)

        # decoder
        x = self.decoder(h)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x, kld
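The kld expression is the closed-form KL divergence between the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ (diagonal covariance) and the standard normal prior, summed over the 10 latent dimensions:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2})\,\big\|\,\mathcal{N}(0,1)\big)
= \frac{1}{2}\sum_{i}\left(\mu_i^{2} + \sigma_i^{2} - \log \sigma_i^{2} - 1\right)
$$

The code additionally sums over the batch and divides by batchsz * 28 * 28 so the term lives on the same per-pixel scale as the MSE reconstruction loss, and the 1e-8 inside the log is there for numerical stability. One caveat: because the encoder ends in a ReLU, both mu and sigma are constrained to be non-negative here; many VAE implementations instead predict log sigma^2 and exponentiate, which avoids that constraint.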
2. Training script
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from vae import VAE
from torch import nn, optim
import visdom


def main():
    mnist_train = datasets.MNIST("mnist", train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST("mnist", train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))
    print(x.shape)

    model = VAE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # x: [b, 1, 28, 28]
            x_hat, kld = model(x)
            loss = criterion(x_hat, x)
            if kld is not None:
                # ELBO-style objective: reconstruction loss + weighted KL term
                loss = loss + 1.0 * kld

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, "loss:", loss.item(), "kld:", kld.item())

        # visualize reconstructions on one test batch
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))


if __name__ == '__main__':
    main()
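Since training pushes the latent code toward N(0, I), the decoder can also generate brand-new digits from prior samples, not just reconstruct inputs. A minimal sketch (assumes it is appended inside main() after the training loop, so model and viz are already in scope):

# sketch: sample latent codes from the prior and decode them into images
model.eval()
with torch.no_grad():
    z = torch.randn(32, 10)                  # prior samples, latent dim = 10
    samples = model.decoder(z)               # [32, 784]
    samples = samples.view(32, 1, 28, 28)    # back to image shape
viz.images(samples, nrow=8, win="samples", opts=dict(title="samples"))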