项目中有时需要对数据做关联分析，比如分析一家小商店里顾客的购买习惯，即哪些商品经常被一起购买。下面用 Apriori 算法实现：以支持度（包含某组商品的交易占全部交易的比例）为阈值，逐层找出频繁项集。
1 package com.data.algorithm; 2 3 import com.google.common.base.Splitter; 4 import com.google.common.collect.Lists; 5 import com.google.common.collect.Maps; 6 import org.slf4j.Logger; 7 import org.slf4j.LoggerFactory; 8 9 import java.io.BufferedReader; 10 import java.io.FileInputStream; 11 import java.io.IOException; 12 import java.io.InputStreamReader; 13 import java.util.*; 14 15 /** 16 * ********************************************************* 17 * <p/> 18 * Author: XiJun.Gong 19 * Date: 2017-01-20 15:06 20 * Version: default 1.0.0 21 * Class description: 22 * <p/> 23 * ********************************************************* 24 */ 25 26 class EOC { 27 28 private static final Logger logger = LoggerFactory.getLogger(EOC.class); 29 private Map<String, Integer> fmap; //forward map 30 private Map<Integer, String> bmap; //backward map 31 private List<Map<String, Integer>> elements = null; 32 33 private Integer maxDimension; 34 35 public EOC(final String pathFile, String separatSeq) { 36 37 BufferedReader bufferedReader = null; 38 try { 39 this.fmap = Maps.newHashMap(); 40 this.bmap = Maps.newHashMap(); 41 this.elements = Lists.newArrayList(); 42 maxDimension = 0; 43 bufferedReader = new BufferedReader( 44 new InputStreamReader( 45 new FileInputStream(pathFile), "UTF-8")); 46 String _line = null; 47 Integer keyValue = null, mapIndex = 0; 48 while ((_line = bufferedReader.readLine()) != null) { 49 Map<String, Integer> lineMap = Maps.newHashMap(); 50 if (_line.trim().length() > 1) { 51 if (separatSeq.trim().length() < 1) { 52 separatSeq = ","; 53 } 54 for (String word : Splitter.on(separatSeq).split(_line)) { 55 word = word.trim(); 56 if (null == (keyValue = fmap.get(word))) { 57 keyValue = mapIndex++; 58 } 59 fmap.put(word, keyValue); 60 bmap.put(keyValue, word); 61 lineMap.put(word, keyValue); 62 } 63 if (maxDimension < lineMap.size()) 64 maxDimension = lineMap.size(); 65 elements.add(lineMap); 66 } 67 } 68 } catch (Exception e) { 69 logger.error("读取文件出错 , 错误原因:{}", 
e); 70 } finally { 71 if (bufferedReader != null) { 72 try { 73 bufferedReader.close(); 74 } catch (IOException e) { 75 logger.error("bufferedReader , 错误原因:{}", e); 76 } 77 } 78 } 79 } 80 81 public Integer getMaxDimension() { 82 return maxDimension; 83 } 84 85 public float getRateOfSet(Collection<Integer> elementChild) { 86 float rateCnt = 0f; 87 int allSize = 1; 88 for (Map<String, Integer> eMap : elements) { 89 boolean flag = true; 90 for (Integer element : elementChild) { 91 if (null == eMap.get(bmap.get(element))) { 92 flag = false; 93 break; 94 } 95 } 96 if (flag) rateCnt += 1; 97 } 98 return rateCnt / ((allSize = elements.size()) > 1 ? (float) allSize : 1.0f); 99 } 100 101 public Set<Integer> getElements() { 102 103 return new HashSet<Integer>(fmap.values()); 104 } 105 106 public Integer queryByKey(String key) { 107 return fmap.get(key); 108 } 109 110 public String queryByValue(Integer value) { 111 return bmap.get(value); 112 } 113 } 114 115 public class Apriori { 116 private static final Logger logger = LoggerFactory.getLogger(Apriori.class); 117 private EOC eoc = null; 118 private Integer maxDimension; 119 private final float exp = 1e-4f; 120 121 public Apriori(final String pathFile, String separatSeq, Integer maxDimension) { 122 this(pathFile, separatSeq); 123 this.maxDimension = maxDimension; 124 } 125 126 public Apriori(final String pathFile, String separatSeq) { 127 this.eoc = new EOC(pathFile, separatSeq); 128 this.maxDimension = this.eoc.getMaxDimension(); 129 } 130 131 public void work(float confidenceLevel) { 132 List<Set<Integer>> listElement = null; 133 ArrayList<Set<Integer>> middleWareElement = null; 134 Map<Set<Integer>, Float> maps = null; 135 listElement = Lists.newArrayList(); 136 for (Integer element : this.eoc.getElements()) { 137 Set<Integer> set = new HashSet<Integer>(); 138 set.add(element); 139 listElement.add(set); 140 } 141 maps = Maps.newHashMap(); 142 middleWareElement = Lists.newArrayList(); 143 for (int i = 1; i < 
this.maxDimension; i++) { 144 for (Set<Integer> tmpSet : listElement) { 145 float rate = eoc.getRateOfSet(tmpSet); 146 if (confidenceLevel - exp <= rate) 147 maps.put(tmpSet, rate); 148 } 149 System.out.println("+++++++++++第 " + i + " 维度关联数据+++++++++++"); 150 output(maps); 151 listElement.clear(); 152 middleWareElement.addAll(maps.keySet()); 153 maps.clear(); 154 for (int j = 0; j < middleWareElement.size(); j++) { 155 Set<Integer> tmpSet = middleWareElement.get(j); 156 for (int k = j + 1; k < middleWareElement.size(); k++) { 157 Set<Integer> setChild = middleWareElement.get(k); 158 for (Integer label : setChild) { 159 if (!tmpSet.contains(label)) { 160 Set<Integer> newElement = new HashSet<Integer>(tmpSet); 161 newElement.add(label); 162 if (!listElement.contains(newElement)) { 163 listElement.add(newElement); 164 break; 165 } 166 } 167 } 168 } 169 } 170 middleWareElement.clear(); 171 } 172 } 173 174 public void output(Map<Set<Integer>, Float> maps) { 175 for (Map.Entry<Set<Integer>, Float> iter : maps.entrySet()) { 176 for (Integer integer : iter.getKey()) { 177 System.out.print(eoc.queryByValue(integer) + " "); 178 } 179 System.out.println(iter.getValue()*100+"%"); 180 } 181 } 182 }
1 package com.data.algorithm; 2 3 4 /** 5 * ********************************************************* 6 * <p/> 7 * Author: XiJun.Gong 8 * Date: 2017-01-17 17:57 9 * Version: default 1.0.0 10 * Class description: 11 * <p/> 12 * ********************************************************* 13 */ 14 public class Main { 15 public static void main(String args[]) { 16 Apriori apriori = new Apriori("/home/com/src/main/java/com/qunar/data/algorithm/demo.data", ","); 17 apriori.work(0.5f); 18 } 19 }
+++++++++++第 1 维度关联数据+++++++++++
苹果 50.0%
西红柿 75.0%
香蕉 75.0%
矿泉水 75.0%
+++++++++++第 2 维度关联数据+++++++++++
苹果 西红柿 50.0%
西红柿 香蕉 50.0%
西红柿 矿泉水 50.0%
香蕉 矿泉水 75.0%
+++++++++++第 3 维度关联数据+++++++++++
西红柿 香蕉 矿泉水 50.0%