• 使用java调用python训练出的pmml模型


    记录下自己的过程,以后可以随时用,如果能帮到大家就更好了。

    从安装软件说起,嫌麻烦的就别看了。

    一、下载工具(俗话说得好,预先善其事必先利其器!哈哈)

    我刚开始安装的是eclipse,但有诸多麻烦不能解决,就用了IDEA,和Pycharm一个公司发行的。

    首先进入官网: http://www.jetbrains.com/products.html#lang=java

    选择IDEA下载:

    由于社区版的功能太少,我下载的是企业版的,后边会告诉破解方法。

    IDEA的安装教程网上都有,正常安装就好。

    企业版的激活码大家可以关注一个公众号,我也是在网上找到的。

    http://idea.medeming.com/

    关注公众号后粘贴就行了。

    二、Java环境安装

    参考教程:https://blog.csdn.net/weixin_38381149/article/details/89668578

    写博客时想找当时看的博客,但发现了这个很全的,jdk,maven,tomcat都有。

    想当初我为了装一个maven花了好久。。。

    三、新建Maven项目

      File ==》New==》Project==》Maven

    四、接下来在IDEA中配置Maven,这是当时参考的博客:https://www.cnblogs.com/jiangzhaowei/p/9534393.html

    五、添加依赖

      由于我只是为了调用模型,没有太多依赖,只添加了这么几个

        <dependencies>
    
            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-evaluator</artifactId>
                <version>1.4.1</version>
            </dependency>
            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-evaluator-extension</artifactId>
                <version>1.4.1</version>
            </dependency>
    
            <dependency>
                <groupId>javax.xml.bind</groupId>
                <artifactId>jaxb-api</artifactId>
                <version>2.3.0</version>
            </dependency>
            <dependency>
                <groupId>com.sun.xml.bind</groupId>
                <artifactId>jaxb-core</artifactId>
                <version>2.3.0</version>
            </dependency>
            <dependency>
                <groupId>com.sun.xml.bind</groupId>
                <artifactId>jaxb-impl</artifactId>
                <version>2.3.0</version>
            </dependency>
    
        </dependencies>

    六、java调用Python训练出的pmml模型的代码

    import org.dmg.pmml.FieldName;
    import org.dmg.pmml.PMML;
    import org.jpmml.evaluator.*;
    import org.jpmml.model.PMMLUtil;
    import org.xml.sax.SAXException;
    
    import javax.xml.bind.JAXBException;
    import java.io.FileInputStream;
    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.io.InputStream;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    public class ClassificationModel {
        private Evaluator modelEvaluator;
    
        /**
         * 通过传入 PMML 文件路径来生成机器学习模型
         *
         * @param pmmlFileName pmml 文件路径
         */
        public ClassificationModel(String pmmlFileName) {
            PMML pmml = null;
    
            try {
                if (pmmlFileName != null) {
                    InputStream is = new FileInputStream(pmmlFileName);
                    pmml = PMMLUtil.unmarshal(is);
                    try {
                        is.close();
                    } catch (IOException e) {
                        System.out.println("InputStream close error!");
                    }
    
                    ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
    
                    this.modelEvaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
                    modelEvaluator.verify();
                    System.out.println("加载模型成功!");
                }
            } catch (SAXException e) {
                e.printStackTrace();
            } catch (JAXBException e) {
                e.printStackTrace();
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
    
        }
    
        // 获取模型需要的特征名称
        public List<String> getFeatureNames() {
            List<String> featureNames = new ArrayList<String>();
    
            List<InputField> inputFields = modelEvaluator.getInputFields();
    
            for (InputField inputField : inputFields) {
                featureNames.add(inputField.getName().toString());
            }
            return featureNames;
        }
    
        // 获取目标字段名称
        public String getTargetName() {
            return modelEvaluator.getTargetFields().get(0).getName().toString();
        }
    
        // 使用模型生成概率分布
        private ProbabilityDistribution getProbabilityDistribution(Map<FieldName, ?> arguments) {
            Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments);
    
            FieldName fieldName = new FieldName(getTargetName());
    
            return (ProbabilityDistribution) evaluateResult.get(fieldName);
    
        }
    
        // 预测不同分类的概率
        public ValueMap<String, Number> predictProba(Map<FieldName, Number> arguments) {
            ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
            return probabilityDistribution.getValues();
        }
    
        // 预测结果分类
        public Object predict(Map<FieldName, ?> arguments) {
            ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
    
            return probabilityDistribution.getPrediction();
        }
    
        public static void main(String[] args) {
            ClassificationModel clf = new ClassificationModel("D:/JupyterSpace/RandomForestClassifier_Iris.pmml"); //这里模型地址
    
            List<String> featureNames = clf.getFeatureNames();
            System.out.println("feature: " + featureNames);
    
            // 构建待预测数据
            Map<FieldName, Number> waitPreSample = new HashMap<>();
         #这里的key一定要对应python中的列名 waitPreSample.put(
    new FieldName("sepal length (cm)"), 10); waitPreSample.put(new FieldName("sepal width (cm)"), 1); waitPreSample.put(new FieldName("petal length (cm)"), 3); waitPreSample.put(new FieldName("petal width (cm)"), 2); System.out.println("waitPreSample predict result: " + clf.predict(waitPreSample).toString()); System.out.println("waitPreSample predictProba result: " + clf.predictProba(waitPreSample).toString()); } }

    注意事项:

    1、类名和文件名要一致

    2、打开File  ==》Project Structure

    看你的JDK版本和这里是否一致

    运行程序,查看是否报错。

    这是我报的一个错:

    NoClassDefFoundError: javax/activation/DataSource

      解决方法是下载:activation.jar包。

      下载地址:

        链接:https://pan.baidu.com/s/14D8cQWIJp2d7h2iljAPZ2A
        提取码:6f37

    应该没什么问题了。有问题请留言,一定回复。(有问题一定要告诉我,以后还要用呢。。。)

    https://www.cnblogs.com/zhangzhixing/
  • 相关阅读:
    poj 3616 Milking Time
    poj 3176 Cow Bowling
    poj 2229 Sumsets
    poj 2385 Apple Catching
    poj 3280 Cheapest Palindrome
    hdu 1530 Maximum Clique
    hdu 1102 Constructing Roads
    codeforces 592B The Monster and the Squirrel
    CDOJ 1221 Ancient Go
    hdu 1151 Air Raid(二分图最小路径覆盖)
  • 原文地址:https://www.cnblogs.com/zhangzhixing/p/12095815.html
Copyright © 2020-2023  润新知