京东评论情感分类器(基于bag-of-words模型)
近期本来在研究paraVector模型,想拿bag-of-words来做对照。
数据集是京东的评论,经过人工挑选,选出一批正面和负面的评论。
实验的数据量不大:340条正面,314条负面。我一般拿200条正面和200条负面做训练,剩下的做测试。
做着做着,领悟了一些机器学习的道理。发现,对于不同的数据集,效果是不同的。
对于特定的数据集,随便拿来一套模型可能并不适用。
对于这些评论,我感觉就是bag-of-words模型靠谱点。
因为这些评论的特点是语句简短,关键词重要。
paraVector模型感觉比较擅长长文本的分析,注重上下文。
事实上我还结合了两个模型来做一个新的模型,准确率有点提高,可是不大。可能我数据量太少了。
整理了一下思路,做了个评论情感分类的demo。
特征抽取是bag-of-words模型。
分类器是自己想的一个模型,结合了knn和kmeans的思想。根据正负样本的训练集分别求出两个聚类中心,每次新样本进来,跟两个中心做距离比较。
下面是demo的代码:
import java.util.Scanner; public class BowInterTest { public static void main(String[] args) throws Throwable { // TODO Auto-generated method stub BowModel bm = new BowModel("/media/linger/G/sources/comment/test/all");//all=good+bad double[][] good = bm.generateFeature("/media/linger/G/sources/comment/test/good",340); double[][] bad = bm.generateFeature("/media/linger/G/sources/comment/test/bad",314); bm.train(good,0,200,bad,0,200);//指定训练数据 //bm.test(good, 200, 340, bad, 200, 314);//指定測试数据 //交互模式 Scanner sc = new Scanner(System.in); while(sc.hasNext()) { String doc = sc.nextLine(); double[] fea = bm.docFea(doc); Norm.arrayNorm2(fea); double re = bm.predict(fea); if(re<0) { System.out.println("good:"+re); } else { System.out.println("bad:"+re); } } } }
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.StringTokenizer;

/**
 * Bag-of-words front end for {@link KnnCoreModel}: owns the dictionary and
 * turns documents into word-count feature vectors.
 */
public class BowModel extends KnnCoreModel {

    Dict dict;
    DocFeatureFactory dff;

    /**
     * Builds the dictionary from a corpus file and prepares the feature factory.
     *
     * @param path path of the UTF-8 corpus file used to build the dictionary
     * @throws IOException if the corpus file cannot be read
     */
    public BowModel(String path) throws IOException {
        dict = new Dict();
        dict.loadFromLocalFile(path);
        dff = new DocFeatureFactory(dict.getWord2Index());
    }

    /**
     * Converts a single document into its (un-normalised) feature vector.
     *
     * @param doc the document text
     * @return the feature vector produced by the feature factory
     */
    public double[] docFea(String doc) {
        return dff.getFeature(doc);
    }

    /**
     * Reads a UTF-8 file containing one document per line and builds one
     * feature vector per line.
     *
     * <p>At most {@code docNum} lines are read; extra lines are ignored instead
     * of overrunning the table. If the file holds fewer than {@code docNum}
     * lines, the returned table is trimmed to the number of lines actually
     * read so that it never contains {@code null} rows.
     *
     * @param docsFile path of the document file
     * @param docNum   maximum number of documents to read
     * @return a table with one feature vector per document read
     * @throws IOException if the file cannot be opened or read
     */
    public double[][] generateFeature(String docsFile, int docNum) throws IOException {
        double[][] featureTable = new double[docNum][];
        int docIndex = 0;
        BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(docsFile)), "utf-8"));
        try {
            String line;
            while (docIndex < docNum && (line = br.readLine()) != null) {
                featureTable[docIndex++] = dff.getFeature(line);
            }
        } finally {
            // Close even when reading fails so the file handle is not leaked.
            br.close();
        }
        if (docIndex < docNum) {
            double[][] trimmed = new double[docIndex][];
            System.arraycopy(featureTable, 0, trimmed, 0, docIndex);
            return trimmed;
        }
        return featureTable;
    }
}
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Map.Entry;

/**
 * Vocabulary built from a space-tokenised corpus: maps every distinct word to
 * a dense index (in order of first appearance) and counts its occurrences.
 */
public class Dict {

    /** Word to feature index; {@code null} until a corpus has been loaded. */
    HashMap<String,Integer> word2Index = null;
    /** Word to occurrence count; {@code null} until a corpus has been loaded. */
    Hashtable<String,Integer> word2Count = null;

    /**
     * Rebuilds both maps from a UTF-8 text file whose words are separated by
     * spaces. Indices are assigned starting at 0 in order of first appearance.
     *
     * @param path path of the corpus file
     * @throws IOException if the file cannot be opened or read
     */
    void loadFromLocalFile(String path) throws IOException {
        word2Index = new HashMap<String,Integer>();
        word2Count = new Hashtable<String,Integer>();
        int index = 0;
        BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), "utf-8"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");
                while (tokenizer.hasMoreTokens()) {
                    String term = tokenizer.nextToken();
                    Integer seen = word2Count.get(term);
                    if (seen != null) {
                        word2Count.put(term, seen + 1);
                    } else {
                        // First occurrence: start counting and allocate the next index.
                        word2Count.put(term, 1);
                        word2Index.put(term, index++);
                    }
                }
            }
        } finally {
            // Close even when reading fails so the file handle is not leaked.
            br.close();
        }
    }

    /**
     * Returns the live word-to-index map.
     *
     * @return the map built by {@link #loadFromLocalFile(String)}, or
     *         {@code null} if nothing has been loaded yet
     */
    public HashMap<String,Integer> getWord2Index() {
        return word2Index;
    }

    /** Prints every word that occurs more than 30 times, with its count. */
    public void print() {
        for (Entry<String,Integer> item : word2Count.entrySet()) {
            if (item.getValue() > 30) {
                System.out.printf("%s,%d ", item.getKey(), item.getValue());
            }
        }
    }

    /**
     * Ad-hoc driver: loads the combined corpus and prints its frequent words.
     *
     * @param args unused
     * @throws IOException if the corpus file cannot be read
     */
    public static void main(String[] args) throws IOException {
        Dict dict = new Dict();
        dict.loadFromLocalFile("/media/linger/G/sources/comment/test/all");
        dict.print();
    }
}
import java.util.HashMap;
import java.util.StringTokenizer;

/**
 * Produces bag-of-words vectors: one slot per dictionary word, holding the
 * number of times that word occurs in a document.
 */
public class DocFeatureFactory {

    HashMap<String,Integer> word2Index;
    /** The vector produced by the most recent {@link #getFeature(String)} call. */
    double[] feature;
    /** Vector length, fixed to the dictionary size seen at construction time. */
    int dim;

    /**
     * @param w2i word-to-index dictionary; its size determines the vector length
     */
    public DocFeatureFactory(HashMap<String,Integer> w2i) {
        word2Index = w2i;
        dim = w2i.size();
    }

    /**
     * Counts the dictionary words of a space-separated document.
     *
     * @param doc the document text, words separated by spaces
     * @return a freshly allocated count vector; words absent from the
     *         dictionary are ignored
     */
    double[] getFeature(String doc) {
        double[] counts = new double[dim];
        feature = counts;
        for (StringTokenizer words = new StringTokenizer(doc, " "); words.hasMoreTokens(); ) {
            Integer slot = word2Index.get(words.nextToken());
            if (slot != null) {
                counts[slot]++;
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // Intentionally empty.
    }
}
public class KnnCoreModel { double[] good_standard ; double[] bad_standard ; public void train(double[][] good,int train_good_start,int train_good_end, double[][] bad,int train_bad_start,int train_bad_end) { //double[][] good = generateFeature("/media/linger/G/sources/comment/test/good",340); //double[][] bad = generateFeature("/media/linger/G/sources/comment/test/bad",314); //double[] arv = new double[good[0].length]; //double[] var = new double[good[0].length]; //2范式归一化 Norm.tableNorm2(good); Norm.tableNorm2(bad); good_standard = new double[good[0].length]; bad_standard = new double[bad[0].length]; for(int i=train_good_start;i<train_good_end;i++) { for(int j=0;j<good[i].length;j++) { good_standard[j]+=good[i][j]; } } //System.out.println(" good core:"); for(int j=0;j<good_standard.length;j++) { good_standard[j]/=(train_good_end-train_good_start); //System.out.printf("%f,",good_standard[j]); } for(int i=train_bad_start;i<train_bad_end;i++) { for(int j=0;j<bad[i].length;j++) { bad_standard[j]+=bad[i][j]; } } //System.out.println(" bad core:"); for(int j=0;j<bad_standard.length;j++) { bad_standard[j]/=(train_bad_end-train_bad_start); //System.out.printf("%f,",bad_standard[j]); } } public void test(double[][] good,int test_good_start,int test_good_end, double[][] bad,int test_bad_start,int test_bad_end) { Norm.tableNorm2(good); Norm.tableNorm2(bad); int error=0; double good_dis; double bad_dis; //test for(int i=test_good_start;i<test_good_end;i++) { good_dis= distance(good[i],good_standard); bad_dis = distance(good[i],bad_standard); //good_dis= allDistance(good[i],good,train_good_start,train_good_end); //bad_dis = allDistance(good[i],bad,train_bad_start,train_bad_end); double dis= good_dis-bad_dis; if(dis>0) { error++; System.out.println("-:"+(dis)); } else { System.out.println("+:"+(dis)); } } for(int i=test_bad_start;i<test_bad_end;i++) { good_dis= distance(bad[i],good_standard); bad_dis = distance(bad[i],bad_standard); //good_dis= 
allDistance(bad[i],good,train_good_start,train_good_end); //bad_dis = allDistance(bad[i],bad,train_bad_start,train_bad_end); double dis= good_dis-bad_dis; if(dis>0) { System.out.println("+:"+(dis)); } else { error++; System.out.println("-:"+(dis)); } } int count = (test_good_end-test_good_start+test_bad_end-test_bad_start); System.out.println(" error:"+error+",total:"+count); System.out.println("error rate:"+(double)error/count); System.out.println("acc rate:"+(double)(count-error)/count); } public double predict(double[] fea) { double good_dis = distance(fea,good_standard); double bad_dis = distance(fea,bad_standard); return good_dis-bad_dis; } private double distance(double[] src,double[] dst) { double sum=0; if(src.length!=dst.length) { System.out.println("size not right!"); return sum; } for(int i=0;i<src.length;i++) { sum+=(dst[i]-src[i])*(dst[i]-src[i]); } //return Math.sqrt(sum); return sum; } private double allDistance(double[]src,double[][] trainSet,int start,int end) { double sum=0; for(int i=start;i<end && i<trainSet.length;i++) { sum+=distance(src,trainSet[i]); } return sum; } }
public class Norm { public static void arrayNorm2(double[] array) { double sum; sum=0; for(int j=0;j<array.length;j++) { sum +=array[j]*array[j]; } if(sum == 0) return; sum = Math.sqrt(sum); for(int j=0;j<array.length;j++) { array[j]/=sum; } } public static void tableNorm2(double[][] table) { double sum; for(int i=0;i<table.length;i++) { sum=0; for(int j=0;j<table[i].length;j++) { sum +=table[i][j]*table[i][j]; } if(sum == 0) continue; sum = Math.sqrt(sum); for(int j=0;j<table[i].length;j++) { table[i][j]/=sum; } } } }
数据集下载:http://download.csdn.net/detail/linger2012liu/7758939
本文作者:linger
本文链接:http://blog.csdn.net/lingerlanlan/article/details/38418277