• PyTorch 不使用转置卷积实现上采样


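    上采样(upsampling)一般包括 2 种方式:

    1. Resize:用双线性插值(bilinear)等插值方法直接放大特征图(类似于图像缩放),通常在其后再接一个普通卷积;
    2. Deconvolution:也叫转置卷积(Transposed Convolution),在 PyTorch 中对应 nn.ConvTranspose2d。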

    第二种方法(转置卷积)在 PyTorch 中可直接用 nn.ConvTranspose2d 实现(原文此处附有链接,转载时已丢失)。

    这里介绍如何用 PyTorch 实现第一种方法,即"先插值放大、再卷积"的上采样:

    举例:

    1)使用 torch.nn 模块中的 nn.Upsample 实现一个生成器:

    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ConvLayer(nn.Module):
        """Reflection-pad the input, then apply a (possibly strided) Conv2d.

        Every side is padded by ``kernel_size // 2``.  For an odd kernel this
        makes the output spatial size ceil(H / stride): unchanged for
        stride 1, ceil(H / 2) for stride 2.
        """

        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super().__init__()
            # Registration order (pad, then conv) determines the printed repr
            # and the state_dict key order, so keep it.
            self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

        def forward(self, x):
            return self.conv(self.reflection_pad(x))
    
    class Generator(nn.Module):
        """Encoder-decoder generator that upsamples by interpolation + 1x1 conv.

        Encoder: three stride-2 ``ConvLayer`` stages (in_channels -> 32 -> 64
        -> 128), the first two followed by BatchNorm2d + ReLU.
        Decoder: three rounds of ``nn.Upsample(scale_factor=2, bilinear)``,
        each followed by a 1x1 ``nn.Conv2d``; no transposed convolution is
        used.  The last 1x1 conv maps 32 channels to ``out_channels`` and is
        followed by ``Tanh``.

        Each encoder stage maps a spatial size H to ceil(H / 2), so the output
        spatial size is 8 * ceil(H / 8).  It equals the input size only when
        H and W are multiples of 8 (e.g. 224 -> 28 -> 224).

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels of the output tensor.  Defaults
                to 3, the previously hard-coded value, so existing callers
                are unaffected.
        """

        def __init__(self, in_channels, out_channels=3):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels

            self.encoder = nn.Sequential(
                ConvLayer(self.in_channels, 32, 3, 2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ConvLayer(32, 64, 3, 2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ConvLayer(64, 128, 3, 2),
            )

            # nn.Upsample holds no parameters, so one instance can safely be
            # reused at every decoder stage.
            upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.decoder = nn.Sequential(
                upsample,
                nn.Conv2d(128, 64, 1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                upsample,
                nn.Conv2d(64, 32, 1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                upsample,
                nn.Conv2d(32, self.out_channels, 1),
                nn.Tanh(),
            )

        def forward(self, x):
            """Encode ``x`` to 128 channels at ~1/8 resolution, then decode it."""
            return self.decoder(self.encoder(x))
    
    def test():
        """Smoke-test Generator: print its sub-modules and run one forward pass.

        Builds a Generator for 3-channel input, feeds it a random batch of
        two 224x224 images, and prints the output size and type.
        """
        # `import torch.nn as nn` does not bind the name `torch`, and
        # `Variable` was never imported, so the original raised NameError.
        # Since PyTorch 0.4 plain tensors carry autograd state, so the
        # Variable wrapper is unnecessary.
        import torch

        net = Generator(3)
        for module in net.children():
            print(module)
        x = torch.randn(2, 3, 224, 224)
        output = net(x)
        print('output :', output.size())
        print(type(output))
    
    if __name__ == '__main__':
        # Run the smoke test only when this file is executed as a script,
        # not when its classes are imported from another module.
        test()
    View Code

    返回:

    model.py .Sequential(
      (0): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
      )
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      )
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      )
    )
    Sequential(
      (0): Upsample(scale_factor=2, mode=bilinear)
      (1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Upsample(scale_factor=2, mode=bilinear)
      (5): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU()
      (8): Upsample(scale_factor=2, mode=bilinear)
      (9): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
      (10): Tanh()
    )
    output : torch.Size([2, 3, 224, 224])
    <class 'torch.Tensor'>
    View Code

    但是在部分较早的 PyTorch 版本中,这段代码会产生如下警告:

     UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.

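    可改用 torch.nn.functional.interpolate 在 forward 中完成上采样。在较新的 PyTorch 版本中,nn.Upsample 已不再被标记为弃用,其内部同样调用 F.interpolate,因此两种写法的计算结果一致: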

    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ConvLayer(nn.Module):
        """Conv2d preceded by reflection padding of width ``kernel_size // 2``.

        For an odd kernel the output height/width is ceil(H / stride).
        """

        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super().__init__()
            pad_width = kernel_size // 2
            # Keep the pad module registered before the conv: this order is
            # visible in repr() and in state_dict keys.
            self.reflection_pad = nn.ReflectionPad2d(pad_width)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

        def forward(self, x):
            padded = self.reflection_pad(x)
            return self.conv(padded)
    
    class Generator(nn.Module):
        """Encoder-decoder generator that upsamples with ``F.interpolate``.

        Encoder: three stride-2 ``ConvLayer`` stages (in_channels -> 32 -> 64
        -> 128), the first two followed by BatchNorm2d + ReLU.
        Decoder: three stages (``decoder1``..``decoder3``).  Each one is
        preceded by a parameter-free 2x bilinear ``F.interpolate`` call made
        in ``forward``.  The last stage maps 32 channels to ``out_channels``
        and ends in ``Tanh``.

        Each encoder stage maps a spatial size H to ceil(H / 2), so the output
        spatial size is 8 * ceil(H / 8).  It equals the input size only when
        H and W are multiples of 8 (e.g. 224 -> 28 -> 224).

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels of the output tensor.  Defaults
                to 3, the previously hard-coded value, so existing callers
                are unaffected.
        """

        def __init__(self, in_channels, out_channels=3):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels

            self.encoder = nn.Sequential(
                ConvLayer(self.in_channels, 32, 3, 2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ConvLayer(32, 64, 3, 2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ConvLayer(64, 128, 3, 2),
            )

            self.decoder1 = nn.Sequential(
                nn.Conv2d(128, 64, 1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
            )
            self.decoder2 = nn.Sequential(
                nn.Conv2d(64, 32, 1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
            )
            self.decoder3 = nn.Sequential(
                nn.Conv2d(32, self.out_channels, 1),
                nn.Tanh(),
            )

        def forward(self, x):
            """Encode ``x``, then alternate 2x upsampling with decoder stages."""
            x = self.encoder(x)
            for stage in (self.decoder1, self.decoder2, self.decoder3):
                # Double H and W before every decoder stage.
                x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
                x = stage(x)
            return x
    
    def test():
        """Smoke-test Generator: print its sub-modules and run one forward pass.

        Builds a Generator for 3-channel input, feeds it a random batch of
        two 224x224 images, and prints the output size and type.
        """
        # `import torch.nn as nn` / `import torch.nn.functional as F` do not
        # bind the name `torch`, and `Variable` was never imported, so the
        # original raised NameError.  Since PyTorch 0.4 plain tensors carry
        # autograd state, so the Variable wrapper is unnecessary.
        import torch

        net = Generator(3)
        for module in net.children():
            print(module)
        x = torch.randn(2, 3, 224, 224)
        output = net(x)
        print('output :', output.size())
        print(type(output))
    
    if __name__ == '__main__':
        # Execute the smoke test only when run directly as a script;
        # importing this module must not trigger a forward pass.
        test()
    View Code

    返回:

    model.py .Sequential(
      (0): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
      )
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      )
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      )
    )
    Sequential(
      (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    Sequential(
      (0): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
      (1): Tanh()
    )
    output : torch.Size([2, 3, 224, 224])
    <class 'torch.Tensor'>
    View Code
  • 相关阅读:
    joson返回数据库的时间格式在前台用js转换
    网站图片不存在,显示默认图片解决办法
    最常用的截取函数有left,right,substring
    Fortran 基础语法(一)
    SQL Server2008附加数据库之后显示为只读时解决方法
    table 控制单双行颜色以及鼠标hover颜色 table光棒
    select change下拉框改变事件 设置选定项,禁用select
    “如何稀释scroll事件”引出的问题
    自给自足:动手打造html5俄罗斯方块
    一个可以拓展的垂直多级导航栏 Demo
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/11400866.html
Copyright © 2020-2023  润新知