• [学习笔记] Gibbs Sampling


    Gibbs Sampling

    Intro

    Gibbs Sampling 方法是我最近在看概率图模型相关的论文的时候遇见的,采样方法大致为:迭代抽样,最开始从随机样本中抽样,然后将此样本作为条件项,按条件概率抽样,每次只从一个维度考虑,当所有维度均采样完,开始下一轮迭代。

    Random Sampling

    假设我们一直一个随机变量的概率密度函数,我们如何采样得到服从这个分布的样本呢?

    学矩阵论的时候,老师教我们用反函数来生成任意概率分布的随机数,因此,我们也可以用反函数法来生成该分布的样本。即假设 $ xi $ 是 $ [0,1] $ 区间上均匀分布的随机变量,则其反函数$ cdf^{-1}( xi ) $ 服从该概率密度函数为 $ p(x) $ 的分布。

    有一个问题就是,当 $ p(x) $ 复杂到其累积分布函数的反函数无法计算的时候,或者不知道 $ p(x) $ 的精确值的时候,如何采样呢?

    这时候就要用到一些采样的策略,比如拒绝采样、重要性采样、Gibbs采样等等。下面就记一下各种采样策略。

    Rejection Sampling

    拒绝采样的原理是,已知一个提议分布q(往往是简单分布)和原始分布p,从提议分布中采样一个样本(hat{x}),然后计算接受率(a(hat{x}) = frac{p(hat{x}}{kq(hat{x})}),然后从均匀分布中生成一个值z,如果z小于等于a,则接受样本,否则不接受样本,继续采样,知道采样到了足够的样本。

    这个图应该可以说明,上面蓝色的线是提议分布,必须包含原始分布,然后在z0处计算接受率即可。

    然而拒绝采样要求提议分布和原始分布比较接近,这样采样率才会比较高,否则这个采样方法就是低效的,所以往往实际中并不采用这种采样方法。同样的,重要性采样方法也是比较低效的方法。(略去)

    MCMC

    MCMC是马尔可夫蒙特卡罗方法,是一种针对高维变量的采样方法。

    MCMC的核心思想是将采样过程看成一个马尔可夫链,认为第t+1次采样是依赖于第t次抽取样本(x_t)以及状态转移分布(q(x|x_t))。根据马尔可夫性链的收敛特性,我们知道在转移足够多此之后最终的状态将会收敛到一个固定的状态,我们假定收敛时的分布为(p(x)),那么在状态平稳时进行抽样得到的样本就肯定服从与(p(x))分布。

    MCMC一般应用的方法有Metropolis-Hastings算法和Gibbs采样算法。为了快点引入Gibbs Sampling,前者略去。

    Gibbs Sampling

    假设有一随机向量(x = (x_1,x_2,...,x_d)),其中d表示他有d维,每一维是一随机变量,且并不是我们常见的相互独立前提。那么,如果我们已知这个随机向量的概率分布,我们如何从这个分布中进行采样呢?

    显然想要从多元分布的联合概率分布中直接抽样是相当困难的,而Gibbs Sampling就是一种简单而且有效的采样方法。吉布斯采样的大致步骤如下:

    从一个随机的初始化状态(x^{(0)}=[x_1|x_2^{(0)},x_3^{(0)},cdots,x_d^{(0)}])开始,对每个维度单独进行采样,其采样顺序大致如下:

    [x_1^{(1)} hicksim p(x_1|x_2^{(0)},x_3^{(0)},cdots,x_d^{(0)}) \x_2^{(1)} hicksim p(x_2|x_1^{(0)},x_3^{(0)},cdots,x_d^{(0)}) \vdots \x_d^{(1)} hicksim p(x_d|x_1^{(0)},x_2^{(0)},cdots,x_{d-1}^{(0)}) \vdots \x_1^{(t)} hicksim p(x_1|x_2^{(t-1)},x_3^{(t-1)},cdots,x_d^{(t-1)}) \vdots\x_{d}^{(t)} hicksim p(x_d|x_1^{(t-1)},x_2^{(t-1)},cdots,x_{d-1}^{(t-1)}) \ ]

    遵从上面的采样步骤,我们最终能够采样得到所需要的高维分布的样本。需要注意的是,迭代的最开始采样得到的样本并不是完全满足所需要的分布的样本,因为采样之初采样的分布是提议分布,一般是均匀分布,而Gibbs Sampling的过程更像是一个单步迭代的过程,这使我想起了EM算法,都是一样的,一步一步去迭代达到最终结果。

    我在网上找到了一个能够描述这个过程的图片:

    如上图所示,右图是我们需要的分布,左边是迭代的过程,最开始抽样的点0和1都是均匀分布抽样得到的,而越到后面,抽样的点都越满足我们右边的分布,所以这个过程可以说明Gibbs Sampling抽样的过程是可行的。

    还有下面这张图,也差不多:

    Coding

    Gibbs Sampling我是从一篇图像合成的论文中看到并有所了解的,文章基于MRF,使用神经网络去拟合条件分布(p(x_i|x_{-i})),其中(x_{-i})表示除了第i个属性的其他属性。

    具体到图像中来,(x_i)就是第i个位置的像素点的像素值,而(x_{-i})描述的就是除了这个点以外的其他所有点,因此上式的概率分布就是一个条件分布。

    使用神经网络可以拟合出这个分布来,那么如何去生成图片又是一个问题。

    文章给出的解决方案就是Gibbs Sampling,先从随机噪声开始,逐像素进行生成,第一次迭代完成将生成一张图片,那么第二次第三次依次可以使用上一次迭代完前生成的图片进行迭代生成下一次,当迭代次数足够多的时候,即我们认为达到了平稳分布,这个时候生成的图片就是服从该分布的图片了。

    原文参见:

    原文链接

    具体的,我给出下面的代码:

    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch import nn, optim
    from torch.utils import data
    from torchvision import datasets, transforms, utils
    from tqdm import tqdm
    from PIL import Image
    import glob
    import random
    import cv2 as cv
    class MConv(nn.Conv2d):
        '''
        mask_type A or B
        A : the center is zero
        B : the center is not zero
        '''
        def __init__(self,mask_type,*args,**kwargs):
            super(MConv,self).__init__(*args,**kwargs)
            assert mask_type in ["A","B"]
            self.mask_type = mask_type
            self.register_buffer('mask', self.weight.data.clone())
            _,_,h,w = self.weight.size()
            self.mask.fill_(1)
            self.mask[:,:,h//2,w//2 + (mask_type == 'B'):] = 0
            self.mask[:,:,h//2+1:,:] = 0
            
        def forward(self,x):
            self.weight.data *= self.mask
            return super(MaskedConv2d,self).forward(x)
        
        
    class DoublePixelCNN(nn.Module):
        def __init__(self,fm,kernel_size = 7,padding = 3):
            super(DoublePixelCNN, self).__init__()
            self.net1 = nn.Sequential(
                    MConv('A', 1,  64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                    MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    #nn.Conv2d(fm, 256, 1)
            ) 
            self.net2 = nn.Sequential(
                    MConv('A', 1,  64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
                    MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
                    #nn.Conv2d(fm, 256, 1)
            ) 
            
            self.conv1x1 = nn.Conv2d(fm*2, 256, 1)
        def forward(self,x):
            x1 = self.net1(x)
            x2 = self.net2(x.flip(dims = [-1,-2]))
            x = torch.cat([x1,x2.flip(dims = [-1,-2])],dim = 1)
            x = self.conv1x1(x)
            return x
    
    if __name__ == "__main__":
    	tr =       data.DataLoader(datasets.MNIST(root="/media/xueaoru/Ubuntu/dataset/data",transform=transforms.ToTensor(),),
                         batch_size=64, shuffle=True, num_workers=12, pin_memory=True)
        net = DoublePixelCNN(128)
        net.cuda()
        sample = torch.rand(64,1,k,k).cuda()
        optimizer = optim.Adam(net.parameters(),lr = 0.0001)
        for epoch in range(1000):
            net.train()
            running_loss = 0.
            for input,_ in tqdm(tr):
                #print(input.size())
                input = input.cuda()
                #target = target.cuda()
                target = (input.data[:,:] * 255).long() # (b,3,h,w)
                # net(input) (b,256,3,h,w)
                loss = F.cross_entropy(net(input), target) # 计算的是每个像素的二分类交叉熵
                running_loss += loss.item()
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            print("training loss: {:.8f}".format(running_loss / len(tr)))
            if epoch % 5 == 0:
                torch.save(net.state_dict(),open("./{}.pth".format(epoch),"wb"))
                #sample.fill_(0)
                net.eval()
                with torch.no_grad():
                    for t in tqdm(range(300)):
                        for i in range(k):
                            for j in range(k):
                                out = net(sample) # (b,256)
                                probs = F.softmax(out[:, :, i ,j],dim = 1).data # (b,c) = (16,256)
                                sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
                    
                    utils.save_image(sample, 'sample_{:02d}.png'.format(epoch), nrow=12, padding=0)
        			sample = torch.rand(64,1,k,k).cuda()
    

    由于这个方法采样时间极其缓慢,所以我生成的图片尺度比较小,训练周期也比较短,只是做个demo使用。

  • 相关阅读:
    delphi 鼠标拖动
    Tesseract-ocr 工具使用记录
    在dcef3当中执行js代码并获得返回值
    idhttp提交post带参数并带上cookie
    新制作加热块
    java 调用oracle 分页存储过程 返回游标数据集
    JDBC链接
    ------------浪潮之巅读后感---------------
    价值观作业
    --------关于C语言的问卷调查-----------
  • 原文地址:https://www.cnblogs.com/aoru45/p/12092453.html
Copyright © 2020-2023  润新知