• 基于贝叶斯算法实现简单的分类(java)


    参考文章:https://blog.csdn.net/qq_32690999/article/details/78737393

    项目代码目录结构

    模拟训练的数据集

     核心代码

    Bayes.java

    package IsStudent_bys;
     
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Map;
     
    public class Bayes {
     
        //按类别分类
        //输入:训练数据(dataSet)
        //输出:类别到训练数据的一个Map
        public Map<String,ArrayList<ArrayList<String>>> classify(ArrayList<ArrayList<String>> dataSet){
            Map<String,ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>(); //待返回的Map
            int num=dataSet.size();
            for(int i=0;i<num;i++)  //遍历所有数据项
            {
                ArrayList<String> Y = dataSet.get(i);  //将第i个训练样本的信息取出
                String Class = Y.get(Y.size()-1).toString();  //约定将类别信息放在最后一个字符串
                
                if(map.containsKey(Class)){  //判断map中是否已经有这个类了
                    map.get(Class).add(Y);
                }else{  //若没有这个类就新建一个可变长数组记录并加入map
                    ArrayList<ArrayList<String>> nlist = new ArrayList<ArrayList<String>>();
                    nlist.add(Y);
                    map.put(Class,nlist);
                }
            }
            return map;
        }
        
        //计算分类后每个类对应的样本中某个特征出现的概率
        //输入:某一类别对应的数据(classdata) 目标值(value) 相应的列值(index)
        //输出:该类数据中相应列上的值等于目标值得频率
        public double CalPro_yj_c(ArrayList<ArrayList<String>> classdata, String value, int index){
            int sum = 0;  //sum用于记录相同特征出现的频数
            int num = classdata.size();
            for(int i=0;i<num;i++)
            {
                ArrayList<String> Y = classdata.get(i);
                if(Y.get(index).equals(value)) sum++;  //相同则计数
            }
            return (double)sum/num;  //返回频率,以频率带概率
            
        }
        
        //贝叶斯分类器主函数
        //输入:训练集(可变长数组);待分类集
        //输出:概率最大的类别
        public String bys_Main(ArrayList<ArrayList<String>> dataSet, ArrayList<String> testSet){
            Map<String, ArrayList<ArrayList<String>>> doc = this.classify(dataSet);  //用本class中的分类函数构造映射
            
            Object classes[] = doc.keySet().toArray(); //把map中所有的key取出来(即所有类别) ,借鉴学习了object的使用(待深入了解)
            double Max_Value=0.0; //最大的概率
            int Max_Class=-1;     //用于记录最大类的编号
            for(int i=0;i<doc.size();i++)  //对每一个类分别计算,本程序只有两个类
            {
                String c = classes[i].toString();  //将类提取出
                ArrayList<ArrayList<String>> y = doc.get(c);  //提取该类对应的数据列表
                double prob = (double)y.size()/dataSet.size();  //计算比例
                
                System.out.println(c+" : "+prob);  //输出该类的样本占总样本个数的比例!
                
                for(int j=0;j<testSet.size();j++)  //对每个属性计算先验概率
                {
                    double P_yj_c = CalPro_yj_c(y,testSet.get(j),j);
                    //输出中间结果以便测试System.out.println("now in bys_Main!!"+P_yj_c);
                    prob = prob*P_yj_c;
                }
                
                System.out.printf("P(%s | testcase) * P(testcase) = %f
    ",c,prob);  //输出分子的概率大小
                if(prob>Max_Value)  //更新分子最大概率
                {
                    Max_Value=prob;
                    Max_Class=i;
                }
            }
            return classes[Max_Class].toString();
        }
    }

    FetchData.java

    package IsStudent_bys;
     
    import java.io.IOException;
    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.ResultSet;
    import java.sql.SQLException;
    import java.sql.Statement;
    import java.util.ArrayList;
    import java.util.StringTokenizer;
     
    public class FetchData {
     
        //连接数据库,读取训练数据
        //输入:数据库
        //输出:可变长数组
        public ArrayList<ArrayList<String>> fetch_traindata(){
            ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();  //待返回
            
            Connection conn;    
            String driver = "com.mysql.jdbc.Driver"; 
            String url = "jdbc:mysql://localhost:3306/Bayes";  //指向要访问的数据库!注意后面跟的是数据库名称
            String user = "root";   //navicat for sql配置的用户名
            String password = "root";  //navicat for sql配置的密码
            try{
                Class.forName(driver);  //用class加载动态链接库——驱动程序
                conn = DriverManager.getConnection(url,user,password);  //利用信息链接数据库
                if(!conn.isClosed())
                    System.out.println("Succeeded connecting to the Database!");
                
                Statement statement = conn.createStatement();  //用statement 来执行sql语句
                String sql = "select * from TrainData";   //这是sql语句中的查询某个表,注意后面的emp是表名!!!
                ResultSet rs = statement.executeQuery(sql);  //用于返回结果
                
                String str = null;
                while(rs.next()){  //一直读到最后一条表
                    ArrayList<String> s= new ArrayList<String>();
                    str = rs.getString("Sex");  //分别读取相应栏位的信息加入到可变长数组中
                    s.add(str);
                    str = rs.getString("tatto");
                    s.add(str);
                    str = rs.getString("smoking");
                    s.add(str);
                    str = rs.getString("wearglasses");
                    s.add(str);
                    str = rs.getString("ridebike");
                    s.add(str);
                    str = rs.getString("isStudent");
                    s.add(str);
                    dataSet.add(s);  //加入dataSet
                    //System.out.println(s);  输出中间结果调试
                }
                rs.close();
                conn.close();
            }catch(ClassNotFoundException e){  //catch不同的错误信息,并报错
                System.out.println("Sorry,can`t find the Driver!");
                e.printStackTrace();
            }catch(SQLException e){
                e.printStackTrace();
            }catch (Exception e) {
                e.printStackTrace();
            }finally{
                System.out.println("数据库训练数据读取成功!");
            }
            return dataSet;
        }
        
        
        public ArrayList<String> read_testdata(String str) throws IOException  //将用户输入的一整行字符串分割解析成可变长数组
        {
            ArrayList<String> testdata=new ArrayList<String>();  //待返回
            StringTokenizer tokenizer = new StringTokenizer(str);  
            while (tokenizer.hasMoreTokens()) { 
                testdata.add(tokenizer.nextToken());
            }
            return testdata;
        }
    }

    Main.java

    package IsStudent_bys;
     
    import java.io.BufferedInputStream;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Scanner;
     
    public class Main {
     
        //主函数,读取数据库,并读入待判定数据,输出结果
        public static void main(String[] args) {
            FetchData Fdata = new FetchData();   //java对函数的调用要先声明相应的对象再调用
            Bayes bys = new Bayes();
            ArrayList<ArrayList<String>> dataSet = null; //训练数据列表
            ArrayList<String> testSet = null; //测试数据
            try{
                System.out.println("从数据库读入训练数据:");
                dataSet = Fdata.fetch_traindata();   //读取训练数据集合
                System.out.println("请输入测试数据:"); 
                Scanner cin = new Scanner(new BufferedInputStream(System.in));  //从标准输入输出中读取测试数据
                while(cin.hasNext())  //支持多条测试数据读取
                {
                    String str = cin.nextLine();   //先读入一行
                    testSet = Fdata.read_testdata(str);//将这一行进行字符串分隔解析后返回可变长数组类型
                    //System.out.println(testSet);  //输出中间结果
                    String ans = bys.bys_Main(dataSet, testSet);  //调用贝叶斯分类器
                    if(ans.equals("yes")) System.out.println("Yes!!! 根据已有数据推断极有可能像是一个学生!");  //输出结果
                    else System.out.println("他/她 的特征不像一名学生!");
                }
                cin.close();
            }catch (IOException e) {  //处理异常
                e.printStackTrace();
            } 
        }
     
    }

    运行效果截图:

  • 相关阅读:
    VS 2013 中如何自定义快捷键(图解)
    c# XML读取
    Java与.NET的WebServices相互调用
    .NET 的 WCF 和 WebService 有什么区别?(转载)
    2017年第六届数学中国数学建模国际赛(小美赛)比赛心得
    网络分析法(Analytic Network Process,ANP)
    图的简单应用(C/C++实现)
    【Android开发学习笔记之一】5大布局方式详解
    Android布局属性详解
    Android应用程序使用两个LinearLayout编排5个Button控件
  • 原文地址:https://www.cnblogs.com/zyt-bg/p/10405947.html
Copyright © 2020-2023  润新知