• 贝叶斯文本分类 java实现


      昨天实现了一个基于贝叶斯定理的的文本分类,贝叶斯定理假设特征属性(在文本中就是词汇)对待分类项的影响都是独立的,道理比较简单,在中文分类系统中,分类的准确性与分词系统的好坏有很大的关系,这段代码也是试验不同分词系统才顺手写的一个。 

      试验数据用的sogou实验室的文本分类样本,一共分为9个类别,每个类别文件夹下大约有2000篇文章。由于文本数据量确实较大,所以得想办法让每次训练的结果都能保存起来,以便于下次直接使用,我这里使用序列化的方式保存在硬盘。 


      训练代码如下:  

      1 /**
      2  * 训练器
      3  * 
      4  * <a href="http://my.oschina.net/arthor" target="_blank" rel="nofollow">@author</a>  duyf
      5  * 
      6  */
      7 class Train implements Serializable {
      8 
      9     /**
     10      * 
     11      */
     12     private static final long serialVersionUID = 1L;
     13 
     14     public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";
     15     // 训练集的位置
     16     private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";
     17 
     18     // 类别序号对应的实际名称
     19     private Map<String, String> classMap = new HashMap<String, String>();
     20 
     21     // 类别对应的txt文本数
     22     private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>();
     23 
     24     // 所有文本数
     25     private AtomicInteger actCount = new AtomicInteger(0);
     26 
     27     
     28 
     29     // 每个类别对应的词典和频数
     30     private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>();
     31 
     32     // 分词器
     33     private transient Participle participle;
     34 
     35     private static Train trainInstance = new Train();
     36 
     37     public static Train getInstance() {
     38         trainInstance = new Train();
     39 
     40         // 读取序列化在硬盘的本类对象
     41         FileInputStream fis;
     42         try {
     43             File f = new File(SERIALIZABLE_PATH);
     44             if (f.length() != 0) {
     45                 fis = new FileInputStream(SERIALIZABLE_PATH);
     46                 ObjectInputStream oos = new ObjectInputStream(fis);
     47                 trainInstance = (Train) oos.readObject();
     48                 trainInstance.participle = new IkParticiple();
     49             } else {
     50                 trainInstance = new Train();
     51             }
     52         } catch (Exception e) {
     53             e.printStackTrace();
     54         }
     55 
     56         return trainInstance;
     57     }
     58 
     59     private Train() {
     60         this.participle = new IkParticiple();
     61     }
     62 
     63     public String readtxt(String path) {
     64         BufferedReader br = null;
     65         StringBuilder str = null;
     66         try {
     67             br = new BufferedReader(new FileReader(path));
     68 
     69             str = new StringBuilder();
     70 
     71             String r = br.readLine();
     72 
     73             while (r != null) {
     74                 str.append(r);
     75                 r = br.readLine();
     76 
     77             }
     78 
     79             return str.toString();
     80         } catch (IOException ex) {
     81             ex.printStackTrace();
     82         } finally {
     83             if (br != null) {
     84                 try {
     85                     br.close();
     86                 } catch (IOException e) {
     87                     e.printStackTrace();
     88                 }
     89             }
     90             str = null;
     91             br = null;
     92         }
     93 
     94         return "";
     95     }
     96 
     97     /**
     98      * 训练数据
     99      */
    100     public void realTrain() {
    101         // 初始化
    102         classMap = new HashMap<String, String>();
    103         classP = new HashMap<String, Integer>();
    104         actCount.set(0);
    105         classWordMap = new HashMap<String, Map<String, Double>>();
    106 
    107         // classMap.put("C000007", "汽车");
    108         classMap.put("C000008", "财经");
    109         classMap.put("C000010", "IT");
    110         classMap.put("C000013", "健康");
    111         classMap.put("C000014", "体育");
    112         classMap.put("C000016", "旅游");
    113         classMap.put("C000020", "教育");
    114         classMap.put("C000022", "招聘");
    115         classMap.put("C000023", "文化");
    116         classMap.put("C000024", "军事");
    117 
    118         // 计算各个类别的样本数
    119         Set<String> keySet = classMap.keySet();
    120 
    121         // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df
    122         final Set<String> allWords = new HashSet<String>();
    123 
    124         // 存放每个类别的文件词汇内容
    125         final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>();
    126 
    127         for (String classKey : keySet) {
    128 
    129             Participle participle = new IkParticiple();
    130             Map<String, Double> wordMap = new HashMap<String, Double>();
    131             File f = new File(trainPath + File.separator + classKey);
    132             File[] files = f.listFiles(new FileFilter() {
    133 
    134                 @Override
    135                 public boolean accept(File pathname) {
    136                     if (pathname.getName().endsWith(".txt")) {
    137                         return true;
    138                     }
    139                     return false;
    140                 }
    141 
    142             });
    143 
    144             // 存储每个类别的文件词汇向量
    145             List<String[]> fileContent = new ArrayList<String[]>();
    146             if (files != null) {
    147                 for (File txt : files) {
    148                     String content = readtxt(txt.getAbsolutePath());
    149                     // 分词
    150                     String[] word_arr = participle.participle(content, false);
    151                     fileContent.add(word_arr);
    152                     // 统计每个词出现的个数
    153                     for (String word : word_arr) {
    154                         if (wordMap.containsKey(word)) {
    155                             Double wordCount = wordMap.get(word);
    156                             wordMap.put(word, wordCount + 1);
    157                         } else {
    158                             wordMap.put(word, 1.0);
    159                         }
    160                         
    161                     }
    162                 }
    163             }
    164 
    165             // 每个类别对应的词典和频数
    166             classWordMap.put(classKey, wordMap);
    167 
    168             // 每个类别的文章数目
    169             classP.put(classKey, files.length);
    170             actCount.addAndGet(files.length);
    171             classContentMap.put(classKey, fileContent);
    172 
    173         }
    174 
    175         
    176 
    177         
    178 
    179         // 把训练好的训练器对象序列化到本地 (空间换时间)
    180         FileOutputStream fos;
    181         try {
    182             fos = new FileOutputStream(SERIALIZABLE_PATH);
    183             ObjectOutputStream oos = new ObjectOutputStream(fos);
    184             oos.writeObject(this);
    185         } catch (Exception e) {
    186             e.printStackTrace();
    187         }
    188 
    189     }
    190 
    191     /**
    192      * 分类
    193      * 
    194      * @param text
    195      * <a href="http://my.oschina.net/u/556800" target="_blank" rel="nofollow">@return</a>  返回各个类别的概率大小
    196      */
    197     public Map<String, Double> classify(String text) {
    198         // 分词,并且去重
    199         String[] text_words = participle.participle(text, false);
    200 
    201         Map<String, Double> frequencyOfType = new HashMap<String, Double>();
    202         Set<String> keySet = classMap.keySet();
    203         for (String classKey : keySet) {
    204             double typeOfThis = 1.0;
    205             Map<String, Double> wordMap = classWordMap.get(classKey);
    206             for (String word : text_words) {
    207                 Double wordCount = wordMap.get(word);
    208                 int articleCount = classP.get(classKey);
    209 
    210                 /*
    211                  * Double wordidf = idfMap.get(word); if(wordidf==null){
    212                  * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); }
    213                  */
    214 
    215                 // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算
    216                 double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1))
    217                         : (wordCount / articleCount);
    218 
    219                 // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
    220                 // 当double无限小的时候会归为0,为了避免 *10
    221 
    222                 typeOfThis = typeOfThis * term_frequency * 10;
    223                 typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE
    224                         : typeOfThis);
    225                 // System.out.println(typeOfThis+" : "+term_frequency+" :
    226                 // "+actCount);
    227             }
    228 
    229             typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);
    230 
    231             // 此类别文章出现的概率
    232             double classOfAll = classP.get(classKey) / actCount.doubleValue();
    233 
    234             // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果
    235             frequencyOfType.put(classKey, typeOfThis * classOfAll);
    236         }
    237 
    238         return frequencyOfType;
    239     }
    240 
    241     public void pringAll() {
    242         Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap
    243                 .entrySet();
    244         for (Entry<String, Map<String, Double>> ent : classWordEntry) {
    245             System.out.println("类别: " + ent.getKey());
    246             Map<String, Double> wordMap = ent.getValue();
    247             Set<Entry<String, Double>> wordMapSet = wordMap.entrySet();
    248             for (Entry<String, Double> wordEnt : wordMapSet) {
    249                 System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());
    250             }
    251         }
    252     }
    253 
    254     public Map<String, String> getClassMap() {
    255         return classMap;
    256     }
    257 
    258     public void setClassMap(Map<String, String> classMap) {
    259         this.classMap = classMap;
    260     }
    261 
    262 }

      在试验过程中,发觉某篇文章的分类不太准,某篇IT文章分到招聘类别下了,在仔细对比了训练数据后,发觉这是由于招聘类别每篇文章下面都带有“搜狗”的标志,而待分类的这篇IT文章里面充斥这搜狗这类词汇,结果招聘类下的概率比较大。由此想到,在除了做常规的贝叶斯计算时,需要把不同文本中出现次数多的词汇权重降低甚至删除(好比关键词搜索中的tf-idf),通俗点讲就是,在所有训练文本中某词汇(如的,地,得)出现的次数越多,这个词越不重要,比如IT文章中“软件”和“应用”这两个词汇,“应用”应该是很多文章类别下都有的,反而不太重要,但是“软件”这个词汇大多只出现在IT文章里,出现在大量文章的概率并不大。 我这里原本打算计算每个词的idf,然后给定一个阀值来判断是否需要纳入计算,但是由于词汇太多,计算量较大(等待结果时间较长),所以暂时注释掉了。 

    来源:http://my.oschina.net/duyunfei/blog/80283

  • 相关阅读:
    Vue.js 2.x笔记:安装与起步(1)
    EntityFramework Core笔记:保存数据(4)
    EntityFramework Core笔记:查询数据(3)
    EntityFramework Core笔记:表结构及数据基本操作(2)
    EntityFramework Core笔记:入门(1)
    ASP.NET MVC系列:web.config中ConnectionString aspnet_iis加密与AppSettings独立文件
    EntityFramework优化:第一次启动优化
    EntityFramework优化:查询性能
    EntityFramework优化:查询WITH(NOLOCK)
    SpringCloud学习笔记:熔断器Hystrix(5)
  • 原文地址:https://www.cnblogs.com/94julia/p/3103115.html
Copyright © 2020-2023  润新知