• [论文理解] Connectionist Text Proposal Network


    Connectionist Text Proposal Network

    简介

    CTPN是通过VGG16后在特征图上采用3*3窗口进行滑窗,采用与RPN类似的anchor机制,固定width而只预测anchor的y坐标和高度,达到比较精准的text proposal效果。同时,文章的亮点在于引入了RNN,使用BLSTM使得预测更加精准。CTPN在自然场景下文本提取的效果很不错,不同于传统的bottom-up方法,传统方法通过检测单个字符然后再去连接文本线,其准确性主要依赖于单个字符的识别,而且错误会累积,其使用的仅仅是low-level的feature;而本文采用的方法提取的是深度的特征,采用anchor机制做的精准预测,然后用循环神经网络对anchor识别的区域进行连接,精度要高很多。

    结构:

    Detecting Text in Fine-scale Proposals

    detection过程很简单,直接在vgg-16后面用3*3的滑窗去滑feature map的最后一个卷积层,固定感受野大小为228pixels,total stride为16pixels,这样每个anchor对应在原图中的间隔就是16pixels。total stride和感受野的大小都是由网络结构决定的,也就是说,在网络结构确定的情况下,我们可以人为地去设置感受野的大小和total stride,由于total stride = s *2 *2 *2 *2,由于设置的total stride =16 ,所以可以确定3*3的stride是1,也就是后面每个anchor的水平距离在原图中对应的是16pixels。

    之后,作者修改了原始的rpn,去预测长度固定为16pixels的区域,与rpn不同的是,本文只预测区域的y轴坐标和高度,此外,还输出anchor是或不是文字区域的二分类结果。由于上面确定了每次anchor移动的距离恰好是total stride,所以这里对应上了。然后对每个特征点设计了10种vertical anchor,这些anchor的宽度都为16pixels,高度从11 到 273pixels(每次除以0.7),让这10个anchor独立地预测中心点坐标(vc)和高度(vh),定义如下:

    对每个预测而言,水平坐标和k个anchor的位置是固定的,这些都是可以预先在图像进来之后计算出来的,而分类器输出的结果是text/non-text的得分和预测的k个anchor的y轴坐标(v)。而识别出来的text proposals 是从那些text/non-text的得分大于0.7,然后再经过MNS得到的。这样只预测纵坐标的做法比rpn的准确率提升了很多,因为其提供了更多的监督信息。

    Recurrent Connectionist Text Proposals

    本文的亮点就在于使用了循环神经网络来连接text proposals,为了提升定位的准确率,作者把文本线看成是一连串的text proposals,然后去单独预测,但是这样做发现很容易错将非文字区域识别为文字区域。由于RNN对处理上下文很好,而文字有着很强的上下文关联,所以作者顺理成章的引入RNN,将conv5层的feature的每个window扫描后的结果作为RNN的输入,然后循环更新这个隐状态定义如下:

    作者使用的是双向LSTM作为RNN的结构,因此每个window都具有他之前的window的上下文信息,每个window的卷积特征作为256D的 双向lstm的输入,然后将每个隐状态全连接到输出层,预测第t个proposal。

    使用RNN后,明显减少了错误的识别,将很多之前没识别到的地方也识别到了,说明上下文信息对预测确实很有帮助。

    Side-refinement

    由于预测的text proposal 可能与ground truth在最左和最右两边不一定重叠度高,所以可能被弃掉,因此提出了边框修正,来修正这一点,如果不修正,那么预测到的proposal的文字区域可能在两边有缺失。

    结果如下

    Outputs And Loss Functions

    模型一共有三个输出,分别是text/non-text scores、竖直坐标v(包括anchor在原图中对应的竖直坐标和高度)以及修正系数o。对于每个特征点k个anchor,分别输出2k,2k,k个参数,而文章也是采用了多任务学习来进行优化模型参数,模型的loss functions定义如下:

    分类误差用的是softmax计算的,回归误差用的是smooth L1函数计算的,两个λ是为了调整loss的权重。

    论文原文

    简单写的model:

    import torch
    
    import torch.nn as nn
    from torchvision.models import vgg16
    from torchsummary import summary
    
    
    class Backbone(nn.Module):
        def __init__(self):
            super(Backbone,self).__init__()
            self.feature_extractor = vgg16(pretrained = False).features
        def forward(self,x):
            return self.feature_extractor(x)
    class BasicConv(nn.Module):
        def __init__(self,in_size,out_size,kernel_size):
            super(BasicConv,self).__init__()
            self.basic = nn.Sequential(
                nn.Conv2d(in_size,out_size,kernel_size = kernel_size),
                nn.BatchNorm2d(out_size),
                nn.ReLU()
            )
        def forward(self,x):
            return self.basic(x)
    
    class CTPN(nn.Module):
        def __init__(self):
            super(CTPN,self).__init__()
            self.backbone = Backbone()
            self.brnn = nn.GRU(512,128,2,bidirectional = True)
            self.fc = BasicConv(256,512,1)
            self.coordinates = BasicConv(512,20,1)
            self.scores = BasicConv(512,20,1)
            self.sides = BasicConv(512,10,1)
        def forward(self,x):
            x = self.backbone(x) # (b,c,h,w)
            s = x.permute(0,2,3,1).contiguous().view(-1,x.size(3),512) # (b,h,w,c) -> (bh,w,c)
            s,_ = self.brnn(s) # (bh,w,2c)
            s = s.view(-1, x.size(2),x.size(3),256).permute(0,3,1,2).contiguous()# (b,2c,h,w)
            output = self.fc(s) # (b,512,h,w)
            coordinates = self.coordinates(output).permute(0,2,3,1).contiguous().view(-1,10 * output.size(2) * output.size(3),2)
            scores = self.scores(output).permute(0,2,3,1).contiguous().view(-1,10 * output.size(2) * output.size(3),2)
            sides = self.sides(output).permute(0,2,3,1).contiguous().view(-1,10 * output.size(2) * output.size(3),1)
            return [coordinates,scores,sides]
    if __name__ == "__main__":
        net = CTPN()
        summary(net,(3,224,224),device = "cpu") # -1 512 7 7 
        #for name,module in net.named_children():
        #    print(name,module)
    
    
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 224, 224]           1,792
                  ReLU-2         [-1, 64, 224, 224]               0
                Conv2d-3         [-1, 64, 224, 224]          36,928
                  ReLU-4         [-1, 64, 224, 224]               0
             MaxPool2d-5         [-1, 64, 112, 112]               0
                Conv2d-6        [-1, 128, 112, 112]          73,856
                  ReLU-7        [-1, 128, 112, 112]               0
                Conv2d-8        [-1, 128, 112, 112]         147,584
                  ReLU-9        [-1, 128, 112, 112]               0
            MaxPool2d-10          [-1, 128, 56, 56]               0
               Conv2d-11          [-1, 256, 56, 56]         295,168
                 ReLU-12          [-1, 256, 56, 56]               0
               Conv2d-13          [-1, 256, 56, 56]         590,080
                 ReLU-14          [-1, 256, 56, 56]               0
               Conv2d-15          [-1, 256, 56, 56]         590,080
                 ReLU-16          [-1, 256, 56, 56]               0
            MaxPool2d-17          [-1, 256, 28, 28]               0
               Conv2d-18          [-1, 512, 28, 28]       1,180,160
                 ReLU-19          [-1, 512, 28, 28]               0
               Conv2d-20          [-1, 512, 28, 28]       2,359,808
                 ReLU-21          [-1, 512, 28, 28]               0
               Conv2d-22          [-1, 512, 28, 28]       2,359,808
                 ReLU-23          [-1, 512, 28, 28]               0
            MaxPool2d-24          [-1, 512, 14, 14]               0
               Conv2d-25          [-1, 512, 14, 14]       2,359,808
                 ReLU-26          [-1, 512, 14, 14]               0
               Conv2d-27          [-1, 512, 14, 14]       2,359,808
                 ReLU-28          [-1, 512, 14, 14]               0
               Conv2d-29          [-1, 512, 14, 14]       2,359,808
                 ReLU-30          [-1, 512, 14, 14]               0
            MaxPool2d-31            [-1, 512, 7, 7]               0
             Backbone-32            [-1, 512, 7, 7]               0
                  GRU-33  [[-1, 7, 256], [-1, 7, 128]]               0
               Conv2d-34            [-1, 512, 7, 7]         131,584
          BatchNorm2d-35            [-1, 512, 7, 7]           1,024
                 ReLU-36            [-1, 512, 7, 7]               0
            BasicConv-37            [-1, 512, 7, 7]               0
               Conv2d-38             [-1, 20, 7, 7]          10,260
          BatchNorm2d-39             [-1, 20, 7, 7]              40
                 ReLU-40             [-1, 20, 7, 7]               0
            BasicConv-41             [-1, 20, 7, 7]               0
               Conv2d-42             [-1, 20, 7, 7]          10,260
          BatchNorm2d-43             [-1, 20, 7, 7]              40
                 ReLU-44             [-1, 20, 7, 7]               0
            BasicConv-45             [-1, 20, 7, 7]               0
               Conv2d-46             [-1, 10, 7, 7]           5,130
          BatchNorm2d-47             [-1, 10, 7, 7]              20
                 ReLU-48             [-1, 10, 7, 7]               0
            BasicConv-49             [-1, 10, 7, 7]               0
    ================================================================
    Total params: 14,873,046
    Trainable params: 14,873,046
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 207.18
    Params size (MB): 56.74
    Estimated Total Size (MB): 264.49
    
    
  • 相关阅读:
    MySQL学习——操作表
    MySQL学习——数据类型
    MySQL学习——操作数据库
    MySQL学习——存储引擎
    Linux网络——配置防火墙的相关命令
    查询各分类中最大自增ID
    CentOS7下Rsync+sersync实现数据实时同步
    mysql的join连接查询优化经历
    搭建nginx代理支持前端页面跨域调用接口
    Centos查看系统CPU个数、核心数、线程数
  • 原文地址:https://www.cnblogs.com/aoru45/p/10498444.html
Copyright © 2020-2023  润新知