• 【语义分割】large kernel matters中GCN模块的pytorch实现


    GCN模块的实现比较简单,在giuhub上看到两种实现,轻微不同

    实现一:https://github.com/ycszen/pytorch-segmentation/blob/master/gcn.py

    class GCN(nn.Module):
        def __init__(self, inplanes, planes, ks=7):
            super(GCN, self).__init__()
            self.conv_l1 = nn.Conv2d(inplanes, planes, kernel_size=(ks, 1),
                                     padding=(ks/2, 0))
    
            self.conv_l2 = nn.Conv2d(planes, planes, kernel_size=(1, ks),
                                     padding=(0, ks/2))
            self.conv_r1 = nn.Conv2d(inplanes, planes, kernel_size=(1, ks),
                                     padding=(0, ks/2))
            self.conv_r2 = nn.Conv2d(planes, planes, kernel_size=(ks, 1),
                                     padding=(ks/2, 0))
    
        def forward(self, x):
            x_l = self.conv_l1(x)
            x_l = self.conv_l2(x_l)
    
            x_r = self.conv_r1(x)
            x_r = self.conv_r2(x_r)
    
            x = x_l + x_r
    
            return x
    

    实现二:https://github.com/ogvalt/large_kernel_matters/blob/master/scripts/model.py

    class GCN(nn.Module):
        def __init__(self, inchannels, channels=21, k=3):
            super(GCN, self).__init__()
    
            self.conv_l1 = Conv2D(in_channels=inchannels, out_channels=channels, kernel_size=(k, 1), padding='same')
            self.conv_l2 = Conv2D(in_channels=channels, out_channels=channels, kernel_size=(1, k), padding='same')
    
            self.conv_r1 = Conv2D(in_channels=inchannels, out_channels=channels, kernel_size=(1, k), padding='same')
            self.conv_r2 = Conv2D(in_channels=channels, out_channels=channels, kernel_size=(k, 1), padding='same')
    
        def forward(self, x):
            x1 = self.conv_l1(x)
            x1 = self.conv_l2(x1)
    
            x2 = self.conv_r1(x)
            x2 = self.conv_r2(x2)
    
            out = x1 + x2
    
            return out
    

    两种实现不同之处在padding的方式,一种是设定值,一种是自动的。不过我发现pytorch0.4.0是不支持对padding关键字参数传入字符串的,另外,我自己写了一个3D版的,不知道对否。

    class GCN(nn.Module):
        def __init__(self, inplanes, planes, ks=7):
            super(GCN, self).__init__()
            self.conv_l1 = nn.Conv3d(inplanes, planes, kernel_size=(ks, 1, 1),
                                     padding=(ks/2, 0, 0))
            self.conv_l2 = nn.Conv3d(planes, planes, kernel_size=(1, ks, 1),
                                     padding=(0, ks/2, 0))
            self.conv_l3 = nn.Conv3d(planes, planes, kernel_size=(1, 1, ks),
                                     padding=(0, 0, ks/2))
    
            self.conv_c1 = nn.Conv3d(inplanes, planes, kernel_size=(1, ks, 1),
                                     padding=(0, ks/2, 0))
            self.conv_c2 = nn.Conv3d(planes, planes, kernel_size=(1, 1, ks),
                                     padding=(0, 0, ks/2))
            self.conv_c3 = nn.Conv3d(planes, planes, kernel_size=(ks, 1, 1),
                                     padding=(ks/2, 0, 0))
    
            self.conv_r1 = nn.Conv3d(inplanes, planes, kernel_size=(1, 1, ks),
                                     padding=(0, 0, ks/2))
            self.conv_r2 = nn.Conv3d(planes, planes, kernel_size=(ks, 1, 1),
                                     padding=(ks/2, 0, 0))
            self.conv_r3 = nn.Conv3d(planes, planes, kernel_size=(1, ks, 1),
                                     padding=(0, ks/2, 0))
    
        def forward(self, x):
            x_l = self.conv_l1(x)
            x_l = self.conv_l2(x_l)
            x_l = self.conv_l3(x_l)
    
            x_c = self.conv_c1(x)
            x_c = self.conv_c2(x_c)
            x_c = self.conv_c3(x_c)
    
            x_r = self.conv_r1(x)
            x_r = self.conv_r2(x_r)
            x_r = self.conv_r3(x_r)
            x = x_l + x_r + x_c
    
            return x
    

      

  • 相关阅读:
    Android-调用优酷SDK上传视频
    新浪微博客户端(16)-获得并显示用户昵称
    新浪微博客户端(15)-保存用户名和密码
    转:Java NIO系列教程(九) Pipe
    新浪微博客户端(14)-截取回调地址中的授权成功的请求标记,换取access_token
    iOS-AFN "Request failed: unacceptable content-type: text/plain"
    新浪微博客户端(13)-使用UIWebView加载OAuth授权界面
    iOS-(kCFStreamErrorDomainSSL, -9802)
    转:Java NIO系列教程(八) DatagramChannel
    转:Java NIO系列教程(七) Socket Channel
  • 原文地址:https://www.cnblogs.com/wzyuan/p/10191929.html
Copyright © 2020-2023  润新知