网上 Java 版的 KNN（k 近邻）实现看了不少，于是自己也写了一份。本人也是初学者，如有错误欢迎指出，大家一起学习。
1 import java.io.BufferedReader; 2 import java.io.File; 3 import java.io.FileNotFoundException; 4 import java.io.FileReader; 5 import java.io.IOException; 6 import java.util.*; 7 8 9 public class Index { 10 public static void main(String[] args){ 11 List<List<Double>> Filedatas = new ArrayList<List<Double>>(); 12 List<List<Double>> Testdatas = new ArrayList<List<Double>>(); 13 14 readFile(Filedatas,Testdatas); 15 KNN knn = new KNN(); 16 17 for(int i=0;i<Filedatas.size();i++){ 18 String s = knn.comdistance(3,Filedatas,Testdatas.get(i)); 19 print(s,Testdatas.get(i)); 20 } 21 } 22 //第4步、打印出结果 23 private static void print(String s,List<Double> testdata) { 24 System.out.print("测试数据:"); 25 for(int i=0;i<testdata.size();i++){ 26 System.out.print(testdata.get(i) + " "); 27 } 28 int label = Math.round(Float.parseFloat(s)); 29 System.out.println("所属类别:" + label); 30 } 31 32 //第1.1步、读取文件 33 private static void readFile(List<List<Double>> Filedatas, List<List<Double>> Testdatas) { 34 try { 35 BufferedReader bfd = new BufferedReader(new FileReader(new File("D://a.txt"))); 36 Filedatas = read(bfd,Filedatas); 37 BufferedReader bft = new BufferedReader(new FileReader(new File("D://b.txt"))); 38 Testdatas = read(bft,Testdatas); 39 } catch (FileNotFoundException e) { 40 e.printStackTrace(); 41 } 42 } 43 44 //第1.2步、读取文件 45 private static List<List<Double>> read(BufferedReader bf, List<List<Double>> datas) { 46 try { 47 String str = bf.readLine(); 48 while(str != null){ 49 List<Double> d = new ArrayList<Double>(); 50 String[] string = str.split(" "); 51 for (String s : string) { 52 d.add(Double.parseDouble(s)); 53 } 54 datas.add(d); 55 str = bf.readLine(); 56 } 57 } catch (IOException e) { 58 e.printStackTrace(); 59 } 60 return datas; 61 } 62 63 64 }
1 import java.util.Comparator; 2 import java.util.HashMap; 3 import java.util.List; 4 import java.util.Map; 5 import java.util.PriorityQueue; 6 7 public class KNN { 8 9 public String comdistance(int k, List<List<Double>> filedatas,List<Double> testdata) { 10 //第2.1步、对加入queue队列的项进行距离的排序 11 PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k,new Comparator<KNNNode>() { //优先级队列,按照distance的大小进行排列 12 @Override 13 public int compare(KNNNode o1, KNNNode o2) { 14 if(o1.getDistance() >= o2.getDistance()){ 15 return -1; 16 } 17 else{ 18 return 1; 19 } 20 } 21 }); 22 //第2.2步、计算测试点与训练点的距离,并add进队列,挑出与测试点距离最近的K个点 23 for(int i=0;i<k;i++){ 24 double distance = 0; 25 for(int j=0;j<filedatas.get(i).size()-1;j++){ 26 distance += (filedatas.get(i).get(j) - testdata.get(j)) * (filedatas.get(i).get(j) - testdata.get(j)); 27 } 28 KNNNode node = new KNNNode(filedatas.get(i).get(filedatas.get(i).size()-1).toString(),distance); 29 pq.add(node); 30 } 31 for(int i=k;i<filedatas.size();i++){ 32 double distance = 0; 33 for(int j=0;j<filedatas.get(i).size()-1;j++){ 34 distance += (filedatas.get(i).get(j) - testdata.get(j)) * (filedatas.get(i).get(j) - testdata.get(j)); 35 } 36 KNNNode node = new KNNNode(filedatas.get(i).get(filedatas.get(i).size()-1).toString(),distance); 37 if(pq.peek().getDistance() >= distance ){ 38 pq.remove(); 39 pq.add(node); 40 } 41 } 42 String s = decide(pq); 43 return s; 44 } 45 //第3步、把选择好的最近的K个点的类别进行比较,多的即是测试点的类别 46 private String decide(PriorityQueue<KNNNode> pq) { 47 Map<String,Integer> m = new HashMap<String,Integer>(); 48 for (KNNNode Node : pq) { 49 if(m.containsKey(Node.getC())){ 50 m.put(Node.getC(), m.get(Node.getC()) + 1); 51 } 52 else{ 53 m.put(Node.getC(), 1); 54 } 55 } 56 Object[] o = m.keySet().toArray(); 57 58 if(o.length == 1){ 59 return o[0].toString(); 60 } 61 else{ 62 for(int i=0;i<o.length;i++){ 63 for(int j=i;j<o.length;j++){ 64 if(i != j){ 65 if(m.get(o[i]) > m.get(o[j])){ 66 return o[i].toString(); 67 } 68 else{ 69 return o[j].toString(); 70 
} 71 } 72 } 73 } 74 } 75 return null; 76 } 77 }
/**
 * One candidate neighbour used by the KNN classifier: the class label of a training row paired
 * with that row's distance to the test point.
 */
public class KNNNode {

    /** Class label of the training row. */
    private String label;
    /** Distance between the training row and the test point. */
    private double distance;

    /**
     * @param label class label of the training row
     * @param dist distance to the test point
     */
    public KNNNode(String label, double dist) {
        this.label = label;
        distance = dist;
    }

    /** @return the class label */
    public String getC() {
        return label;
    }

    /** @param label the new class label */
    public void setC(String label) {
        this.label = label;
    }

    /** @return the distance to the test point */
    public double getDistance() {
        return distance;
    }

    /** @param dist the new distance */
    public void setDistance(double dist) {
        distance = dist;
    }
}
训练数据（保存为 D://a.txt；每行前 8 列为特征，最后一列为类别标签，列之间用空格分隔）：
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
测试数据（保存为 D://b.txt；每行只有 8 列特征，不含类别标签）：
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5