• StarGAN Code Analysis


    # Parameter settings
    import sys
    sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
    from torchvision.datasets import ImageFolder
    from PIL import Image
    import torch
    import os
    import random

    # Model configuration.
    c_dim = 5                 # dimension of domain labels (1st dataset)
    c2_dim = 8                # dimension of domain labels (2nd dataset)
    celeba_crop_size = 178    # crop size for the CelebA dataset
    rafd_crop_size = 256      # crop size for the RaFD dataset
    image_size = 128          # image resolution
    g_conv_dim = 64           # number of conv filters in the first layer of G
    d_conv_dim = 64           # number of conv filters in the first layer of D
    g_repeat_num = 6          # number of residual blocks in G
    d_repeat_num = 6          # number of strided conv layers in D
    lambda_cls = 1            # weight for domain classification loss
    lambda_rec = 10           # weight for reconstruction loss
    lambda_gp = 10            # weight for gradient penalty

    # Training configuration.
    dataset = 'CelebA'        # choices=['CelebA', 'RaFD', 'Both']
    batch_size = 16           # mini-batch size
    num_iters = 200000        # number of total iterations for training D
    num_iters_decay = 100000  # number of iterations for decaying lr
    g_lr = 0.0001             # learning rate for G
    d_lr = 0.0001             # learning rate for D
    n_critic = 5              # number of D updates per each G update
    beta1 = 0.5               # beta1 for Adam optimizer
    beta2 = 0.999             # beta2 for Adam optimizer
    resume_iters = None       # resume training from this step
    selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']  # selected attributes for the CelebA dataset

    # Test configuration.
    test_iters = 200000       # test model from this step

    # Miscellaneous.
    num_workers = 1
    mode = 'test'             # choices=['train', 'test']
    use_tensorboard = True

    # Directories.
    celeba_image_dir = '../data/CelebA_nocrop/images/' if mode == 'train' else '../test/test/'
    attr_path = '../data/list_attr_celeba.txt' if mode == 'train' else '../test/test_celeba.txt'
    rafd_image_dir = '../data/RaFD/train/'
    log_dir = '../test/logs'
    model_save_dir = '../stargan/models'
    sample_dir = '../test/samples'
    result_dir = '../test/result'

    # Step size.
    log_step = 10
    sample_step = 1000
    model_save_step = 10000
    lr_update_step = 1000
    import tensorflow as tf
    # Set up the TensorBoard logger (TensorFlow 1.x summary API).
    class Logger(object):
        """Tensorboard logger."""
    
        def __init__(self, log_dir):
            """Initialize summary writer."""
            self.writer = tf.summary.FileWriter(log_dir)
    
        def scalar_summary(self, tag, value, step):
            """Add scalar summary."""
            summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
            self.writer.add_summary(summary, step)
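
    For illustration, here is a minimal way this Logger might be used (my own sketch, not part of the original code; the scalar values are made up):

    logger = Logger('../test/logs')  # event files are written under log_dir
    for step in range(1, 11):
        logger.scalar_summary('D/loss_real', 1.0 / step, step)  # placeholder scalar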

    # Preprocess and load the data

    from torch.utils import data
    from torchvision import transforms as T
    
    class CelebA(data.Dataset):
        """Dataset class for the CelebA dataset."""
    
        def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
            """Initialize and preprocess the CelebA dataset."""
            self.image_dir = image_dir
            self.attr_path = attr_path
            self.selected_attrs = selected_attrs
            self.transform = transform
            self.mode = mode
            self.train_dataset = []
            self.test_dataset = []
            self.attr2idx = {}
            self.idx2attr = {}
            self.preprocess()
            
            if mode == 'train':
                self.num_images = len(self.train_dataset)
            else:
                self.num_images = len(self.test_dataset)
            """
            train_dataset的数据格式如下
             '000003.jpg', [True, False, False, False, True]],
            """
           
        def preprocess(self):
            """Preprocess the CelebA attribute file."""
            lines = [line.rstrip() for line in open(self.attr_path, 'r')]
            all_attr_names = lines[1].split()
    
            for i, attr_name in enumerate(all_attr_names):
                self.attr2idx[attr_name] = i
                self.idx2attr[i] = attr_name
           
            lines = lines[2:]
            random.seed(1234)
            random.shuffle(lines)
            for i, line in enumerate(lines):
                split = line.split()
                filename = split[0]
                values = split[1:]
    
                label = []
                for attr_name in self.selected_attrs:
                    idx = self.attr2idx[attr_name]
                    label.append(values[idx] == '1')
    
                if (i+1) < 2000:  # hold out the first ~2000 shuffled images as the test set
                    self.test_dataset.append([filename, label])
                else:
                    self.train_dataset.append([filename, label])
    
            print('Finished preprocessing the CelebA dataset...')
    
        # This method overrides __getitem__ of torch.utils.data.Dataset, which CelebA subclasses.
        def __getitem__(self, index):
            """Return one image and its corresponding attribute label."""
            dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
            filename, label = dataset[index]
            image = Image.open(os.path.join(self.image_dir, filename))
            return self.transform(image), torch.FloatTensor(label)
    
        def __len__(self):
            """Return the number of images."""
            return self.num_images
    
    
    def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, 
                   batch_size=16, dataset='CelebA', mode='train', num_workers=1):
        """Build and return a data loader."""
        transform = []
        if mode == 'train':
            transform.append(T.RandomHorizontalFlip())
        transform.append(T.CenterCrop(crop_size))
        # center-crop first, then resize down to image_size
        transform.append(T.Resize(image_size))
        transform.append(T.ToTensor())
        transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        transform = T.Compose(transform)
    
        if dataset == 'CelebA':
            # dataset becomes an instance of the CelebA class defined above
            dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
            
        # For RaFD, a custom dataset is loaded via torchvision's ImageFolder
        # (folder.py); the author notes this raised an error for them.
        elif dataset == 'RaFD':
            dataset = ImageFolder(image_dir, transform)
            
        # The dataset argument of DataLoader must be a data.Dataset instance.
        data_loader = data.DataLoader(dataset=dataset,
                                      batch_size=batch_size,
                                      shuffle=(mode=='train'),
                                      num_workers=num_workers)
        return data_loader
    
    # celeba_loader is the data_loader returned by torch.utils.data.DataLoader;
    # the dataset it wraps is an instance of the CelebA class above.
    celeba_loader = get_loader(celeba_image_dir, attr_path, selected_attrs,celeba_crop_size, image_size, 
                               batch_size,'CelebA', mode, num_workers)
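
    As a quick sanity check (my own sketch, not in the original post), one batch from the loader should yield image tensors of shape [batch_size, 3, image_size, image_size] and label tensors of shape [batch_size, len(selected_attrs)]:

    images, labels = next(iter(celeba_loader))
    print(images.shape)  # torch.Size([16, 3, 128, 128]) with the settings above
    print(labels.shape)  # torch.Size([16, 5]); one entry per selected attribute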

    Network model architecture

    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    
    class ResidualBlock(nn.Module):
        """Residual Block with instance normalization."""
        def __init__(self, dim_in, dim_out):
            super(ResidualBlock, self).__init__()
            self.main = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))
    
        def forward(self, x):
            return x + self.main(x)
    
    
    class Generator(nn.Module):
        """Generator network."""
        def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
            super(Generator, self).__init__()
    
            layers = []
            # First conv layer: the input is the image concatenated with the label;
            # 3 is the number of image channels and c_dim is the label dimension.
            layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
            layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
    
            # Down-sampling layers.
            curr_dim = conv_dim  # 64 channels at this point
            for i in range(2):
                layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
                layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
                layers.append(nn.ReLU(inplace=True))
                curr_dim = curr_dim * 2
                
            # After the two down-sampling steps, curr_dim is 256.
            # Bottleneck layers.
            for i in range(repeat_num):
                layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
            
            # Up-sampling layers.
            for i in range(2):
                layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
                layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
                layers.append(nn.ReLU(inplace=True))
                curr_dim = curr_dim // 2
                
            # The final convolution brings the output back to 3 channels.
            layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
            layers.append(nn.Tanh())
            self.main = nn.Sequential(*layers)
    
        def forward(self, x, c):  # defines the forward computation
            # Replicate spatially and concatenate domain information.
            c = c.view(c.size(0), c.size(1), 1, 1)    # view is analogous to NumPy's reshape
            c = c.repeat(1, 1, x.size(2), x.size(3))  # repeat the tensor along the spatial dimensions
            x = torch.cat([x, c], dim=1)              # concatenate the input image x with the label map c along the channel dimension
            return self.main(x)
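
    The label injection in forward() is easiest to follow with concrete shapes. A standalone sketch (my own, with made-up tensors):

    x = torch.randn(1, 3, 128, 128)                # one RGB image
    c = torch.tensor([[1., 0., 0., 1., 1.]])       # one 5-dim domain label
    c = c.view(1, 5, 1, 1).repeat(1, 1, 128, 128)  # tile the label over the spatial grid
    x_in = torch.cat([x, c], dim=1)                # 3 + c_dim = 8 input channels
    print(x_in.shape)                              # torch.Size([1, 8, 128, 128])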
    
    
    class Discriminator(nn.Module):
        """Discriminator network with PatchGAN."""
        def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
            super(Discriminator, self).__init__()
            layers = []
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
    
            curr_dim = conv_dim
            for i in range(1, repeat_num):
                layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
                layers.append(nn.LeakyReLU(0.01))
                curr_dim = curr_dim * 2
    
            kernel_size = int(image_size / np.power(2, repeat_num))
            self.main = nn.Sequential(*layers)  # assemble the layers into the network
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)  # scores the image as real or fake
            self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)         # predicts the label of the input image
            
        def forward(self, x):
            h = self.main(x)         # x is an input image; after main(), the feature map h has 2048 channels
            out_src = self.conv1(h)  # out_src: real/fake score map
            out_cls = self.conv2(h)  # out_cls: predicted domain label
            return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
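
    With image_size=128 and repeat_num=6, the six stride-2 convolutions reduce the input to a 2048 x 2 x 2 feature map, so out_src is a 2x2 PatchGAN score map and out_cls collapses to a c_dim vector (kernel_size works out to 128 / 2^6 = 2). A quick shape check (my own sketch):

    D = Discriminator(image_size=128, conv_dim=64, c_dim=5, repeat_num=6)
    out_src, out_cls = D(torch.randn(1, 3, 128, 128))
    print(out_src.shape)  # torch.Size([1, 1, 2, 2])
    print(out_cls.shape)  # torch.Size([1, 5])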

    Solver

    from torchvision.utils import save_image
    import time
    import datetime
    
    class Solver(object):
        """Solver for training and testing StarGAN."""
    
        def __init__(self, celeba_loader, rafd_loader):
            """Initialize configurations."""
    
            # Data loader.
            self.celeba_loader = celeba_loader
            self.rafd_loader = rafd_loader
    
            # Model configurations.
            self.c_dim = c_dim
            self.c2_dim = c2_dim
            self.image_size = image_size
            self.g_conv_dim = g_conv_dim
            self.d_conv_dim = d_conv_dim
            self.g_repeat_num = g_repeat_num
            self.d_repeat_num = d_repeat_num
            self.lambda_cls = lambda_cls
            self.lambda_rec = lambda_rec
            self.lambda_gp = lambda_gp
    
            # Training configurations.
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_iters = num_iters
            self.num_iters_decay = num_iters_decay
            self.g_lr = g_lr
            self.d_lr = d_lr
            self.n_critic = n_critic
            self.beta1 = beta1
            self.beta2 = beta2
            self.resume_iters = resume_iters
            self.selected_attrs = selected_attrs
    
            # Test configurations.
            self.test_iters = test_iters
    
            # Miscellaneous.
            self.use_tensorboard = use_tensorboard
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            #self.device = torch.device('cpu')
            
            # Directories.
            self.log_dir = log_dir
            self.sample_dir = sample_dir
            self.model_save_dir = model_save_dir
            self.result_dir = result_dir
    
            # Step size.
            self.log_step = log_step
            self.sample_step = sample_step
            self.model_save_step = model_save_step
            self.lr_update_step = lr_update_step
    
            # Build the model and tensorboard.
            self.build_model()
            if self.use_tensorboard:
                self.build_tensorboard()
        
        def build_model(self):
            """Create a generator and a discriminator."""
            if self.dataset in ['CelebA', 'RaFD']:
                self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
                self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) 
            elif self.dataset in ['Both']:
                self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector.
                self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
    
            self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
            self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
            
            # Print the network structures.
            #self.print_network(self.G, 'G')
            #self.print_network(self.D, 'D')
                
            self.G.to(self.device)
            self.D.to(self.device)
    
        def print_network(self, model, name):
            """Print out the network information."""
            num_params = 0
            for p in model.parameters():
                num_params += p.numel()
            print(model)
            print(name)
            print("The number of parameters: {}".format(num_params))
            
        def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
            """Generate target domain labels for debugging and testing."""
            # Get hair color indices.
            if dataset == 'CelebA':
                hair_color_indices = []
                for i, attr_name in enumerate(selected_attrs):
                    if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                        hair_color_indices.append(i)
                        # with the selected attributes above, hair_color_indices == [0, 1, 2]
            c_trg_list = []
            for i in range(c_dim):
                if dataset == 'CelebA':
                    c_trg = c_org.clone()
                    if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                        c_trg[:, i] = 1
                        for j in hair_color_indices:
                            if j != i:
                                c_trg[:, j] = 0
                    else:
                        c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
                elif dataset == 'RaFD':
                    c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
    
                c_trg_list.append(c_trg.to(self.device))
            return c_trg_list
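
        # Worked example (added for illustration, not in the original code):
        # with selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
        # and c_org = [[1, 0, 0, 1, 1]], create_labels returns one target per attribute:
        # the three hair colors are set one-hot in turn (they are mutually exclusive),
        # while 'Male' and 'Young' are simply flipped:
        # [[1,0,0,1,1], [0,1,0,1,1], [0,0,1,1,1], [1,0,0,0,1], [1,0,0,1,0]]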
        
        def denorm(self, x):
            """Convert the range from [-1, 1] to [0, 1]."""
            out = (x + 1) / 2
            return out.clamp_(0, 1)
        
    
        def build_tensorboard(self):
            self.logger = Logger(self.log_dir)
    
        def restore_model(self, resume_iters):
            """Restore the trained generator and discriminator."""
            print('Loading the trained models from step {}...'.format(resume_iters))
            G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
            D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
            self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
            self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
            
        def update_lr(self, g_lr, d_lr):
            """Decay learning rates of the generator and discriminator."""
            for param_group in self.g_optimizer.param_groups:
                param_group['lr'] = g_lr
            for param_group in self.d_optimizer.param_groups:
                param_group['lr'] = d_lr
    
        def reset_grad(self):
            """Reset the gradient buffers."""
            self.g_optimizer.zero_grad()
            self.d_optimizer.zero_grad()
            
        def classification_loss(self, logit, target, dataset='CelebA'):
            """Compute binary or softmax cross entropy loss."""
            if dataset == 'CelebA':
                return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)
            elif dataset == 'RaFD':
                return F.cross_entropy(logit, target)
            
        def gradient_penalty(self, y, x):
            """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
            weight = torch.ones(y.size()).to(self.device)
            dydx = torch.autograd.grad(outputs=y,
                                       inputs=x,
                                       grad_outputs=weight,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
    
            dydx = dydx.view(dydx.size(0), -1)
            dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
            return torch.mean((dydx_l2norm-1)**2)
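
        # Why this enforces the Lipschitz constraint (comment added for
        # illustration): for a linear critic y = w.x the gradient dydx is w,
        # so the penalty is (||w||_2 - 1)**2, which is zero exactly when the
        # critic has unit gradient norm along the real/fake interpolations.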
    
        def label2onehot(self, labels, dim):
            """Convert label indices to one-hot vectors."""
            batch_size = labels.size(0)
            out = torch.zeros(batch_size, dim)
            out[np.arange(batch_size), labels.long()] = 1
            return out
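
        # Example (added for illustration):
        # label2onehot(torch.tensor([2, 0]), 4) ->
        # tensor([[0., 0., 1., 0.],
        #         [1., 0., 0., 0.]])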
    
            
        def train(self):
            """Train StarGAN within a single dataset."""
            # Set data loader.
            if self.dataset == 'CelebA':
                data_loader = self.celeba_loader
            elif self.dataset == 'RaFD':
                data_loader = self.rafd_loader
    
            # Fetch fixed inputs for debugging.
            data_iter = iter(data_loader)
            x_fixed, c_org = next(data_iter)
            # x_fixed holds the image pixel values; c_org holds the real labels, e.g. tensor([[1., 0., 0., 1., 1.]])
    
            x_fixed = x_fixed.to(self.device)
            c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
            #print(c_fixed_list)
            #[tensor([[ 1.,  0.,  0.,  1.,  1.]]), tensor([[ 0.,  1.,  0.,  1.,  1.]]), tensor([[ 0.,  0.,  1.,  1.,  1.]]),
            # tensor([[ 1.,  0.,  0.,  0.,  1.]]), tensor([[ 1.,  0.,  0.,  1.,  0.]])]
            # Learning rate cache for decaying.
            g_lr = self.g_lr
            d_lr = self.d_lr
    
            # Start training from scratch or resume training.
            start_iters = 0  
            if self.resume_iters:  # resume_iters is set to None in the configuration above
                start_iters = self.resume_iters  # training can resume from a previously saved step instead of starting over
                self.restore_model(self.resume_iters)
            
            # Start training.
            print('Start training...')
            start_time = time.time()
            for i in range(start_iters, self.num_iters):
    
                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
    
                # Fetch real images and labels.
                try:
                    x_real, label_org = next(data_iter)
                except StopIteration:  # restart the iterator when an epoch ends
                    data_iter = iter(data_loader)
                    x_real, label_org = next(data_iter)
               
                # Generate target domain labels randomly.
                rand_idx = torch.randperm(label_org.size(0))  # e.g. tensor([0]) when the batch size is 1
                label_trg = label_org[rand_idx]  # real labels drawn from the data, permuted within the batch, e.g. tensor([[1., 0., 0., 1., 1.]])
                if self.dataset == 'CelebA':
                    c_org = label_org.clone()
                    c_trg = label_trg.clone()
                elif self.dataset == 'RaFD':
                    c_org = self.label2onehot(label_org, self.c_dim)
                    c_trg = self.label2onehot(label_trg, self.c_dim)
    
                x_real = x_real.to(self.device)           # Input images.
                c_org = c_org.to(self.device)             # Original domain labels.
                # print(c_org) -> tensor([[1., 0., 0., 1., 1.]])
                c_trg = c_trg.to(self.device)             # Target domain labels.
                # print(c_trg) -> tensor([[1., 0., 0., 1., 1.]])
                label_org = label_org.to(self.device)     # Labels for computing classification loss.
                label_trg = label_trg.to(self.device)     # Labels for computing classification loss.
    
                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #
    
                # Compute loss with real images.
                out_src, out_cls = self.D(x_real)
                """ 
                out_src
                tensor(1.00000e-03 *
                   [[[[-1.8202,  0.3373],
                      [-0.5725,  0.4968]]]])
                out_cls
                tensor(1.00000e-03 *
                   [[ 0.3915,  2.0016,  0.4509, -2.0520,  2.4382]])
                """
                d_loss_real = - torch.mean(out_src)  # minimizing this pushes out_src up toward "real" for real images
                # d_loss_real = tensor(1.00000e-04 * 3.8965)
                d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)  # classification loss on the real image's labels
                # d_loss_cls = tensor(3.4666)
                # Compute loss with fake images.
                # Feed the real image x_real and the target label c_trg into the
                # generator to obtain the generated image x_fake.

                x_fake = self.G(x_real, c_trg)  # x_fake is a generated image
    
                out_src, out_cls = self.D(x_fake.detach())
                """
                out_src
                tensor(1.00000e-03 *
                   [[[[-1.5289,  0.8110],
                  [ 0.2153,  0.4624]]]])
                out_cls
                tensor(1.00000e-03 *
                   [[ 1.4681,  1.9497,  1.2743, -1.1915,  0.7609]])
                """
                d_loss_fake = torch.mean(out_src)  # minimizing this pushes out_src down toward "fake" for generated images
                # e.g. tensor(1.00000e-05 * -1.0045)
    
                # Compute loss for gradient penalty.
                # Sample a random interpolation coefficient alpha, mix x_real and x_fake
                # with it, feed the mixture to D, and compute the gradient penalty loss.
                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
                # alpha is one random number per sample, e.g. tensor([[[[0.7610]]]])
                x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
                # x_hat is an image-sized tensor interpolated between real and fake
                out_src, _ = self.D(x_hat)  # score the interpolated image
                d_loss_gp = self.gradient_penalty(out_src, x_hat)
                # in the author's run, d_loss_gp fluctuates around 0.9954-0.9956
                
                # Backward and optimize.
                # The discriminator loss has four terms:
                # 1. real images should be scored as real
                # 2. images generated by G from a real image and a target label should be scored as fake
                # 3. the classification loss between D's label prediction on real images and the true labels
                # 4. the gradient penalty on interpolations between real and generated images
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
    
                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls'] = d_loss_cls.item()
                loss['D/loss_gp'] = d_loss_gp.item()
                
                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #
            # The generator should translate an image to the target domain, and
            # translate the generated image back to the original domain (reconstruction).
                if (i+1) % self.n_critic == 0:
                    # Original-to-target domain.
                    # Feed the real image x_real and the target label c_trg into the
                    # generator to obtain the generated image x_fake.
                    x_fake = self.G(x_real, c_trg)
                    # print("c_trg:", c_trg)  # debug print left by the author
                    out_src, out_cls = self.D(x_fake)
                    g_loss_fake = - torch.mean(out_src)  # adversarial loss: the generator wants its fakes scored as real
                    g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)  # push the output toward the target label
    
                    # Target-to-original domain.
                    x_reconst = self.G(x_fake, c_org)
                    print("c_org:",c_org)
                    sys.exit(0)
                    g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))
    
                    # Backward and optimize.
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()
    
                    # Logging.
                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_rec'] = g_loss_rec.item()
                    loss['G/loss_cls'] = g_loss_cls.item()
    
                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #
    
                # Print out training information.
                if (i+1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
    
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i+1)
    
                # Translate fixed images for debugging.
                if (i+1) % self.sample_step == 0:
                    with torch.no_grad():
                        x_fake_list = [x_fixed]
                        for c_fixed in c_fixed_list:
                            x_fake_list.append(self.G(x_fixed, c_fixed))
                        x_concat = torch.cat(x_fake_list, dim=3)
                        sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                        save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                        print('Saved real and fake images into {}...'.format(sample_path))
    
                # Save model checkpoints.
                if (i+1) % self.model_save_step == 0:
                    G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                    D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                    torch.save(self.G.state_dict(), G_path)
                    torch.save(self.D.state_dict(), D_path)
                    print('Saved model checkpoints into {}...'.format(self.model_save_dir))
    
                # Decay learning rates.
                if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                    g_lr -= (self.g_lr / float(self.num_iters_decay))
                    d_lr -= (self.d_lr / float(self.num_iters_decay))
                    self.update_lr(g_lr, d_lr)
                    print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
    
        def test(self):
            """Translate images using StarGAN trained on a single dataset."""
            # Load the trained generator.
            self.restore_model(self.test_iters)

            # Set data loader.
            if self.dataset == 'CelebA':
                data_loader = self.celeba_loader
            elif self.dataset == 'RaFD':
                data_loader = self.rafd_loader
            
            with torch.no_grad():
                for i, (x_real, c_org) in enumerate(data_loader):
                    # Prepare input images and target domain labels.
                    x_real = x_real.to(self.device)
                    c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
                
                    # Translate images.
                    x_fake_list = [x_real]
                       
                    for c_trg in c_trg_list:
                        x_fake_list.append(self.G(x_real, c_trg))
                  
                    # Save the translated images.
                    x_concat = torch.cat(x_fake_list, dim=3)
                    result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(result_path))

    Start training

    rafd_loader = None
    solver = Solver(celeba_loader, rafd_loader)
    solver.train()
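
    Note that the configuration at the top sets mode='test'. With that setting one would presumably restore a checkpoint and run inference instead, along these lines (my own sketch):

    if mode == 'train':
        solver.train()
    else:
        solver.test()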
  • Original post: https://www.cnblogs.com/hxjbc/p/9361168.html