• Gibbs LDA Java implementation


    1. A story about arts-leaning and science-leaning classes


        A certain school's first-year grade has six classes, each with a certain number of students. A few students in Class 3 are very good at math and have won provincial olympiad prizes. Now the education bureau is coming to the school to sit in on a math lesson; which class should the school send the observers to? Clearly Class 3: a few of its students are exceptionally strong at math, so Class 3 is stronger at math, or at least looks that way. Here we call "science-leaning" a characteristic of Class 3. Similarly, many students in Classes 2 and 4 are very good at Chinese and have had their essays published in literary journals, so we can say "arts-leaning" is a characteristic of Classes 2 and 4. And if Classes 5 and 6 made the final of the school basketball tournament, we can say they are "sports-leaning". Whenever the bureau comes to observe a given subject, we can send them to a class that has that subject as its characteristic.
        Originally the class structure has only two layers, the student layer and the class layer, and every student belongs to a designated class. To distinguish what each class is good at, we insert a third layer between students and classes: the characteristic layer, i.e. "arts-leaning", "science-leaning", "sports-leaning". This characteristic layer is the most intuitive way to understand LDA. Continuing the story: besides its few math stars, Class 3 also has some students of excellent moral character who made the news for repeatedly helping elderly ladies across the road, so good moral character is a characteristic of Class 3 as well. Class 3 therefore has more than one characteristic, which brings in the notion of a distribution: each class may have several characteristics, some covering many students and some covering few, and we take the one covering the most students as the class's main characteristic. Each characteristic in turn covers multiple students, which is also a distribution: "science-leaning" includes the olympiad winners as well as the students who scored full marks on the final math exam. So a class contains multiple characteristics, and each characteristic contains multiple students; that is the main structure of LDA.
        For a document, how do we decide which category it belongs to? Following the example above, think of a document as a class and the words as students. If words such as bank, exchange rate, stock and decline occur many times in a document, the document is very likely about economics and we can file it under the economics category. If a document contains words like James, Kobe, dunk and foul, it is very likely about sports. Of course the classification need not be single-valued; one article may have several topics.
    The three-layer structure looks like this:

    doc:  |           doc1                     doc2                     doc3 ...
    
    topic:|      t1     t2      t3...      t1   t2    t3            t1  t2  t3...
    
    word: |w1 w5 w8... w6 w2 w3....      w3,w5... 
    

    What we ultimately want are the distribution of topics under each doc and the distribution of words under each topic. For the underlying theory, see the original LDA paper; it is not repeated here.
    The output files contain several distributions:
    topic ~ words
    doc ~ topics
    topic ~ docs
    See JGibbLDA's output files for details.
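
    Written out in standard LDA notation, these are exactly the smoothed count ratios computed in computeTheta and computePhi below, with n_{m,k} the number of words in doc m assigned to topic k and n_{k,w} the number of times word w is assigned to topic k:

    \theta_{m,k} = \frac{n_{m,k} + \alpha}{\sum_{k'} n_{m,k'} + K\alpha},
    \qquad
    \varphi_{k,w} = \frac{n_{k,w} + \beta}{\sum_{w'} n_{k,w'} + V\beta}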

    2. Gibbs LDA code structure


        On a first reading I split LDA into two parts: a training part and an inference part. Training produces the model, i.e. the word distribution under each topic and so on; inference uses the trained model to infer topics for new documents. Later I realized that inference is itself a kind of training, one that consults the already-trained counts while sampling. If the inference dataset is larger than the reference dataset, the result inferred on it can serve as a new, more accurate model. The training and inference outputs have exactly the same form and contain all the full distributions; see JGibbLDA's output files for details.
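
    Both modes are driven the same way from an LDACmdOption. A minimal training driver sketch: option.est, option.dir and option.savestep appear in the Estimator listing below, while dfile, K, alpha, beta and niters are assumed from JGibbLDA's LDACmdOption, and all paths and values here are hypothetical:

    package jgibblda;
    
    public class TrainDemo {
    	public static void main(String[] args) {
    		LDACmdOption option = new LDACmdOption();
    		option.est = true;              // train a new model from scratch
    		option.dir = "models/demo";     // working directory (hypothetical path)
    		option.dfile = "docs.dat";      // input corpus file (hypothetical name)
    		option.K = 20;                  // number of topics
    		option.alpha = 50.0 / option.K; // common default: 50/K
    		option.beta = 0.1;              // common default
    		option.niters = 1000;           // Gibbs iterations
    		option.savestep = 200;          // save a snapshot every 200 iterations
    		
    		Estimator estimator = new Estimator();
    		if (estimator.init(option))
    			estimator.estimate();       // runs Gibbs sampling, saves model-final
    	}
    }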

    Apart from reading input, saving output and the like, the core code is under 200 lines. The LDA structure and code follow:

    • 1. Preprocessing:
      remove stop words, noise words, low-frequency words, etc.
    • 2. Estimate: the training process
    package jgibblda;
    
    import java.io.File;
    import java.util.Vector;
    
    public class Estimator {
    	
    	// output model
    	protected Model trnModel;
    	LDACmdOption option;
    	
    	public boolean init(LDACmdOption option){
    		this.option = option;
    		trnModel = new Model();
    		
    		if (option.est){
    			if (!trnModel.initNewModel(option))
    				return false;
    			trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName);
    		}
    		else if (option.estc){
    			if (!trnModel.initEstimatedModel(option))
    				return false;
    		}
    		
    		return true;
    	}
    	
    	public void estimate(){
    		System.out.println("Sampling " + trnModel.niters + " iteration!");
    		
    		int lastIter = trnModel.liter;
    		for (trnModel.liter = lastIter + 1; trnModel.liter < trnModel.niters + lastIter; trnModel.liter++){
    			System.out.println("Iteration " + trnModel.liter + " ...");
    			
    			// for all z_i
    			for (int m = 0; m < trnModel.M; m++){				
    				for (int n = 0; n < trnModel.data.docs[m].length; n++){
    					// z_i = z[m][n]
    					// sample from p(z_i|z_-i, w)
    					int topic = sampling(m, n);
    					trnModel.z[m].set(n, topic);
    				}// end for each word
    			}// end for each document
    			
    			if (option.savestep > 0){
    				if (trnModel.liter % option.savestep == 0){
    					System.out.println("Saving the model at iteration " + trnModel.liter + " ...");
    					computeTheta();
    					computePhi();
    					trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5));
    				}
    			}
    		}// end iterations		
    		
    		System.out.println("Gibbs sampling completed!
    ");
    		System.out.println("Saving the final model!
    ");
    		computeTheta();
    		computePhi();
    		trnModel.liter--;
    		trnModel.saveModel("model-final");
    	}
    	
    	/**
    	 * Do sampling
    	 * @param m document number
    	 * @param n word number
    	 * @return topic id
    	 */
    	public int sampling(int m, int n){
    		// remove z_i from the count variable
    		int topic = trnModel.z[m].get(n);
    		int w = trnModel.data.docs[m].words[n];
    		
    		trnModel.nw[w][topic] -= 1;
    		trnModel.nd[m][topic] -= 1;
    		trnModel.nwsum[topic] -= 1;
    		trnModel.ndsum[m] -= 1;
    		
    		double Vbeta = trnModel.V * trnModel.beta;
    		double Kalpha = trnModel.K * trnModel.alpha;
    		
    		// do multinomial sampling via the cumulative method
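    		// p(z_i = k | z_-i, w) ∝ (nw[w][k]+beta)/(nwsum[k]+Vbeta) * (nd[m][k]+alpha)/(ndsum[m]+Kalpha)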
    		for (int k = 0; k < trnModel.K; k++){
    			trnModel.p[k] = (trnModel.nw[w][k] + trnModel.beta)/(trnModel.nwsum[k] + Vbeta) *
    					(trnModel.nd[m][k] + trnModel.alpha)/(trnModel.ndsum[m] + Kalpha);
    		}
    		
    		// cumulate multinomial parameters
    		for (int k = 1; k < trnModel.K; k++){
    			trnModel.p[k] += trnModel.p[k - 1];
    		}
    		
    		// scaled sample because of unnormalized p[]
    		double u = Math.random() * trnModel.p[trnModel.K - 1];              // the part I did not understand at first (see notes below)
    		
    		for (topic = 0; topic < trnModel.K; topic++){
    			if (trnModel.p[topic] > u) //sample topic w.r.t distribution p
    				break;
    		}
    		
    		// add newly estimated z_i to count variables
    		
    		trnModel.nw[w][topic] += 1;
    		trnModel.nd[m][topic] += 1;
    		trnModel.nwsum[topic] += 1;
    		trnModel.ndsum[m] += 1;
    		return topic;
    	}
    	
    	public void computeTheta(){
    		for (int m = 0; m < trnModel.M; m++){
    			for (int k = 0; k < trnModel.K; k++){
    				trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha);
    			}
    		}
    	}
    	
    	public void computePhi(){
    		for (int k = 0; k < trnModel.K; k++){
    			for (int w = 0; w < trnModel.V; w++){
    				trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta);
    			}
    		}
    	}
    }
    
    

    In sampling(), I did not understand the code below at first. Each word's topic is assigned randomly at initialization, so why is the draw still random during the iterations?
    After the cumulative loop, p[K-1] holds the sum of all topic weights; multiplying it by a random number gives u, which can be understood as a point falling somewhere within the total topic mass available to this word.
    The final loop returns the index k of the first p[k] greater than u. Does k stand for the k-th topic, or for the first k topics?
    What we ultimately want is not a single fixed topic per word, but the topic distribution under each doc and the word distribution under each topic; on the next pass through the code this part should be read with those distributions in mind. (A standalone sketch of this step follows the snippet below.)

    // cumulate multinomial parameters
    		for (int k = 1; k < trnModel.K; k++){
    			trnModel.p[k] += trnModel.p[k - 1];
    		}
    		
    		// scaled sample because of unnormalized p[]
    		double u = Math.random() * trnModel.p[trnModel.K - 1];              // the part I did not understand
    		
    		for (topic = 0; topic < trnModel.K; topic++){
    			if (trnModel.p[topic] > u) //sample topic w.r.t distribution p
    				break;
    		}
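
    In fact this is the standard cumulative (inverse-CDF) method for drawing from an unnormalized discrete distribution: u lands inside the interval [p_cum[k-1], p_cum[k]) with probability proportional to the original p[k], so the returned k is one single sampled topic (the k-th one), not the first k topics. And the draw stays random during the iterations because Gibbs sampling draws z_i from its full conditional every time; it is random but weighted, so high-probability topics are chosen more often. A minimal standalone Java sketch of the same trick (the weights are made up for illustration):

    import java.util.Random;
    
    public class CumulativeSamplingDemo {
    	// Draw an index from an unnormalized discrete distribution p
    	// using the cumulative (inverse-CDF) method, as in sampling().
    	static int sample(double[] p, Random rng) {
    		double[] cdf = p.clone();
    		for (int k = 1; k < cdf.length; k++)
    			cdf[k] += cdf[k - 1];                          // running sums
    		double u = rng.nextDouble() * cdf[cdf.length - 1]; // uniform in [0, total)
    		for (int k = 0; k < cdf.length; k++)
    			if (cdf[k] > u) return k;                      // first k whose cumulative mass exceeds u
    		return cdf.length - 1;                             // guard against rounding
    	}
    
    	public static void main(String[] args) {
    		double[] p = {0.1, 0.6, 0.3};                      // topic weights (need not sum to 1)
    		int[] counts = new int[p.length];
    		Random rng = new Random(42);
    		for (int i = 0; i < 100000; i++)
    			counts[sample(p, rng)]++;
    		// counts come out roughly proportional to p: ~10%, ~60%, ~30%
    		for (int k = 0; k < p.length; k++)
    			System.out.println("topic " + k + ": " + counts[k]);
    	}
    }

    Over many iterations each word's topic assignments therefore follow the conditional distribution, which is exactly what the final theta and phi estimates are built from.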
    

    3. Inference: the inference process

    
    package jgibblda;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.InputStreamReader;
    import java.util.StringTokenizer;
    import java.util.Vector;
    
    public class Inferencer {	
    	// Train model
    	public Model trnModel;
    	public Dictionary globalDict;
    	private LDACmdOption option;
    	
    	private Model newModel;
    	public int niters = 100;
    	
    	//-----------------------------------------------------
    	// Init method
    	//-----------------------------------------------------
    	public boolean init(LDACmdOption option){
    		this.option = option;
    		trnModel = new Model();
    		
    		if (!trnModel.initEstimatedModel(option))
    			return false;		
    		
    		globalDict = trnModel.data.localDict;
    		computeTrnTheta();
    		computeTrnPhi();
    		
    		return true;
    	}
    	
    	//inference new model ~ getting data from a specified dataset
    	public Model inference( LDADataset newData){
    		System.out.println("init new model");
    		Model newModel = new Model();		
    		
    		newModel.initNewModel(option, newData, trnModel);		
    		this.newModel = newModel;		
    		
    		System.out.println("Sampling " + niters + " iteration for inference!");		
    		for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){
    			//System.out.println("Iteration " + newModel.liter + " ...");
    			
    			// for all newz_i
    			for (int m = 0; m < newModel.M; ++m){
    				for (int n = 0; n < newModel.data.docs[m].length; n++){
    				// newz_i = newz[m][n]
    				// sample from p(z_i|z_-i, w)
    					int topic = infSampling(m, n);
    					newModel.z[m].set(n, topic);
    				}
    			}//end foreach new doc
    			
    		}// end iterations
    		
    		System.out.println("Gibbs sampling for inference completed!");
    		
    		computeNewTheta();
    		computeNewPhi();
    		newModel.liter--;
    		return this.newModel;
    	}
    	
    	public Model inference(String [] strs){
    		//System.out.println("inference");
    		
    		//System.out.println("read dataset");
    		LDADataset dataset = LDADataset.readDataSet(strs, globalDict);
    		
    		return inference(dataset);
    	}
    	
    	//inference new model ~ getting dataset from file specified in option
    	public Model inference(){	
    		//System.out.println("inference");
    		
    		newModel = new Model();
    		if (!newModel.initNewModel(option, trnModel)) return null;
    		
    		System.out.println("Sampling " + niters + " iteration for inference!");
    		
    		for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){
    			//System.out.println("Iteration " + newModel.liter + " ...");
    			
    			// for all newz_i
    			for (int m = 0; m < newModel.M; ++m){
    				for (int n = 0; n < newModel.data.docs[m].length; n++){
    				// newz_i = newz[m][n]
    				// sample from p(z_i|z_-i, w)
    					int topic = infSampling(m, n);
    					newModel.z[m].set(n, topic);
    				}
    			}//end foreach new doc
    			
    		}// end iterations
    		
    		System.out.println("Gibbs sampling for inference completed!");		
    		System.out.println("Saving the inference outputs!");
    		
    		computeNewTheta();
    		computeNewPhi();
    		newModel.liter--;
    		newModel.saveModel(newModel.dfile + "." + newModel.modelName);		
    		
    		return newModel;
    	}
    	
    	/**
    	 * Do sampling for inference
    	 * @param m document number
    	 * @param n word number
    	 */
    	protected int infSampling(int m, int n){
    		// remove z_i from the count variables
    		int topic = newModel.z[m].get(n);
    		int _w = newModel.data.docs[m].words[n];
    		int w = newModel.data.lid2gid.get(_w);
    		newModel.nw[_w][topic] -= 1;
    		newModel.nd[m][topic] -= 1;
    		newModel.nwsum[topic] -= 1;
    		newModel.ndsum[m] -= 1;
    		
    		double Vbeta = trnModel.V * newModel.beta;
    		double Kalpha = trnModel.K * newModel.alpha;
    		
    		// do multinomial sampling via the cumulative method		
    		for (int k = 0; k < newModel.K; k++){			
    			newModel.p[k] = (trnModel.nw[w][k] + newModel.nw[_w][k] + newModel.beta)/(trnModel.nwsum[k] +  newModel.nwsum[k] + Vbeta) *
    					(newModel.nd[m][k] + newModel.alpha)/(newModel.ndsum[m] + Kalpha);
    		}
    		
    		// cumulate multinomial parameters
    		for (int k = 1; k < newModel.K; k++){
    			newModel.p[k] += newModel.p[k - 1];
    		}
    		
    		// scaled sample because of unnormalized p[]
    		double u = Math.random() * newModel.p[newModel.K - 1];     
    		
    		for (topic = 0; topic < newModel.K; topic++){
    			if (newModel.p[topic] > u)
    				break;
    		}
    		
    		// add newly estimated z_i to count variables
    		newModel.nw[_w][topic] += 1;
    		newModel.nd[m][topic] += 1;
    		newModel.nwsum[topic] += 1;
    		newModel.ndsum[m] += 1;
    		
    		return topic;
    	}
    	
    	protected void computeNewTheta(){
    		for (int m = 0; m < newModel.M; m++){
    			for (int k = 0; k < newModel.K; k++){
    				newModel.theta[m][k] = (newModel.nd[m][k] + newModel.alpha) / (newModel.ndsum[m] + newModel.K * newModel.alpha);
    			}//end foreach topic
    		}//end foreach new document
    	}
    	
    	protected void computeNewPhi(){
    		for (int k = 0; k < newModel.K; k++){
    			for (int _w = 0; _w < newModel.V; _w++){
    				Integer id = newModel.data.lid2gid.get(_w);
    				
    				if (id != null){
    					// pool word-topic counts from the trained model and the new data, consistent with infSampling
    					newModel.phi[k][_w] = (trnModel.nw[id][k] + newModel.nw[_w][k] + newModel.beta) / (trnModel.nwsum[k] + newModel.nwsum[k] + trnModel.V * newModel.beta);
    				}
    			}//end foreach word
    		}// end foreach topic
    	}
    	
    	protected void computeTrnTheta(){
    		for (int m = 0; m < trnModel.M; m++){
    			for (int k = 0; k < trnModel.K; k++){
    				trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha);
    			}
    		}
    	}
    	
    	protected void computeTrnPhi(){
    		for (int k = 0; k < trnModel.K; k++){
    			for (int w = 0; w < trnModel.V; w++){
    				trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta);
    			}
    		}
    	}
    }
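
    The key difference from training is inside infSampling: the word-topic counts of the trained model and of the new data are pooled in the word term, while the doc-topic term uses only the new data. My reading of the code as a formula, in LaTeX:

    p(z_i = k \mid \vec{z}_{-i}, \vec{w}) \propto
    \frac{n^{trn}_{k,w} + n^{new}_{k,w} + \beta}{n^{trn}_{k} + n^{new}_{k} + V\beta}
    \cdot
    \frac{n^{new}_{m,k} + \alpha}{n^{new}_{m} + K\alpha}

    computeNewPhi pools the counts in the same way, so the inferred phi stays consistent with what was sampled.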
    
    

    4. Data visualization and output
    For the complete code, see the original JGibbLDA.
