主要记录以下输入、输出参数处理过程,其他初始化百度资料很多。
背景
项目中用到鉴黄识别,从Github上找到了别人训练好的pb模型,项目地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model
但是项目中只提供了python代码,首先对python不熟悉,并且发现tensorflow提供了对java预测模型的支持,并且项目使用的是java,所以想把tensorflow 集成到项目中,调用pb模型预测。
但通过tensorboard工具查看模型时发现输入参数为string,虽然可以跑通,但到现在也不理解入参为什么设计成string类型.
pb文件参数(output_graph.pb)
在调用模型之前,需要先清楚模型输入、输出参数类型。
输入名称:DecodeJpeg/contents:0 类型: string,实际传入图片文件原始数据就可以
输出名称:final_result:0 类型: float
这个文件的输入、输出参数类型,通过CVSample项目库中python调用代码,找到输入、输出名称
也可以先用python生成日志,通过tensorboard工具分析日志,拿到模型输入、输出参数
推荐参考示例(LabelImage):
tensorflow 官方有一个labelImg的java示例,如果第一次使用tensorflow java api,应该会对你有用: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
如果想运行这个示例,下载示例中提到的模型: https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
这个示例中,输入、输出参数和鉴黄识别模型参数不太一样,所以也会有一些区别。
在这个示例中,对图片进行了一些图像预处理。
图像是否需要预处理,需要看模型,有些模型需要,有些不需要(比如这个鉴黄模型)。
精简代码:
tensorflow: 1.15.0
<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.15.0</version> </dependency>
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
public static void main(String[] args) throws IOException { try (Graph g = new Graph()) { //pb 模型文件 byte modelBytes[] = Files.readAllBytes(new File("/opt/work/java_work/tensorflow_demo/inception_model/output_graph.pb").toPath()); g.importGraphDef(modelBytes); try (Session s = new Session(g)) { //生成输入参数,此处生成从 https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java 中找到的方法 Tensor<String> tensor = (Tensor<String>) Tensor.create(Files.readAllBytes(Paths.get("/root/test.png"))); Tensor<Float> result = s.runner() //输入参数 .feed("DecodeJpeg/contents:0", tensor) //输出参数 .fetch("final_result:0") .run() .get(0) .expect(Float.class); //存储结果容器, 输出固定有5条数据,分别是每个分类(0:porn 1:neutral 2:hentai 3:drawings 4:sexy)的分数 float[][] values = new float[1][5]; result.copyTo(values); System.out.println(Arrays.toString(values[0])); //结果[0.027002065, 0.8941082, 0.02338332, 0.044249564, 0.011256761] //porn(色情): 0.027002065, neutral(正常): 0.8941082, hentai: 0.02338332, drawings: 0.044249564, sexy(性感): 0.011256761 } } }
前前后后为了生成输入参数查了一周,网上资料是真的少,为了有相同问题的人可以快速解决,避免和我类似情况出现,所以此处记录以下。