• 深度学习应用系列(四)| 使用 TFLite Android构建自己的图像识别App


            深度学习要想落地实践,一个少不了的路径即是朝着智能终端、嵌入式设备等方向发展。但终端设备没有GPU服务器那样的强大性能,那如何使得终端设备应用上深度学习呢?

    所幸谷歌已经推出了TFMobile,去年又更进一步,推出了TFLite,其应用思路为在GPU服务器上利用迁移学习训练自己的模型,然后将定制化模型移植到TFLite上,

    终端设备仅利用模型做前向推理,预测结果。本文基于以下三篇文章而成:

          相信大家掌握后,也能轻松定制化自己的图像识别应用。

    第一步. 准备数据

      数据下载地址为:http://download.tensorflow.org/example_images/flower_photos.tgz  

      这是一个关于花分类的图片集合,下载解压后,可以看出有5个品种分类:daisy(雏菊)、dandelion(蒲公英)、rose(玫瑰)、sunflower(向日葵)、tulip(郁金香)。

      我们的目的即是通过重新训练预编译模型,得到一个花类识别的模型。

    第二步. 重新训练

      1. 挑选预编译模型

            从上述“谷歌提供的预编译模型”列表中,我们大体可以看出分为两类模型,一种是Float Models(浮点数模型),一种是Quantized Models(量化模型),什么区别呢?

      其实Float Models表示为一种高精度值的模型,该模型意味着模型size较大,识别精度更高、识别时长更长,适合高性能终端设备;而Quantized Models则反之,是低精度值的模型,其精度采取固定的8位大小,故其模型size较小,识别精度低、识别时长较短,适合低性能终端设备,更细的说明可以参见 https://www.tensorflow.org/performance/quantization  。

      我们的手机设备更新换代很快,一般可以使用Float Models。在这个模型下,有不少预编译模型可选,对于本文来说,主要集中为Inception 和Mobilenet两种架构。  

           注意Mobilenet其实也分为很多种类,如Mobilenet_V1_0.50_224,其中第三个参数为模型大小比例值(只能算是近似,不准确),分为0.25/0.50/0.75/1.0四个比例值,第四个参数为图片大小,其值有128/160/192/224四种值。

      有兴趣想观察各模型层次结构的可通过以下代码查看:  

    import tensorflow as tf
    import tensorflow.gfile as gfile
    
    MODEL_PATH = '/home/yourname/Documents/mobilenet_v1_1.0_224/frozen_graph.pb'
    
    def main(unusedArgv):
        with tf.Graph().as_default() as graph:
            with gfile.FastGFile(MODEL_PATH, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
        for op in graph.get_operations():
            for tensor in op.values():
                print(tensor)
    
    if __name__ == '__main__':
        tf.app.run()

      考虑到测试手机性能还不赖,我们选择mobilenet_v1_1.0_224这个版本作为我们的预编译模型。

      2. 下载训练代码

      需要下载训练模型代码和android相关代码,如下:

    git clone https://github.com/googlecodelabs/tensorflow-for-poets-2
    
    cd tensorflow-for-poets-2

      其中,scripts目录下的retrain.py是我们需要关注的,这个代码目前仅支持Inception_v3和Mobilenet两种预编译模型,默认的训练模型为Inception_v3。

      3. 重新训练模型

      两种模型的训练命令不同,若走默认的Inception_v3模型,可通过如下命令: 

    python -m scripts.retrain 
    --learning_rate=0.01
    --bottleneck_dir=tf_files/bottlenecks --how_many_training_steps=4000 --model_dir=tf_files/models/ --output_graph=tf_files/retrained_graph.pb --output_labels=tf_files/retrained_labels.txt --image_dir=tf_files/flower_photos

      若走Mobilenet模型,可通过如下命令:

    python -m scripts.retrain 
     --learning_rate=0.01  
    --bottleneck_dir=tf_files/bottlenecks --how_many_training_steps=4000 --model_dir=tf_files/models/ --output_graph=tf_files/retrained_graph.pb --output_labels=tf_files/retrained_labels.txt --image_dir=tf_files/flower_photos --architecture=mobilenet_1.0_224

      模型命令解释如下:

     --architecture 为架构类型,支持mobilenet和Inception_v3两种
       --image_dir 为数据地址,假定你在tensorflow-for-poets-2目录下建立了tflite目录,把花图片集放入其中
       --output_labels 最后训练生成模型的标签,由于花图片集合已经按照子目录进行了分类,故retrained_labels.txt最后包含了上述五种花的分类名称
       --output_graph 最后训练生成的模型
       --model_dir 命令启动后,预编译模型的下载地址
       --how_many_training_steps 训练步数,不指定的话默认为4000
       --bottleneck_dir用来把top层的训练数据缓存成文件
       --learning_rate 学习率
       此外,还有些参数可以根据需要进行调整:
       --testing_percentage 把图片按多少比例划分出来当做test数据,默认为10
       --validation_percentage 把图片按多少比例划分出来当做validation数据,默认为10,这两个值设置完后,training数据占比80%
       --eval_step_interval 多少步训练后进行一次评估,默认为10
       --train_batch_size 一次训练的图片数,默认为100
       --validation_batch_size 一次验证的图片数,默认为100
       --random_scale 给定一个比例值,然后随机扩大训练图片的大小,默认为0
       --random_brightness 给定一个比例值,然后随机增强或减弱训练图片的明亮程度,默认为0
       --random_crop 给定一个比例值,然后随机裁剪训练图片的边缘值,默认为0 

        4. 检验训练效果

        我们用Mobilenet_1.0_224进行训练,完成后找一张图片看看是否能正确识别:

    python -m scripts.label_image 
      --graph=tf_files/retrained_graph.pb  
      --image=tf_files/flower_photos/daisy/3475870145_685a19116d.jpg

    结果为:

    Evaluation time (1-image): 1.010s
    
    daisy (score=0.62305)
    tulips (score=0.22490)
    dandelion (score=0.14169)
    roses (score=0.00966)
    sunflowers (score=0.00071)

    还是准确地识别了daisy出来。

        5. 转换模型格式

          pb格式是不能运行在TFLite上的,TFLite吸收了谷歌的protobuffer优点,创造了FlatBuffer格式,具体表现就是后缀名为.tflite的文件。

         上述TOCO的官网已经介绍了如何通过命令行把pb格式转成为tflite文件,或者在代码里也可以转换格式。不仅支持pb格式,也支持HDF5文件格式转换成tflite,实现了与其他框架的模型共享。

         那如何转呢?本例通过命令行方式转换。若训练模型为Inception_v3,命令行方式如下:

    toco 
      --graph_def_file=tf_files/retrained_graph.pb 
      --output_file=tf_files/optimized_graph.lite 
      --input_format=TENSORFLOW_GRAPHDEF 
      --output_format=TFLITE 
      --input_shape=1,299,299,3 
      --input_array=Mul 
      --output_array=final_result 
      --inference_type=FLOAT 
      --input_data_type=FLOAT

         若训练模型为mobilenet,命令行方式则如下:

    toco 
      --graph_def_file=tf_files/retrained_graph.pb 
      --output_file=tf_files/optimized_graph.lite 
      --input_format=TENSORFLOW_GRAPHDEF 
      --output_format=TFLITE 
      --input_shape=1,224,224,3 
      --input_array=input 
      --output_array=final_result 
      --inference_type=FLOAT 
      --input_data_type=FLOAT

      需要说明几点:

           --input_array 参数表示模型图结构的入口tensor op名称,mobilenet的入口名称为input,Inception_v3的入口名称为Mul,为什么这样?可查看scripts/retrain.py代码里内容:

      if architecture == 'inception_v3':
        # pylint: disable=line-too-long
        data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
        # pylint: enable=line-too-long
        bottleneck_tensor_name = 'pool_3/_reshape:0'
        bottleneck_tensor_size = 2048
        input_width = 299
        input_height = 299
        input_depth = 3
        resized_input_tensor_name = 'Mul:0'
        model_file_name = 'classify_image_graph_def.pb'
        input_mean = 128
        input_std = 128
    elif architecture.startswith('mobilenet_'):
    ... data_url
    = 'http://download.tensorflow.org/models/mobilenet_v1_' data_url += version_string + '_' + size_string + '_frozen.tgz' bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' bottleneck_tensor_size = 1001 input_width = int(size_string) input_height = int(size_string) input_depth = 3 resized_input_tensor_name = 'input:0'

      其中的resized_input_tensor_name即是新生成模型的入口名称,大家也可以通过上面“1.挑选预编译模型”的代码可视化查看新生成的模型层次结构。所以名称必须正确写对,否则运行该命令会抛出“ValueError: Invalid tensors 'input' were found” 的异常。

           --output_array则是模型的出口名称。为什么是final_result这个名称,因为在scripts/retrain.py里有:

    parser.add_argument(
          '--final_tensor_name',
          type=str,
          default='final_result',
          help="""
          The name of the output classification layer in the retrained graph.
          """
      )

           即出口名称默认为final_result。

           --input_shape 需要注意的是mobilenet的训练图片大小为224,而Inception_v3的训练图片大小为299。

           最后optimized_graph.lite即是我们要移植到android上的模型文件啦。

    第三步. Android TFLite 

           1. 下载Android Studio

             这一步骤不是本文重点,请大家自行在 https://developer.android.com/studio/ 进行下载安装,安装最新的SDK和NDK。

      2. 引入工程

      从android studio上引入 tensorflow-for-poets-2/android/tflite 下的代码,共有四个类,有三个类是跟布局打交道,而我们只需要关注ImageClassifier.java类。

            3. 导入模型

      可通过命令行方式把生成的模型导入上述工程的资源目录下:

    cp tf_files/optimized_graph.lite android/tflite/app/src/main/assets/mobilenet.lite 
    cp tf_files/retrained_labels.txt android/tflite/app/src/main/assets/mobilenet.txt

       4. 修改ImageClassifier.java类

      注意修改四个地方即可:

     /** Name of the model file stored in Assets. */
      private static final String MODEL_PATH = "mobilenet.lite";
    
      /** Name of the label file stored in Assets. */
      private static final String LABEL_PATH = "mobilenet.txt";
    
      static final int DIM_IMG_SIZE_X = 224; //若是inception,改成299
      static final int DIM_IMG_SIZE_Y = 224; //若是inception,改成299

      5. 运行观看效果

      连上手机后,点击“Run”->"Run app"即会部署app到手机上,此时任何被摄像头捕获的图片都会按照标签里的5个分类进行识别排名。

           我们可以通过百度搜一些这五种类别的花进行识别,以看看其识别的正确率。

    后记:根据我的测试结果,在花的图片集上,mobilenet_1.0_244模型生成的新模型识别率较高,而inception_v3模型生成的新模型识别率较低或不准。

           建议大家新的数据集可在两种模型间进行比较,以找到最适合自己的模型。

      

      

           

      

     

      

      

  • 相关阅读:
    Python基础:18类和实例之二
    Python基础:17类和实例之一(类属性和实例属性)
    Python基础:16面向对象概述
    Python基础:15私有化
    Python基础:14生成器
    Python基础:13装饰器
    Python基础:12函数细节
    Python基础:11变量作用域和闭包
    gcc需找头文件路径
    监控系统
  • 原文地址:https://www.cnblogs.com/hutao722/p/9603113.html
Copyright © 2020-2023  润新知