• Style Transfer Training: Practice and Analysis


    The previous post shared the code for a simple PyTorch style transfer. In the spirit of "not giving up until the server falls over", I kept increasing the number of optimization steps to see how the blended images turn out.

    To make everything easy to run in one go, I lightly reworked the code; it is largely the same as in the previous post:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    from PIL import Image
    import matplotlib.pyplot as plt

    import torchvision.transforms as transforms
    import torchvision.models as models
    import datetime

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def get_img_size(img_name):
        """
        Open an image and return it together with its height and width
        :param img_name:
        :return:
        """
        im = Image.open(img_name).convert('RGB')
        return im, im.height, im.width


    def image_loader(img, im_h, im_w):
        """
        Load an image as a float tensor on the target device
        :param img:
        :param im_h:
        :param im_w:
        :return:
        """

        # loader = transforms.Compose([transforms.Resize([im_h, im_w]), transforms.ToTensor()])
        # Both images are forced to 1000x1000, so their tensors always match in size
        loader = transforms.Compose([transforms.Resize([1000, 1000]), transforms.ToTensor()])
        im_l = loader(img).unsqueeze(0)
        return im_l.to(device, torch.float)


    def im_show(tensor, save_file_path):
        """
        Display and save the image
        :param tensor:
        :param save_file_path:
        :return:
        """
        image = tensor.cpu().clone()
        image = image.squeeze(0)
        image = transforms.ToPILImage()(image)
        plt.imshow(image, aspect='equal')
        plt.axis('off')
        plt.savefig(save_file_path, bbox_inches='tight', pad_inches=0.0)
        plt.pause(0.001)


    class ContentLoss(nn.Module):
        """
        Content loss
        """

        def __init__(self, target,):
            super(ContentLoss, self).__init__()
            self.target = target.detach()

        def forward(self, cl_input):
            self.loss = F.mse_loss(cl_input, self.target)
            return cl_input


    def gram_matrix(gm_input):
        """
        Gram matrix used by the style loss
        :param gm_input:
        :return:
        """
        a, b, c, d = gm_input.size()
        features = gm_input.view(a * b, c * d)
        G = torch.mm(features, features.t())

        return G.div(a * b * c * d)


    class StyleLoss(nn.Module):
        """
        Style loss
        """

        def __init__(self, target_feature):
            super(StyleLoss, self).__init__()
            self.target = gram_matrix(target_feature).detach()

        def forward(self, fw_input):
            G = gram_matrix(fw_input)
            self.loss = F.mse_loss(G, self.target)
            return fw_input


    # Use the 19-layer VGG network
    cnn = models.vgg19(pretrained=True).features.to(device).eval()


    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)


    class Normalization(nn.Module):
        """
        Normalize the input image
        """
        def __init__(self, mean, std):
            super(Normalization, self).__init__()
            self.mean = mean.clone().detach().view(-1, 1, 1)
            self.std = std.clone().detach().view(-1, 1, 1)

        def forward(self, img):
            return (img - self.mean) / self.std


    def get_style_model_and_losses(cn, normalization_mean, normalization_std, style_i, content_i, cld, sld):
        """
        Build the model and collect the content and style losses
        :param cn:
        :param normalization_mean:
        :param normalization_std:
        :param style_i:
        :param content_i:
        :param cld:
        :param sld:
        :return:
        """

        normalization = Normalization(normalization_mean, normalization_std).to(device)
        content_losses = []
        style_losses = []

        model = nn.Sequential(normalization)

        i = 0
        for layer in cn.children():
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = 'conv_{}'.format(i)
            elif isinstance(layer, nn.ReLU):
                name = 'relu_{}'.format(i)
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = 'pool_{}'.format(i)
            elif isinstance(layer, nn.BatchNorm2d):
                name = 'bn_{}'.format(i)
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

            model.add_module(name, layer)

            if name in cld:
                target = model(content_i).detach()
                content_loss = ContentLoss(target)
                model.add_module("content_loss_{}".format(i), content_loss)
                content_losses.append(content_loss)

            if name in sld:
                target_feature = model(style_i).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module("style_loss_{}".format(i), style_loss)
                style_losses.append(style_loss)

        # Trim away everything after the last loss layer
        for i in range(len(model) - 1, -1, -1):
            if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
                break

        model = model[:(i + 1)]

        return model, style_losses, content_losses


    def get_input_optimizer(input_i):
        """
        Use the L-BFGS algorithm
        to minimize the style and content losses
        :param input_i:
        :return:
        """
        optimizer = optim.LBFGS([input_i])
        return optimizer


    def run_style_transfer(cn, norma_mean, normalization_std, ct_img, sl_img, in_img, steps, style_weight, content_weight):
        """
        Build the style transfer model and run the optimization
        :param cn:
        :param norma_mean:
        :param normalization_std:
        :param ct_img:
        :param sl_img:
        :param in_img:
        :param steps:
        :param style_weight:
        :param content_weight:
        :return:
        """
        content_layers = ['conv_4']
        style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        model, style_losses, content_losses = get_style_model_and_losses(cn, norma_mean, normalization_std, sl_img, ct_img, content_layers, style_layers)
        in_img.requires_grad_(True)
        model.requires_grad_(False)

        optimizer = get_input_optimizer(in_img)
        print('Optimizing..')
        run = [0]
        while run[0] <= steps:

            def closure():
                with torch.no_grad():
                    in_img.clamp_(0, 1)

                optimizer.zero_grad()
                model(in_img)
                style_score = 0
                content_score = 0

                for sl in style_losses:
                    style_score += sl.loss
                for cl in content_losses:
                    content_score += cl.loss

                style_score *= style_weight
                content_score *= content_weight

                loss = style_score + content_score
                loss.backward()

                run[0] += 1
                if run[0] % 50 == 0:
                    print("run {}:".format(run))
                    print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score.item(), content_score.item()))
                return style_score + content_score

            optimizer.step(closure)
        with torch.no_grad():
            in_img.clamp_(0, 1)
        return in_img


    def style_transfer(content_image_path, style_image_path, image_save_path, run_steps):
        """
        Main entry point for style transfer
        :param content_image_path: content image
        :param style_image_path: style image
        :param image_save_path: where to save the result
        :param run_steps: number of optimization steps
        :return:
        """
        c_image, c_im_h, c_im_w = get_img_size(content_image_path)
        s_image, s_im_h, s_im_w = get_img_size(style_image_path)
        content_img = image_loader(c_image, c_im_h, c_im_w)
        style_img = image_loader(s_image, c_im_h, c_im_w)
        assert style_img.size() == content_img.size()
        # Start the optimization from the content image
        input_img = content_img.clone()
        begin_time = datetime.datetime.now()
        print("****************** start time *****************", begin_time)
        output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, input_img, run_steps, s_weight, c_weight)
        try:
            im_show(output, image_save_path)
        except Exception as e:
            print(e)
        print("****************** end time *****************", datetime.datetime.now())
        print("****************** elapsed *****************", datetime.datetime.now() - begin_time)


    if __name__ == '__main__':
        s_weight = 1000000
        c_weight = 1
        # content_img_path = "data/drew/img/512.png"
        content_img_path = "/data/drew/img/dancing.jpg"
        # style_img_path = "data/drew/img/512r.png"
        style_img_path = "/data/drew/img/picasso.jpg"
        for steps in range(100, 3200, 200):
            # save_path = "data/drew/img/end_%s_%s.jpg" % (steps, datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
            save_path = "/data/drew/img/end_%s_%s.jpg" % (steps, datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
            style_transfer(content_img_path, style_img_path, save_path, steps)

    Step count vs. run time (the blended output at each step count was shown as an image here):

    100 steps: 16 s
    500 steps: 1 min 18 s
    900 steps: 2 min 20 s
    1300 steps: 3 min 22 s
    1700 steps: 4 min 24 s
    1900 steps: 4 min 54 s
    2100 steps: 5 min 25 s
    2300 steps: 5 min 55 s
    2900 steps: 7 min 29 s
    3100 steps: 8 min
    5000 steps: 12 min 49 s
    10000 steps: 25 min 33 s

    Grouped comparison of all the results:

    As you can see, the more optimization steps you run, the more finely the two images blend. That said, the optimizer may land on different colour values during a given run, so the result can differ quite a bit from one step count to the next; the program has no way of knowing which image suits you, so you have to run it several times and pick the one that looks best. Personally I think both 1700 and 3100 look good; sometimes more computation actually makes the result less pleasing. Overall, though, the longer you train, the closer the output's colours, their intensity, their layering and their placement get to the style image. The result after 30000 steps, for example, is clearly converging toward the style image. What would it look like if you just kept training, say for a million steps? ~~O(∩_∩)O, anyone with a beefy enough machine is welcome to find out!!

     

    Step count vs. elapsed time:

    100-------0:16.042588
    300-------0:45.938147
    500-------1:18.010645
    700-------1:49.314894
    900-------2:20.706951
    1100------2:51.681638
    1300------3:22.309657
    1500------3:53.468090
    1700------4:24.346204
    1900------4:54.553869
    2100------5:25.256131
    2300------5:55.873070
    2500------6:26.832513
    2700------6:57.715941
    2900------7:29.032707
    3100------7:59.818595
    5000------12:49.948595
    10000-----25:33.300846
    30000-----1:16:40.921489
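
    Converted to seconds, these timings show that run time grows almost linearly with the step count: beyond the fixed startup cost, each extra step costs roughly 0.15 s on this machine. Below is a small sketch that checks this and plots the curve; the seconds values are my own rounded conversions of the table above, and matplotlib is assumed to be available (the script already uses it):

    # Check how run time scales with the number of optimization steps,
    # using the timings from the table above converted to seconds (rounded).
    import matplotlib.pyplot as plt

    steps = [100, 300, 500, 700, 900, 1100, 1300, 1500, 1700, 1900,
             2100, 2300, 2500, 2700, 2900, 3100, 5000, 10000, 30000]
    seconds = [16.04, 45.94, 78.01, 109.31, 140.71, 171.68, 202.31, 233.47,
               264.35, 294.55, 325.26, 355.87, 386.83, 417.72, 449.03, 479.82,
               769.95, 1533.30, 4600.92]

    # Incremental cost per step between consecutive runs -- it stays close to
    # 0.15 s/step, i.e. total time is essentially linear in the step count.
    per_step = [(seconds[i] - seconds[i - 1]) / (steps[i] - steps[i - 1])
                for i in range(1, len(steps))]
    print([round(p, 3) for p in per_step])

    plt.plot(steps, seconds, marker='o')
    plt.xlabel('optimization steps')
    plt.ylabel('elapsed time (s)')
    plt.title('Style transfer: steps vs. run time')
    plt.savefig('steps_vs_time.png', bbox_inches='tight')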

    Of course, adjusting the relative weights of the two images also makes a big difference to the blended result:

    For example, let's raise the content weight from 1 to 10 and then to 100 and see what happens:

    if __name__ == '__main__':
        s_weight = 1000000
        c_weight = 100

      

      

    The left image uses a style-to-content ratio of a million to one, the middle one a hundred thousand to one, and the right one ten thousand to one; the difference in the results is quite striking.
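
    To produce this kind of comparison in one pass, the content weight can be swept in a loop instead of being edited by hand. Here is a minimal sketch reusing the script above: style_transfer() reads s_weight and c_weight as module-level globals, so the loop has to replace the script's __main__ block; the output file naming and the fixed 500 steps are just illustrative choices, not values from the runs above.

    if __name__ == '__main__':
        s_weight = 1000000
        content_img_path = "/data/drew/img/dancing.jpg"
        style_img_path = "/data/drew/img/picasso.jpg"
        # Rebinding the global c_weight changes the style:content ratio to
        # 1e6:1, 1e5:1 and 1e4:1 in turn, everything else staying fixed.
        for c_weight in (1, 10, 100):
            save_path = "/data/drew/img/end_w%s_%s.jpg" % (
                c_weight, datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
            style_transfer(content_img_path, style_img_path, save_path, 500)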

    Example two:

    I found a painting of bamboo and the character 竹 (bamboo) and ran the same experiment:

    The conclusion is the same.

     

    If you're interested, you're welcome to discuss and explore further!!
