• SegNet网络的PyTorch实现


    1.文章原文地址

    SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

    2.文章摘要

    语义分割具有非常广泛的应用，从场景理解、目标间相互关系推断到自动驾驶。早期依赖低层视觉线索的方法已迅速被流行的机器学习算法所取代。特别是近年来，深度学习在手写数字识别、语音识别、整幅图像分类以及图像中的目标检测上取得了巨大成功。如今，语义分割（对每个像素进行分类）是一个活跃的研究领域。然而，最近的一些方法直接采用为图像分类而设计的网络结构来完成语义分割任务。这些方法的结果虽然令人鼓舞，但仍比较粗糙，其首要原因是最大池化和下采样降低了特征图的分辨率。我们设计SegNet的动机在于：分割任务需要将低分辨率的特征图映射回输入分辨率以进行像素级分类，而这一映射必须产生有助于准确定位边界的特征。

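    3.网络结构

    SegNet由对称的编码器和解码器组成。编码器对应VGG16的13个卷积层，分为5个阶段（分别包含2、2、3、3、3个“卷积+BN+ReLU”层），每个阶段末尾进行2×2最大池化，并记录最大值所在位置的索引。解码器的每个阶段先利用对应编码阶段保存的池化索引进行上采样（MaxUnpool），再经卷积得到稠密的特征图。最后由一个卷积分类层输出每个像素在各类别上的得分。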

    4.PyTorch实现

      1 import torch.nn as nn
      2 import torch
      3 
      4 class conv2DBatchNormRelu(nn.Module):
      5     def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
      6                  bias=True,dilation=1,is_batchnorm=True):
      7         super(conv2DBatchNormRelu,self).__init__()
      8         if is_batchnorm:
      9             self.cbr_unit=nn.Sequential(
     10                 nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
     11                           bias=bias,dilation=dilation),
     12                 nn.BatchNorm2d(out_channels),
     13                 nn.ReLU(inplace=True),
     14             )
     15         else:
     16             self.cbr_unit=nn.Sequential(
     17                 nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
     18                           bias=bias, dilation=dilation),
     19                 nn.ReLU(inplace=True)
     20             )
     21 
     22     def forward(self,inputs):
     23         outputs=self.cbr_unit(inputs)
     24         return outputs
     25 
     26 class segnetDown2(nn.Module):
     27     def __init__(self,in_channels,out_channels):
     28         super(segnetDown2,self).__init__()
     29         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
     30         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
     31         self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True)
     32 
     33     def forward(self,inputs):
     34         outputs=self.conv1(inputs)
     35         outputs=self.conv2(outputs)
     36         unpooled_shape=outputs.size()
     37         outputs,indices=self.maxpool_with_argmax(outputs)
     38         return outputs,indices,unpooled_shape
     39 
     40 class segnetDown3(nn.Module):
     41     def __init__(self,in_channels,out_channels):
     42         super(segnetDown3,self).__init__()
     43         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
     44         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
     45         self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
     46         self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True)
     47 
     48     def forward(self,inputs):
     49         outputs=self.conv1(inputs)
     50         outputs=self.conv2(outputs)
     51         outputs=self.conv3(outputs)
     52         unpooled_shape=outputs.size()
     53         outputs,indices=self.maxpool_with_argmax(outputs)
     54         return outputs,indices,unpooled_shape
     55 
     56 
     57 class segnetUp2(nn.Module):
     58     def __init__(self,in_channels,out_channels):
     59         super(segnetUp2,self).__init__()
     60         self.unpool=nn.MaxUnpool2d(2,2)
     61         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
     62         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
     63 
     64     def forward(self,inputs,indices,output_shape):
     65         outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
     66         outputs=self.conv1(outputs)
     67         outputs=self.conv2(outputs)
     68         return outputs
     69 
     70 class segnetUp3(nn.Module):
     71     def __init__(self,in_channels,out_channels):
     72         super(segnetUp3,self).__init__()
     73         self.unpool=nn.MaxUnpool2d(2,2)
     74         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
     75         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
     76         self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
     77 
     78     def forward(self,inputs,indices,output_shape):
     79         outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
     80         outputs=self.conv1(outputs)
     81         outputs=self.conv2(outputs)
     82         outputs=self.conv3(outputs)
     83         return outputs
     84 
     85 class segnet(nn.Module):
     86     def __init__(self,in_channels=3,num_classes=21):
     87         super(segnet,self).__init__()
     88         self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
     89         self.down2=segnetDown2(64,128)
     90         self.down3=segnetDown3(128,256)
     91         self.down4=segnetDown3(256,512)
     92         self.down5=segnetDown3(512,512)
     93 
     94         self.up5=segnetUp3(512,512)
     95         self.up4=segnetUp3(512,256)
     96         self.up3=segnetUp3(256,128)
     97         self.up2=segnetUp2(128,64)
     98         self.up1=segnetUp2(64,64)
     99         self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1)
    100 
    101     def forward(self,inputs):
    102         down1,indices_1,unpool_shape1=self.down1(inputs)
    103         down2,indices_2,unpool_shape2=self.down2(down1)
    104         down3,indices_3,unpool_shape3=self.down3(down2)
    105         down4,indices_4,unpool_shape4=self.down4(down3)
    106         down5,indices_5,unpool_shape5=self.down5(down4)
    107 
    108         up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
    109         up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
    110         up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
    111         up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
    112         up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
    113         outputs=self.finconv(up1)
    114 
    115         return outputs
    116 
    117 if __name__=="__main__":
    118     inputs=torch.ones(1,3,224,224)
    119     model=segnet()
    120     print(model(inputs).size())
    121     print(model)

    参考

    https://github.com/meetshah1995/pytorch-semseg

  • 相关阅读:
    [dubbo实战] dubbo+zookeeper伪集群搭建 (转)
    [Dubbo实战]dubbo + zookeeper + spring 实战 (转)
    DUBBO本地搭建及小案例 (转)
    【Dubbo实战】 Dubbo+Zookeeper+Spring整合应用篇-Dubbo基于Zookeeper实现分布式服务(转)
    Quartz集成springMVC 的方案二(持久化任务、集群和分布式)
    【Quartz】Quartz的搭建、应用(单独使用Quartz)
    Javascript判断Crontab表达式是否合法
    给Java程序员的几条建议
    使用maven编译Java项目
    使用Docker运行Java Web应用
  • 原文地址:https://www.cnblogs.com/ys99/p/10900870.html
Copyright © 2020-2023  润新知