• Feedforward Networks Training Speed Enhancement by Optimal Initialization of the Synaptic Coefficients


    Yam J. Y. F. and Chow T. W. S. Feedforward networks training speed enhancement by optimal initialization of the synaptic coefficients.

    here 超级像的一个工作, 都希望让输出的区域落在激活函数的非饱和区域.

    主要内容

    先考虑单层:

    \[h(x) = \sum_{i=1}^N w_i x_i + w_0. \]

    假设该结点所使用的激活函数为

    \[\sigma(z) = \frac{1}{1 + \exp(-z)} \in (0, 1), \]

    定义其非饱和区域为(此区域内关于\(z\)的导数大于最大导数的\(1/20\)):

    \[z \in [-4.36, 4.36]. \]

    在输入空间中, 定义超平面

    \[P(a) = \sum_{i=1}^H w_i x_i + w_0 - a, \]

    则非饱和区域以\(P(-4.36), P(4.36)\)为边界. 该区域的宽度为

    \[d = \frac{8.72}{\sqrt{\sum_{i=1}^N w_i^2}}. \]

    又该结点的定义域为:

    \[x_i \in [x_{i}^{min}, x_{i}^{max}], \]

    可知定义域的'宽度'(实际上是对角边)为:

    \[D = \sqrt{\sum_{i=1}^N [x_i^{max} - x_i^{min}]^2}. \]

    显然, 希望有

    \[d \ge D \]

    成立, 文中更精确地让 \(d = D\), 即

    \[\sqrt{\sum_{i=1}^N w_i^2} = \frac{8.72}{D}. \]

    我们希望让 \(w_i\) 采样自

    \[w_i \sim \mathcal{U}[-w_{max}, w_{max}], \]

    \[\mathbb{E} \sum_{i=1}^N w_i^2 = N \frac{w_{max}^2}{3}. \]

    故不妨令

    \[\frac{8.72}{D} = N \frac{w_{max}^2}{3}, \]

    \[w_{max} = \frac{8.72}{D} \sqrt{\frac{3}{N}}. \]

    最后, 对于偏置\(w_0\), 我们希望两个区域的中心是一致的. 其中定义域的中心为:

    \[C = (\frac{x_1^{min} + x_1^{max}}{2}, \frac{x_2^{min} + x_2^{max}}{2}, \cdots, \frac{x_N^{min} + x_N^{max}}{2})^T. \]

    故需要满足:

    \[w_0 + \sum_{i=1}^N w_i C_i = 0 \Rightarrow w_0 = - \sum_{i=1}^N w_i C_i. \]

    当再往下考虑的时候, 因为激活函数的值域为\((0, 1)\), 所以下一层权重的初始化为:

    \[D = \sqrt{H}, \\ v_{max} = \frac{15.1}{H}, \\ v_0 = - \sum_{i=1}^N 0.5 v_i. \]

    代码

    
    import torch.nn as nn
    import torch.nn.functional as F
    
    import math
    
    
    def active_init(weight, bias, low: float = 0., high: float = 1.):
        assert high > low, "high should be greater than low"
        out_channels, in_channels = weight.size()
        D = math.sqrt(in_channels * (high - low))
        w = 8.72 * math.sqrt(3 / in_channels) / D
        nn.init.uniform_(weight, -w, w)
        C = -torch.ones(in_channels) * ((high - low) / 2)
        nn.init.constant_(bias, weight.data[0] @ C)
    
    
    
  • 相关阅读:
    part17 一些知识总结
    part16 php面向对象
    part15 php函数
    part14 php foreach循环
    part13 数组排序
    part12 php数组
    part11 php条件语句
    part10 php运算符
    part09 php字符串变量
    part08 php常量
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/16096682.html
Copyright © 2020-2023  润新知