• MxNet 模型转Tensorflow pb模型


    用mmdnn实现模型转换

    参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af

    1. 安装mmdnn
      pip install mmdnn
    2. 准备好mxnet模型的.json文件和.params文件, 以InsightFace MxNet r50为例        https://github.com/deepinsight/insightface
    3. 用mmdnn运行命令行
      python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d resnet50 --inputShape 3,112,112 

       会生成resnet50.json(可视化文件) resnet50.npy(权重参数) resnet50.pb(网络结构)三个文件。

    4. 用mmdnn运行命令行
      python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath tf_resnet50.py 

       生成tf_resnet50.py文件,可以调用tf_resnet50.py中的KitModel函数加载npy权重参数重新生成原网络框架。

    5. 打开tf_resnet.py文件,修改load_weights()中的代码 (tensorflow=1.14.0报错) 

       try:
              weights_dict = np.load(weight_file).item()
          except:
              weights_dict = np.load(weight_file, encoding='bytes').item()

      改为

       try:
              weights_dict = np.load(weight_file, allow_pickle=True).item()
      except:
              weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()
    6. 基于resnet50.npy和tf_resnet50.py文​​件,固化参数,生成PB文件:

      import tensorflow as tf
      import tf_resnet50 as tf_fun
      def netWork():
          model=tf_fun.KitModel("./resnet50.npy")
          return model
      def freeze_graph(output_graph):
          output_node_names = "output"
          data,fc1=netWork()
          fc1=tf.identity(fc1,name="output")
      
          graph = tf.get_default_graph()  # 獲得默認的圖
          input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖
          init = tf.global_variables_initializer()
          with tf.Session() as sess:
              sess.run(init)
              output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
                  sess=sess,
                  input_graph_def=input_graph_def,  # 等於:sess.graph_def
                  output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開
      
              with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
                  f.write(output_graph_def.SerializeToString())  # 序列化輸出
      
      if __name__ == '__main__':
          freeze_graph("frozen_insightface_r50.pb")
          print("finish!")
    7. 采用tensorflow的post-train quantization离线量化方法(有一定的精度损失)转换成tflite模型,从而完成端侧的模型部署:
      import tensorflow as tf
      
      convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen_insightface_r50.pb",input_arrays=["data"],output_arrays=["output"],
                                                        input_shapes={"data":[1,112,112,3]})
      convert.post_training_quantize=True
      tflite_model=convert.convert()
      open("quantized_insightface_r50.tflite","wb").write(tflite_model)
      print("finish!")
  • 相关阅读:
    学习笔记之正向代理和反向代理的区别
    PHP程序员的进阶之路
    go语言笔记——切片函数常见操作,增删改查和搜索、排序
    golang的垃圾回收(GC)机制
    堆栈的详细讲解
    springAop必导jar包
    sring框架的jdbc应用
    下载jar包方法
    mysql数据乱码
    Eclipse打包java工程
  • 原文地址:https://www.cnblogs.com/qiangz/p/11134240.html
Copyright © 2020-2023  润新知