• [论文理解] Mutual Information Neural Estimation


    Mutual Information Neural Estimation

    互信息定义:

    (I(X;Z) = int_{X imes Z} logfrac{dmathbb{P}(XZ)}{dmathbb{P}(X) otimes mathbb{P}(Z)}dmathbb{P}(XZ))

    CPC文章里用下面这个公式定义要更加容易理解,都是一样的:

    [I(x;z) = sum_{x,z}p(x,z) log frac{p(x,z)}{p(x)p(z)} ]

    互信息越大,表明两个变量依赖关系越强,互信息越小,表示两个随机变量越独立。

    KL散度的对偶问题:

    因此根据KL散度和其对偶问题之间的关系我们可以得到:

    [D_{K L}(mathbb{P} | mathbb{Q}) geq sup _{T in mathcal{F}} mathbb{E}_{mathbb{P}}[T]-log left(mathbb{E}_{mathbb{Q}}left[e^{T} ight] ight) ]

    利用上式优化互信息的下界:

    [I(X ; Z) geq I_{Theta}(X, Z) ]

    [I_{Theta}(X, Z)=sup _{ heta in Theta} mathbb{E}_{mathbb{P}_{X Z}}left[T_{ heta} ight]-log left(mathbb{E}_{mathbb{P}_{X} otimes mathbb{P}_{Z}}left[e^{T_{ heta}} ight] ight) ]

    优化算法:

    一般来说z的分布用高斯分布,x和z的分布(marginal distribution)都好采样;

    对于joint distribution,用一个神经网络来建模,即F(x,z),然后其结果就是joint distribution的采样了。

    代入公式计算即可。

    class Mine(nn.Module):
        def __init__(self, input_size=2, hidden_size=100):
            super().__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, hidden_size)
            self.fc3 = nn.Linear(hidden_size, 1)
            
        def forward(self, input):
            output = F.elu(self.fc1(input))
            output = F.elu(self.fc2(output))
            output = self.fc3(output)
            return output
    
    def mutual_information(joint, marginal, mine_net):
        t = mine_net(joint)
        et = torch.exp(mine_net(marginal))
        mi_lb = torch.mean(t) - torch.log(torch.mean(et))
        return mi_lb, t, et
    
    
  • 相关阅读:
    JAVA最简单常识
    BREW的资源文件概述及问题
    c语言 512
    c语言510 求矩阵的乘积
    c语言 511
    c语言57
    c语言 59
    c语言55 在应用对象式宏的数组中对数组元素进行倒序排列
    c语言 511
    c语言 510 求4行3列矩阵和3行4列矩阵的乘积。各构成元素的值从键盘输入。
  • 原文地址:https://www.cnblogs.com/aoru45/p/15362515.html
Copyright © 2020-2023  润新知