• pytorch jupyter下的CycleGAN代码


    模型使用的是苹果转橘子(apple2orange)数据集。可能由于模型较大而训练图片数量不足(1000 张左右),部分图片的转换效果不是很好。

    模型是在天池平台上运行的。运行前还需要导入 utils.py 文件,我把它放在文末了。

    import glob
    import random
    import os
    import torch
    from torch.utils.data import Dataset
    from PIL import Image
    import utils
    import torchvision.transforms as transforms
    from torch.autograd import Variable
    from PIL import Image
    import matplotlib.pyplot as plt
    %matplotlib inline
    import torchvision.utils as vutils
    import numpy as np
    import torch.nn as nn
    import torch.nn.functional as F
    import itertools
    import torchvision

    定义一些超参数

    """ gpu """
    gpu_id = [0]
    utils.cuda_devices(gpu_id)
    # 决定我们在哪个设备上运行
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    
    """ param """
    epochs = 2500
    batch_size = 50
    size=64
    lr = 0.0002
    n_critic = 5
    z_dim = 100

    导入数据集

    class ImageDataset(Dataset):
        """Two-domain image dataset for unpaired (CycleGAN-style) training.

        Files are collected from ``<root>/<mode>A`` and ``<root>/<mode>B``.
        Every image is converted to 3-channel RGB before the transform is
        applied, so grayscale or RGBA files do not produce tensors with a
        different channel count.

        Args:
            root: dataset root directory.
            transforms_: list of torchvision transforms, composed in order and
                applied to every image. ``None`` applies no transform (the
                RGB PIL image is returned as is).
            unaligned: if True, the B image is drawn at random instead of
                being selected by index.
            mode: subfolder prefix, e.g. ``'train'`` or ``'test'``.
        """
        def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
            # Combine the individual transforms into one callable; an empty
            # Compose is the identity, which avoids iterating over None.
            self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
            self.unaligned = unaligned

            # All files under `<root>/<mode>A` and `<root>/<mode>B`, sorted for
            # a deterministic order (shuffling is left to the DataLoader and
            # to `unaligned`).
            self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
            self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))

        def _load(self, path):
            """Open *path*, convert it to RGB and apply the transform; the file is closed on return."""
            with Image.open(path) as img:
                # convert() returns a new, fully loaded image, so it stays
                # valid after the file handle is closed.
                return self.transform(img.convert('RGB'))

        def __getitem__(self, index):
            """Return ``{'A': image_A, 'B': image_B}`` (after the transform).

            A is selected by ``index`` (wrapping around with modulo). B is
            selected the same way unless ``unaligned`` is set, in which case it
            is drawn uniformly at random.
            """
            item_A = self._load(self.files_A[index % len(self.files_A)])

            if self.unaligned:
                # Unpaired training: pair A with a random B image.
                item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])
            else:
                item_B = self._load(self.files_B[index % len(self.files_B)])

            return {'A': item_A, 'B': item_B}

        def __len__(self):
            # Length of the larger domain, so every image of both domains can be visited.
            return max(len(self.files_A), len(self.files_B))
    # Dataset loader
    # Resize (shorter side) to ~1.12x `size`, random-crop back to `size`,
    # randomly flip, then map pixels from [0, 1] to [-1, 1], the same range as
    # the generator's Tanh output.
    transforms_ = [transforms.Resize(int(size*1.12), Image.BICUBIC),
                   transforms.RandomCrop(size),
                   transforms.RandomHorizontalFlip(),  # random horizontal flip
                   transforms.ToTensor(),              # PIL.Image/np.ndarray (HWC) [0, 255] -> torch.FloatTensor (CHW) [0.0, 1.0]
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]  # per channel: (x - 0.5) / 0.5
    dataloader = torch.utils.data.DataLoader(
        ImageDataset(r'dataset/apple2orange', transforms_=transforms_, unaligned=True),
        batch_size=batch_size, shuffle=True)

    # Show some training images (domain B). The batch is already on the CPU
    # and make_grid/imshow work there, so it is not sent to `device` and back.
    real_batch = next(iter(dataloader))['B']
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[:64], padding=2, normalize=True), (1,2,0)))

    定义模型

    class ResidualBlock(nn.Module):
        """Residual unit: ``x + F(x)``.

        F is pad -> 3x3 conv -> InstanceNorm -> ReLU -> pad -> 3x3 conv ->
        InstanceNorm, all with ``in_features`` channels. Reflection padding of
        1 keeps H and W unchanged, so the skip connection can add directly.
        """
        def __init__(self, in_features):
            super(ResidualBlock, self).__init__()

            def pad_conv_norm():
                # One reflection-padded 3x3 conv followed by instance norm.
                return [nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)]

            # Same layer order (and thus state_dict keys) as before:
            # 0 pad, 1 conv, 2 norm, 3 relu, 4 pad, 5 conv, 6 norm.
            layers = pad_conv_norm() + [nn.ReLU(inplace=True)] + pad_conv_norm()
            self.conv_block = nn.Sequential(*layers)

        def forward(self, x):
            residual = self.conv_block(x)
            return x + residual
    class Generator(nn.Module):
        """ResNet-style image-to-image generator.

        Layout: reflection-padded 7x7 stem -> two stride-2 downsampling convs
        (each doubles the channels) -> ``n_residual_blocks`` ResidualBlocks ->
        two stride-2 transposed convs (each halves the channels) -> 7x7 output
        conv + Tanh. Spatial size is preserved when H and W are multiples of 4.

        Args:
            input_nc: number of input channels.
            output_nc: number of output channels.
            n_residual_blocks: number of residual blocks at the bottleneck.
            ngf: channel width of the stem; defaults to 64, the original fixed
                value (so default-constructed models keep their state_dict keys
                and shapes).
        """
        def __init__(self, input_nc, output_nc, n_residual_blocks=2, ngf=64):
            super(Generator, self).__init__()

            # Initial convolution block
            model = [   nn.ReflectionPad2d(3),
                        nn.Conv2d(input_nc, ngf, 7),
                        nn.InstanceNorm2d(ngf),
                        nn.ReLU(inplace=True) ]

            # Downsampling: each stage halves H/W and doubles the channels
            in_features = ngf
            out_features = in_features*2
            for _ in range(2):
                model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                            nn.InstanceNorm2d(out_features),
                            nn.ReLU(inplace=True) ]
                in_features = out_features
                out_features = in_features*2

            # Residual blocks
            for _ in range(n_residual_blocks):
                model += [ResidualBlock(in_features)]

            # Upsampling: each stage doubles H/W and halves the channels
            out_features = in_features//2
            for _ in range(2):
                model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                          nn.InstanceNorm2d(out_features),
                          nn.ReLU(inplace=True) ]
                in_features = out_features
                out_features = in_features//2

            # Output layer. It uses the tracked width (back to `ngf` after the
            # two upsampling stages) instead of a hard-coded 64, so it stays
            # consistent with the stem.
            model += [nn.ReflectionPad2d(3),
                      nn.Conv2d(in_features, output_nc, 7),
                      nn.Tanh() ]

            self.model = nn.Sequential(*model)

        def forward(self, x):
            return self.model(x)
    class Discriminator(nn.Module):
        """Fully convolutional (PatchGAN-style) discriminator.

        Stacks 4x4 convolutions with LeakyReLU(0.2) (InstanceNorm on all but
        the first), ends in a single-channel conv, and averages the resulting
        score map over H and W, returning a tensor of shape (N, 1).
        """
        def __init__(self, input_nc):
            super(Discriminator, self).__init__()

            # First conv has no normalization.
            layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                      nn.LeakyReLU(0.2, inplace=True)]

            # (in_channels, out_channels, stride) of the conv + norm + act stages
            for c_in, c_out, step in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
                layers.append(nn.Conv2d(c_in, c_out, 4, stride=step, padding=1))
                layers.append(nn.InstanceNorm2d(c_out))
                layers.append(nn.LeakyReLU(0.2, inplace=True))

            # One-channel score map
            layers.append(nn.Conv2d(512, 1, 4, padding=1))

            self.model = nn.Sequential(*layers)

        def forward(self, x):
            scores = self.model(x)
            # Global average pooling over the full score map, then flatten to (N, 1)
            pooled = F.avg_pool2d(scores, scores.shape[2:])
            return pooled.view(scores.shape[0], -1)

    实例化模型

    # Two generators (A->B, B->A) and one discriminator per domain. They are
    # created in this exact order so weight initialisation consumes the RNG
    # the same way.
    netG_A2B, netG_B2A = Generator(3, 3), Generator(3, 3)
    netD_A, netD_B = Discriminator(3), Discriminator(3)

    # Least-squares (MSE) adversarial loss; L1 for cycle-consistency and identity.
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle, criterion_identity = torch.nn.L1Loss(), torch.nn.L1Loss()
    utils.cuda([netG_A2B, netG_B2A, netD_A, netD_B,
                criterion_GAN, criterion_cycle, criterion_identity])

    # Optimizers: both generators share one Adam over their chained parameters;
    # each discriminator gets its own.
    optimizer_G = torch.optim.Adam(
        itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
        lr=lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=lr, betas=(0.5, 0.999))

    每次训练中断后,都需要重新运行下面这段代码,以加载最近保存的检查点(模型和优化器状态),从上次的进度继续训练。

    """ load checkpoint """
    ckpt_dir = './checkpoints1/celeba_cyclegan'
    utils.mkdir(ckpt_dir)
    try:
        ckpt = utils.load_checkpoint(ckpt_dir)
        start_epoch = ckpt['epoch']
        netD_A.load_state_dict(ckpt['netD_A'])
        netD_B.load_state_dict(ckpt['netD_B'])
        netG_A2B.load_state_dict(ckpt['netG_A2B'])
        netG_B2A.load_state_dict(ckpt['netG_B2A'])
        optimizer_G.load_state_dict(ckpt['optimizer_G'])
        optimizer_D_A.load_state_dict(ckpt['optimizer_D_A'])
        optimizer_D_B.load_state_dict(ckpt['optimizer_D_B'])
    except:
        print(' [*] No checkpoint!')
        start_epoch = 0
    class ReplayBuffer():
        """Pool of previously generated images fed to the discriminators.

        Until ``max_size`` images are stored, every incoming image is stored
        and returned unchanged. After that, each incoming image is returned
        unchanged when ``random.uniform(0, 1) <= 0.5``. Otherwise a random
        stored image is returned (cloned) and replaced by the incoming one.
        """
        def __init__(self, max_size=50):
            assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
            self.max_size = max_size
            self.data = []

        def push_and_pop(self, data):
            """Return a batch of the same length as *data*, mixing in pooled images."""
            picked = []
            # Iterate over `.data` (the tensor without its autograd history).
            for row in data.data:
                sample = row.unsqueeze(0)  # add a leading dim so picks can be concatenated
                if len(self.data) < self.max_size:
                    # Pool still filling: store and pass through.
                    self.data.append(sample)
                    picked.append(sample)
                    continue
                if random.uniform(0, 1) <= 0.5:
                    # Pass the new image through without storing it.
                    picked.append(sample)
                    continue
                # Swap: return a stored image and keep the new one in its slot.
                slot = random.randint(0, self.max_size - 1)
                picked.append(self.data[slot].clone())
                self.data[slot] = sample
            return Variable(torch.cat(picked))

    定义一些用到的变量

    # Pools of previously generated images used to train the discriminators.
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()
    A = []  # make_grid snapshots of generated A-domain images, for display
    B = []  # make_grid snapshots of generated B-domain images, for display
    for epoch in range(start_epoch, epochs):
        for i, batch in enumerate(dataloader):
            # Model inputs (N, 3, H, W) go straight to `device`. There are no
            # fixed-size, CUDA-only preallocated buffers, so this also runs on
            # CPU-only machines, and the final (possibly smaller) batch is used
            # rather than skipped.
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            #-------- Generators A2B and B2A --------
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)*5.0   # scalar
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)*5.0   # scalar

            # GAN loss: the generators want the discriminators to output 1.
            # Targets are built with *_like so they match the (N, 1)
            # prediction shape; an (N,) target would broadcast to (N, N)
            # inside MSELoss and trigger a UserWarning.
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()

            #-------- Discriminator A --------
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            # Fake loss, on a mix of current and pooled fakes
            fake_A = fake_A_buffer.push_and_pop(fake_A).to(device)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake)*0.5
            loss_D_A.backward()

            optimizer_D_A.step()

            #-------- Discriminator B --------
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            # Fake loss, on a mix of current and pooled fakes
            fake_B = fake_B_buffer.push_and_pop(fake_B).to(device)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake)*0.5
            loss_D_B.backward()

            optimizer_D_B.step()

            if (i + 1) % 15 == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, i + 1, len(dataloader)))

        if (epoch + 1) % 5 == 0:  # save sample images every 5 epochs (the run is long, ~2000 epochs)
            save_dir = './sample_images_while_training/cycleGAN'
            utils.mkdir(save_dir)
            # fake_A / fake_B are the last batch's discriminator inputs
            # (returned by the replay buffers).
            torchvision.utils.save_image(fake_A, '%s/Epoch_(%d)_(%dof%d)_fake_A.jpg' % (save_dir, epoch, i + 1, len(dataloader)), nrow=10)
            torchvision.utils.save_image(fake_B, '%s/Epoch_(%d)_(%dof%d)_fake_B.jpg' % (save_dir, epoch, i + 1, len(dataloader)), nrow=10)

            with torch.no_grad():
                A.append(vutils.make_grid(fake_A.detach().cpu(), padding=2, normalize=True))
                B.append(vutils.make_grid(fake_B.detach().cpu(), padding=2, normalize=True))

        # Checkpoint every epoch; save_checkpoint keeps only the 2 newest on disk.
        utils.save_checkpoint({'epoch': epoch + 1,
                               'netD_A': netD_A.state_dict(),
                               'netD_B': netD_B.state_dict(),
                               'netG_A2B': netG_A2B.state_dict(),
                               'netG_B2A': netG_B2A.state_dict(),
                               'optimizer_G': optimizer_G.state_dict(),
                               'optimizer_D_A': optimizer_D_A.state_dict(),
                               'optimizer_D_B': optimizer_D_B.state_dict(),},
                              '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch + 1),
                              max_keep=2)

    显示训练过程中生成的图片

    # Show the most recent snapshot of generated images for each domain.
    # A and B are filled by the training loop every 5 epochs, so they can be
    # empty (or hold a single entry) after a short run.
    if not A or not B:
        print(' [*] No generated samples to show yet.')
    else:
        plt.figure(figsize=(15,15))

        # Left: generated A-domain images (netG_B2A outputs via the replay buffer)
        plt.subplot(1,2,1)
        plt.axis("off")
        plt.title("A")
        plt.imshow(np.transpose(A[-1], (1,2,0)))

        # Right: generated B-domain images (netG_A2B outputs via the replay buffer)
        plt.subplot(1,2,2)
        plt.axis("off")
        plt.title("B")
        plt.imshow(np.transpose(B[-1], (1,2,0)))
        plt.show()

    utils.py 文件,其中定义了创建目录、设置可见 GPU、将数据/模型转到 CUDA、保存和加载模型检查点等函数

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os
    import shutil
    
    import torch
    
    
    def mkdir(paths):
        """Create each directory in *paths*, including missing parents.

        Args:
            paths: a single path, or a list/tuple of paths.

        Existing directories are left untouched. An existing non-directory
        file at a path still raises ``FileExistsError``.
        """
        if not isinstance(paths, (list, tuple)):
            paths = [paths]
        for path in paths:
            # exist_ok avoids the race between a separate isdir() check and makedirs().
            os.makedirs(path, exist_ok=True)
    
    
    def cuda_devices(gpu_ids):
        """Set CUDA_VISIBLE_DEVICES to the comma-joined *gpu_ids* (e.g. [0, 2] -> "0,2").

        NOTE: presumably this only takes effect if called before CUDA is
        initialised in this process.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_ids))
    
    
    def cuda(xs):
        """Move *xs* to the GPU when CUDA is available.

        Args:
            xs: an object with a ``.cuda()`` method (tensor or module), or a
                list/tuple of such objects.

        Returns:
            The moved object, or a list of moved objects for list/tuple input.
            Without CUDA the input is returned unchanged (a list/tuple comes
            back as a list, like the GPU path) instead of ``None``, so
            ``a, b = cuda([a, b])`` also works on CPU-only machines.
        """
        is_seq = isinstance(xs, (list, tuple))
        if not torch.cuda.is_available():
            return list(xs) if is_seq else xs
        if is_seq:
            return [x.cuda() for x in xs]
        return xs.cuda()
    
    
    def save_checkpoint(state, save_path, is_best=False, max_keep=None):
        """Save *state* with ``torch.save`` and update the checkpoint index.

        The index file ``latest_checkpoint`` lives in the directory of
        *save_path*. It lists checkpoint file names, newest first, one per
        line; ``load_checkpoint`` reads its first line.

        Args:
            state: object passed to ``torch.save``.
            save_path: destination file path.
            is_best: if True, also copy the checkpoint to ``best_model.ckpt``
                in the same directory.
            max_keep: if not None, keep only the newest *max_keep* index
                entries and delete the files of older ones.
        """
        torch.save(state, save_path)

        save_dir = os.path.dirname(save_path)
        list_path = os.path.join(save_dir, 'latest_checkpoint')
        save_name = os.path.basename(save_path)

        # Existing entries (newest first), ignoring blank lines.
        ckpt_names = []
        if os.path.exists(list_path):
            with open(list_path) as f:
                ckpt_names = [line.rstrip('\n') for line in f if line.strip()]
        # Drop a stale entry with the same name. Otherwise a re-save would
        # appear twice, and max_keep could delete the wrong or just-written file.
        ckpt_names = [save_name] + [name for name in ckpt_names if name != save_name]

        if max_keep is not None:
            for name in ckpt_names[max_keep:]:
                stale_path = os.path.join(save_dir, name)
                if os.path.exists(stale_path):
                    os.remove(stale_path)
            ckpt_names = ckpt_names[:max_keep]

        with open(list_path, 'w') as f:
            f.writelines(name + '\n' for name in ckpt_names)

        # Copy from the full path (not the bare file name), so this works
        # regardless of the current working directory.
        if is_best:
            shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))
    
    
    def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
        """Load a checkpoint written by ``save_checkpoint``.

        Args:
            ckpt_dir_or_file: a checkpoint file, or a directory. For a
                directory, the file named on the first line of its
                ``latest_checkpoint`` index is loaded, or ``best_model.ckpt``
                when *load_best* is True.
            map_location: forwarded to ``torch.load``.
            load_best: see *ckpt_dir_or_file*.

        Returns:
            Whatever ``torch.load`` returns.

        Raises:
            FileNotFoundError: if the index or the checkpoint file is missing,
                or the index lists no checkpoint.
        """
        if os.path.isdir(ckpt_dir_or_file):
            if load_best:
                ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
            else:
                list_path = os.path.join(ckpt_dir_or_file, 'latest_checkpoint')
                with open(list_path) as f:
                    # strip() rather than [:-1]: the line may lack a trailing newline
                    latest = f.readline().strip()
                if not latest:
                    # Otherwise the directory itself would be passed to torch.load.
                    raise FileNotFoundError('No checkpoint listed in %s' % list_path)
                ckpt_path = os.path.join(ckpt_dir_or_file, latest)
        else:
            ckpt_path = ckpt_dir_or_file
        ckpt = torch.load(ckpt_path, map_location=map_location)
        print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
        return ckpt
  • 相关阅读:
    蓝桥杯 勾股数 暴力
    蓝桥杯 连接乘积 暴力
    蓝桥杯 师座操作系统 map
    蓝桥杯 洗牌 vector
    蓝桥杯 盾神与砝码称重 dfs 剪枝
    蓝桥杯 盾神与积木游戏 贪心
    RESTful风格API
    APIview使用
    linux常用命令
    python中的三种路径
  • 原文地址:https://www.cnblogs.com/abc23/p/14390153.html
Copyright © 2020-2023  润新知