对抗生成网络主要的原理,主要是使用生成器生成网络,判别器进行判别
生成器损失值:
判别器判别生成图片为真的BCE损失值
判别器损失值
判别真实图片为真和判别生成图片为假的BCE损失值
第一步: 使用argparse构造cmd输入的参数函数, 包含batch_size, lr学习率 ,latent_dim表示噪音生成的维度
第二步: 构造mnist数据集的dataloaders,使用torchvison.dataset.MNIST数据集, 使用transforms.compose([])进行数据集的转换, 使用torch.utils.data.Dataloaders构造batch_size数据集
第三步: 实例化生成网络
生成网络网络结构:
构造block模块,包含nn.Leanear, nn.BatchNormal1d(out_feats, 0.8) 和 nn.LeakyRelu(0.2) 表示对于小于0的数据乘以0.2,将其比例进行稀释
第一层: 转换为int(latent_dim, 128)
第二层: 转换为(128, 256)
第三层:转换为(256, 512)
第四层: 转换为(512, 1024)
第五层: 转换为(1024, int(np.prod(input_size)))
第六层: nn.Tanh()
实例化判别网络
判别网络网络结构:
第一层: 转换为(int(np.prod(input_size), 512))
第二层: 转换为(512, 256)
第三层: 转换为(256, 1)
第四层:转换为nn.Sigmoid()
第四步: 进行网络训练操作
第一步: 使用torch.nn.BCELoss() 构造损失函数
第二步: 实例化判别网络和生成网络
第三步: 构造迭代优化器
第四步:进行网络的训练操作,构造全1的真实标签valid和构造全0的虚假标签fake, 将输入的数据转换为Variable的tensor类型,构造随机的100维噪音数据,将噪音数据传入生成生成图片
第五步: 构造生成器的优化函数,构造判别器的优化函数, 生成器的损失值,使用判别器判别为真的BCE损失值 , 对于判别器的损失值,使用判别器判别为真实图片为真,判别生成图片为假的损失值
第六步: 打印数据,同时使用save_images进行数据集的保存
import argparse import time import torch import torch.utils.data from torchvision import transforms, datasets from torch import nn from torch.autograd import Variable import os import numpy as np from torch import optim from torchvision.utils import save_image parser = argparse.ArgumentParser() parser.add_argument('--n_epochs', type=int, default=20, help='迭代的次数') parser.add_argument('--batch_size', type=int, default=64, help='每个batch_size') parser.add_argument('--lr', type=int, default=0.0002, help='表示学习率') parser.add_argument('--b1', type=float, default=0.5, help='表示动量梯度下降第一个参数') parser.add_argument('--b2', type=float, default=0.99, help='动量梯度下降第二个参数') parser.add_argument('--n_cpu', type=int, default=8, help='表示cpu运行的个数') parser.add_argument('--latent_dim', type=int, default=100, help='表示噪音数据生成的维度') parser.add_argument('--image_size', type=int, default=28, help='表示输入数据的维度') parser.add_argument('--channel', type=int, default=1, help='表示输入数据的通道数') parser.add_argument('--sample_interval', type=int, default=400, help='表示保存图片的迭代数') opt = parser.parse_args() # 表示输入数据的尺寸 input_size = (opt.image_size, opt.image_size, opt.channel) os.makedirs('./data', exist_ok=True) os.makedirs('./data/mnist', exist_ok=True) # 进行数据集的准备 os.makedirs('./data/mnist', exist_ok=True) dataloaders = torch.utils.data.DataLoader( datasets.MNIST( './data/mnist', train = True, download=True, transform=transforms.Compose( [transforms.Resize(opt.image_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] ), ), batch_size=opt.batch_size, shuffle=True ) cuda = True if torch.cuda.is_available() else False # 构建生成网络 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() def block(in_feats, out_feats, Normalize=True): layers = [nn.Linear(in_feats, out_feats)] if Normalize: layers.append(nn.BatchNorm1d(out_feats, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *block(opt.latent_dim, 128, Normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), # 最后一层全连接层,不需要进行batchnomalize 和 relu操作 nn.Linear(1024, int(np.prod(input_size))), nn.Tanh(), ) def forward(self, x): output = self.model(x) return output class Discrimator(nn.Module): def __init__(self): super(Discrimator, self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(input_size)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid(), ) def forward(self, x): output = self.model(x) return output # 构造损失值函数 adversial_loss = torch.nn.BCELoss() generator = Generator() discrimator = Discrimator() # 将数据放在cuda上 if cuda: adversial_loss.cuda() generator.cuda() discrimator.cuda() optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) optimizer_D = torch.optim.Adam(discrimator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor for epoch in range(opt.n_epochs): for i, (image, _) in enumerate(dataloaders): # 构造标签损失函数 valid = Variable(tensor(image.size(0), 1).fill_(1.0), requires_grad=False) fake = Variable(tensor(image.size(0), 1).fill_(0.0), requires_grad=False) # 构建真实的输入值 real_image = torch.reshape(Variable(image.type(tensor)), (int(image.shape[0]), -1)) optimizer_G.zero_grad() # 对于生成器 z = Variable(tensor(np.random.normal(0, 1, (image.shape[0], opt.latent_dim)))) gen_images = generator(z) g_loss = adversial_loss(discrimator(gen_images), valid) g_loss.backward() optimizer_G.step() # # 构造判别器的损失函数 optimizer_D.zero_grad() real_loss = adversial_loss(discrimator(real_image), valid) fake_loss = adversial_loss(discrimator(gen_images.detach()), fake) d_loss = (real_loss + fake_loss) / 2 d_loss.backward() optimizer_D.step() print( '[Epoch %d / %d] Batch %d / %d [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, i, len(dataloaders), d_loss.item(), g_loss.item()) ) batches_done = epoch * len(dataloaders) + i if batches_done % opt.sample_interval == 0: save_image(gen_images.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)