• Machine Learning: Implementing Entropy and Information Gain for Decision Trees


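    This post does not cover the theory; it deals only with the code:

    1. Entropy formula:

                 P is the proportion of positive examples in S, Q the proportion of negative examples

         Entropy(S) = -P*log2(P) - Q*log2(Q)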
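
    As a quick check of the formula, the sketch below evaluates it for the PlayTennis column used later in this post (9 Yes, 5 No). The class name EntropyDemo and its method names are illustrative and are not part of the original code; 0*log2(0) is taken to be 0 by convention.

```java
public class EntropyDemo {

    // p * log2(p), using the convention 0 * log2(0) = 0
    public static double pLog2(double p) {
        return p == 0 ? 0 : p * (Math.log(p) / Math.log(2));
    }

    // Entropy of a set given its counts of positive and negative examples
    public static double entropy(int positives, int negatives) {
        double total = positives + negatives;
        if (total == 0) {
            return 0; // empty set: treat the entropy as 0 (assumption of this sketch)
        }
        return -pLog2(positives / total) - pLog2(negatives / total);
    }

    public static void main(String[] args) {
        // PlayTennis column: 9 "Yes" and 5 "No"
        System.out.println("Entropy(S) = " + entropy(9, 5));
    }
}
```

    For 9 positive and 5 negative examples the value is about 0.940; a set containing only one class has entropy 0.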

    2. Information gain:

        Gain(S, A) = Entropy(S) - Σ_v (|Sv| / |S|) * Entropy(Sv)

        where Sv is the subset of S in which attribute A takes the value v
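
    To make the weighted sum concrete, here is a minimal sketch that computes the gain directly from per-value Yes/No counts. The class GainDemo and the int[][] layout ({positives, negatives} per attribute value) are assumptions of this sketch, not the author's code. The counts for Wind are taken from the sample data in this post: Weak has 6 Yes and 2 No, Strong has 3 Yes and 3 No.

```java
public class GainDemo {

    static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    // Entropy from class counts; a class with zero examples contributes 0
    public static double entropy(int positives, int negatives) {
        double total = positives + negatives;
        double result = 0;
        for (int count : new int[]{positives, negatives}) {
            if (count > 0) {
                double p = count / total;
                result -= p * log2(p);
            }
        }
        return result;
    }

    // subsets[i] = {positives, negatives} for the i-th value of the attribute
    public static double gain(int[][] subsets) {
        int totalPos = 0;
        int totalNeg = 0;
        for (int[] s : subsets) {
            totalPos += s[0];
            totalNeg += s[1];
        }
        double total = totalPos + totalNeg;
        double g = entropy(totalPos, totalNeg);      // Entropy(S)
        for (int[] s : subsets) {
            // subtract (|Sv| / |S|) * Entropy(Sv) for each value v
            g -= ((s[0] + s[1]) / total) * entropy(s[0], s[1]);
        }
        return g;
    }

    public static void main(String[] args) {
        int[][] wind = {{6, 2}, {3, 3}};             // Weak, Strong
        System.out.println("Gain(S, Wind) = " + gain(wind));
    }
}
```

    For Wind this gives roughly 0.048, which agrees with the value the program prints at the end of this post.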

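    Example:

    Input data, transposed so that each line holds one attribute. The first line gives the number of attributes (result column included) and the number of samples; the last line lists the attribute names, with the result attribute last: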

     5  14
     Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
     Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
     Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
     Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
     PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
     Outlook Temperature Humidity Wind PlayTennis
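
    Before any gain can be computed, each attribute value needs its own Yes/No tally against the PlayTennis row. The sketch below does this with plain java.util collections for the Outlook row above; the class TallyDemo and the int[2] layout ({yes, no}) are illustrative assumptions, and "Yes" is assumed to mark a positive example.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TallyDemo {

    // Returns attribute value -> {yesCount, noCount}
    public static Map<String, int[]> tally(List<String> feature, List<String> labels) {
        Map<String, int[]> counts = new LinkedHashMap<String, int[]>();
        for (int i = 0; i < feature.size(); i++) {
            int[] c = counts.get(feature.get(i));
            if (c == null) {
                c = new int[2];
                counts.put(feature.get(i), c);
            }
            if ("Yes".equals(labels.get(i))) {
                c[0]++;
            } else {
                c[1]++;
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        List<String> outlook = Arrays.asList(
                "Sunny", "Sunny", "Overcast", "Rain", "Rain", "Rain", "Overcast",
                "Sunny", "Sunny", "Rain", "Sunny", "Overcast", "Overcast", "Rain");
        List<String> play = Arrays.asList(
                "No", "No", "Yes", "Yes", "Yes", "No", "Yes",
                "No", "Yes", "Yes", "Yes", "Yes", "Yes", "No");
        for (Map.Entry<String, int[]> e : tally(outlook, play).entrySet()) {
            System.out.println(e.getKey() + ": Yes=" + e.getValue()[0] + ", No=" + e.getValue()[1]);
        }
    }
}
```

    For Outlook the tallies are Sunny 2 Yes / 3 No, Overcast 4 Yes / 0 No and Rain 3 Yes / 2 No.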

    package com.qunar.data.tree;

    /**
     * Author:     XiJun.Gong
     * Date:       2016-09-02 15:28
     * Version:    default 1.0.0
     * Class description:
     * <p>Counts how many times a given type occurs</p>
     */
    public class CountMap<T> {

        private T key;       // the type being counted
        private int value;   // number of occurrences

        public CountMap() {
            this(null, 0);
        }

        public CountMap(T key, int value) {
            this.key = key;
            this.value = value;
        }

        public T getKey() {
            return key;
        }

        public void setKey(T key) {
            this.key = key;
        }

        public int getValue() {
            return value;
        }

        public void setValue(int value) {
            this.value = value;
        }
    }

    package com.qunar.data.tree;

    import com.google.common.collect.ArrayListMultimap;
    import com.google.common.collect.Maps;
    import com.google.common.collect.Multimap;
    import com.google.common.collect.Sets;

    import java.util.*;

    /**
     * Author:     XiJun.Gong
     * Date:       2016-09-02 14:24
     * Version:    default 1.0.0
     * Class description:
     * <p>Decision tree: entropy and information gain</p>
     *
     * @param <T> type of the cell values (attribute values and result labels)
     * @param <K> type of the attribute names
     */
    public class DecisionTree<T, K> {

        private static final String positiveExampleType = "Yes";
        private static final String counterExampleType = "No";

        /** Returns p * log2(p), with 0 * log2(0) taken as 0. */
        public double pLog2(final double p) {
            if (0 == p) return 0;
            return p * (Math.log(p) / Math.log(2));
        }

        /**
         * Entropy
         *
         * @param positiveExample number of positive examples
         * @param counterExample  number of negative examples
         * @return entropy value
         */
        public double entropy(final double positiveExample, final double counterExample) {
            double total = positiveExample + counterExample;
            if (0 == total) return 0;   // avoid 0/0 = NaN on an empty set
            double positiveP = positiveExample / total;
            double counterP = counterExample / total;
            return -1d * (pLog2(positiveP) + pLog2(counterP));
        }

        /**
         * @param features values of one attribute, one per sample
         * @param results  the corresponding result labels
         * @return for each attribute value, the count of each result label
         */
        public Multimap<T, CountMap<T>> merge(final List<T> features, final List<T> results) {
            Multimap<T, CountMap<T>> infoMap = ArrayListMultimap.create();
            Iterator<T> result = results.iterator();
            for (T feature : features) {
                T res = result.next();
                boolean found = false;
                for (CountMap<T> countMap : infoMap.get(feature)) {
                    if (countMap.getKey().equals(res)) {
                        // label already seen for this value: increment in place
                        countMap.setValue(countMap.getValue() + 1);
                        found = true;
                        break;
                    }
                }
                if (!found)
                    infoMap.put(feature, new CountMap<T>(res, 1));
            }
            return infoMap;
        }

        /**
         * Information gain
         *
         * @param infoMap   label counts per value of the attribute, as built by merge
         * @param dataTable the input data table
         * @param type      the attribute (e.g. Outlook, with values {Sunny, Overcast, Rain})
         * @param entropyS  entropy of the whole sample set
         * @param totalSize total number of samples
         * @return information gain
         */
        public double gain(Multimap<T, CountMap<T>> infoMap,
                           Map<K, List<T>> dataTable,
                           final K type,
                           double entropyS,
                           final int totalSize) {
            // distinct values of the attribute
            Set<T> subTypes = Sets.newHashSet();
            subTypes.addAll(dataTable.get(type));
            for (T subType : subTypes) {
                double subSize = 0;
                double positiveExample = 0;
                double counterExample = 0;
                for (CountMap<T> countMap : infoMap.get(subType)) {
                    subSize += countMap.getValue();
                    if (positiveExampleType.equals(countMap.getKey()))
                        positiveExample = countMap.getValue();
                    else
                        counterExample = countMap.getValue();
                }
                // subtract (|Sv| / |S|) * Entropy(Sv)
                entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
            }
            return entropyS;
        }

        /**
         * Computes the information gain of every attribute
         *
         * @param dataTable  the data table
         * @param types      attribute list {Outlook, Temperature, Humidity, Wind}
         * @param resultType the result attribute (PlayTennis)
         * @return map from attribute name to information gain
         */
        public Map<String, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
            Map<String, Double> answer = Maps.newHashMap();
            List<T> results = dataTable.get(resultType);
            int totalSize = results.size();
            int positiveExample = 0;
            int counterExample = 0;
            for (T exampleType : results) {
                if (positiveExampleType.equals(exampleType)) {
                    ++positiveExample;
                } else {
                    ++counterExample;
                }
            }
            // entropy of the whole sample set
            double entropyS = entropy(positiveExample, counterExample);

            for (K type : types) {
                Multimap<T, CountMap<T>> infoMap = merge(dataTable.get(type), results);
                answer.put(String.valueOf(type), gain(infoMap, dataTable, type, entropyS, totalSize));
            }
            return answer;
        }

    }

    package com.qunar.data.tree;

    import com.google.common.collect.Lists;
    import com.google.common.collect.Maps;

    import java.util.*;

    /**
     * Author:     XiJun.Gong
     * Date:       2016-09-02 16:43
     * Version:    default 1.0.0
     * Class description:
     * <p>Reads the data table from standard input and prints the gains in descending order</p>
     */
    public class Main {

        public static void main(String[] args) {

            Scanner scanner = new Scanner(System.in);
            while (scanner.hasNext()) {
                DecisionTree<String, String> dt = new DecisionTree<String, String>();
                Map<String, List<String>> dataTable = Maps.newHashMap();
                List<String> types = Lists.newArrayList();
                String resultType;
                int factorSize = scanner.nextInt();   // number of attributes, result included
                int demoSize = scanner.nextInt();     // number of samples
                String type;

                // one line per attribute: name followed by demoSize values
                for (int i = 0; i < factorSize; i++) {
                    List<String> demos = Lists.newArrayList();
                    type = scanner.next();
                    for (int j = 0; j < demoSize; j++) {
                        demos.add(scanner.next());
                    }
                    dataTable.put(type, demos);
                }
                // last line: the factor names, then the result attribute
                for (int i = 1; i < factorSize; i++) {
                    types.add(scanner.next());
                }
                resultType = scanner.next();
                Map<String, Double> ans = dt.calculate(dataTable, types, resultType);
                List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet());
                // sort by gain, largest first; Double.compare returns 0 for equal values
                Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
                    @Override
                    public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                        return Double.compare(o2.getValue(), o1.getValue());
                    }
                });

                for (Map.Entry<String, Double> entry : list) {
                    System.out.println(entry.getKey() + "= " + entry.getValue());
                }
            }
        }

    }
    /*
     Usage example:
     5  14
     Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
     Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
     Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
     Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
     PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
     Outlook Temperature Humidity Wind PlayTennis
     */

    Output:

    Outlook= 0.2467498197744391
    Humidity= 0.15183550136234136
    Wind= 0.04812703040826927
    Temperature= 0.029222565658954647
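
    The attribute with the largest gain is the one a gain-based tree builder such as ID3 would split on first. A minimal sketch of that selection, using the four values printed above as hard-coded input (the class BestAttributeDemo and the method best are hypothetical names, not part of the original code):

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class BestAttributeDemo {

    // Returns the name of the attribute with the largest gain, or null for an empty map
    public static String best(Map<String, Double> gains) {
        String bestName = null;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, Double> e : gains.entrySet()) {
            if (e.getValue() > bestGain) {
                bestGain = e.getValue();
                bestName = e.getKey();
            }
        }
        return bestName;
    }

    public static void main(String[] args) {
        // Gains copied from the program output above
        Map<String, Double> gains = new LinkedHashMap<String, Double>();
        gains.put("Outlook", 0.2467498197744391);
        gains.put("Humidity", 0.15183550136234136);
        gains.put("Wind", 0.04812703040826927);
        gains.put("Temperature", 0.029222565658954647);
        System.out.println("Split first on: " + best(gains));
    }
}
```

    With these values the selected attribute is Outlook.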
  • Original article: https://www.cnblogs.com/gongxijun/p/5835589.html