• 【PyTorch】Normalization Vivid


    一图胜千言

      以四维 N x C x H x W (批量大小 x 通道数 x 高 x 宽)为例,更形象一点,有 N 张图片,每张图片有 C 个通道,每个通道的大小为 H x W。

      相同的颜色代表同一个计算均值和方差的区域。

      以 Batch Norm 为例,粉色区域:[1, 1, *, *](即第一张图片的第一个通道的大小为H x W 的区域),[2, 1, *, *],…… ,[N, 1, *, *]。利用这 N 个粉色区域的数去计算均值和方差,然后将计算得到的均值和方差作用到在每个粉色区域的数上,就完成了标准化。

    再理解

      如果上面的图看懂了,那么下面这个常见(但难以理解)的图也就懂了。

    代码

     1 import torch
     2 import math
     3 
     4 
     5 def manual_fun(tensor):
     6     mean = tensor.sum() / tensor.numel()
     7     var = ((tensor - mean) * (tensor - mean)).sum() / tensor.numel()
     8     new_tensor = (tensor - mean) / (math.sqrt(var + 1e-5))
     9     return new_tensor
    10 
    11 
    12 l = torch.tensor(
    13     [[[[11, 2, 3], [4, 57, 6], [7, 8, 9]], [[1, 2, -3], [-4, 5, 6], [7, 8, 9]], [[1, 92, 3], [4, -95, 6], [7, 18, 9]]],
    14      [[[46, 7, 8], [4, 66, 7], [7, 8, 9]], [[100, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 6, 6], [7, 8, 9]]]],
    15     dtype=torch.float)
    16 print(l.shape, l, sep='\n')
    17 auto_tensor = torch.nn.InstanceNorm2d(3, momentum=1)(l)
    18 for i in range(0, 2):
    19     for j in range(0, 3):
    20         print('manual: ', manual_fun(l[i, j, :, :]))
    21         print('auto: ', auto_tensor[i, j, :, :])

    最后

      关于 Normalization ,思考了很久,网上的资料也乱七八糟(我菜,所以就自己整理下吧。

      把第一张图看懂了,然后结合 PyTorch 官方代码和我的样例,还不会你来扇我(:。

      第一张图片参考于NJU1healer 的博客

  • 相关阅读:
    linux 操作 I/O 端口
    linux I/O 端口分配
    大数问题:求n的阶乘
    POJ 2586 Y2K Accounting Bug
    每天一点儿Java--ComboBox
    android的一些控件
    解决Linux(ubuntu),windows双系统重装后恢复开机选单
    Mysql数据备份与恢复
    log4net 存储到oracle 调试 Could not load type [log4net.Appender.OracleAppender]
    POJ 2533 Longest Ordered Subsequence
  • 原文地址:https://www.cnblogs.com/VividBinGo/p/15972361.html
Copyright © 2020-2023  润新知