• (转) Parameter estimation for text analysis 暨LDA学习小结


    Reading Note : Parameter estimation for text analysis 暨LDA学习小结

    伟大的Parameter estimation for text analysis!当把这篇看的差不多的时候,也就到了LDA基础知识终结的时刻了,意味着LDA基础模型的基本了解完成了。所以对该模型的学习告一段落,下一阶段就是了解LDA无穷无尽的变种,不过那些不是很有用了,因为LDA已经被人水遍了各大“论坛”……

    抛开LDA背后复杂深入的数学背景不说,光就LDA的内容,确实不多,虽然变分法还是不懂,不过现在终于还是理解了“LDA is just a simple model”这句话。

    总结一下学习过程:
    1.概率的基本概念:CDF、PDF、Bayes’rule、各种简单的分布Bernoulli,binomial,multinomial、包括对prior、likelihood、postprior的理解(PRML1.2)
    ​3.概率图模型 Probabilistic Graphical Models: PRML Chapter 8 基本概念即可
    4.采样算法:Basic Sampling,Sampling Methods(PRML Chapter 11),马尔科夫蒙特卡洛 MCMC,Gibbs Sampling
    ​6.进阶资料:《Gibbs Sampling for the Uninitiated》、本文
     
    ——————————————– 伟大的分割线 !PETA! ​——————————————–

    一、前面无关部分

    关于ML、MAP、Bayesian inference

    二、模型进一步记忆

    从本图来看,需要记住:

    1.θm是每一个document单独一个θ,所以M个doc共有M个θm,整个θ是一个M*K的矩阵(M个doc,每个doc一个K维topic分布向量)。

    2.φk总共只有K个,对于每一个topic,有一个φk,这些参数是独立于文档的,也就是对于整个corpus只sample一次。不像θm那样每一个都对应一个文档,每个文档都不同,φk对于所有文档都相同,是一个K*V的矩阵(K个topic,每个topic一个V维从topic产生词的概率分布)。

    就这些了。

    三、推导

    公式(39):P(p|α)=Dir(p|α)意思是从参数为α的狄利克雷分布,采样一个多项分布参数p的概率是多少,概率是标准狄利克雷PDF。这里Dirichlet delta function为:

    Δ(α⃗ )=Γ(α1)Γ(α2)Γ(αk)Γ(K1 αk)

    这个function要记住,下面一溜烟全是这个。

    公式(43)是一元语言模型的likelihood,意思是如果提供了语料库W,知道了W里面每个词的个数,那么使用最大似然估计最大化L就可以估计出参数多项分布p。

    公式(44)是考虑了先验的情形,假如已知语料库W和参数α,那么他们产生多项分布参数p的概率是Dir(p|α+n),这个推导我记得在PRML2.1中有解释,抛开复杂的数学证明,只要参考标准狄利克雷分布的归一化项,很容易想出式(46)的归一化项就是Δ(α+n)。这时如果要通过W估计参数p,那么就要使用贝叶斯推断,用这个狄利克雷pdf输出一个p的期望即可。

    最关键的推导(63)-(78):从63-73的目标是要求出整个LDA的联合概率表达式,这样(63)就可以被用在Gibbs Sampler的分子上。首先(63)把联合概率拆成相互独立的两部分p(w|z,β)p(z|α),然后分别对这两部分布求表达式。式(64)、(65)首先不考虑超参数β,而是假设已知参数Φ。这个Φ就是那个K*V维矩阵,表示从每一个topic产生词的概率。然后(66)要把Φ积分掉,这样就可以求出第一部分p(w|z,β)为表达式(68)。从66-68的积分过程一直在套用狄利克雷积分的结果,反正整篇文章套来套去始终就是这么一个狄利克雷积分。n⃗ z是一个V维的向量,对于topic z,代表每一个词在这个topic里面有几个。从69到72的道理其实和64-68一模一样了。n⃗ m是一个K维向量,对于文档m,代表每一个topic在这个文档里有几个词。

    最后(78)求出了Gibbs Sampler所需要的条件概率表达式。这个表达式还是要贴出来的,为了和代码里面对应:

    具体选择下一个新topic的方法是:通过计算每一个topic的新的产生概率p(zi=k|zi,w)也就是代码中的p[k]产生一个新topic。比如有三个topic,算出来产生新的p的概率值为{0.3,0.2,0.4},注意这个条件概率加起来并不一定是一。然后我为了按照这个概率产生一个新topic,我用random函数从uniform distribution产生一个0至0.9的随机数r。如果0<=r<0.3,则新topic赋值为1,如果0.3<=r<0.5,则新topic赋值为2,如果0.5<=r<0.9,那么新topic赋值为3。

    四、代码

    1. /* 
    2.  * (C) Copyright 2005, Gregor Heinrich (gregor :: arbylon : net)  
    3.  * LdaGibbsSampler is free software; you can redistribute it and/or modify it 
    4.  * under the terms of the GNU General Public License as published by the Free 
    5.  * Software Foundation; either version 2 of the License, or (at your option) any 
    6.  * later version. 
    7.  * LdaGibbsSampler is distributed in the hope that it will be useful, but 
    8.  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 
    9.  * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more 
    10.  * details. 
    11.  * You should have received a copy of the GNU General Public License along with 
    12.  * this program; if not, write to the Free Software Foundation, Inc., 59 Temple 
    13.  * Place, Suite 330, Boston, MA 02111-1307 USA 
    14.  */  
    15. import java.text.DecimalFormat;  
    16. import java.text.NumberFormat;  
    17.   
    18. public class LdaGibbsSampler {  
    19.     /** 
    20.      * document data (term lists) 
    21.      */  
    22.     int[][] documents;  
    23.     /** 
    24.      * vocabulary size 
    25.      */  
    26.     int V;  
    27.     /** 
    28.      * number of topics 
    29.      */  
    30.     int K;  
    31.     /** 
    32.      * Dirichlet parameter (document--topic associations) 
    33.      */  
    34.     double alpha;  
    35.     /** 
    36.      * Dirichlet parameter (topic--term associations) 
    37.      */  
    38.     double beta;  
    39.     /** 
    40.      * topic assignments for each word. 
    41.      * N * M 维,第一维是文档,第二维是word 
    42.      */  
    43.     int z[][];  
    44.     /** 
    45.      * nw[i][j] number of instances of word i (term?) assigned to topic j. 
    46.      */  
    47.     int[][] nw;  
    48.     /** 
    49.      * nd[i][j] number of words in document i assigned to topic j. 
    50.      */  
    51.     int[][] nd;  
    52.     /** 
    53.      * nwsum[j] total number of words assigned to topic j. 
    54.      */  
    55.     int[] nwsum;  
    56.     /** 
    57.      * nasum[i] total number of words in document i. 
    58.      */  
    59.     int[] ndsum;  
    60.     /** 
    61.      * cumulative statistics of theta 
    62.      */  
    63.     double[][] thetasum;  
    64.     /** 
    65.      * cumulative statistics of phi 
    66.      */  
    67.     double[][] phisum;  
    68.     /** 
    69.      * size of statistics 
    70.      */  
    71.     int numstats;  
    72.     /** 
    73.      * sampling lag (?) 
    74.      */  
    75.     private static int THIN_INTERVAL = 20;  
    76.   
    77.     /** 
    78.      * burn-in period 
    79.      */  
    80.     private static int BURN_IN = 100;  
    81.   
    82.     /** 
    83.      * max iterations 
    84.      */  
    85.     private static int ITERATIONS = 1000;  
    86.   
    87.     /** 
    88.      * sample lag (if -1 only one sample taken) 
    89.      */  
    90.     private static int SAMPLE_LAG;  
    91.   
    92.     private static int dispcol = 0;  
    93.   
    94.     /** 
    95.      * Initialise the Gibbs sampler with data. 
    96.      *  
    97.      * @param V 
    98.      *            vocabulary size 
    99.      * @param data 
    100.      */  
    101.     public LdaGibbsSampler(int[][] documents, int V) {  
    102.   
    103.         this.documents = documents;  
    104.         this.V = V;  
    105.     }  
    106.   
    107.     /** 
    108.      * Initialisation: Must start with an assignment of observations to topics ? 
    109.      * Many alternatives are possible, I chose to perform random assignments 
    110.      * with equal probabilities 
    111.      *  
    112.      * @param K 
    113.      *            number of topics 
    114.      * @return z assignment of topics to words 
    115.      */  
    116.     public void initialState(int K) {  
    117.         int i;  
    118.   
    119.         int M = documents.length;  
    120.   
    121.         // initialise count variables.  
    122.         nw = new int[V][K];  
    123.         nd = new int[M][K];  
    124.         nwsum = new int[K];  
    125.         ndsum = new int[M];  
    126.   
    127.         // The z_i are are initialised to values in [1,K] to determine the  
    128.         // initial state of the Markov chain.  
    129.         // 为了方便,他没用从狄利克雷参数采样,而是随机初始化了!  
    130.   
    131.         z = new int[M][];  
    132.         for (int m = 0; m < M; m++) {  
    133.             int N = documents[m].length;  
    134.             z[m] = new int[N];  
    135.             for (int n = 0; n < N; n++) {  
    136.                 //随机初始化!  
    137.                 int topic = (int) (Math.random() * K);  
    138.                 z[m][n] = topic;  
    139.                 // number of instances of word i assigned to topic j  
    140.                 // documents[m][n] 是第m个doc中的第n个词  
    141.                 nw[documents[m][n]][topic]++;  
    142.                 // number of words in document i assigned to topic j.  
    143.                 nd[m][topic]++;  
    144.                 // total number of words assigned to topic j.  
    145.                 nwsum[topic]++;  
    146.             }  
    147.             // total number of words in document i  
    148.             ndsum[m] = N;  
    149.         }  
    150.     }  
    151.   
    152.     /** 
    153.      * Main method: Select initial state ? Repeat a large number of times: 1. 
    154.      * Select an element 2. Update conditional on other elements. If 
    155.      * appropriate, output summary for each run. 
    156.      *  
    157.      * @param K 
    158.      *            number of topics 
    159.      * @param alpha 
    160.      *            symmetric prior parameter on document--topic associations 
    161.      * @param beta 
    162.      *            symmetric prior parameter on topic--term associations 
    163.      */  
    164.     private void gibbs(int K, double alpha, double beta) {  
    165.         this.K = K;  
    166.         this.alpha = alpha;  
    167.         this.beta = beta;  
    168.   
    169.         // init sampler statistics  
    170.         if (SAMPLE_LAG > 0) {  
    171.             thetasum = new double[documents.length][K];  
    172.             phisum = new double[K][V];  
    173.             numstats = 0;  
    174.         }  
    175.   
    176.         // initial state of the Markov chain:  
    177.         //启动马尔科夫链需要一个起始状态  
    178.         initialState(K);  
    179.   
    180.         //每一轮sample  
    181.         for (int i = 0; i < ITERATIONS; i++) {  
    182.   
    183.             // for all z_i  
    184.             for (int m = 0; m < z.length; m++) {  
    185.                 for (int n = 0; n < z[m].length; n++) {  
    186.   
    187.                     // (z_i = z[m][n])  
    188.                     // sample from p(z_i|z_-i, w)  
    189.                     //核心步骤,通过论文中表达式(78)为文档m中的第n个词采样新的topic  
    190.                     int topic = sampleFullConditional(m, n);  
    191.                     z[m][n] = topic;  
    192.                 }  
    193.             }  
    194.   
    195.             // get statistics after burn-in  
    196.             //如果当前迭代轮数已经超过 burn-in的限制,并且正好达到 sample lag间隔  
    197.             //则当前的这个状态是要计入总的输出参数的,否则的话忽略当前状态,继续sample  
    198.             if ((i > BURN_IN) && (SAMPLE_LAG > 0) && (i % SAMPLE_LAG == 0)) {  
    199.                 updateParams();  
    200.             }  
    201.         }  
    202.     }  
    203.   
    204.     /** 
    205.      * Sample a topic z_i from the full conditional distribution: p(z_i = j | 
    206.      * z_-i, w) = (n_-i,j(w_i) + beta)/(n_-i,j(.) + W * beta) * (n_-i,j(d_i) + 
    207.      * alpha)/(n_-i,.(d_i) + K * alpha) 
    208.      *  
    209.      * @param m 
    210.      *            document 
    211.      * @param n 
    212.      *            word 
    213.      */  
    214.     private int sampleFullConditional(int m, int n) {  
    215.   
    216.         // remove z_i from the count variables  
    217.         //这里首先要把原先的topic z(m,n)从当前状态中移除  
    218.         int topic = z[m][n];  
    219.         nw[documents[m][n]][topic]--;  
    220.         nd[m][topic]--;  
    221.         nwsum[topic]--;  
    222.         ndsum[m]--;  
    223.   
    224.         // do multinomial sampling via cumulative method:  
    225.         double[] p = new double[K];  
    226.         for (int k = 0; k < K; k++) {  
    227.             //nw 是第i个word被赋予第j个topic的个数  
    228.             //在下式中,documents[m][n]是word id,k为第k个topic  
    229.             //nd 为第m个文档中被赋予topic k的词的个数  
    230.             p[k] = (nw[documents[m][n]][k] + beta) / (nwsum[k] + V * beta)  
    231.                 * (nd[m][k] + alpha) / (ndsum[m] + K * alpha);  
    232.         }  
    233.         // cumulate multinomial parameters  
    234.         for (int k = 1; k < p.length; k++) {  
    235.             p[k] += p[k - 1];  
    236.         }  
    237.         // scaled sample because of unnormalised p[]  
    238.         double u = Math.random() * p[K - 1];  
    239.         for (topic = 0; topic < p.length; topic++) {  
    240.             if (u < p[topic])  
    241.                 break;  
    242.         }  
    243.   
    244.         // add newly estimated z_i to count variables  
    245.         nw[documents[m][n]][topic]++;  
    246.         nd[m][topic]++;  
    247.         nwsum[topic]++;  
    248.         ndsum[m]++;  
    249.   
    250.         return topic;  
    251.     }  
    252.   
    253.     /** 
    254.      * Add to the statistics the values of theta and phi for the current state. 
    255.      */  
    256.     private void updateParams() {  
    257.         for (int m = 0; m < documents.length; m++) {  
    258.             for (int k = 0; k < K; k++) {  
    259.                 thetasum[m][k] += (nd[m][k] + alpha) / (ndsum[m] + K * alpha);  
    260.             }  
    261.         }  
    262.         for (int k = 0; k < K; k++) {  
    263.             for (int w = 0; w < V; w++) {  
    264.                 phisum[k][w] += (nw[w][k] + beta) / (nwsum[k] + V * beta);  
    265.             }  
    266.         }  
    267.         numstats++;  
    268.     }  
    269.   
    270.     /** 
    271.      * Retrieve estimated document--topic associations. If sample lag > 0 then 
    272.      * the mean value of all sampled statistics for theta[][] is taken. 
    273.      *  
    274.      * @return theta multinomial mixture of document topics (M x K) 
    275.      */  
    276.     public double[][] getTheta() {  
    277.         double[][] theta = new double[documents.length][K];  
    278.   
    279.         if (SAMPLE_LAG > 0) {  
    280.             for (int m = 0; m < documents.length; m++) {  
    281.                 for (int k = 0; k < K; k++) {  
    282.                     theta[m][k] = thetasum[m][k] / numstats;  
    283.                 }  
    284.             }  
    285.   
    286.         } else {  
    287.             for (int m = 0; m < documents.length; m++) {  
    288.                 for (int k = 0; k < K; k++) {  
    289.                     theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);  
    290.                 }  
    291.             }  
    292.         }  
    293.   
    294.         return theta;  
    295.     }  
    296.   
    297.     /** 
    298.      * Retrieve estimated topic--word associations. If sample lag > 0 then the 
    299.      * mean value of all sampled statistics for phi[][] is taken. 
    300.      *  
    301.      * @return phi multinomial mixture of topic words (K x V) 
    302.      */  
    303.     public double[][] getPhi() {  
    304.         double[][] phi = new double[K][V];  
    305.         if (SAMPLE_LAG > 0) {  
    306.             for (int k = 0; k < K; k++) {  
    307.                 for (int w = 0; w < V; w++) {  
    308.                     phi[k][w] = phisum[k][w] / numstats;  
    309.                 }  
    310.             }  
    311.         } else {  
    312.             for (int k = 0; k < K; k++) {  
    313.                 for (int w = 0; w < V; w++) {  
    314.                     phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);  
    315.                 }  
    316.             }  
    317.         }  
    318.         return phi;  
    319.     }  
    320.   
    321.     /** 
    322.      * Configure the gibbs sampler 
    323.      *  
    324.      * @param iterations 
    325.      *            number of total iterations 
    326.      * @param burnIn 
    327.      *            number of burn-in iterations 
    328.      * @param thinInterval 
    329.      *            update statistics interval 
    330.      * @param sampleLag 
    331.      *            sample interval (-1 for just one sample at the end) 
    332.      */  
    333.     public void configure(int iterations, int burnIn, int thinInterval,  
    334.         int sampleLag) {  
    335.         ITERATIONS = iterations;  
    336.         BURN_IN = burnIn;  
    337.         THIN_INTERVAL = thinInterval;  
    338.         SAMPLE_LAG = sampleLag;  
    339.     }  
    340.   
    341.     /** 
    342.      * Driver with example data. 
    343.      *  
    344.      * @param args 
    345.      */  
    346.     public static void main(String[] args) {  
    347.         // words in documents  
    348.         int[][] documents = { {1432314323143236},  
    349.             {224242222422},  
    350.             {1656016560165600},  
    351.             {56623365622656660},  
    352.             {224444155555511110},  
    353.             {542345665432}};  
    354.         // vocabulary  
    355.         int V = 7;  
    356.         int M = documents.length;  
    357.         // # topics  
    358.         int K = 2;  
    359.         // good values alpha = 2, beta = .5  
    360.         double alpha = 2;  
    361.         double beta = .5;  
    362.   
    363.         LdaGibbsSampler lda = new LdaGibbsSampler(documents, V);  
    364.           
    365.         //设定sample参数,采样运行10000轮,burn-in 2000轮,第三个参数没用,是为了显示  
    366.         //第四个参数是sample lag,这个很重要,因为马尔科夫链前后状态conditional dependent,所以要跳过几个采样  
    367.         lda.configure(10000200010010);  
    368.           
    369.         //跑一个!走起!  
    370.         lda.gibbs(K, alpha, beta);  
    371.   
    372.         //输出模型参数,论文中式 (81)与(82)  
    373.         double[][] theta = lda.getTheta();  
    374.         double[][] phi = lda.getPhi();  
    375.     }  
    376. }  
  • 相关阅读:
    微信小程序开发(十)获取手机的经纬度
    微信小程序开发(九)获取手机连接的wifi信息
    微信小程序开发(八)获取手机ip地址
    微信小程序开发(七)获取手机网络类型
    微信小程序开发(六)获取手机信息
    关于JPype报FileNotFoundError: [Errno 2] No such file or directory: '/usr/lib/jvm'错误的解决
    微信小程序开发(五)数据绑定
    微信小程序开发(四)页面跳转
    微信小程序开发(三)点击事件
    微信小程序开发(二)创建一个小程序页面
  • 原文地址:https://www.cnblogs.com/lifegoesonitself/p/3423272.html
Copyright © 2020-2023  润新知