• PyTorch/TensorFlow自定义OP导出ONNX


    PyTorch

    根据 PyTorch 官方文档，需要先用 torch.autograd.Function 封装自定义算子；为了能够导出 ONNX，还需要为其添加一个 symbolic 静态方法：

    class relu5_func(Function):
        """Autograd wrapper around the ``relu5_cuda.relu5`` extension kernel.

        Only ``forward`` and the ONNX ``symbolic`` hook are defined here. No
        ``backward`` is provided, so this wrapper is presumably meant for
        inference / ONNX export only (TODO confirm before using it in training).
        """

        @staticmethod
        def forward(ctx, tensor):
            # The actual computation is delegated to the compiled extension.
            result = relu5_cuda.relu5(tensor)
            return result

        @staticmethod
        def symbolic(g, *inputs):
            """Emit a single ``Relu5`` node when the graph is exported to ONNX.

            The first argument to ``g.op`` ("Relu5") is the op type that appears
            in the exported ONNX model. Keyword attributes carry a type suffix:
            ``myattr_f`` declares a float attribute named ``myattr`` (the name
            itself is arbitrary).
            """
            # NOTE(review): without a "domain::" prefix the node lands in the
            # default ONNX domain; newer exporters may require e.g.
            # "custom::Relu5" for non-standard ops -- confirm for your version.
            first_input = inputs[0]
            node = g.op("Relu5", first_input, myattr_f=1.0)
            return node

    relu5 = relu5_func.apply
    

    定义好后，用以下代码测试（需把上面的 relu5_func 与 relu5 定义放在这些 import 语句之后、TinyNet 定义之前）：

    import torch
    import torch.nn as nn
    import relu5_cuda
    import onnx
    from torch.autograd import Function
    from torch.autograd.function import once_differentiable
    import netron
    
    class TinyNet(nn.Module):
        """Minimal demo network: conv -> ReLU -> flatten -> custom relu5 op."""

        def __init__(self):
            super().__init__()
            # 3 -> 1 channels, 3x3 kernel; padding=1 keeps the spatial size at stride 1.
            self.conv1 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
            self.relu1 = nn.ReLU()

        def forward(self, x):
            # Standard layers first, then flatten everything (batch dim included)
            # into a 1-D tensor and run it through the custom relu5 op.
            features = self.relu1(self.conv1(x))
            return relu5(features.view(-1))
    
    # Requires a CUDA device and the relu5 op defined earlier in this article.
    device = 'cuda'
    onnx_path = 'tinynet.onnx'

    net = TinyNet().to(device)
    ipt = torch.ones(2, 3, 12, 12, device=device)

    # Trace the model with a dummy input and write the ONNX file,
    # then dump the protobuf and open it in the netron viewer.
    torch.onnx.export(net, (ipt,), onnx_path)
    print(onnx.load(onnx_path))
    netron.start(onnx_path)
    

    TensorFlow

    先将模型导出为冻结（变量已转为常量）的 pb 文件（以下代码使用 TensorFlow 1.x 的 API）：

    import tensorflow as tf 
    from tensorflow.python.framework import graph_util
    
    # netron is used at the end to view the result; import it here so this
    # snippet also runs on its own.
    import netron

    # Toy graph (TF1 graph-mode API): two 3x3 SAME convs, 2 -> 3 -> 1 channels.
    # tf.Variable names are auto-generated from creation order, so keep it stable.
    conv1_w = tf.Variable(tf.random_normal([3, 3, 2, 3]))
    conv1_b = tf.Variable(tf.random_normal([3]))
    conv2_w = tf.Variable(tf.random_normal([3, 3, 3, 1]))
    conv2_b = tf.Variable(tf.random_normal([1]))
    xs = tf.placeholder(tf.float32, shape=[1, 12, 12, 2], name="input")
    conv1 = tf.nn.conv2d(xs, conv1_w, strides=[1,1,1,1], padding='SAME') + conv1_b
    conv2 = tf.nn.conv2d(conv1, conv2_w, strides=[1,1,1,1], padding='SAME') + conv2_b
    # Stable name for the result: node 'output', tensor 'output:0'.
    tf.identity(conv2, name='output')

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Freeze the graph: copy the current variable values into constants for
        # everything the 'output' node depends on.
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
        # tf.gfile.FastGFile is deprecated; GFile writes the same bytes.
        with tf.gfile.GFile('tfmodel.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())
    netron.start('tfmodel.pb')
    

    转换前需要先安装 tf2onnx：

    # Install into the same interpreter that later runs `python3 -m tf2onnx.convert`.
    python3 -m pip install tf2onnx
    

    以下参数中的张量名必须采用“节点名:输出序号”的形式（例如 output:0），且节点名须与图中定义的名字一致——本例输入占位符的名字是 input，对应的张量名为 input:0。

    # Tensor names are "<node_name>:<output_index>". The input placeholder is named
    # "input" and the identity op "output", hence input:0 / output:0.
    # Trailing backslashes keep this a single shell command.
    python3 -m tf2onnx.convert \
        --input tfmodel.pb \
        --inputs input:0 \
        --output tfmodel.onnx \
        --outputs output:0
    

    或者直接用 Python 代码完成转换：

    import tensorflow as tf 
    import tf2onnx
    from tf2onnx import loader
    
    # --- Build the graph (TF1 graph-mode API) ---
    # Same toy model as the pb-export example: two 3x3 SAME convs, 2 -> 3 -> 1 channels.
    # tf.Variable names are auto-generated from creation order, so statement order
    # here determines node names in the frozen graph.
    conv1_w = tf.Variable(tf.random_normal([3, 3, 2, 3]))
    conv1_b = tf.Variable(tf.random_normal([3]))
    conv2_w = tf.Variable(tf.random_normal([3, 3, 3, 1]))
    conv2_b = tf.Variable(tf.random_normal([1]))
    # Input placeholder; its tensor name is "input:0" (used by process_tf_graph below).
    xs = tf.placeholder(tf.float32, shape=[1, 12, 12, 2], name="input")
    conv1 = tf.nn.conv2d(xs, conv1_w, strides=[1,1,1,1], padding='SAME') + conv1_b
    conv2 = tf.nn.conv2d(conv1, conv2_w, strides=[1,1,1,1], padding='SAME') + conv2_b
    # Give the result a stable name so it can be addressed as "output:0".
    tf.identity(conv2, name='output')
    # --- Freeze: initialise the variables, then fold them into a GraphDef ---
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # NOTE(review): loader.freeze_session presumably returns a GraphDef with the
        # variables replaced by constants; newer tf2onnx releases renamed this module
        # (tf_loader) -- verify against the installed version.
        output_graph_def = loader.freeze_session(sess, output_names=["output:0"])
    # --- Convert the frozen GraphDef to ONNX ---
    # Discard the graph built above before importing the frozen one.
    tf.reset_default_graph()
    with tf.Graph().as_default() as tf_graph:
        # name='' imports without a name prefix, so "input:0"/"output:0" stay valid.
        tf.import_graph_def(output_graph_def, name='')
        # Convert to a tf2onnx graph targeting ONNX opset 11.
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph, input_names=["input:0"], output_names=["output:0"], opset=11)
        # NOTE(review): "test" is passed to make_model as a graph doc/name string -- confirm.
        model_proto = onnx_graph.make_model("test")
        with open("tfmodel.onnx", "wb") as f:
            f.write(model_proto.SerializeToString())
    # --- Inspect: print the ONNX protobuf and open it in the netron viewer ---
    import onnx 
    import netron
    print(onnx.load('tfmodel.onnx'))
    netron.start('tfmodel.onnx')
    
  • 相关阅读:
    洛谷 P3355 骑士共存问题
    Redis 安装
    Java 集合:(十八) Map接口
    并发编程——Java版【目录】
    Java 集合:(十七) Queue 常用接口:BlockingQueue 子接口
    Java 集合:(十六) Queue 常用接口:Deque 子接口
    Java 集合:(十五) Queue 子接口
    Java 集合:(番外篇一) ArrayList线程不安全性
    第十三章:StringTable
    Java 集合:(十三) Set实现类:LinkedHashSet
  • 原文地址:https://www.cnblogs.com/xytpai/p/13042667.html
Copyright © 2020-2023  润新知