• TensorFlow 图像分类模型 inception_resnet_v2 模型导出、冻结与使用


    1. 背景

    作为一名深度学习萌新,项目突然需要使用图像分类模型去作分类,因此找到了TensorFlow的模型库,使用它的框架进行训练和后续的操作,项目地址:https://github.com/tensorflow/models/tree/master/research/slim

    在使用真正的数据集之前,我首先使用的是它提供的flowers的数据集,用的模型是inception_resnet_v2,因为top-5 Accuracy比较高嘛。

    然后我安装flowers的目录结构,将我的数据按照类似的结构进行组织;

    仿照download_and_convert_flowers.py增加了自己的数据处理文件convert_normal_data.py;

    仿照数据集读取文件flowers.py增加了自己的文件normal.py;

    然后使用项目的教程,一步步的进行fine-tuning,直到准确率到了百分之九十以上,停止训练。

    但是这个时候在导出模型的时候遇到了坑。

    2. 导出Inference Graph

    实际上教程写得很简单,就是先导出模型的框架:

    Saves out a GraphDef containing the architecture of the model.

    然后再往框架里把训练好的checkpoints写到graph中:

    If you then want to use the resulting model with your own or pretrained checkpoints as part of a mobile model, you can run freeze_graph to get a graph def with the variables inlined

    它放出来的教程是这样的:

    $ python export_inference_graph.py 
      --alsologtostderr 
      --model_name=inception_v3 
      --output_file=/tmp/inception_v3_inf_graph.pb

    我安装这个格式去把模型改成inception_resnet_v2,然后把checkpoint导进去,总是会报:

    tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [1001] rhs shape= [2]
    [[{{node save/Assign_916}}]]

    找了个群问了一下,说是模型最后一层输出的数目没有改变,于是重新理了思路,去看了export_inference_graph.py的源码,发现里面有个num_classes的参数,是用来决定最后输出层的数量的,于是最后增加了一下导出参数,最后的命令为:

    python export_inference_graph.py 
      --alsologtostderr 
      --model_name=${MODEL_NAME} 
      --dataset_name=normal 
      --dataset_dir=${DATASET_DIR} 
      --output_file=/you/path/to/sava/${MODEL_NAME}_inf_graph.pb

    最后获得我的graph.pb。

    3. 冻结Graph

    冻结是个大坑,为什么呢,因为官方给出的教程是使用bazel先编译freeze_graph,然后再使用它进行模型冻结。麻烦来了,首先Ubuntu 18.04无法使用apt进行安装,所以一番折腾,使用它放出的install脚本进行了安装。

    然后是需要git clone TensorFlow的源码进行编译,这个编译期间又报了很多错,而且我编译失败后,conda环境的TensorFlow GPU版本还不能用了。。。

    最后发现,如果你已经使用conda或者git安装了TensorFlow,直接使用

    find / -name freeze_graph.py

    找出这个python文件的位置就行了,最后使用命令:

    python tensorflow/python/tools/freeze_graph.py 
      --input_graph=/you/path/to/sava/${MODEL_NAME}_inf_graph.pb 
      --input_checkpoint=/you/trained/checkpoints/model.ckpt-10000 
      --input_binary=true 
      --output_node_names=InceptionResnetV2/Logits/Predictions 
      --output_graph=/your/path/to/save/frozen_graph.pb

    最后终于导出了模型。

    4. 使用模型进行预测

    主要参考了博文【深度学习-模型eval+模型导出】使用Tensorflow Slim对训练的模型进行评估+导出模型,进行微调:

    # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
     
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
     
    import argparse
    import os.path
    import re
    import sys
    import tarfile
     
    import numpy as np
    from six.moves import urllib
    import tensorflow as tf
     
    FLAGS = None
     
    class NodeLookup(object):
      def __init__(self, label_lookup_path=None):
        self.node_lookup = self.load(label_lookup_path)
     
      def load(self, label_lookup_path):
        node_id_to_name = {}
        with open(label_lookup_path) as f:
          for line in f:
            line_list = line.strip().split(":")
            node_id_to_name[int(line_list[0])] = line_list[1]
        return node_id_to_name
     
      def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
          return ''
        return self.node_lookup[node_id]
     
     
    def create_graph():
      """Creates a graph from saved GraphDef file and returns a saver."""
      # Creates graph from saved graph_def.pb.
      with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
     
    def preprocess_for_eval(image, height, width,
                            central_fraction=0.875, scope=None):
      with tf.name_scope(scope, 'eval_image', [image, height, width]):
        if image.dtype != tf.float32:
          image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # Crop the central region of the image with an area containing 87.5% of
        # the original image.
        if central_fraction:
          image = tf.image.central_crop(image, central_fraction=central_fraction)
     
        if height and 
          # Resize the image to the specified height and width.
          image = tf.expand_dims(image, 0)
          image = tf.image.resize_bilinear(image, [height, width],
                                           align_corners=False)
          image = tf.squeeze(image, [0])
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return image
     
    def run_inference_on_image(image):
      """Runs inference on an image.
      Args:
        image: Image file name.
      Returns:
        Nothing
      """
      with tf.Graph().as_default():
        image_data = tf.gfile.FastGFile(image, 'rb').read()
        image_data = tf.image.decode_jpeg(image_data)
        image_data = preprocess_for_eval(image_data, 299, 299)
        image_data = tf.expand_dims(image_data, 0)
        with tf.Session() as sess:
          image_data = sess.run(image_data)
     
      # Creates graph from saved GraphDef.
      create_graph()
     
      with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('InceptionResnetV2/Logits/Predictions:0')
        predictions = sess.run(softmax_tensor,
                               {'input:0': image_data})
        predictions = np.squeeze(predictions)
     
        # Creates node ID --> English string lookup.
        node_lookup = NodeLookup(FLAGS.label_path)
     
        top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
        for node_id in top_k:
          human_string = node_lookup.id_to_string(node_id)
          score = predictions[node_id]
          print('%s (score = %.5f)' % (human_string, score))
     
     
    def main(_):
      image = FLAGS.image_file
      run_inference_on_image(image)
     
     
    if __name__ == '__main__':
      parser = argparse.ArgumentParser()
      parser.add_argument(
          '--model_path',
          type=str,
      )
      parser.add_argument(
          '--label_path',
          type=str,
      )
      parser.add_argument(
          '--image_file',
          type=str,
          default='',
          help='Absolute path to image file.'
      )
      parser.add_argument(
          '--num_top_predictions',
          type=int,
          default=5,
          help='Display this many predictions.'
      )
      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

    最后使用一张图片进行测试:

    python classify_image_inception_resnet_v2.py 
      --model_path /your/saved/path/frozen_graph.pb 
      --label_path /your/path/labels.txt 
      --image_file /your/path/test.jpg

    最后输出:

    unsuited (score = 0.94713)
    suited (score = 0.05287)

    虽然有点高兴,但是蓦然回首,还是很心累,然后现在conda的TensorFlow GPU版本跪了,需要修复。

    5. 参考

    (1) 【深度学习-模型eval+模型导出】使用Tensorflow Slim对训练的模型进行评估+导出模型

    (2) 【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控

    (完)

  • 相关阅读:
    npm ci命令比npm installer命令快2至10倍
    Liferay 7.1发布啦
    2016/07/05 zend optimizer
    2016/06/16 phpexcel
    2016/06/13 phpexcel 未完待续
    2016/06/10 日历插件 Datepicker
    2016/06/09 ThinkPHP3.2.3使用分页
    2016/06/02 网摘记录 svn 服务器端 客户端 安装使用
    2016/05/27 php上传文件常见问题总结
    2016/05/25 抽象类与API(接口)差别
  • 原文地址:https://www.cnblogs.com/harrymore/p/12149756.html
Copyright © 2020-2023  润新知