• 深度学习模型转换,以pytorch转tensorflow为例


    这里以onnx为中介进行转换。主要用到

    STEP1. 将pytorch 模型转换成onnx模型

    注意这里关键是要构造一个模型的输入输入,这里假设模型接受两个输入。

    pmodel = PytorchModel()
    dummy_input = (np.zeros((1, 30), dtype=np.float32), np.zeros((1, 2), dtype=np.float32))
    torch.onnx.export(pmodel, (torch.as_tensor(dummy_input[0]), torch.as_tensor(dummy_input[1])), "/tmp/xx.onnx",
                      verbose=True, input_names=['input1', 'input2'], output_names=['output1', 'output2'])

    参数 input_names表示模型的输入参数(随便起名字),output_names表示输出名字

    STEP 2. 将onnx模型转成tf

    这里需要借助onnx_tf这个库

    import onnx
    from onnx_tf.backend import prepare
    
    onnx_model = onnx.load("/tmp/xx.onnx")  # load onnx model
    tf_model = prepare(onnx_model)
    tf_model.export_graph("/tmp/xxpb/")  # export the model

    STEP 3 使用tensorflow模型

    import tensorflow as tf
    import io
    import numpy as np
    
    model_path = '/tmp/xxpb/'
    
    sess = tf.compat.v1.Session()
    metagraph = tf.compat.v1.saved_model.loader.load(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], model_path)
    sig = metagraph.signature_def["serving_default"]
    input_dict = dict(sig.inputs)
    output_dict = dict(sig.outputs)
    print(input_dict, output_dict)
    output_stochastic_act_label_0 = output_dict["output_0"].name
    output_stochastic_act_label_1 = output_dict["output_1"].name
    
    input_state_label = None
    initial_state = None
    state = None
    if "state" in input_dict.keys():
        input_state_label = input_dict["state"].name
        strfile = io.StringIO()
        print(input_dict["state"].tensor_shape, file=strfile)
        lines = strfile.getvalue().split("
    ")
        dim_1 = int(lines[1].split(":")[1].strip(" "))
        dim_2 = int(lines[4].split(":")[1].strip(" "))
        initial_state = np.zeros((dim_1, dim_2), dtype=np.float32)
        state = np.zeros((dim_1, dim_2), dtype=np.float32)
    input_obs_label_1 = input_dict["input1"].name
    input_obs_label_0 = input_dict["input2"].name
    input_dict = {input_obs_label_0: np.zeros((1, 2), dtype=np.float32), input_obs_label_1:np.zeros((1, 30), dtype=np.float32)}
    out = sess.run((output_stochastic_act_label_0, output_stochastic_act_label_1), feed_dict=input_dict)
    print(out)

    注意这里的name需要重新设置一遍。





  • 相关阅读:
    TCO 2013 2A
    matlab 中的fmincon参数设定问题
    一步步写自己SqlHelper类库(四):Connection对象
    珠海立方科技实习总结
    Web Services 应用开发学习笔记(三):XML模式定义
    C#笔记(一):类型,泛型,集合
    Web Services 应用开发学习笔记(二):XML文档类型定义
    一步步写自己SqlHelper类库(三):连接字符串
    一步步写自己SqlHelper类库(二):.NET Framework 数据提供程序
    (Joomla)多功能健康模块
  • 原文地址:https://www.cnblogs.com/MrLJC/p/14145763.html
Copyright © 2020-2023  润新知