原创声明:本文系作者原创,转载请写明出处。
一、前言
前几天由于科研需要，一直在研究稀疏表示下的矩阵乘法。程序虽然已经写出来了，但仍然无法处理大规模的矩阵（尽管矩阵已经采用了稀疏表示）。原因可能是乘积结果不够稀疏，或者参与相乘的矩阵本身就不够稀疏：乘积 C = AB 中的元素 c_ij 只要存在某个 k 使 a_ik 与 b_kj 同时非零，就需要被存储，因此乘积的非零元个数可能远多于两个因子。
还是把实现的程序放在这里，以供以后研究使用。
二、程序实现功能
首先将稀疏矩阵封装为三元组（行号，列号，值）形式。
程序的主要功能有:
稀疏矩阵的转置
稀疏矩阵的乘法
稀疏矩阵的加法
以及从文本文件导入矩阵等辅助功能。
三、代码展示
以下程序是在 Eclipse 下用 Java 编写的，依赖第三方库 Weka 和 Jama。
package others; import java.io.BufferedReader; import java.io.File; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Map.Entry; import weka.clusterers.SimpleKMeans; import weka.core.DistanceFunction; import weka.core.Instances; import weka.core.converters.ArffLoader; import Jama.Matrix; /* * 本类可实现稀疏矩阵三元组表示下的矩阵乘法和矩阵加法,以及矩阵转置等。结果也是三元组存储。 * 但是当数据量非常庞大时,乘积的结果无法存储,会出现内存溢出的现象。 */ public class SMatrix { public Map<ArrayList<Integer>,Integer> Triples;//矩阵的三元组表示 public int rowNum;//矩阵行数 public int colNum;//矩阵列数 public int getRowNum() { return rowNum; } public void setRowNum(int rowNum) { this.rowNum = rowNum; } public int getColNum() { return colNum; } public void setColNum(int colNum) { this.colNum = colNum; } /* * 构造函数1 */ public SMatrix(){ } /* * 构造函数2 */ public SMatrix(Map<ArrayList<Integer>, Integer> triples, int rowNum, int colNum) { Triples = triples; this.rowNum = rowNum; this.colNum = colNum; } /* * 构造函数3 */ public SMatrix(Map<ArrayList<Integer>, Integer> triples) { Triples = triples; } /* * 稀疏矩阵相乘函数 */ public SMatrix Multiply(SMatrix M,SMatrix N){ if(M.colNum != N.rowNum){ System.out.println("矩阵相乘不满足条件"); return null; } Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>(); Iterator<Map.Entry<ArrayList<Integer>, Integer>> it1 = M.Triples.entrySet().iterator(); int iter = 0; while(it1.hasNext()){ iter++; // System.out.println("迭代次数:"+iter); Entry<ArrayList<Integer>, Integer> entry = it1.next(); ArrayList<Integer> position = entry.getKey(); // System.out.println("检查程序:" + position); int value = entry.getValue(); int flag = 0; Iterator<Map.Entry<ArrayList<Integer>, Integer>> it2 = N.Triples.entrySet().iterator(); while(it2.hasNext()){ Entry<ArrayList<Integer>,Integer> entry2 = it2.next(); ArrayList<Integer> position2 = entry2.getKey(); int value2 = entry2.getValue(); 
if(position.get(1) == position2.get(0)){ flag = 1; ArrayList<Integer> temp = new ArrayList<Integer>(); temp.add(position.get(0)); temp.add(position2.get(1)); int v = value * value2; if(triples.containsKey(temp)){ triples.put(temp, triples.get(temp) + v); System.out.println(temp+ " "+(triples.get(temp) + v)); } else{ triples.put(temp, v); System.out.println(temp + " " + v); } } } } SMatrix s = new SMatrix(triples,M.rowNum,N.colNum); return s; } /* * 稀疏矩阵相加函数 */ public static SMatrix Add(SMatrix M,SMatrix N){ if(M.colNum != N.colNum || M.rowNum != N.rowNum){ System.out.println("矩阵相加不满足条件"); return null; } SMatrix s = new SMatrix(); Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>(); Iterator<Map.Entry<ArrayList<Integer>, Integer>> it1 = M.Triples.entrySet().iterator(); Iterator<Map.Entry<ArrayList<Integer>, Integer>> it2 = N.Triples.entrySet().iterator(); while(it1.hasNext()){ Entry<ArrayList<Integer>, Integer> entry = it1.next(); ArrayList<Integer> position = entry.getKey(); int value = entry.getValue(); if(triples.containsKey(position)){ triples.put(position, triples.get(position) + value); }else{ triples.put(position, value); } } while(it2.hasNext()){ Entry<ArrayList<Integer>,Integer> entry = it2.next(); ArrayList<Integer> position = entry.getKey(); int value = entry.getValue(); if(triples.containsKey(position)){ triples.put(position, triples.get(position) + value); }else{ triples.put(position, value); } } return s; } /* * 稀疏矩阵求转置矩阵函数 */ public SMatrix Transposition(){ Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>(); Iterator<Map.Entry<ArrayList<Integer>, Integer>> it = this.Triples.entrySet().iterator(); while(it.hasNext()){ Entry<ArrayList<Integer>, Integer> entry = it.next(); ArrayList<Integer> position = entry.getKey(); int value = entry.getValue(); ArrayList<Integer> transP = new ArrayList<Integer>(); transP.add(position.get(1)); transP.add(position.get(0)); triples.put(transP, value); } 
SMatrix s = new SMatrix(triples,this.colNum,this.rowNum); return s; } /* * 加载文本数据为稀疏矩阵三元组形式的函数 */ public SMatrix Load(String file, String delimeter){ Map<ArrayList<Integer>,Integer> triples = new HashMap<ArrayList<Integer>,Integer>(); try{ File f = new File(file); FileReader fr = new FileReader(f); BufferedReader br = new BufferedReader(fr); String line; while((line = br.readLine()) != null){ String[] str = line.trim().split(delimeter); ArrayList<Integer> s = new ArrayList<Integer>(); for(int i = 0;i < str.length - 1; i++){ s.add(Integer.parseInt(str[i])); } triples.put(s, Integer.parseInt(str[str.length - 1])); } br.close(); fr.close(); }catch(IOException e){ e.printStackTrace(); } SMatrix sm = new SMatrix(triples); return sm; } /* * 打印稀疏矩阵(三元组形式) */ public void Print(){ Map<ArrayList<Integer>, Integer> triples = this.Triples; Iterator<Map.Entry<ArrayList<Integer>, Integer>> it = triples.entrySet().iterator(); int num = 0; while(it.hasNext()){ Entry<ArrayList<Integer>, Integer> entry = it.next(); ArrayList<Integer> position = entry.getKey(); num++; System.out.print(num+":"); for(Integer in:position){ System.out.print(in + " "); } System.out.println(entry.getValue()); } } public static void main(String[] args){ /* * 测试程序 String testS = "data/me"; int k = 3; SMatrix te = new SMatrix(); te = te.Load(testS," "); te.rowNum = 4; te.colNum = 6; System.out.println("打印原矩阵"); te.Print(); System.out.println("打印原矩阵的转置矩阵"); te.Transposition().Print(); System.out.println("打印乘积矩阵"); SMatrix A2 = new SMatrix(); A2 = te.Multiply(te, te.Transposition()); A2.Print(); */ long start = System.currentTimeMillis(); String file1 = "data/AT.txt";//author to term 的稀疏矩阵 String file2 = "data/CA.txt";//conference to author 的稀疏矩阵 String delimeter = " "; int k = 11; SMatrix M = new SMatrix(); SMatrix MT = new SMatrix(); SMatrix N = new SMatrix(); SMatrix NT = new SMatrix(); SMatrix P = new SMatrix(); SMatrix Q = new SMatrix(); M = M.Load(file1, delimeter); M.colNum = 9225; M.rowNum = 6456; 
System.out.println("打印矩阵M"); M.Print(); MT = M.Transposition(); System.out.println("打印矩阵MT"); MT.Print(); System.out.println("计算M和MT的乘积"); System.out.println(M.rowNum); P = M.Multiply(M, MT); System.out.println("打印矩阵M与矩阵M转置的乘积"); P.Print(); N = N.Load(file2, delimeter); N.colNum = 6456; N.rowNum = 20; System.out.println("打印矩阵N"); N.Print(); NT = N.Transposition(); System.out.println("打印矩阵NT:"); NT.Print(); System.out.println("计算NT 和 N的乘积"); System.out.println(NT.colNum); System.out.println(N.rowNum); Q = M.Multiply(NT, N); Q.Print(); SMatrix A = new SMatrix(); A = A.Load("data/AA.txt"," "); SMatrix A1 = new SMatrix(); SMatrix A2 = new SMatrix(); System.out.println("计算矩阵A1=P+Q:"); A1 = SMatrix.Add(Q, P); System.out.println("打印矩阵A1:"); A1.Print(); A2 = SMatrix.Add(A1, A);//得到了比较全面的author to author 矩阵三元组 A2.Print(); double[][] matrix = new double[A2.rowNum][A2.colNum]; for(int i = 0;i < A2.rowNum;i++){ for (int j = 0; j < A2.colNum; j++) { ArrayList<Integer> list = new ArrayList<Integer>(); list.add(i); list.add(j); if (A2.Triples.containsKey(list)) { matrix[i][j] = A2.Triples.get(list); } else{ matrix[i][j] = 0; } } } for(int i = 0;i<A2.rowNum;i++){ for(int j = 0;j < A2.colNum;j++){ System.out.print(matrix[i][j]+" "); } System.out.println(); } Matrix Author = new Matrix(matrix); //第二步:求矩阵的特征值eigValue及其相应的特征向量矩阵,取前K个(最大的) Matrix diagA = Author.eig().getD(); diagA.print(4, 2); int m = diagA.getRowDimension(); int n = diagA.getColumnDimension(); Matrix eigVector = Author.eig().getV(); eigVector.print(eigVector.getRowDimension(),4); //将特征向量输出到文本中。 String outFile = "data/eigenVector.txt"; try{ File f = new File(outFile); FileOutputStream fout = new FileOutputStream(f); fout.write("@RELATION eigenVector ".getBytes()); for(int i = n-k;i<n;i++){ fout.write(("@ATTRIBUTE "+i + " REAL ").getBytes()); } fout.write("@DATA ".getBytes()); if(k <= n){ for(int i = 0;i < m;i++){ for(int j = n-k;j<n;j++){ Double temp = new Double(eigVector.getArray()[i][j]); String tem = 
temp.toString(); fout.write((tem + " ").getBytes()); } fout.write((" ").getBytes()); } } } catch(IOException e){ e.printStackTrace(); } //第三步:对特征向量矩阵进行kmeans聚类 Instances ins = null; SimpleKMeans KM = null; // 目前没有使用到,但是在3.7.10的版本之中可以指定距离算法 // 默认是欧几里得距离 DistanceFunction disFun = null; try { // 读入样本数据 File file = new File("data/eigenVector.txt"); ArffLoader loader = new ArffLoader(); loader.setFile(file); ins = loader.getDataSet(); // 初始化聚类器 (加载算法) KM = new SimpleKMeans(); KM.setNumClusters(2); //设置聚类要得到的类别数量 KM.setMaxIterations(100); KM.buildClusterer(ins); //开始进行聚类 System.out.println(KM.preserveInstancesOrderTipText()); // 打印聚类结果 System.out.println(KM.toString()); // for(String option : KM.getOptions()) { // System.out.println(option); // } // System.out.println("CentroIds:" + tempIns); } catch(Exception e) { e.printStackTrace(); } System.out.println("程序正常结束"); long end = System.currentTimeMillis(); System.out.println(end - start); } }