• BoTNet:Bottleneck Transformers for Visual Recognition


    【GiantPandaCV导语】基于Transformer的骨干网络,同时使用卷积与自注意力机制来保持全局性和局部性。模型在ResNet最后三个BottleNeck中使用了MHSA替换3x3卷积。属于早期的结合CNN+Transformer的工作。简单来讲Non-Local+Self Attention+BottleNeck = BoTNet

    引言

    本文的发展脉络如下图所示:

    实际上沿着Transformer Block改进的方向进行的,与CNN架构也是兼容的。具体结构如下图所示:

    两者都遵循了BottleNeck的设计原则,可以有效降低计算量。

    设计Transformer中self attention存在几个挑战:

    • 图片尺寸比较大,比如目标检测中分辨率在1024x1024
    • 内存和计算量的占用高,导致训练开销比较大。

    本文设计如下:

    • 使用卷积识别底层特征的抽象信息。
    • 使用self attention处理通过卷积层得到的高层信息。

    这样可以有效处理大分辨率图像。

    方法

    BoTNet中MHSA模块如下图所示:

    上边的这个MHSA Block是核心创新点,其与Transformer中的MHSA有所不同:

    • 由于处理对象不是一维的,而是类似CNN模型,所以有非常多特性与此相关。
    • 归一化这里并没有使用Layer Norm而是采用的Batch Norm,与CNN一致。
    • 非线性激活,BoTNet使用了三个非线性激活
    • 左侧content-position模块引入了二维的位置编码,这是与Transformer中最大区别。

    由于该模块是处理BxCHW的形式,所以难免让人想起来Non Local 操作,这里列出笔者以前绘制的一幅图:

    可以看出主要区别就是在于Content-postion模块引入的位置信息。

    BoTNet细节设计:

    整体的设计和ResNet50几乎一样,唯一不同在于最后一个阶段中三个BottleNeck使用了MHSA模块。具体这样做的原因是Self attention需要消耗巨大的计算量,在模型最后加入时候feature map的size比较小,相对而言计算量比较小。

    实验

    在目标检测和分割领域性能对比

    分辨率改变对BoTNet帮助更大

    消融实验-相对位置编码

    BoTNet对ResNet系列模型的提升

    BoTNet与更大的图片适配

    BoTNet与Non-Local Net的比较

    与ImageNet上结果比较

    模型放缩的影响

    显卡香气飘来,又是谷歌的骚操作,将EfficientNet方法放在BoTNet上:

    可以看出与期望相符合,Transformer架构带来的性能上限要高于CNN,虽然模型大小比较小的时候性能比较弱,但是模型量变大以后其性能就有了保证。

    代码

    核心模块:MHSA (由第三方进行实现)

    class MHSA(nn.Module):
        def __init__(self, n_dims, width=14, height=14, heads=4):
            super(MHSA, self).__init__()
            self.heads = heads
    
            self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
            self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
            self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
    
            self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
            self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)
    
            self.softmax = nn.Softmax(dim=-1)
    
        def forward(self, x):
            n_batch, C, width, height = x.size()
            q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
            k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
            v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
    
            content_content = torch.matmul(q.permute(0, 1, 3, 2), k)
    
            content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
            content_position = torch.matmul(content_position, q)
    
            energy = content_content + content_position
            attention = self.softmax(energy)
    
            out = torch.matmul(v, attention.permute(0, 1, 3, 2))
            out = out.view(n_batch, C, width, height)
    
            return out
    

    参考

    https://arxiv.org/abs/2101.11605

    https://zhuanlan.zhihu.com/p/347849929

    https://github.com/leaderj1001/BottleneckTransformers/blob/main/model.py

    跑不动ImageNet,想试试Vision Transformer的同学可以看看这个仓库,

    https://github.com/pprp/pytorch-cifar-model-zoo

    在CIFAR10上测试:

    python train.py --model 'botnet' --name "fast_training" --sched 'cosine' --epochs 100 --cutout True --lr 0.1 --bs 128 --nw 4
    

    目前可以在100个epoch内达到验证集91.1%的准确率。

    代码改变世界
  • 相关阅读:
    用户模板和用户场景
    移动端疫情展示
    数据爬取
    全国疫情统计可视化地图-第二、三阶段
    学习进度条-第三周
    学习进度条-第二周
    软件工程第二周开课博客
    返回一个整数数组中最大子数组的和
    JavaWeb选课系统(2)
    JavaWeb选课系统
  • 原文地址:https://www.cnblogs.com/pprp/p/15736209.html
Copyright © 2020-2023  润新知