• pytorch 的 sum 和 softmax 方法 dim 参数的使用


      在阅读使用 pytorch 实现的代码时,笔者会遇到需要对某一维数据进行求和( sum )或 softmax 的操作。在 pytorch 中,上述两个方法均带有一个指定维度的 dim 参数,这里记录下 dim 参数的用法。

      torch.sum

      在 pytorch 中,提供 torch.sum 的两种形式,一种直接将待求和数据作为参数,则返回参数数据所有维度所有元素的和,另外一种除接收待求和数据作为参数外,还可加入 dim 参数,指定对待求和数据的某一维进行求和。

        out = torch.sum( a )                   #对 a 中所有元素求和
        out = torch.sum( a , dim = 1 )         #对 a 中第 1 维的元素求和

      上述第一种形式比较好理解,但第二种形式,加入 dim 参数后,比较令人疑惑的是到底哪些元素参与了求和?这里通过例子来进行说明。

      1)首先我们生成一个维度为 ( 3, 4, 5, 6 ) 的元素全为 1.0 的 tensor a。

        >>> import torch
        >>> a = torch.ones( 3, 4, 5, 6 )         #生成一个形状为 ( 3, 4, 5, 6 ) 的数据,数据类型默认为 torch.FloatTensor

      2)使用 sum 方法对上述生成的 tensor 进行求和操作。注意 tensor 的维度索引从 0 开始。

        >>> b = torch.sum( a )               #对 a 中所有元素求和, b = 360.0
        >>> c = torch.sum( a, dim = 0 )      #对 a 中 dim = 0 元素求和
        >>> c.shape                          # c 的 shape 为 torch.Size( [ 4, 5, 6 ] ),其中所有元素值为 3.0
        >>> d = torch.sum( a, dim = 3 )      #对 a 中 dim = 3 元素求和
        >>> d.shape                          # d 的 shape 为 torch.Size( [ 3, 4, 5 ] ),其中所有元素值为 6.0

      对上述结果进行解释,b 的结果很好理解,因为 tensor a 的维度为 ( 3, 4, 5, 6 ) 且其中所有元素的值为 1,则对其中所有元素求和的结果为 3 * 4 * 5 * 6 * 1.0 = 360.0 .

      对于 c 和 d 的结果,首先可以观察得到的是, 若在第 i 维进行求和,即 sum( a, dim = i ),则求和结果的每一个元素的值均为该维度的大小。如在 dim = 0 求和,在 dim = 0 上 a 的尺寸为 3,则求和结果 c 的每一个元素值为 3.0 .也就是说每个结果元素值均为是三个求和元素值( 1.0 )相加的结果,求和结果 c 的维度为 ( 4, 5, 6 ),说明待求和数据 a 分为 ( 4, 5, 6 ) 共 4 * 5 * 6 组的元素进行了求和运算。在 dim = 3 上的求和结果 d 现象与 c 保持一致。

      对于输入待求和数据所有数据元素均为 1 时,可以归纳出一个结论,对于维度为 ( s0, s1, s2, s3 ) 的 tensor 的第 i 维进行求和,如第 0 维,则结果的维度为 ( s1, s2, s3 ),其维度为原输入维度去除求和维度。结果的每一个元素值即为 1 * s0 = s0,即为待求和维度的尺寸。

      下面以三维数据即维度为 ( 3, 4, 4 ) 的 tensor a 为例展示 sum 在某一维度的实际计算过程。

                          

      使用 dim = 0 参数计算时,产生的结果维度为 ( 4, 4 ), 对于结果中的每一个位置 ( i, j ) ,由 3 个元素进行计算,实际计算的是 a[ 0 ][ i ][ j ] + a[ 1 ][ i ][ j ] + a[ 2 ][ i ][ j ],当上述三个元素的值均为 1.0 时,计算结果元素即为 3.0 。如上图左侧的图,a[ 0 ][ 3 ][ 3 ] + a[ 1 ][ 3 ][ 3 ] + a[ 2 ][ 3 ][ 3 ] 的结果即为输出 ( 3, 3 ) 位置上的值。上述位置索引 ( i, j ) 的数量由输入的待求和数据的其他维度的尺寸决定。 

      使用 dim = 2 参数计算时,产生的结果维度为 ( 3, 4 ),对于结果中的每一个位置( i, j ) ,由 4 个元素进行计算,实际计算的是 a[ i ][ j ][ 0 ] + a[ i ][ j ][ 1 ] + a[ i ][ j ][ 2 ] + a[ i ][ j ][ 3 ],当上述四个元素的值均为 1.0 时,计算结果元素即为 4.0 。如  a[ 0 ][ 0 ][ 0 ] + a[ 0 ][ 0 ][ 1 ] + a[ 0 ][ 0 ][ 2 ] + a[ 0 ][ 0 ][ 3 ] 即为输出 ( 0, 0 ) 位置上的值。

      对于维度为 ( s0, s1, s2, ... , si, ... , sn ) 的待求和向量,使用 dim = i 调用 sum 方法,则实际产生的结果维度为 ( s0, s1, s2, ... , si-1, si+1, ... , sn ),每个结果元素由 si 个元素元素求和获得。这 si 个元素坐标在其他维度索引保持一致,而在待求和维度索引由 0 至 si 变化。可以看到共有 ( s0, s1, s2, ... , si-1, si+1, ... , sn ) 组这样的求和元素( 索引的数量 ),即为结果的维度。

      torch.nn.softmax / torch.nn.functional.softmax

      softmax 是神经网路中常见的一种计算函数,其可将所有参与计算的对象值映射到 0 到 1 之间,并使得计算对象的和为 1. 在 pytorch 中的 softmax 方法在使用时也需要通过 dim 方法来指定具体进行 softmax 计算的维度。这里以 torch.nn.functional.softmax 为例进行说明。

      softmax 在 pytorch 官方文档中的描述如下:

      It is applied to all slices along dim, and will re-scale them so that the elements lie in the range [0, 1] and sum to 1.  

      可以明确的是, softmax 计算获得的数值在 0 - 1 之间,但是同样比较令人疑惑的是,all slices along the dim 具体指代的是那些数据。这里使用一个维度为 ( 2, 2, 2 ) 的 tensor a 作为示例。

        >>> import torch
        >>> import torch.nn.functional as f
        >>> a = torch.ones( 2, 2, 2 )
        >>> b = f.softmax( a, dim=0 )           #对 a 的第 0 维进行 softmax 计算

      与 sum 方法不同,softmax 方法计算获得的结果的维度与输入的待计算的数据的维度保持一致( sum 方法求和后进行指定求和的那一维不会出现在结果维度中 )。若对第 n 维进行 softmax 操作,且该维尺寸为 s,则 softmax 的结果为其他维索引保持一致,当前维索引由 0 至 s - 1 共 s 个值得和为 1.

      参与 softmax 计算的元素与 sum 方法很相似,对于 tensor a 在 dim = 0 进行 softmax,输出结果 b 实际上是 b[ 0 ][ i ][ j ] + b[ 1 ][ i ][ j ] 的值为 1.即其他维度索引保持一致,而在进行 softmax 维度索引由 0 至 si 变化,如 b[ 0 ][ 0 ][ 1 ] + a[ 1 ][ 0 ][ 1 ] 的值为1.对于 tensor a 在 dim = 2 进行 softmax,输出结果 b 实际上是 b[ i ][ j ][ 0 ] + a[ i ][ j ][ 1 ] 的值为 1.

        >>> c = b[ 0 ][ 0 ][ 1 ] + b[ 1 ][ 0 ][ 1 ]        #c 的值为 1

      参考

      pytorch - tensor-creation-ops

      pytorch - torch.tensor

      pytorch - torch.nn.functional

  • 相关阅读:
    python--txt文件处理
    Tomcat默认工具manager管理页面访问配置
    如何限制只有某些IP才能使用Tomcat Manager
    tomcat manager 禁止外网访问 只容许内网访问
    Tomcat 配置错误界面
    Mysql:Forcing close of thread xxx user: 'root' 的解决方法
    在Tomcat中配置404自定义错误页面详解
    白叔自创放大镜教程
    jQuery实现网页放大镜功能 转载
    转载 jQuery实现放大镜特效
  • 原文地址:https://www.cnblogs.com/yhjoker/p/12701601.html
Copyright © 2020-2023  润新知