我的一位同事告诉我,pdf抽取标题,用机器学习可以完美解决问题,抽取的准确率比较高。于是,我看了一些资料,就动起手来,实践了下。
我主要是根据以往历史块的特征生成一个决策树,然后利用这棵决策树,去判断一个新的块到底是不是标题。理论上,历史块的数量越庞大,结果就越准确。但经过实践发现并非如此:影响判断的特征因素越少,并且历史库的数量达到一定规模之后,判断才越准确。这个记录块信息的历史库,就是供计算机学习的原料。
首先看下,如何形成一个决策树?
/// <summary>
/// Loads every block record from the database, discretizes each numeric
/// feature into categorical buckets ("high"/"middle"/"low"), and trains a
/// decision tree on the resulting table. The last column is the IsTitle label.
/// </summary>
/// <returns>A trained <see cref="DecisionTreeID3{T}"/> ready for Search().</returns>
private static DecisionTreeID3<string> BuildTree()
{
    var blocks = DBHelper.Select<BlockData>();

    // One row per historical block: five discretized features plus the class label.
    var samples = new string[blocks.Count, 6];

    for (int row = 0; row < blocks.Count; row++)
    {
        var block = blocks[row];

        // Feature 0: position index bucket.
        var index = block.Index;
        samples[row, 0] = (index >= 1 && index <= 5) ? "high"
                        : (index >= 6 && index <= 12) ? "middle"
                        : "low";

        // Feature 1: spacing bucket ("非数字" marks a non-numeric raw value, treated as 0).
        var space = block.Space.ToString() == "非数字" ? 0 : (int)block.Space;
        samples[row, 1] = ((space >= 3 && space <= 10) || (space >= 17 && space <= 20)) ? "high"
                        : (space >= 11 && space <= 16) ? "middle"
                        : "low";

        // Feature 2: horizontal size bucket.
        var xSize = block.XSize;
        samples[row, 2] = ((xSize >= 11 && xSize <= 19) || (xSize >= 400 && xSize <= 440) || (xSize >= 250 && xSize <= 260))
                        ? "high" : "low";

        // Feature 3: vertical size bucket.
        var ySize = block.YSize;
        samples[row, 3] = ((ySize >= 11 && ySize <= 19) || (ySize >= 250 && ySize <= 290) || (ySize >= 400 && ySize <= 440))
                        ? "high" : "low";

        // Feature 4: font height bucket.
        var height = (int)block.Height;
        samples[row, 4] = ((height >= 6 && height <= 13) || (height >= 22 && height <= 24))
                        ? "high" : "low";

        // Last column: the class label ("True"/"False").
        samples[row, 5] = block.IsTitle.ToString();
    }

    var names = new string[] { "Index", "Space", "XSize", "YSize", "Height", "IsTitle" };
    var tree = new DecisionTreeID3<string>(samples, names, new string[] { "True", "False" });
    tree.Learn();
    return tree;
}
把数据库中的块信息,通过转换,变成二维数组,并且每个特征值都被转换为离散值——原始值几乎是连续的,取值个数无法确定,只有离散化之后才能控制决策树的规模。下面,我们看看决策树类 DecisionTreeID3:
/// <summary>
/// A decision-tree classifier over a table of categorical values, supporting
/// both ID3 (information gain) and C4.5 (gain ratio) attribute selection.
/// The class label must occupy the LAST column of the training matrix.
/// </summary>
public class DecisionTreeID3<T> where T : IEquatable<T>
{
    T[,] Data;                  // training table: rows = samples, columns = attributes + label
    string[] Names;             // display name for each column
    int Category;               // index of the class-label column (always the last column)
    T[] CategoryLabels;         // all possible class labels, e.g. { "True", "False" }
    public DecisionTreeNode<T> Root;

    public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
    {
        Data = data;
        Names = names;
        Category = data.GetLength(1) - 1; // the class label must be placed in the last column
        CategoryLabels = categoryLabels;
    }

    /// <summary>
    /// Builds the tree from the entire training table, then renders the result
    /// into <see cref="sb"/> for display.
    /// </summary>
    public void Learn()
    {
        int nRows = Data.GetLength(0);
        int nCols = Data.GetLength(1);
        int[] rows = new int[nRows];
        int[] cols = new int[nCols];
        for (int i = 0; i < nRows; i++) rows[i] = i;
        for (int i = 0; i < nCols; i++) cols[i] = i;
        Root = new DecisionTreeNode<T>(-1, default(T));
        Learn(rows, cols, Root);

        DisplayNode(Root);
    }

    /// <summary>
    /// Classifies a sample by walking the tree; returns true when the sample
    /// reaches a leaf whose label value is "True". The last slot of
    /// <paramref name="test"/> corresponds to the label column and is never compared.
    /// </summary>
    public bool Search(string[] test, DecisionTreeNode<T> Node = null)
    {
        bool isResult = false;

        if (Node == null) Node = Root;

        foreach (var item in Node.Children)
        {
            var label = item.Label;
            // Skip sibling branches whose attribute value does not match the sample.
            if (label < test.Length - 1 && test[label] != item.Value.ToString()) continue;
            else
            {
                // A node carrying the label column index is a leaf.
                if (label == test.Length - 1 && item.Value.ToString() == "True")
                {
                    isResult = true;
                    return isResult;
                }
                else
                {
                    isResult = Search(test, item);
                }
            }
        }
        return isResult;
    }

    public StringBuilder sb = new StringBuilder();

    /// <summary>Recursively renders the tree as indented text into <see cref="sb"/>.</summary>
    public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
    {
        if (Node.Label != -1)
        {
            string nodeStr = string.Format("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
            sb.AppendLine(nodeStr);
        }
        foreach (var item in Node.Children)
            DisplayNode(item, depth + 1);
    }

    /// <summary>
    /// Recursive tree construction over the given row/column subsets. Growth
    /// stops when the subset is pure, when only the label column remains
    /// (majority vote), or when the depth cap is reached (pre-pruning).
    /// </summary>
    private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root, int depth = 0)
    {
        var categoryValues = GetAttribute(Data, Category, pnRows);
        var categoryCount = categoryValues.Distinct().Count();
        if (categoryCount == 1)
        {
            // Pure subset: emit a leaf with the single remaining class label.
            var node = new DecisionTreeNode<T>(Category, categoryValues.First());
            Root.Children.Add(node);
        }
        else
        {
            // Depth cap keeps the tree from growing without bound.
            if (depth > 10) return;

            if (pnRows.Length == 0) return;
            else if (pnCols.Length == 1)
            {
                // Only the label column is left: decide by majority vote.
                // BUG FIX: the original used OrderBy(...).First(), which selected
                // the MINORITY class; majority voting needs the largest group.
                var vote = categoryValues.GroupBy(i => i).OrderByDescending(i => i.Count()).First();
                var node = new DecisionTreeNode<T>(Category, vote.First());
                Root.Children.Add(node);
            }
            else
            {
                // C4.5: split on the attribute with the highest gain ratio.
                // (Use MaxEntropy here instead for plain ID3 information gain.)
                var maxCol = MaxEntropyRate(pnRows, pnCols);

                var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
                foreach (var attr in attributes)
                {
                    int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                    int[] cols = pnCols.Where(i => i != maxCol).ToArray();
                    var node = new DecisionTreeNode<T>(maxCol, attr);
                    Root.Children.Add(node);
                    Learn(rows, cols, node, depth + 1); // recurse into each attribute value
                }
            }
        }
    }

    /// <summary>
    /// Conditional entropy of the class label given the attribute in
    /// <paramref name="attrCol"/>, weighted over the attribute's distinct values.
    /// </summary>
    public double AttributeInfo(int attrCol, int[] pnRows)
    {
        var tuples = AttributeCount(attrCol, pnRows);
        var sum = (double)pnRows.Length;
        double Entropy = 0.0;
        foreach (var tuple in tuples)
        {
            // Per-class counts within the rows carrying this attribute value.
            int[] count = new int[CategoryLabels.Length];
            foreach (var irow in pnRows)
                if (Data[irow, attrCol].Equals(tuple.Item1))
                {
                    int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                    count[index]++; // only supports the class label in the last column
                }
            double k = 0.0;
            for (int i = 0; i < count.Length; i++)
            {
                double frequency = count[i] / (double)tuple.Item2;
                k += -frequency * Log2(frequency);
            }
            Entropy += (tuple.Item2 / sum) * k;
        }
        return Entropy;
    }

    /// <summary>
    /// Split information of an attribute (the entropy of the attribute's own
    /// value distribution) — the denominator of the C4.5 gain ratio.
    /// </summary>
    public double AttributeInfoRate(int attrCol, int[] pnRows)
    {
        var tuples = AttributeCount(attrCol, pnRows);
        var sum = (double)pnRows.Length;
        double SplitE = 0.0;

        foreach (var tuple in tuples)
        {
            double frequency = tuple.Item2 / (double)sum;
            SplitE += -frequency * Log2(frequency);
        }
        return SplitE;
    }

    /// <summary>Entropy of the class-label distribution over the row subset.</summary>
    public double CategoryInfo(int[] pnRows)
    {
        var tuples = AttributeCount(Category, pnRows);
        var sum = (double)pnRows.Length;
        double Entropy = 0.0;
        foreach (var tuple in tuples)
        {
            double frequency = tuple.Item2 / sum;
            Entropy += -frequency * Log2(frequency);
        }
        return Entropy;
    }

    /// <summary>Streams the values of one column for the selected rows.</summary>
    private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
    {
        foreach (var irow in pnRows)
            yield return data[irow, col];
    }

    // log2 with the entropy convention 0 * log2(0) == 0.
    private static double Log2(double x)
    {
        return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
    }

    /// <summary>
    /// ID3: returns the column with the largest information gain.
    /// (DOC FIX: the original comment wrongly described this as the gain ratio.)
    /// </summary>
    /// <param name="pnRows">Row subset under consideration.</param>
    /// <param name="pnCols">Candidate attribute columns.</param>
    /// <returns>Index of the best attribute column.</returns>
    public int MaxEntropy(int[] pnRows, int[] pnCols)
    {
        double cateEntropy = CategoryInfo(pnRows);
        int maxAttr = 0;
        double max = double.MinValue;
        foreach (var icol in pnCols)
            if (icol != Category)
            {
                double gain = cateEntropy - AttributeInfo(icol, pnRows);
                if (max < gain)
                {
                    max = gain;
                    maxAttr = icol;
                }
            }
        return maxAttr;
    }

    /// <summary>
    /// C4.5: returns the column with the largest gain ratio
    /// (information gain divided by split information).
    /// </summary>
    /// <param name="pnRows">Row subset under consideration.</param>
    /// <param name="pnCols">Candidate attribute columns.</param>
    /// <returns>Index of the best attribute column.</returns>
    public int MaxEntropyRate(int[] pnRows, int[] pnCols)
    {
        double cateEntropy = CategoryInfo(pnRows);
        int maxAttr = 0;
        double max = double.MinValue;
        foreach (var icol in pnCols)
            if (icol != Category)
            {
                double splitE = AttributeInfoRate(icol, pnRows);

                // BUG FIX: guard against division by zero — an attribute with a
                // single value over these rows has zero split information and
                // would otherwise produce NaN/Infinity here.
                if (splitE == 0.0) continue;

                double gain = cateEntropy - AttributeInfo(icol, pnRows);
                double gainRatio = gain / splitE;

                if (max < gainRatio)
                {
                    max = gainRatio;
                    maxAttr = icol;
                }
            }
        return maxAttr;
    }

    /// <summary>Counts occurrences of each distinct value in a column over the row subset.</summary>
    public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
    {
        var tuples = from n in GetAttribute(Data, col, pnRows)
                     group n by n into i
                     select Tuple.Create(i.First(), i.Count());
        return tuples;
    }
}

/// <summary>
/// One tree node: Label is the column index it tests (-1 for the synthetic
/// root), Value is the attribute (or class-label) value it represents.
/// </summary>
public sealed class DecisionTreeNode<T>
{
    public int Label { get; set; }
    public T Value { get; set; }
    public List<DecisionTreeNode<T>> Children { get; set; }

    public DecisionTreeNode(int label, T value)
    {
        Label = label;
        Value = value;
        Children = new List<DecisionTreeNode<T>>();
    }
}
这个类里面包含着两个算法,C4.5和ID3,C4.5是在ID3的基础上进行改进的一种算法。我采取了C4.5的算法,在94行。C4.5 算法,是用信息增益率来选择属性。ID3选择属性用的是子树的信息增益,这里可以用很多方法来定义信息,ID3使用的是熵(entropy, 熵是一种不纯度度量准则),也就是熵的变化值,而C4.5用的是信息增益率。 此处信息量比较大,可以参考 http://shiyanjun.cn/archives/428.html 这篇文章。
决策树建好后,我们开始调用:
// Build the decision tree from the historical block database.
var tree = BuildTree();
// Render the learned tree structure as text (filled in by DisplayNode during Learn()).
tree.sb.ToString();

// Classify one new block; the feature values are already discretized and the
// last slot corresponds to the label column (left empty for prediction).
// NOTE(review): this sample uses "True"/"False" values rather than the
// "high"/"middle"/"low" buckets BuildTree produces — confirm the intended encoding.
var test = new string[] { "True", "False", "True", "False", "False", "" };

bool isTitle = tree.Search(test);
第三行,是把树型结构输出来,最后两行是判断一个块信息是否是标题。这个数组当然也是数值转换为离散值后的结果。
有一点必须得明确,就是决策树的剪裁,否则树可能无限生长,导致内存占用过大。决策树类中的78行,如果树的层次结构超过了10层,就停止生长了。其实在规则过滤和决策树预测两种方案之间,我选择了规则过滤,因为用决策树的结果,经测试,准确率并不高,有可能是我才开始用,没有把握精髓,所以我保守选择。