贝叶斯决策
- 简单例子引入
- 先验概率
- 后验概率
- 最小错误率决策
- 最小风险贝叶斯决策
简单的例子
正常情况下,我们可以快速地将街上的人分成男和女两类。这里街上的人就是我们观测到的样本,将每一个人分成男、女两类就是我们做决策的过程。上面的问题就是一个分类问题。
分类可以看作是一种决策,即我们根据观测对样本做出应归属哪一类的决策。
假定我手里握着一枚硬币,让你猜它的面值是多少,这其实就可以看作一个分类决策问题:你需要从各种可能的面值中做出一个决策。假设硬币的面值有1角、5角、1元三种。
如果事先告知这枚硬币只可能是1角或者5角,那么这个问题就是一个两分类问题。
先验概率
最小错误率
后验概率
决策
最小错误率决策
最小风险贝叶斯决策
最小风险决策
贝叶斯决策理论的分类方法
总结
Bayes.java
package byas;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.lionsoul.jcseg.ASegment;
import org.lionsoul.jcseg.core.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import static org.apache.commons.math3.util.FastMath.log;

/**
 * Two-class naive Bayes text classifier working on bag-of-words count vectors.
 *
 * <p>Documents are lists of tokens. {@link #createVocabSet} builds the vocabulary,
 * {@link #bagOfWords2VecMN} turns a document into a count vector over that vocabulary,
 * {@link #trainNB} estimates per-class log word probabilities, and {@link #classify}
 * picks the class with the larger log score.
 */
public class Bayes {

    /**
     * Jcseg segmentation task configuration, i.e. the configuration initialised from
     * the jcseg.properties file.
     */
    public static JcsegTaskConfig config = new JcsegTaskConfig();

    /** Default Jcseg dictionary created from {@link #config}. */
    public static ADictionary dic = DictionaryFactory.createDefaultDictionary(config);

    /** Class-1 prior used by the three-argument {@code classify} overload. */
    private static final double DEFAULT_CLASS1_PRIOR = 0.5;

    /** Initial value of each per-class word-count denominator (smoothing term). */
    private static final double INITIAL_DENOMINATOR = 2.0;

    /**
     * Creates the training data holder.
     *
     * <p>NOTE(review): a Jcseg segmenter is created here but never used, and nothing is
     * added to the returned lists, so both come back empty. The sample corpus that used
     * to be here is kept below as a comment for reference.
     *
     * @return a two-element array: index 0 is the list of tokenised documents
     *         ({@code ArrayList<ArrayList<String>>}), index 1 the list of labels
     *         ({@code ArrayList<Integer>})
     * @throws IOException declared for callers; nothing in the current body throws it directly
     */
    public static Object[] createdata() throws IOException {
        ArrayList<ArrayList<String>> retList = Lists.newArrayList();
        ArrayList<Integer> labels = Lists.newArrayList();

        // The segmenter is not used yet; it is still created so that the call and its
        // failure reporting behave as before.
        ASegment seg = null;
        try {
            seg = (ASegment) SegmentFactory
                    .createJcseg(JcsegTaskConfig.SIMPLE_MODE, new Object[]{config, dic});
        } catch (JcsegException e) {
            // Best effort: report the failure and carry on without a segmenter.
            e.printStackTrace();
        }

        // Sample corpus (disabled):
        // retList.add(Lists.newArrayList("my", "dog", "has", "flea", "problems", "help", "please"));
        // retList.add(Lists.newArrayList("maybe", "not", "take", "him", "to", "dog", "park", "stupid"));
        // retList.add(Lists.newArrayList("my", "dalmation", "is", "so", "cute", "I", "love", "him"));
        // retList.add(Lists.newArrayList("stop", "posting", "stupid", "worthless", "garbage"));
        // retList.add(Lists.newArrayList("mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him"));
        // retList.add(Lists.newArrayList("quit", "buying", "worthless", "dog", "food", "stupid"));
        // labels.addAll(Lists.newArrayList(0, 1, 0, 1, 0, 1));

        return new Object[]{retList, labels};
    }

    /**
     * Collects the distinct words of all documents.
     *
     * @param lists tokenised documents
     * @return a list holding every distinct word once, in the iteration order of a {@code HashSet}
     */
    public static ArrayList<String> createVocabSet(ArrayList<ArrayList<String>> lists) {
        HashSet<String> vocabulary = Sets.newHashSet();
        for (ArrayList<String> document : lists) {
            vocabulary.addAll(document);
        }
        return Lists.newArrayList(vocabulary);
    }

    /**
     * Counts how often each vocabulary word occurs in a document.
     *
     * @param set       the vocabulary; position {@code i} of the result belongs to {@code set.get(i)}
     * @param inputData the tokens of one document; tokens missing from the vocabulary are ignored
     * @return a count vector with one entry per vocabulary position
     */
    public static double[] bagOfWords2VecMN(ArrayList<String> set, ArrayList<String> inputData) {
        // Index the vocabulary once rather than scanning the list for every token.
        // Only the first position of a repeated word is recorded, matching List.indexOf.
        Map<String, Integer> positionOf = new HashMap<String, Integer>();
        for (int i = 0; i < set.size(); i++) {
            String word = set.get(i);
            if (!positionOf.containsKey(word)) {
                positionOf.put(word, i);
            }
        }

        double[] returnVec = new double[set.size()];
        for (String token : inputData) {
            Integer position = positionOf.get(token);
            if (position != null) {
                returnVec[position]++;
            }
        }
        return returnVec;
    }

    /**
     * Trains the classifier.
     *
     * <p>Word counts start at 1 and each class denominator starts at 2, so no word ends up
     * with probability zero and the logarithm is always finite.
     *
     * @param realMatrix one row per training document, one column per vocabulary word
     * @param labels     one label per row; label 1 selects class 1, any other value class 0
     * @return a three-element array: the class-0 log word probabilities (1 x words
     *         {@code RealMatrix}), the class-1 log word probabilities (same shape), and the
     *         fraction of documents labelled 1 (a {@code Float})
     * @throws IllegalArgumentException if the number of labels differs from the number of rows
     */
    public static Object[] trainNB(RealMatrix realMatrix, ArrayList<Integer> labels) {
        int numTrainDocs = realMatrix.getRowDimension();
        int numWords = realMatrix.getColumnDimension();
        if (labels.size() != numTrainDocs) {
            throw new IllegalArgumentException("expected " + numTrainDocs
                    + " labels (one per matrix row) but got " + labels.size());
        }

        // Per-class word counts, initialised to 1.
        RealMatrix p0Matrix = oneNums(MatrixUtils.createRealMatrix(1, numWords));
        RealMatrix p1Matrix = oneNums(MatrixUtils.createRealMatrix(1, numWords));
        double p0Denom = INITIAL_DENOMINATOR;
        double p1Denom = INITIAL_DENOMINATOR;

        // Add each document's counts to its class and grow that class's total word count.
        int class1Docs = 0;
        for (int i = 0; i < numTrainDocs; i++) {
            RealMatrix document = realMatrix.getRowMatrix(i);
            if (labels.get(i) == 1) {
                class1Docs++;
                p1Matrix = p1Matrix.add(document);
                p1Denom += sumFirstRow(document);
            } else {
                p0Matrix = p0Matrix.add(document);
                p0Denom += sumFirstRow(document);
            }
        }
        float pAbusive = (float) class1Docs / numTrainDocs;

        // Log word probabilities; the division is done in double to avoid float rounding.
        RealMatrix p0 = logMatrix(p0Matrix.scalarMultiply(1.0 / p0Denom));
        RealMatrix p1 = logMatrix(p1Matrix.scalarMultiply(1.0 / p1Denom));

        return new Object[]{p0, p1, pAbusive};
    }

    /**
     * Sets every entry of the matrix to 1, in place.
     *
     * @param realMatrix the matrix to fill
     * @return the same matrix instance
     */
    public static RealMatrix oneNums(RealMatrix realMatrix) {
        int rows = realMatrix.getRowDimension();
        int columns = realMatrix.getColumnDimension();
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < columns; c++) {
                realMatrix.setEntry(r, c, 1.0);
            }
        }
        return realMatrix;
    }

    /**
     * Sums the entries of the first row of the matrix.
     *
     * @param realMatrix the matrix whose first row is summed; other rows are not read
     * @return the sum, narrowed to {@code float}
     */
    public static float sumMatrix(RealMatrix realMatrix) {
        return (float) sumFirstRow(realMatrix);
    }

    /** Sums the first row of the matrix in double precision. */
    private static double sumFirstRow(RealMatrix realMatrix) {
        double total = 0.0;
        for (double value : realMatrix.getRow(0)) {
            total += value;
        }
        return total;
    }

    /**
     * Replaces every entry of the first row with its natural logarithm, in place.
     *
     * @param realMatrix the matrix to transform; other rows are left unchanged
     * @return the same matrix instance
     */
    public static RealMatrix logMatrix(RealMatrix realMatrix) {
        double[] values = realMatrix.getRow(0);
        for (int i = 0; i < values.length; i++) {
            values[i] = log(values[i]);
        }
        realMatrix.setRow(0, values);
        return realMatrix;
    }

    /**
     * Multiplies the first rows of two matrices element by element.
     *
     * @param m1 left operand; it also fixes the dimensions of the result
     * @param m2 right operand
     * @return a new matrix with the dimensions of {@code m1} whose first row holds the
     *         element-wise products; any further rows are zero
     */
    public static RealMatrix multiply(RealMatrix m1, RealMatrix m2) {
        RealVector r1 = m1.getRowVector(0);
        RealVector r2 = m2.getRowVector(0);
        RealMatrix product =
                MatrixUtils.createRealMatrix(m1.getRowDimension(), m1.getColumnDimension());
        product.setRowVector(0, r1.ebeMultiply(r2));
        return product;
    }

    /**
     * Classifies a document assuming both classes are equally likely a priori.
     *
     * @param realMatrix the document's count vector in its first row
     * @param p0M        class-0 log word probabilities, a {@code RealMatrix} as returned by {@link #trainNB}
     * @param p1M        class-1 log word probabilities, a {@code RealMatrix} as returned by {@link #trainNB}
     * @return 0 if class 0 scores strictly higher, otherwise 1
     */
    public static int classify(RealMatrix realMatrix, Object p0M, Object p1M) {
        return classify(realMatrix, p0M, p1M, DEFAULT_CLASS1_PRIOR);
    }

    /**
     * Classifies a document using an explicit prior probability for class 1, for example
     * the third element returned by {@link #trainNB}.
     *
     * @param realMatrix the document's count vector in its first row
     * @param p0M        class-0 log word probabilities, a {@code RealMatrix}
     * @param p1M        class-1 log word probabilities, a {@code RealMatrix}
     * @param pClass1    prior probability of class 1, between 0 and 1 inclusive
     * @return 0 if class 0 scores strictly higher, otherwise 1
     * @throws IllegalArgumentException if {@code pClass1} is NaN or outside [0, 1]
     */
    public static int classify(RealMatrix realMatrix, Object p0M, Object p1M, double pClass1) {
        if (!(pClass1 >= 0.0 && pClass1 <= 1.0)) {
            throw new IllegalArgumentException("pClass1 must be within [0, 1] but was " + pClass1);
        }
        // Scores stay in double so that close log-likelihoods are not merged by float rounding.
        double score0 = sumFirstRow(multiply(realMatrix, (RealMatrix) p0M)) + log(1.0 - pClass1);
        double score1 = sumFirstRow(multiply(realMatrix, (RealMatrix) p1M)) + log(pClass1);
        return score0 > score1 ? 0 : 1;
    }

    /**
     * Entry point. Currently only calls {@link #createdata()}.
     *
     * @param args unused
     * @throws IOException propagated from {@link #createdata()}
     */
    public static void main(String[] args) throws IOException {
        // Example pipeline (disabled; createdata() currently returns no documents):
        // Object[] retData = createdata();
        // ArrayList<ArrayList<String>> lists = (ArrayList<ArrayList<String>>) retData[0];
        // ArrayList<String> set = createVocabSet(lists);
        // RealMatrix m = MatrixUtils.createRealMatrix(lists.size(), set.size());
        // for (int i = 0; i < lists.size(); i++) {
        //     m.setRow(i, bagOfWords2VecMN(set, lists.get(i)));
        // }
        // Object[] retP = trainNB(m, (ArrayList<Integer>) retData[1]);
        // ArrayList<String> test = Lists.newArrayList("love");
        // RealMatrix m1 = MatrixUtils.createRealMatrix(1, set.size());
        // m1.setRow(0, bagOfWords2VecMN(set, test));
        // System.out.println(classify(m1, retP[0], retP[1]));
        createdata();
    }
}