• 贝叶斯决策


    贝叶斯决策

    • 简单例子引入
    • 先验概率
    • 后验概率
    • 最小错误率决策
    • 最小风险贝叶斯决策

    简单的例子

      正常情况下,我们可以快速的将街上的人分成男和女两类。这里街上的人就是我们观测到的样本,将每一个人分成男、女两类就是我们做决策的过程。上面的问题就是一个分类问题。

      分类可以看作是一种决策,即我们根据观测对样本做出应归属哪一类的决策。

      假定我手里握着一枚硬币,让你猜是多少钱的硬币,这其实就可以看作一个分类决策的问题:你需要从各种可能的硬币中做出一个决策。硬币假设面值有1角、5角、1块。

      如果事先告知这枚硬币只可能是一角或者五角,那么问题就是一个两分类问题。

    先验概率

          


    最小错误率

            

    后验概率

    决策

     

     最小错误率决策

     

    最小风险贝叶斯决策

    最小风险决策

     

     

     

     

     贝叶斯决策理论的分类方法

    总结

                  

    Bayes.java

    package byas;
    
    import com.google.common.collect.Lists;
    import com.google.common.collect.Sets;
    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;
    import org.lionsoul.jcseg.ASegment;
    import org.lionsoul.jcseg.core.*;
    
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.HashSet;
    
    import static org.apache.commons.math3.util.FastMath.log;
    
    
    public class Bayes {
        //创建JcsegTaskConfig分词任务实例
        //即从jcseg.properties配置文件中初始化的配置
        public static JcsegTaskConfig config = new JcsegTaskConfig();
        public static ADictionary dic = DictionaryFactory
                .createDefaultDictionary(config);
    
        //生成数据
        public static Object[] createdata() throws IOException {
            ArrayList<ArrayList<String>> retList = Lists.newArrayList();
            ArrayList<Integer> labels = Lists.newArrayList();
            ASegment seg = null;
            try {
                seg = (ASegment) SegmentFactory
                        .createJcseg(JcsegTaskConfig.SIMPLE_MODE,
                                new Object[]{config, dic});
            } catch (JcsegException e) {
                e.printStackTrace();
            }
    
            /*IWord word;
            while ( (word = seg.next()) != null ) {
                System.out.println(word.getValue());
            }
            /*String title = article.getTitle();
            String content = article.getContent();
    
            List<Term> termList = new ArrayList<Term>();
            List<String> wordList = new ArrayList<String>();
            Map<String,Set<String>> words = new HashMap<String, Set<String>>();
            Queue<String> que = new LinkedList<String>();
            try {
                if(seg!=null){
                    seg.reset(new StringReader(title + content));
                    IWord word;
                    while ( (word = seg.next()) != null ) {
                        if(shouldInclude(word.getValue())){
                            wordList.add(word.getValue());
                        }
                    }
                }
    
            } catch (IOException e) {
                e.printStackTrace();
            }*/
    
            /*retList.add(Lists.newArrayList("my", "dog", "has", "flea", "problems", "help", "please"));
            retList.add(Lists.newArrayList("maybe", "not", "take", "him", "to", "dog", "park", "stupid"));
            retList.add(Lists.newArrayList("my", "dalmation", "is", "so", "cute", "I", "love", "him"));
            retList.add(Lists.newArrayList("stop", "posting", "stupid", "worthless", "garbage"));
            retList.add(Lists.newArrayList("mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him"));
            retList.add(Lists.newArrayList("quit", "buying", "worthless", "dog", "food", "stupid"));
            ArrayList<Integer> labels = Lists.newArrayList(0,1,0,1,0,1);*/
            return new Object[]{retList,labels};
        }
    
    
        //获取单词set
        public static ArrayList<String> createVocabSet(ArrayList<ArrayList<String>> lists){
            HashSet<String> retSet = Sets.newHashSet();
            for(ArrayList<String> list : lists){
                for(String str : list){
                    retSet.add(str);
                }
            }
    
            return Lists.newArrayList(retSet);
        }
    
        //计算set中包含的单词数量
        public static double[] bagOfWords2VecMN(ArrayList<String> set,ArrayList<String> inputData){
            double[] returnVec = new double[set.size()];
            for (int i = 0; i < inputData.size(); i++) {
                if(set.contains(inputData.get(i))){
                    returnVec[set.indexOf(inputData.get(i))]++;
                }
            }
            return returnVec;
        }
    
        //训练
        public static Object[] trainNB(RealMatrix realMatrix,ArrayList<Integer> labels){
    
            int numTrainDocs = realMatrix.getRowDimension();
            int numWords = realMatrix.getRow(0).length;
            int count = 0;
            for(int l : labels){
                count += l;
            }
    
            float pAbusive = (float)count / numTrainDocs;
            //生成单词矩阵
            RealMatrix p0Matrix = MatrixUtils.createRealMatrix(1, numWords);
            p0Matrix = oneNums(p0Matrix);
            RealMatrix p1Matrix = MatrixUtils.createRealMatrix(1,numWords);
            p1Matrix = oneNums(p1Matrix);
    
            float p0Denom = 2;
            float p1Denom = 2;
    
            //不同类别单词增加,总单词增加
            for (int i = 0; i < labels.size(); i++) {
                if(labels.get(i)==1){
                    p1Matrix = p1Matrix.add(realMatrix.getRowMatrix(i));
                    p1Denom += sumMatrix(realMatrix.getRowMatrix(i));
                }else{
                    p0Matrix = p0Matrix.add(realMatrix.getRowMatrix(i));
                    p0Denom += sumMatrix(realMatrix.getRowMatrix(i));
                }
            }
    
            //单词概率矩阵
            RealMatrix p0 = logMatrix(p0Matrix.scalarMultiply(1 / p0Denom));
            RealMatrix p1 = logMatrix(p1Matrix.scalarMultiply(1 / p1Denom));
            return new Object[]{p0,p1,pAbusive};
        }
    
    
        /**
         * 矩阵填充1
         * @param realMatrix
         * @return
         */
        public static RealMatrix oneNums(RealMatrix realMatrix){
            for(int i=0;i<realMatrix.getColumnDimension();i++){
                realMatrix.setColumn(i,new double[]{1});
            }
            return realMatrix;
        }
    
        /**
         * 计算矩阵元素和
         * @param realMatrix
         * @return
         */
        public static float sumMatrix(RealMatrix realMatrix){
            float num = 0;
            double[] rows = realMatrix.getRow(0);
            for(double row : rows){
                num += row;
            }
            return num;
        }
    
        /**
         * 矩阵元素log操作
         * @param realMatrix
         * @return
         */
        public static RealMatrix logMatrix(RealMatrix realMatrix){
            double[] rows = realMatrix.getRow(0);
            double[] newRows = new double[rows.length];
            for (int i = 0; i < rows.length; i++) {
                newRows[i] = log(rows[i]);
            }
            realMatrix.setRow(0,newRows);
            return realMatrix;
        }
    
        /**
         * 矩阵元素相乘
         * @param m1
         * @param m2
         * @return
         */
        public static RealMatrix multiply(RealMatrix m1,RealMatrix m2){
            RealVector r1 = m1.getRowVector(0);
            RealVector r2 = m2.getRowVector(0);
            RealMatrix m = MatrixUtils.createRealMatrix(m1.getRowDimension(),m1.getColumnDimension());
            m.setRowVector(0,r1.ebeMultiply(r2));
    
            return m;
        }
    
    
        /**
         * 验证方法
         * @param realMatrix
         * @param p0M
         * @param p1M
         * @return
         */
        public static int classify(RealMatrix realMatrix,Object p0M,Object p1M){
    
            float p0 = (float) (sumMatrix(multiply(realMatrix, (RealMatrix) p0M))+log(1.0-0.5));
            float p1 = (float) (sumMatrix(multiply(realMatrix, (RealMatrix) p1M))+log(0.5));
            if(p0>p1){
                return 0;
            }
            return 1;
        }
    
    
    
        public static void main(String[] args) throws IOException {
            /*Object[] retData = createData();
            ArrayList<String> set = createVocabSet((ArrayList<ArrayList<String>>) retData[0]);
            ArrayList<ArrayList<String>> lists = (ArrayList<ArrayList<String>>) retData[0];
            RealMatrix m = MatrixUtils.createRealMatrix(lists.size(),set.size());
            for (int i = 0; i < lists.size(); i++) {
                m.setRow(i,bagOfWords2VecMN(set,lists.get(i)));
            }
    
            Object[] retP = trainNB(m, (ArrayList<Integer>) retData[1]);
    
            ArrayList<String> test = Lists.newArrayList("love");
            RealMatrix m1 = MatrixUtils.createRealMatrix(1,set.size());
            m1.setRow(0,bagOfWords2VecMN(set,test));
            System.out.println(classify(m1,retP[0],retP[1]));*/
            createdata();
    
    
        }
    }
  • 相关阅读:
    C++中的 . 和 >
    JVM内存学习
    JAVA内部类
    2013年3月25日
    Dalvik虚拟机
    多线程访问数据库
    深入学习Android笔记(一)
    结对项目四则运算 “软件”之升级版
    分布式版本控制系统Git的安装与使用
    第一次作业准备
  • 原文地址:https://www.cnblogs.com/zlslch/p/6789129.html
Copyright © 2020-2023  润新知