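// (Stray aside from the original author, roughly "unbelievable~~~"; kept as a comment so the file compiles.)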
package edu.dcy.weka;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.misc.InputMappedClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;
public class PredictModel {
public static void main(String[] args) throws Exception {
if(args.length!=3){
System.err.println("Usage: <model file> <test file > < output dir>");
System.exit(1);
}
// load data
System.out.println("loading data.....");
DataSource dstest = new DataSource(args[1]);
Instances test = dstest.getDataSet();
int attrnums=test.numAttributes();
// lack of the class lable 'Y'
if(attrnums==283){
List<String> strname = new ArrayList<String>();
strname.add("false");
strname.add("true");
Attribute attr = new Attribute("Y", strname);
test.insertAttributeAt(attr, test.numAttributes());
}
test.setClassIndex(test.numAttributes() - 1);
System.out.println("loading done.....");
System.out.println("loading predict model.....");
// load classifier
Object objs[]=SerializationHelper.readAll(args[0]);
Classifier cls=(Classifier) objs[0];
InputMappedClassifier mapper =new InputMappedClassifier();
mapper.setClassifier((Classifier) objs[0]);
mapper.setModelHeader((Instances) objs[1]);
mapper.setTestStructure(test);
cls=mapper;
// output predictions
System.out.println("execute ...");
FileWriter fw=new FileWriter(args[2]);
StringBuffer sb=new StringBuffer();
sb.append("#,actual,predicted,error,p_false,p_true
");
for (int i = 0; i < test.numInstances(); i++) {
Instance ins=test.instance(i);
double pred = cls.classifyInstance(ins);
double[] dist = cls.distributionForInstance(ins);
sb.append(i+1).append(",")
.append(ins.toString(test.classIndex()))
.append(",")
.append(test.classAttribute().value((int) pred))
.append(",");
if (pred != test.instance(i).classValue()) {
sb.append("yes");
} else {
sb.append("no");
}
sb.append(",").append(Utils.arrayToString(dist)).append("
");
// System.out.println(sb.toString());
fw.write(sb.toString());
fw.flush();
sb.delete(0, sb.length());
}
fw.close();
System.out.println("finished,please check the outfile .....");
// print the summary
if(attrnums==284){
Evaluation eval = new Evaluation(test);
eval.evaluateModel(cls, test);
System.out.println(eval.toClassDetailsString());
System.out.println(eval.toSummaryString());
System.out.println(eval.toMatrixString());
}
}
}