• Tensorflow Lite tflite模型的生成与导入


    假如想要在ARM板上用tensorflow lite,那么意味着必须要把PC上的模型生成tflite文件,然后在ARM上导入这个tflite文件,通过解析这个文件来进行计算。
    根据前面所说,tensorflow的所有计算都会在内部生成一个图,包括变量的初始化,输入定义等,那么即便不是经过训练的神经网络模型,只是简单的三角函数计算,也可以生成一个tflite模型用于在tensorflow lite上导入。所以,这里我就只做了简单的sin()计算来跑一编这个流程。

    生成tflite模型

    这部分主要是调用TFLiteConverter函数,直接生成tflite文件,不再通过pb文件转化。
    先上代码:

    import numpy as np
    import time
    import math
    import tensorflow as tf
    
    SIZE = 1000
    X = np.random.rand(SIZE, 1)
    X = X*(math.pi/2.0)
    
    start = time.time()
    x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')
    x2 = tf.placeholder(tf.float32, [SIZE, 1], name='x2-input')
    y1 = tf.sin(x1)
    y2 = tf.sin(x2)
    y = y1*y2
    
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        converter = tf.lite.TFLiteConverter.from_session(sess, [x1, x2], [y])
        tflite_model = converter.convert()
        open("/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite", "wb").write(tflite_model)
    
    end = time.time()
    print("2nd ", str(end - start))
    转化函数
    主要遇到的问题是tensorflow的变化实在太快,这些个转化函数一直在变。位置也一直在变,现在参考官方文档,是按上面代码中调用,否则就会报找不到lite之类的错误。我现在PC上的tensorflow Python版本是1.13,所以lite已经在contrib外面了,如果是以前的版本,要按文档中下面这样调用。
     
    TensorFlow Version Python API
    1.12 tf.contrib.lite.TFLiteConverter
    1.9-1.11 tf.contrib.lite.TocoConverter
    1.7-1.8 tf.contrib.lite.toco_convert

    输入参数shape

    本来在本文件中为了给定的输入数据大小自由,x1,x2shape会写成[None, 1],但是如果这样写,转化成tflite模型后会默认为[1,1],并不能自由接收数据大小,所以在这里要指定大小SIZE

    x1 = tf.placeholder(tf.float32, [SIZE, 1], name='x1-input')

    导入tflite模型

    本来这部分应该是在ARM板子上做的,但是为了验证tflite文件的可用性,我先在PC的Python上试验。先上代码:

    import tensorflow as tf
    import numpy as np
    import math
    import time
    
    SIZE = 1000
    X = np.random.rand(SIZE, 1, ).astype(np.float32)
    X = X*(math.pi/2.0)
    
    start = time.time()
    
    interpreter = tf.lite.Interpreter(model_path="/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite")
    interpreter.allocate_tensors()
    
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    interpreter.set_tensor(input_details[0]['index'], X)
    interpreter.set_tensor(input_details[1]['index'], X)
    
    interpreter.invoke()
    
    output_data = interpreter.get_tensor(output_details[0]['index'])
    end = time.time()
    print("1st ", str(end - start))
    首先根据tflite文件生成解析器,然后用allocate_tensors()分配内存。将输入通过set_tensor传入,然后调用invoke()来真正运行。最后得到输出。
    Python跑的时候可以很清楚的看到input_details的数据结构。官方的例子是只传入一个数据,所以只需要取input_details[0],而我传入了2个输入,所以需要设置2个。同时可以看到input_details的2个数据的名字都是我在之前设置的x1-inputx2-input,这样非常好理解。
    输入参数类型
    这里有个坑是输入参数的类型一定要注意。我在生成模型的时候定义的输入参数类型是tf.float32,而在导入的时候如果直接是X = np.random.rand(SIZE, 1, )的话,会报错:
    ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 3

    这里把通过astype(np.float32)把输入参数指定为float32就OK了。

    • 操作不支持的坑
      可以从前面的代码里看到我写了两个sin(),其实一开始是一个sin()一个cos()的,但是好像默认的tflite模型不支持cos()操作,无法生成,所以我只好暂时先只写sin(),后面再研究怎么把cos()加上。
  • 相关阅读:
    ASP.NET Routing Debugger
    浏览器 CSS & JS Hack 手册
    基于vmWare5.5环境的VxWorks系统安装总结
    TFS 迁移到 Git
    关于websocket
    自定义单一模块Model类
    学习 C++的用途
    Navigation Controllers and Table Views(中)
    Mac环境下svn的使用
    减少.NET应用程序内存占用的一则实践
  • 原文地址:https://www.cnblogs.com/scarecrow-blog/p/11475139.html
Copyright © 2020-2023  润新知