• torch中的mask:masked_fill, masked_select, masked_scatter


    1. 简介

      pytorch提供mask机制用来提取数据中“感兴趣”的部分。过程如下:左边的矩阵是原数据,中间的mask是遮罩矩阵,标记为1的表明对这个位置的数据“感兴趣”-保留,反之舍弃。整个过程可以视作是在原数据上盖了一层mask,只有感兴趣的部分(值为1)显露出来,而其他部分则背遮住。(matlab中也有mask操作)

      mask为一个和元数据size相匹配的tensor-bool,相匹配: broadcastable-广播机制。如一个2*3*3的原数据可以由一个3*3的mask来提取。

      mask一般是先建立0/1矩阵,然后通过tensor.bool()来转为bool类型的tensor,其他true表示原数据被遮住或者被选中,false表示原数据没有被遮住或者未被选中:这句话在下面的演示中更容易理解。

    2. 程序演示

      这里涉及的是torch中的三个常见mask函数:masked_fill, masked_select, masked_scatter。

      先构造好input和mask矩阵:

    imgs = torch.randint(0, 255, [2, 3, 3], dtype=torch.float32)
    """
    tensor([[[182., 242.,  11.],
             [163.,  92., 183.],
             [222.,  54.,  86.]],
            [[157., 139., 254.],
             [158., 148.,  46.],
             [  1.,  13.,  56.]]])
    """
    mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).bool()
    """
    tensor([[ True, False, False],
            [False,  True, False],
            [False, False,  True]])
    """

    1)torch.masked_fill(input, mask, value)

      参数:  

    • input:输入的原数据
    • mask:遮罩矩阵
    • value:被“遮住的”部分填充的数据,可以取0、1等值,数据类型不限,int、float均可

      返回值:一个和input相同size的masked-tensor

      使用:

    • output = torch.masked_fill(input, mask, value)
    • output = input.masked_fill(mask, value)
    imgs_masked = torch.masked_fill(input=imgs, mask=~mask, value=0) # 这里mask取反:true表示被“遮住的”
    """
    tensor([[[182.,   0.,   0.],
             [  0.,  92.,   0.],
             [  0.,   0.,  86.]],
            [[157.,   0.,   0.],
             [  0., 148.,   0.],
             [  0.,   0.,  56.]]])
    """

    2)torch.masked_select(input, mask, out)

      参数:  

    • input:输入的原数据
    • mask:遮罩矩阵
    • out:输出的结果,和原tensor不共用内存,一般在左侧接收,而不在形参中赋值

      返回值:一维tensor,数据为“选中”的数据

      使用:

    • torch.masked_select(input, mask, out)
    • output = input.masked_select(mask)
    selected_ele = torch.masked_select(input=imgs, mask=mask)  # true表示selected,false则未选中,所以这里没有取反
    # tensor([182., 92., 86., 157., 148., 56.])

    3)torch.masked_scatter(input, mask, source)

      说明:将从input中mask得到的数据赋值到source-tensor中

      参数:  

    • input:输入的原数据
    • mask:遮罩矩阵
    • source:遮罩矩阵的”样子“(全零还是全一或是其他),true表示遮住了

      返回值:一个和source相同size的masked-tensor

      使用:

    • output = torch.masked_scatter(input, mask, source)
    • output = input.masked_scatter(mask, source)
    source = torch.zeros_like(imgs)
    imgs_masked_copied = torch.masked_scatter(input=imgs, mask=~mask, source=source)
    """
    tensor([[[173.,   0.,   0.],
             [  0.,  77.,   0.],
             [  0.,   0., 159.]],
            [[ 85.,   0.,   0.],
             [  0., 184.,   0.],
             [  0.,   0., 223.]]])
    """

    3. 参考链接

  • 相关阅读:
    linux:安装php7.x
    linux:搭建 WordPress 个人站点
    linux:lnmp环境
    knn初了解
    Pycharm:鼠标滚动控制字体大小
    数据集的获取
    弄懂Java的自增变量
    面试中的volatile关键字
    Java的类锁、对象锁和方法锁
    Error creating bean with name 'entityManagerFactory' defined in class path resource解决方案
  • 原文地址:https://www.cnblogs.com/YuanShiRenY/p/torch_mask.html
Copyright © 2020-2023  润新知