• java 朴素贝叶斯


    由于在网上找到的朴素贝叶斯(Naive Bayes)源码大多是面向具体应用的,本人才疏学浅,看不太懂,于是自己花了2天时间写了一份粗糙的实现(基于李航《统计学习方法》第4章的例4.1)。由于只是初学,若有错误,请指出,大家一起学习!

      1 import java.io.BufferedReader;
      2 import java.io.File;
      3 import java.io.FileNotFoundException;
      4 import java.io.FileReader;
      5 import java.io.IOException;
      6 import java.util.ArrayList;
      7 import java.util.HashMap;
      8 import java.util.List;
      9 import java.util.Map;
     10 
     /**
      * Naive Bayes classifier for categorical features, following Example 4.1 of Li Hang's
      * "Statistical Learning Methods" (maximum-likelihood estimates, no smoothing).
      *
      * <p>Training file layout: values are separated by whitespace; every line except the last
      * holds the values of one feature across all samples, and the last line holds the class
      * label of each sample, so column j is training sample j. Test file layout: line k holds
      * the value(s) of feature k; only the first column (the first test sample) is classified.
      */
     public class Bayes {
         /** Training-data file used when no path is passed on the command line. */
         private static final String DEFAULT_TRAIN_FILE = "D://1.txt";
         /** Test-data file used when no second path is passed on the command line. */
         private static final String DEFAULT_TEST_FILE = "D://2.txt";

         /**
          * Trains on the training file, classifies the first test sample and prints the result.
          *
          * @param args optional {@code [trainFile, testFile]}; missing entries fall back to
          *     {@code D://1.txt} and {@code D://2.txt}
          */
         public static void main(String[] args){
             String s1 = (args != null && args.length > 0) ? args[0] : DEFAULT_TRAIN_FILE;
             String s2 = (args != null && args.length > 1) ? args[1] : DEFAULT_TEST_FILE;
             List<List<String>> filelist = read(new ArrayList<List<String>>(), s1);
             List<List<String>> testlist = read(new ArrayList<List<String>>(), s2);
             if (!isValid(filelist, testlist)) {
                 return;
             }
             Map<String,Integer> prioriNo = new HashMap<String,Integer>();
             Map<String,Double> prioriP = computepirior(filelist, new HashMap<String,Double>(), prioriNo);
             Map<String,Double> result = decide(prioriP, filelist, testlist, prioriNo);
             print(result, testlist);
         }

         // Checks that the data can be classified; reports the first problem on stderr.
         private static boolean isValid(List<List<String>> filelist, List<List<String>> testlist) {
             if (filelist.isEmpty()) {
                 System.err.println("No training data could be read.");
                 return false;
             }
             int samples = filelist.get(filelist.size() - 1).size();
             for (int i = 0; i < filelist.size(); i++) {
                 if (filelist.get(i).size() != samples) {
                     System.err.println("Training row " + (i + 1) + " has " + filelist.get(i).size()
                             + " values, expected " + samples + ".");
                     return false;
                 }
             }
             int featureCount = filelist.size() - 1;
             if (testlist.size() < featureCount) {
                 System.err.println("Test data has " + testlist.size() + " feature rows, expected "
                         + featureCount + ".");
                 return false;
             }
             return true;
         }

         // Step 4: print the test sample and the predicted class.
         private static void print(Map<String, Double> result,List<List<String>> testlist) {
             StringBuilder sb = new StringBuilder("测试数据:" + "   ");
             for (int i = 0; i < testlist.size(); i++) {
                 sb.append("特征").append(i + 1).append(" :");
                 for (String value : testlist.get(i)) {
                     sb.append(value).append("   ");
                 }
             }
             // On a tie, HashMap iteration order decides which of the best classes is shown.
             sb.append("所属类别:").append(result.keySet().iterator().next());
             System.out.println(sb);
         }

         /**
          * Step 3: scores every class c with P(Y=c) * prod_k P(X_k = x_k | Y=c) for the first
          * test sample and returns the best-scoring class(es).
          *
          * @param prioriP  class label -> prior probability P(Y=c)
          * @param filelist training rows: feature rows followed by the label row
          * @param testlist test rows; element 0 of row k is the sample's value of feature k
          * @param prioriNo class label -> number of training samples with that label
          * @return every class whose unnormalized posterior equals the maximum, mapped to that
          *     score (several entries on a tie; all classes if every score is 0)
          */
         private static Map<String, Double> decide(Map<String, Double> prioriP, List<List<String>> filelist, List<List<String>> testlist, Map<String, Integer> prioriNo) {
             int featureCount = filelist.size() - 1;
             List<String> labels = filelist.get(featureCount);

             Map<String, Double> scores = new HashMap<String, Double>();
             for (Map.Entry<String, Integer> entry : prioriNo.entrySet()) {
                 String label = entry.getKey();
                 double classSize = entry.getValue();
                 double likelihood = 1.0;
                 for (int k = 0; k < featureCount; k++) {
                     // Count per feature: different features must not share a counter even when
                     // their values are spelled the same (e.g. two binary features both "1").
                     String testValue = testlist.get(k).get(0);
                     List<String> row = filelist.get(k);
                     int matches = 0;
                     for (int j = 0; j < labels.size(); j++) {
                         if (label.equals(labels.get(j)) && testValue.equals(row.get(j))) {
                             matches++;
                         }
                     }
                     // A value never seen with this class yields P = 0 (maximum-likelihood estimate).
                     likelihood *= matches / classSize;
                 }
                 scores.put(label, prioriP.get(label) * likelihood);
             }

             double max = 0.0;
             for (double score : scores.values()) {
                 if (score >= max) {
                     max = score;
                 }
             }
             Map<String, Double> result = new HashMap<String, Double>();
             for (Map.Entry<String, Double> e : scores.entrySet()) {
                 // Compare as primitives: == on two boxed Doubles would test reference identity.
                 if (e.getValue().doubleValue() == max) {
                     result.put(e.getKey(), e.getValue());
                 }
             }
             return result;
         }

         /**
          * Step 2: counts the labels of the last row into {@code m} and stores the prior
          * P(Y=c) = count(c) / N into {@code prioriP}.
          *
          * @param list     training rows; the last row holds the class labels
          * @param prioriP  receives class label -> prior probability
          * @param m        receives class label -> number of samples with that label
          * @return {@code prioriP}
          */
         private static Map<String, Double> computepirior(List<List<String>> list, Map<String, Double> prioriP, Map<String, Integer> m) {
             List<String> labels = list.get(list.size() - 1);
             for (String label : labels) {
                 Integer count = m.get(label);
                 m.put(label, count == null ? 1 : count + 1);
             }
             for (Map.Entry<String, Integer> entry : m.entrySet()) {
                 prioriP.put(entry.getKey(), entry.getValue() / (double) labels.size());
             }
             return prioriP;
         }

         /**
          * Step 1: reads a whitespace-separated data file, appending one row per non-blank line
          * to {@code list}. On an I/O error the stack trace is printed and the rows read so far
          * are returned.
          *
          * @param list  receives the parsed rows
          * @param sread path of the file to read
          * @return {@code list}
          */
         private static List<List<String>> read(List<List<String>> list, String sread) {
             BufferedReader br = null;
             try {
                 br = new BufferedReader(new FileReader(new File(sread)));
                 String line;
                 while ((line = br.readLine()) != null) {
                     String trimmed = line.trim();
                     if (trimmed.length() == 0) {
                         continue; // blank lines carry no feature row
                     }
                     List<String> row = new ArrayList<String>();
                     // \s+ so that runs of spaces/tabs do not produce empty tokens.
                     for (String token : trimmed.split("\\s+")) {
                         row.add(token);
                     }
                     list.add(row);
                 }
             } catch (FileNotFoundException e) {
                 e.printStackTrace();
             } catch (IOException e) {
                 e.printStackTrace();
             } finally {
                 if (br != null) {
                     try {
                         br.close();
                     } catch (IOException ignored) {
                         // Nothing useful to do if closing a read-only file fails.
                     }
                 }
             }
             return list;
         }
     }

     训练数据(文件 D://1.txt;每一行是一个特征在全部 15 个样本上的取值,最后一行是类别标记,因此每一列对应一个样本):

    1 1 1 1 1 2 2 2 2 2 3 3 3 3 3
    S M M S S S M M L L L M M L L
    -1 -1 1 1 -1 -1 -1 1 1 1 1 1 1 1 -1

    测试数据(文件 D://2.txt;每一行是一个特征的取值,即待分类样本 x = (2, S)):

    2

    S

    实现结果:

    测试数据:   特征1 :2   特征2 :S   所属类别:-1

  • 相关阅读:
    python 数据分析3
    python 数据分析2
    Python 数据分析1
    Python18 Django 基础
    Python 17 web框架&Django
    一只救助犬的最后遗言
    With As 获取 id parentId 递归获取所有
    分布式事物
    div 浮动框
    sql时间比较
  • 原文地址:https://www.cnblogs.com/wn19910213/p/3329590.html
Copyright © 2020-2023  润新知