• tensorflow java 调用pb模型预测实例(CVSample 鉴黄检测)


    主要记录以下输入、输出参数处理过程,其他初始化百度资料很多。

    背景

    项目中用到鉴黄识别,从Github上找到了别人训练好的pb模型,项目地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model

    但是项目中只提供了python代码,首先对python不熟悉,并且发现tensorflow提供了对java预测模型的支持,并且项目使用的是java,所以想把tensorflow 集成到项目中,调用pb模型预测。

    但通过tensorboard工具查看模型时发现输入参数为string,虽然可以跑通,但到现在也不理解入参为什么设计成string类型.

    pb文件参数(output_graph.pb)

    在调用模型之前,需要先清楚模型输入、输出参数类型。

    输入名称:DecodeJpeg/contents:0    类型: string,实际传入图片文件原始数据就可以
    输出名称:final_result:0       类型: float

    这个文件的输入、输出参数类型,通过CVSample项目库中python调用代码,找到输入、输出名称

    也可以先用python生成日志,通过tensorboard工具分析日志,拿到模型输入、输出参数

    推荐参考示例(LabelImage):

    tensorflow 官方有一个labelImg的java示例,如果第一次使用tensorflow java api,应该会对你有用: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java

    如果想运行这个示例,下载示例中提到的模型: https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

    这个示例中,输入、输出参数和鉴黄识别模型参数不太一样,所以也会有一些区别。

    在这个示例中,对图片进行了一些图像预处理。 

    图像是否需要预处理,需要看模型,有些模型需要,有些不需要(比如这个鉴黄模型)。

    精简代码:

    tensorflow: 1.15.0

            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>tensorflow</artifactId>
                <version>1.15.0</version>
            </dependency>

    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;

    import java.io.File;
    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Paths;
    import java.util.Arrays;

    public static void main(String[] args) throws IOException { try (Graph g = new Graph()) { //pb 模型文件 byte modelBytes[] = Files.readAllBytes(new File("/opt/work/java_work/tensorflow_demo/inception_model/output_graph.pb").toPath()); g.importGraphDef(modelBytes); try (Session s = new Session(g)) { //生成输入参数,此处生成从 https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java 中找到的方法 Tensor<String> tensor = (Tensor<String>) Tensor.create(Files.readAllBytes(Paths.get("/root/test.png"))); Tensor<Float> result = s.runner() //输入参数 .feed("DecodeJpeg/contents:0", tensor) //输出参数 .fetch("final_result:0") .run() .get(0) .expect(Float.class); //存储结果容器, 输出固定有5条数据,分别是每个分类(0:porn 1:neutral 2:hentai 3:drawings 4:sexy)的分数 float[][] values = new float[1][5]; result.copyTo(values); System.out.println(Arrays.toString(values[0])); //结果[0.027002065, 0.8941082, 0.02338332, 0.044249564, 0.011256761] //porn(色情): 0.027002065, neutral(正常): 0.8941082, hentai: 0.02338332, drawings: 0.044249564, sexy(性感): 0.011256761 } } }

    前前后后为了生成输入参数查了一周,网上资料是真的少,为了有相同问题的人可以快速解决,避免和我类似情况出现,所以此处记录以下。

  • 相关阅读:
    [LeetCode]题解(python):007-Reverse Integer
    [LeetCode]题解(python):006-ZigZag Conversion
    [LeetCode]题解:005-Longest Palindromic Substring优化
    [LeetCode]题解(python):005-Longest Palindromic Substring
    [LeetCode]题解(python):003-Longest Substring Without Repeating Characters
    [LeetCode]题解(python):002-Add Two Numbers
    [LeetCode]题解(python):001-Two-Sum
    【BZOJ1005】【HNOI2008】明明的烦恼
    BZOJ平推计划
    【BZOJ1004】【HNOI20008】cards
  • 原文地址:https://www.cnblogs.com/GengMingYan/p/16024211.html
Copyright © 2020-2023  润新知