• 风格迁移网络(vgg19提取特征,gram矩阵提取风格特征)


    from __future__ import division
    from torchvision import models
    from torchvision import transforms
    from PIL import Image
    import argparse
    import torch
    import torchvision
    import torch.nn as nn
    import numpy as np
    
    
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    def load_image(image_path, transform=None, max_size=None, shape=None):
        """Load an image and convert it to a torch tensor."""
        image = Image.open(image_path)
        
        if max_size:
            scale = max_size / max(image.size)
            size = np.array(image.size) * scale
            image = image.resize(size.astype(int), Image.ANTIALIAS)
        
        if shape:
            image = image.resize(shape, Image.LANCZOS)
        
        if transform:
            image = transform(image).unsqueeze(0)
        
        return image.to(device)
    
    
    class VGGNet(nn.Module):
        def __init__(self):
            """Select conv1_1 ~ conv5_1 activation maps."""
            super(VGGNet, self).__init__()
            self.select = ['0', '5', '10', '19', '28'] 
            self.vgg = models.vgg19(pretrained=True).features
            
        def forward(self, x):
            """Extract multiple convolutional feature maps."""
            features = []
            for name, layer in self.vgg._modules.items():
                x = layer(x)
                if name in self.select:
                    features.append(x)
            return features
    
    
    def main(config):
        
        # Image preprocessing
        # VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
        # We use the same normalization statistics here.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                                 std=(0.229, 0.224, 0.225))])
        #mean均值,std方差(rgb)
        
        # Load content and style images
        # Make the style image same size as the content image
        content = load_image(config.content, transform, max_size=config.max_size)
        style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])
        
        # Initialize a target image with the content image
        target = content.clone().requires_grad_(True)
        
        optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
        vgg = VGGNet().to(device).eval()
        
        for step in range(config.total_step):
            
            # Extract multiple(5) conv feature vectors
            target_features = vgg(target)
            content_features = vgg(content)
            style_features = vgg(style)
    
            style_loss = 0
            content_loss = 0
            for f1, f2, f3 in zip(target_features, content_features, style_features):
                # Compute content loss with target and content images
                # 内容损失,l2损失函数
                content_loss += torch.mean((f1 - f2)**2)
    
    
                # Reshape convolutional feature maps
                _, c, h, w = f1.size()
                f1 = f1.view(c, h * w)
                f3 = f3.view(c, h * w)
    
                # Compute gram matrix
                # 图像矩阵点乘矩阵的转置
                f1 = torch.mm(f1, f1.t())  #.t() 转置函数
                f3 = torch.mm(f3, f3.t())
    
                # Compute style loss with target and style images
                style_loss += torch.mean((f1 - f3)**2) / (c * h * w) 
            
            # Compute total loss, backprop and optimize
            loss = content_loss + config.style_weight * style_loss 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if (step+1) % config.log_step == 0:
                print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' 
                       .format(step+1, config.total_step, content_loss.item(), style_loss.item()))
    
            if (step+1) % config.sample_step == 0:
                # Save the generated image
                denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
                img = target.clone().squeeze()
                img = denorm(img).clamp_(0, 1)
                torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument('--content', type=str, default='png/content.png')
        parser.add_argument('--style', type=str, default='png/style.png')
        parser.add_argument('--max_size', type=int, default=400)
        parser.add_argument('--total_step', type=int, default=2000)
        parser.add_argument('--log_step', type=int, default=10)
        parser.add_argument('--sample_step', type=int, default=500)
        parser.add_argument('--style_weight', type=float, default=100)
        parser.add_argument('--lr', type=float, default=0.003)
        config = parser.parse_args()
        print(config)
        main(config)
  • 相关阅读:
    Java实现热替换
    SQL判断字符串里不包含字母
    Useful bat command
    Entity FrameworkCore教程(一):包概念理解
    Docker:Docker常见命令
    ASP.NET Core:ASP.NET Core程序使用Docker部署
    ASP.NET Core:中间件
    ASP.NET Core:依赖注入
    Jenkins:创建定时构建任务
    ASP.NET Core 3.1使用Swagger
  • 原文地址:https://www.cnblogs.com/h694879357/p/16005041.html
Copyright © 2020-2023  润新知