• TensorFlow Frontend前端


    TensorFlow Frontend前端

    TensorFlow前端有助于将TensorFlow模型导入TVM。

    Supported versions:

    • 1.12 and below

    Tested models:

    • Inception (V1/V2/V3/V4)
    • Resnet (All)
    • Mobilenet (V1/V2 All)
    • Vgg (16/19)
    • BERT (Base/3-layer)

    Preparing a Model for Inference准备推理模型

    Remove Unneeded Nodes删除不需要的节点

    导出过程将删除许多不需要进行推理的节点,但不幸的是会留下一些剩余的节点。应该手动删除的节点:

    Convert None Dimensions to Constants将无尺寸Dimensions转换为常数

    TVM对动态张量形状的支持最少。None应将尺寸替换为常量。例如,模型可以接受带有shape的输入(None,20)。这应转换为的形状(1,20)。应该相应地修改模型,以确保这些形状在整个图形中都匹配。

    Export

    TensorFlow前端需要冻结的protobuf(.pb)或保存的模型作为输入。不支持检查点(.ckpt)。TensorFlow前端所需的graphdef,可以从活动会话中提取,可以使用TFParser帮助器类提取。

    应该导出该模型并进行许多转换,以准备模型进行推理。设置`add_shapes=True`也很重要,因为这会将每个节点的输出形状嵌入到图形中。这是一个给定会话将模型导出为protobuf的函数:

    import tensorflow as tf

    from tensorflow.tools.graph_transforms import TransformGraph

     

    def export_pb(session):

        with tf.gfile.GFile("myexportedmodel.pb", "wb") as f:

            inputs = ["myinput1", "myinput2"] # replace with your input names

            outputs = ["myoutput1"] # replace with your output names

            graph_def = session.graph.as_graph_def(add_shapes=True)

            graph_def = tf.graph.util.convert_variables_to_constants(session, graph_def, outputs)

            graph_def = TransformGraph(

                graph_def,

                inputs,

                outputs,

                [

                    "remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)",

                    "sort_by_execution_order", # sort by execution order after each transform to ensure correct node ordering

                    "remove_attribute(attribute_name=_XlaSeparateCompiledGradients)",

                    "remove_attribute(attribute_name=_XlaCompile)",

                    "remove_attribute(attribute_name=_XlaScope)",

                    "sort_by_execution_order",

                    "remove_device",

                    "sort_by_execution_order",

                    "fold_batch_norms",

                    "sort_by_execution_order",

                    "fold_old_batch_norms",

                    "sort_by_execution_order"

                ]

            )

            f.write(graph_def.SerializeToString())

    Another method is to export and freeze the graph.

    Import the Model

    Explicit Shape:

    确保可以在整个图形中知道形状,将`shape`参数传递给`from_tensorflow`。该词典将输入名称映射到输入形状。

    Data Layout

    大多数TensorFlow模型以NHWC布局发布。NCHW布局通常提供更好的性能,尤其是在GPU上。该TensorFlow前端可以通过传递参数自动转换模型的数据布局`layout='NCHW'`到`from_tensorflow`。

    Best Practices

    • 使用静态张量形状代替动态形状(删除`None`尺寸)。
    • `TensorArray`目前尚不支持使用静态RNN代替动态RNN。

    Supported Ops

    • Abs
    • Add
    • AddN
    • All
    • Any
    • ArgMax
    • ArgMin
    • AvgPool
    • BatchMatMul
    • BatchMatMulV2
    • BatchNormWithGlobalNormalization
    • BatchToSpaceND
    • BiasAdd
    • BroadcastTo
    • Cast
    • Ceil
    • CheckNumerics
    • ClipByValue
    • Concat
    • ConcatV2
    • Conv2D
    • Cos
    • Tan
    • CropAndResize
    • DecodeJpeg
    • DepthwiseConv2dNative
    • DepthToSpace
    • Dilation2D
    • Equal
    • Elu
    • Enter
    • Erf
    • Exit
    • Exp
    • ExpandDims
    • Fill
    • Floor
    • FloorDiv
    • FloorMod
    • FusedBatchNorm
    • FusedBatchNormV2
    • Gather
    • GatherNd
    • GatherV2
    • Greater
    • GreaterEqual
    • Identity
    • IsFinite
    • IsInf
    • IsNan
    • LeakyRelu
    • LeftShift
    • Less
    • LessEqual
    • Log
    • Log1p
    • LoopCond
    • LogicalAnd
    • LogicalOr
    • LogicalNot
    • LogSoftmax
    • LRN
    • LSTMBlockCell
    • MatMul
    • Max
    • MaxPool
    • Maximum
    • Mean
    • Merge
    • Min
    • Minimum
    • MirrorPad
    • Mod
    • Mul
    • Neg
    • NextIteration
    • NotEqual
    • OneHot
    • Pack
    • Pad
    • PadV2
    • Pow
    • Prod
    • Range
    • Rank
    • RealDiv
    • Relu
    • Relu6
    • Reshape
    • ResizeBilinear
    • ResizeBicubic
    • ResizeNearestNeighbor
    • ReverseV2
    • RightShift
    • Round
    • Rsqrt
    • Select
    • Selu
    • Shape
    • Sigmoid
    • Sign
    • Sin
    • Size
    • Slice
    • Softmax
    • Softplus
    • SpaceToBatchND
    • SpaceToDepth,
    • Split
    • SplitV
    • Sqrt
    • Square
    • SquareDifference
    • Squeeze
    • StridedSlice
    • Sub
    • Sum
    • Switch
    • Tanh
    • TensorArrayV3
    • TensorArrayScatterV3
    • TensorArrayGatherV3
    • TensorArraySizeV3
    • TensorArrayWriteV3
    • TensorArrayReadV3
    • TensorArraySplitV3
    • TensorArrayConcatV3
    • Tile
    • TopKV2
    • Transpose
    • TruncateMod
    • Unpack
    • UnravelIndex
    • Where
    • ZerosLike
    人工智能芯片与自动驾驶
  • 相关阅读:
    数据分析必须掌握的统计学知识!
    数据分析常用指标大全,熟记!
    Java编程基础阶段笔记 day 07 面向对象编程(上)
    Java编程基础阶段笔记 day04 Java基础语法(下)
    Java编程基础阶段笔记 day06 二维数组
    Java编程基础阶段笔记 day05 数组
    Java编程基础阶段笔记 day04 Java基础语法(下)
    Java编程基础阶段笔记 day03 Java基本语法(中)
    啥?虚拟现实技术已经应用到自动化仓库? | 基于unity实现的自动化仓库模拟监控系统
    交互设计书单--西南交大课程推荐
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14532264.html
Copyright © 2020-2023  润新知