1. 贝叶斯定理:
(1) P(A^B) = P(A|B)P(B) = P(B|A)P(A)
由(1)得
P(A|B) = P(B|A)*P(A)/[p(B)]
贝叶斯在最基本题型:
假定一个场景,在一所高中男女比例为4:6, 留长头发的有男学生有女学生, 我们设定女生都留长发 , 而男生中有10%的留长发,90%留短发.那么如果我们看到远处一个长发背影?请问是一只男学生的概率?
分析:
P(男|长发) = P(长发|男)*P(男)/[p(长发)]
= (1/10)*(4/10)/[(6+4*(1/10))/10]
=1/16 =0.0625
P(女|长发) =P(长发|女)*P(女)/[p(长发)]
=1*(6/10)/[(6+4*(1/10))/10]
=30/32 =15/16
再举一个列子:
某个医院早上收了六个门诊病人,如下表。
症状 职业 疾病
打喷嚏 护士 感冒
打喷嚏 农夫 过敏
头痛 建筑工人 脑震荡
头痛 建筑工人 感冒
打喷嚏 教师 感冒
头痛 教师 脑震荡
现在又来了第七个病人,是一个打喷嚏的建筑工人。请问他患上感冒的概率有多大?(来源: http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html)
Java代码实现:
1 /** 2 * ********************************************************* 3 * <p/> 4 * Author: XiJun.Gong 5 * Date: 2016-08-31 20:36 6 * Version: default 1.0.0 7 * Class description: 8 * <p>特征库</p> 9 * <p/> 10 * ********************************************************* 11 */ 12 13 public class FeaturePoint { 14 15 private String key; 16 private double p; 17 18 public FeaturePoint(String key) { 19 this(key, 1); 20 } 21 22 public FeaturePoint(String key, double p) { 23 this.key = key; 24 this.p = p; 25 } 26 27 public String getKey() { 28 return key; 29 } 30 31 public void setKey(String key) { 32 this.key = key; 33 } 34 35 public double getP() { 36 return p; 37 } 38 39 public void setP(double p) { 40 this.p = p; 41 } 42 }
1 import com.google.common.collect.ArrayListMultimap; 2 import com.google.common.collect.Multimap; 3 4 import java.util.Collection; 5 import java.util.List; 6 7 /** 8 * ********************************************************* 9 * <p/> 10 * Author: XiJun.Gong 11 * Date: 2016-08-31 15:48 12 * Version: default 1.0.0 13 * Class description: 14 * <p/> 15 * ********************************************************* 16 */ 17 18 public class Bayes { 19 private static Multimap<String, FeaturePoint> map = ArrayListMultimap.create(); 20 21 /*喂数据*/ 22 public void input(List<String> labels) { 23 24 for (String key : labels) { 25 Collection<FeaturePoint> features = map.get(key); 26 for (String value : labels) { 27 if (features == null || features.size() < 1) { 28 map.put(key, new FeaturePoint(value)); 29 continue; 30 } 31 boolean tag = false; 32 for (FeaturePoint feature : features) { 33 if (feature.getKey().equals(value)) { 34 Double num = feature.getP() + 1; 35 map.remove(key, feature); 36 map.put(key, new FeaturePoint(value, num)); 37 tag = true; 38 break; 39 } 40 } 41 if (!tag) 42 map.put(key, new FeaturePoint(value)); 43 } 44 } 45 } 46 47 /*构造模型*/ 48 public void excute(List<String> labels) { 49 // excute(labels, null); 50 } 51 52 /*构造模型*/ 53 public Double excute(final List<String> labels, final String judge, Integer dataSize) { 54 55 Double denominator = 1d; //分母 56 Double numerator = 1d; //分子 57 Double coughNum = 0d; 58 /*选择相关性分子*/ 59 Collection<FeaturePoint> featurePoints = map.get(judge); 60 for (FeaturePoint featurePoint : featurePoints) { 61 if (judge.equals(featurePoint.getKey())) { 62 coughNum = featurePoint.getP(); 63 denominator *= (featurePoint.getP() / dataSize); 64 break; 65 } 66 } 67 68 Integer size = featurePoints.size() - 1; //容量 69 for (String label : labels) { 70 for (FeaturePoint featurePoint : featurePoints) { 71 if (label.equals(featurePoint.getKey())) { 72 denominator *= (featurePoint.getP() / coughNum); 73 for (FeaturePoint feature : map.get(label)) { 74 if (label.equals(feature.getKey())) { 75 numerator *= (feature.getP() / dataSize); 76 } 77 } 78 } 79 } 80 } 81 82 return denominator / numerator; 83 } 84 85 }
1 import com.google.common.collect.Lists; 2 3 import java.util.List; 4 import java.util.Scanner; 5 6 /** 7 * ********************************************************* 8 * <p/> 9 * Author: XiJun.Gong 10 * Date: 2016-09-01 14:58 11 * Version: default 1.0.0 12 * Class description: 13 * <p/> 14 * ********************************************************* 15 */ 16 public class Main { 17 18 public static void main(String args[]) { 19 20 Scanner scanner = new Scanner(System.in); 21 Integer size = scanner.nextInt(); 22 Integer row = scanner.nextInt(); 23 Bayes bayes = new Bayes(); 24 while (scanner.hasNext()) { 25 26 for (int ro = 0; ro < row; ro++) { 27 List<String> list = Lists.newArrayList(); 28 for (int i = 0; i < size; i++) { 29 list.add(scanner.next()); 30 } 31 bayes.input(list); 32 } 33 List<String> list = Lists.newArrayList(); 34 for (int i = 0; i < size - 1; i++) { 35 list.add(scanner.next()); 36 } 37 String judge = scanner.next(); 38 System.out.println(bayes.excute(list, judge,row)); 39 ; 40 } 41 42 } 43 }
pom.xml包
<dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <version>3.8.1</version> <scope>test</scope> </dependency> <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <version>18.0</version> </dependency>
结果:
1 3 6 2 打喷嚏 护士 感冒 3 打喷嚏 农夫 过敏 4 头痛 建筑工人 脑震荡 5 头痛 建筑工人 感冒 6 打喷嚏 教师 感冒 7 头痛 教师 脑震荡 8 打喷嚏 建筑工人 感冒 9 0.6666666666666666
1 3 6 2 打喷嚏 护士 感冒 3 打喷嚏 农夫 过敏 4 头痛 建筑工人 脑震荡 5 头痛 建筑工人 感冒 6 打喷嚏 教师 感冒 7 头痛 教师 脑震荡 8 打喷嚏 护士 感冒 9 1.3333333333333333
1 2 50 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发 10 男 短发 11 男 短发 12 男 短发 13 男 短发 14 男 短发 15 男 短发 16 男 短发 17 男 短发 18 男 短发 19 男 短发 20 男 短发 21 男 长发 22 女 长发 23 女 长发 24 女 长发 25 女 长发 26 女 长发 27 女 长发 28 女 长发 29 女 长发 30 女 长发 31 女 长发 32 女 长发 33 女 长发 34 女 长发 35 女 长发 36 女 长发 37 女 长发 38 女 长发 39 女 长发 40 女 长发 41 女 长发 42 女 长发 43 女 长发 44 女 长发 45 女 长发 46 女 长发 47 女 长发 48 女 长发 49 女 长发 50 女 长发 51 女 长发 52 53 长发 男 54 0.06250000000000001
1 2 50 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发 10 男 短发 11 男 短发 12 男 短发 13 男 短发 14 男 短发 15 男 短发 16 男 短发 17 男 短发 18 男 短发 19 男 短发 20 男 短发 21 男 长发 22 女 长发 23 女 长发 24 女 长发 25 女 长发 26 女 长发 27 女 长发 28 女 长发 29 女 长发 30 女 长发 31 女 长发 32 女 长发 33 女 长发 34 女 长发 35 女 长发 36 女 长发 37 女 长发 38 女 长发 39 女 长发 40 女 长发 41 女 长发 42 女 长发 43 女 长发 44 女 长发 45 女 长发 46 女 长发 47 女 长发 48 女 长发 49 女 长发 50 女 长发 51 女 长发 52 长发 女 53 0.9375
利用贝叶斯进行分类?
1 import com.google.common.collect.ArrayListMultimap; 2 import com.google.common.collect.Lists; 3 import com.google.common.collect.Multimap; 4 5 import java.util.Collection; 6 import java.util.List; 7 8 /** 9 * ********************************************************* 10 * <p/> 11 * Author: XiJun.Gong 12 * Date: 2016-08-31 15:48 13 * Version: default 1.0.0 14 * Class description: 15 * <p/> 16 * ********************************************************* 17 */ 18 19 public class Bayes { 20 private Multimap<String, FeaturePoint> map = null; 21 private List<String> featurePool = null; 22 23 public Bayes() { 24 map = ArrayListMultimap.create(); 25 featurePool = Lists.newArrayList(); 26 } 27 28 public void add(String label) { 29 featurePool.add(label); 30 } 31 32 /*喂数据*/ 33 public void input(List<String> labels) { 34 35 for (String key : labels) { 36 Collection<FeaturePoint> features = map.get(key); 37 for (String value : labels) { 38 if (features == null || features.size() < 1) { 39 map.put(key, new FeaturePoint(value)); 40 continue; 41 } 42 boolean tag = false; 43 for (FeaturePoint feature : features) { 44 if (feature.getKey().equals(value)) { 45 Double num = feature.getP() + 1; 46 map.remove(key, feature); 47 map.put(key, new FeaturePoint(value, num)); 48 tag = true; 49 break; 50 } 51 } 52 if (!tag) 53 map.put(key, new FeaturePoint(value)); 54 } 55 } 56 } 57 58 /*最符合那个分类*/ 59 public String excute(List<String> labels, Integer dataSize) { 60 61 Double max = -999999999d; 62 String max_obj = null; 63 List<Double> ans = Lists.newArrayList(); 64 for (String label : featurePool) { 65 Double p = excute(labels, label, dataSize); 66 ans.add(p); 67 if (max < p) { 68 max_obj = label; 69 max = p; 70 } 71 } 72 return max_obj; 73 } 74 75 /*构造模型*/ 76 public Double excute(final List<String> labels, final String judge, Integer dataSize) { 77 78 Double denominator = 1d; //分母 79 Double numerator = 1d; //分子 80 Double coughNum = 0d; 81 /*选择相关性分子*/ 82 Collection<FeaturePoint> featurePoints = map.get(judge); 83 for (FeaturePoint featurePoint : featurePoints) { 84 if (judge.equals(featurePoint.getKey())) { 85 coughNum = featurePoint.getP(); 86 denominator *= (featurePoint.getP() / dataSize); 87 break; 88 } 89 } 90 /*O(n^3)*/ 91 Integer size = featurePoints.size() - 1; //容量 92 for (String label : labels) { 93 for (FeaturePoint featurePoint : featurePoints) { 94 if (label.equals(featurePoint.getKey())) { 95 denominator *= (featurePoint.getP() / coughNum); 96 for (FeaturePoint feature : map.get(label)) { 97 if (label.equals(feature.getKey())) { 98 numerator *= (feature.getP() / dataSize); 99 } 100 } 101 } 102 } 103 } 104 105 return denominator / numerator; 106 } 107 108 }
1 import com.google.common.collect.Lists; 2 3 import java.util.List; 4 import java.util.Scanner; 5 6 /** 7 * ********************************************************* 8 * <p/> 9 * Author: XiJun.Gong 10 * Date: 2016-09-01 14:58 11 * Version: default 1.0.0 12 * Class description: 13 * <p/> 14 * ********************************************************* 15 */ 16 public class Main { 17 18 public static void main(String args[]) { 19 20 Scanner scanner = new Scanner(System.in); 21 Integer size = scanner.nextInt(); 22 Integer row = scanner.nextInt(); 23 Integer category = scanner.nextInt(); 24 while (scanner.hasNext()) { 25 Bayes bayes = new Bayes(); 26 for (int ro = 0; ro < row; ro++) { 27 List<String> list = Lists.newArrayList(); 28 for (int i = 0; i < size; i++) { 29 list.add(scanner.next()); 30 } 31 bayes.input(list); 32 } 33 List<String> list = Lists.newArrayList(); 34 for (int i = 0; i < size - 1; i++) { 35 list.add(scanner.next()); 36 } 37 for (int i = 0; i < category; i++) { 38 bayes.add(scanner.next()); 39 } 40 System.out.println(bayes.excute(list, row)); 41 } 42 43 } 44 }
结果:
1 2 50 2 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发 10 男 短发 11 男 短发 12 男 短发 13 男 短发 14 男 短发 15 男 短发 16 男 短发 17 男 短发 18 男 短发 19 男 短发 20 男 短发 21 男 长发 22 女 长发 23 女 长发 24 女 长发 25 女 长发 26 女 长发 27 女 长发 28 女 长发 29 女 长发 30 女 长发 31 女 长发 32 女 长发 33 女 长发 34 女 长发 35 女 长发 36 女 长发 37 女 长发 38 女 长发 39 女 长发 40 女 长发 41 女 长发 42 女 长发 43 女 长发 44 女 长发 45 女 长发 46 女 长发 47 女 长发 48 女 长发 49 女 长发 50 女 长发 51 女 长发 52 长发 53 男 女 54 女
1 2 50 2 2 男 长发 3 男 短发 4 男 短发 5 男 短发 6 男 短发 7 男 短发 8 男 短发 9 男 短发 10 男 短发 11 男 短发 12 男 短发 13 男 短发 14 男 短发 15 男 短发 16 男 短发 17 男 短发 18 男 短发 19 男 短发 20 男 短发 21 男 长发 22 女 长发 23 女 长发 24 女 长发 25 女 长发 26 女 长发 27 女 长发 28 女 长发 29 女 长发 30 女 长发 31 女 长发 32 女 长发 33 女 长发 34 女 长发 35 女 长发 36 女 长发 37 女 长发 38 女 长发 39 女 长发 40 女 长发 41 女 长发 42 女 长发 43 女 长发 44 女 长发 45 女 长发 46 女 长发 47 女 长发 48 女 长发 49 女 长发 50 女 长发 51 女 长发 52 短发 53 男 女 54 男