• c#抽取pdf文档标题(4)——机器学习以及决策树


            我的一位同事告诉我,pdf抽取标题,用机器学习可以完美解决问题,抽取的准确率比较高。于是,我看了一些资料,就动起手来,实践了下。

            我主要是根据以往历史块的特征生成一棵决策树,然后利用这棵决策树,去判断一个新的块到底是不是标题。理论上,历史块的数量越庞大,结果越准确。但实践下来并非完全如此:我的体会是,参与判断的特征越少,并且历史库达到一定规模之后,判断才越准确。这个记录块信息的历史库,就是供计算机学习的原料。

           首先看下,如何形成一个决策树?

     /// <summary>
     /// Builds and trains a decision tree from the stored block records.
     /// Each record becomes one row of six strings: five discretised features
     /// (Index, Space, XSize, YSize, Height) followed by the IsTitle class label.
     /// </summary>
     /// <returns>A tree on which <c>Learn()</c> has already been called.</returns>
     private static DecisionTreeID3<string> BuildTree()
     {
         // Records come from DBHelper (defined elsewhere); an earlier variant read them
         // via Tools.SelectList("/config/Blocks/Block").
         var blockList = DBHelper.Select<BlockData>();

         string[,] da = new string[blockList.Count, 6];

         for (int i = 0; i < blockList.Count; i++)
         {
             var block = blockList[i];

             // Column 0: block index bucketed into high / middle / low.
             var index = block.Index;
             if (index >= 1 && index <= 5)
             {
                 da[i, 0] = "high";
             }
             else if (index >= 6 && index <= 12)
             {
                 da[i, 0] = "middle";
             }
             else
             {
                 da[i, 0] = "low";
             }

             // Column 1: spacing. NaN is mapped to 0 before truncating to int.
             // double.IsNaN replaces a comparison against the localized NaN text, which only
             // matched under one specific culture. Assumes Space is float/double - TODO confirm.
             var space = double.IsNaN(block.Space) ? 0 : (int)block.Space;
             if (space >= 3 && space <= 10 || space >= 17 && space <= 20)
             {
                 da[i, 1] = "high";
             }
             else if (space >= 11 && space <= 16)
             {
                 da[i, 1] = "middle";
             }
             else
             {
                 da[i, 1] = "low";
             }

             // Column 2: XSize falls into one of three "title-like" ranges or not.
             var xSize = block.XSize;
             bool xLikely = xSize >= 11 && xSize <= 19
                            || xSize >= 400 && xSize <= 440
                            || xSize >= 250 && xSize <= 260;
             da[i, 2] = xLikely ? "high" : "low";

             // Column 3: same idea for YSize (note the middle range differs from XSize).
             var ySize = block.YSize;
             bool yLikely = ySize >= 11 && ySize <= 19
                            || ySize >= 250 && ySize <= 290
                            || ySize >= 400 && ySize <= 440;
             da[i, 3] = yLikely ? "high" : "low";

             // Column 4: height, truncated to an integer before bucketing.
             var height = (int)block.Height;
             bool heightLikely = height >= 6 && height <= 13 || height >= 22 && height <= 24;
             da[i, 4] = heightLikely ? "high" : "low";

             // Column 5: class label; must stringify to one of the category labels passed below.
             da[i, 5] = block.IsTitle.ToString();
         }

         var names = new string[] { "Index", "Space", "XSize", "YSize", "Height", "IsTitle" };
         var tree = new DecisionTreeID3<string>(da, names, new string[] { "True", "False" });
         tree.Learn();
         return tree;
     }

    这段代码把数据库中的块信息转换成一个二维数组,并且把每个特征值转为离散的值。转换之前的值几乎是连续的,取值有多少个无法确定;只有转为离散的值,才能控制决策树的规模。下面,我们看看决策树类 DecisionTreeID3:

      /// <summary>
      /// Decision tree over categorical data. Attributes are chosen by gain ratio (C4.5);
      /// the plain information-gain selector of ID3 is available as <see cref="MaxEntropy"/>.
      /// The class (category) variable must be the last column of the data matrix.
      /// </summary>
      /// <typeparam name="T">Type of the cell values.</typeparam>
      public class DecisionTreeID3<T> where T : IEquatable<T>
      {
          // Branches deeper than this stop growing and end in a majority-vote leaf.
          private const int MaxDepth = 10;

          T[,] Data;
          string[] Names;
          int Category;
          T[] CategoryLabels;

          /// <summary>Root of the learned tree; null until <see cref="Learn()"/> has run.</summary>
          public DecisionTreeNode<T> Root;

          /// <summary>Text dump of the tree, rebuilt by every call to <see cref="Learn()"/>.</summary>
          public StringBuilder sb = new StringBuilder();

          /// <summary>Stores the training data; no learning happens here.</summary>
          /// <param name="data">Rows are samples, columns are attributes, last column is the class.</param>
          /// <param name="names">Column names, used when printing the tree.</param>
          /// <param name="categoryLabels">All values the class column may take.</param>
          public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
          {
              Data = data;
              Names = names;
              Category = data.GetLength(1) - 1; // the class variable must be the last column
              CategoryLabels = categoryLabels;
          }

          /// <summary>Builds the tree from all rows and columns, then renders it into <see cref="sb"/>.</summary>
          public void Learn()
          {
              int[] rows = Enumerable.Range(0, Data.GetLength(0)).ToArray();
              int[] cols = Enumerable.Range(0, Data.GetLength(1)).ToArray();
              Root = new DecisionTreeNode<T>(-1, default(T));
              Learn(rows, cols, Root);

              // Start from an empty buffer so repeated Learn() calls do not append duplicate dumps.
              sb.Length = 0;
              DisplayNode(Root);
          }

          /// <summary>
          /// Walks the tree along the branches whose attribute values match <paramref name="test"/>
          /// and reports whether a class leaf with the value "True" is reached.
          /// </summary>
          /// <param name="test">
          /// One value per column in training order. The last element stands in for the class
          /// column and is never read.
          /// </param>
          /// <param name="Node">Subtree to search; defaults to <see cref="Root"/>.</param>
          /// <returns>True when a matching path ends in a "True" leaf, otherwise false.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="test"/> is null.</exception>
          /// <exception cref="InvalidOperationException">The tree has not been learned yet.</exception>
          public bool Search(string[] test, DecisionTreeNode<T> Node = null)
          {
              if (test == null) throw new ArgumentNullException("test");
              if (Node == null) Node = Root;
              if (Node == null) throw new InvalidOperationException("Learn() must be called before Search().");

              int categoryIndex = test.Length - 1;

              foreach (var child in Node.Children)
              {
                  string value = child.Value.ToString();

                  // Attribute node whose value differs from the sample: not our branch.
                  if (child.Label < categoryIndex && test[child.Label] != value) continue;

                  // Class leaf predicting the positive label.
                  if (child.Label == categoryIndex && value == "True") return true;

                  // Matching attribute node (or a negative leaf, which has no children).
                  if (Search(test, child)) return true;
              }
              return false;
          }

          /// <summary>Appends an indented "name: value" line for the node and all its descendants to <see cref="sb"/>.</summary>
          /// <param name="Node">Node to print; the synthetic root (label -1) prints nothing itself.</param>
          /// <param name="depth">Nesting level; each level adds three dashes of indentation.</param>
          public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
          {
              if (Node.Label != -1)
              {
                  string nodeStr = string.Format("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
                  sb.AppendLine(nodeStr);
              }
              foreach (var item in Node.Children)
              {
                  DisplayNode(item, depth + 1);
              }
          }

          /// <summary>Recursively grows the subtree below <paramref name="parent"/>.</summary>
          /// <param name="pnRows">Indices of the rows that reached this node.</param>
          /// <param name="pnCols">Indices of the columns still available (always includes the class column).</param>
          /// <param name="parent">Node that receives the new children.</param>
          /// <param name="depth">Current depth, used for the growth limit.</param>
          private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> parent, int depth = 0)
          {
              if (pnRows.Length == 0) return;

              // Materialise once; the lazy sequence would otherwise be re-enumerated below.
              var categoryValues = GetAttribute(Data, Category, pnRows).ToList();

              // Pure node: every remaining row has the same class.
              if (categoryValues.Distinct().Count() == 1)
              {
                  parent.Children.Add(new DecisionTreeNode<T>(Category, categoryValues[0]));
                  return;
              }

              // Growth limit (the depth cap discussed in the accompanying text) or no attribute
              // left to split on: close the branch with the most frequent class. Descending
              // order matters - ascending order would elect the minority class.
              if (depth > MaxDepth || pnCols.Length <= 1)
              {
                  var majority = categoryValues
                      .GroupBy(v => v)
                      .OrderByDescending(g => g.Count())
                      .First()
                      .Key;
                  parent.Children.Add(new DecisionTreeNode<T>(Category, majority));
                  return;
              }

              // C4.5: choose the split attribute by gain ratio (the call the accompanying text
              // refers to). MaxEntropy(pnRows, pnCols) would give plain ID3 instead.
              var maxCol = MaxEntropyRate(pnRows, pnCols);

              int[] cols = pnCols.Where(i => i != maxCol).ToArray();
              foreach (var attr in GetAttribute(Data, maxCol, pnRows).Distinct().ToList())
              {
                  int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                  var node = new DecisionTreeNode<T>(maxCol, attr);
                  parent.Children.Add(node);
                  Learn(rows, cols, node, depth + 1); // grow the subtree recursively
              }
          }

          /// <summary>
          /// Expected class entropy after splitting the given rows on <paramref name="attrCol"/>.
          /// </summary>
          /// <remarks>
          /// A class value that is missing from the category labels makes Array.IndexOf return -1,
          /// which surfaces as an IndexOutOfRangeException.
          /// </remarks>
          public double AttributeInfo(int attrCol, int[] pnRows)
          {
              double total = pnRows.Length;
              double entropy = 0.0;
              foreach (var tuple in AttributeCount(attrCol, pnRows))
              {
                  // Class histogram of the rows holding this attribute value.
                  int[] count = new int[CategoryLabels.Length];
                  foreach (var irow in pnRows)
                  {
                      if (Data[irow, attrCol].Equals(tuple.Item1))
                      {
                          int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                          count[index]++; // only supports the class variable in the last column
                      }
                  }

                  double k = 0.0;
                  for (int i = 0; i < count.Length; i++)
                  {
                      double frequency = count[i] / (double)tuple.Item2;
                      k += -frequency * Log2(frequency);
                  }

                  double freq = tuple.Item2 / total;
                  entropy += freq * k;
              }
              return entropy;
          }

          /// <summary>Split information of <paramref name="attrCol"/>: the entropy of its own value distribution.</summary>
          public double AttributeInfoRate(int attrCol, int[] pnRows)
          {
              return ColumnEntropy(attrCol, pnRows);
          }

          /// <summary>Entropy of the class column over the given rows.</summary>
          public double CategoryInfo(int[] pnRows)
          {
              return ColumnEntropy(Category, pnRows);
          }

          // Shannon entropy (base 2) of the value distribution of one column over the given rows.
          private double ColumnEntropy(int col, int[] pnRows)
          {
              double total = pnRows.Length;
              double entropy = 0.0;
              foreach (var tuple in AttributeCount(col, pnRows))
              {
                  double frequency = tuple.Item2 / total;
                  entropy += -frequency * Log2(frequency);
              }
              return entropy;
          }

          // Lazily yields the values of one column for the given rows.
          private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
          {
              foreach (var irow in pnRows)
              {
                  yield return data[irow, col];
              }
          }

          // log2 with the entropy convention 0 * log(0) = 0.
          private static double Log2(double x)
          {
              return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
          }

          /// <summary>
          /// Returns the column with the largest information gain (ID3 criterion).
          /// </summary>
          /// <param name="pnRows">Rows to evaluate.</param>
          /// <param name="pnCols">Candidate columns; the class column is skipped.</param>
          /// <returns>Index of the best column, or 0 when there is no candidate.</returns>
          public int MaxEntropy(int[] pnRows, int[] pnCols)
          {
              double cateEntropy = CategoryInfo(pnRows);
              int maxAttr = 0;
              double max = double.MinValue;
              foreach (var icol in pnCols)
              {
                  if (icol == Category) continue;

                  double gain = cateEntropy - AttributeInfo(icol, pnRows);
                  if (max < gain)
                  {
                      max = gain;
                      maxAttr = icol;
                  }
              }
              return maxAttr;
          }

          /// <summary>
          /// Returns the column with the largest gain ratio (C4.5 criterion):
          /// information gain divided by split information.
          /// </summary>
          /// <param name="pnRows">Rows to evaluate.</param>
          /// <param name="pnCols">Candidate columns; the class column is skipped.</param>
          /// <returns>
          /// Index of the best candidate, always taken from <paramref name="pnCols"/> when it holds
          /// at least one non-class column; 0 when there is no candidate.
          /// </returns>
          public int MaxEntropyRate(int[] pnRows, int[] pnCols)
          {
              double cateEntropy = CategoryInfo(pnRows);
              int maxAttr = 0;
              bool found = false;
              double max = double.MinValue;
              foreach (var icol in pnCols)
              {
                  if (icol == Category) continue;

                  double gain = cateEntropy - AttributeInfo(icol, pnRows);
                  double splitE = AttributeInfoRate(icol, pnRows);

                  // A column with a single value has zero split information; dividing would give
                  // NaN or +/-Infinity, so such a column is scored 0 instead.
                  double gainRatio = splitE > 0.0 ? gain / splitE : 0.0;

                  // The first candidate is always accepted so the result is a real candidate.
                  if (!found || max < gainRatio)
                  {
                      max = gainRatio;
                      maxAttr = icol;
                      found = true;
                  }
              }
              return maxAttr;
          }

          /// <summary>Counts how often each distinct value of column <paramref name="col"/> occurs in the given rows.</summary>
          /// <returns>Lazy sequence of (value, count) pairs in order of first appearance.</returns>
          public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
          {
              return GetAttribute(Data, col, pnRows)
                  .GroupBy(v => v)
                  .Select(g => Tuple.Create(g.First(), g.Count()));
          }
      }

      /// <summary>
      /// Tree node. <see cref="Label"/> is the column index the node tests (-1 for the root, the
      /// class column index for a leaf) and <see cref="Value"/> is the column value it stands for.
      /// </summary>
      public sealed class DecisionTreeNode<T>
      {
          public int Label { get; set; }
          public T Value { get; set; }
          public List<DecisionTreeNode<T>> Children { get; set; }

          public DecisionTreeNode(int label, T value)
          {
              Label = label;
              Value = value;
              Children = new List<DecisionTreeNode<T>>();
          }
      }

           这个类里面包含着两个算法:ID3 和 C4.5,C4.5 是在 ID3 的基础上进行改进的一种算法。ID3 选择属性用的是信息增益,信息的度量可以有很多种定义方法,ID3 使用的是熵(entropy,熵是一种不纯度度量准则),信息增益也就是划分前后熵的变化值;而 C4.5 用的是信息增益率,即信息增益除以该属性的分裂信息。我采取了 C4.5 的算法,即私有 Learn 方法里调用 MaxEntropyRate 选择属性的那一行。此处信息量比较大,可以参考 http://shiyanjun.cn/archives/428.html 这篇文章。

           决策树建好后,我们开始调用:

    var tree = BuildTree();
    // Dump the learned tree as indented text.
    string treeText = tree.sb.ToString();

    // Classify a block: discretised Index, Space, XSize, YSize, Height, then an unused class placeholder.
    var test = new string[] { "high", "low", "high", "low", "low", "" };

    bool isTitle = tree.Search(test);

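    第三行,是把树型结构输出来,最后两行是判断一个块信息是否是标题。这个数组当然也是数值转换为离散值后的结果:元素顺序与训练数据的列一致(Index、Space、XSize、YSize、Height、IsTitle),前五项要按 BuildTree 中同样的规则离散化,最后一项只是类别列的占位,Search 不会读取它。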

    有一点必须得明确,就是决策树得剪枝,否则树可能无限制地生长,递归过深,占用大量内存甚至栈溢出(严格来说这并不是内存泄漏)。决策树类的私有 Learn 方法里有对 depth 的判断,如果树的层次结构超过了10层,就停止生长了。其实在规则过滤和决策树预测之间,我选择了规则过滤,因为用决策树的结果,经测试,准确率并不高,有可能是我才开始用,没有把握精髓,所以我保守选择。

  • 相关阅读:
    强大的Resharp插件
    配置SPARK 2.3.0 默认使用 PYTHON3
    python3 数据库操作
    python3 学习中的遇到一些难点
    log4j的一个模板分析
    MYSQL内连接,外连接,左连接,右连接
    rabbitmq实战记录
    领域模型分析
    分布式系统学习笔记
    阿里开发规范 注意事项
  • 原文地址:https://www.cnblogs.com/wangqiang3311/p/7743906.html
Copyright © 2020-2023  润新知