    Counting the FLOPs and Parameters of a TensorFlow Model
    2018-08-28

    This post summarizes how to count the floating-point operations (FLOPs) and the number of trainable parameters of a TensorFlow model.
    stats_graph.py

    import tensorflow as tf

    def stats_graph(graph):
        # Count the floating-point operations registered for the ops in the graph.
        flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
        # Count the trainable parameters (variables) in the graph.
        params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
        print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
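    The code in this post targets the TensorFlow 1.x API. On TensorFlow 2.x, a similar count can be obtained through the compat layer; the sketch below is an adaptation under that assumption (the helper name get_flops and the use of a concrete function's graph are my own choices, not part of the original post):

    import tensorflow as tf  # assumed: TensorFlow 2.x

    def get_flops(concrete_fn):
        # Profile the graph behind a concrete tf.function using the TF1-style
        # profiler exposed under tf.compat.v1.
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        prof = tf.compat.v1.profiler.profile(concrete_fn.graph, options=opts)
        return prof.total_float_ops

    @tf.function
    def matmul_fn(a, b):
        return tf.matmul(a, b)

    concrete = matmul_fn.get_concrete_function(
        tf.TensorSpec([25, 16], tf.float32), tf.TensorSpec([16, 9], tf.float32))
    print('FLOPs:', get_flops(concrete))  # should report 7200 (TF-style 2 x 25 x 16 x 9) when shapes are fully known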
    

    Initializing a variable from a Gaussian (random normal) distribution costs some FLOPs. For the matrix product below, the conventional count charges each of the 25 × 9 output elements 16 multiplications and 15 additions, whereas TensorFlow counts the multiplies and adds separately (16 + 16):

    C[25,9] = A[25,16] · B[16,9]

    FLOPs = (16 + 15) × (25 × 9) = 6975
    FLOPs (in TF style) = (16 + 16) × (25 × 9) = 7200
    total_parameters = 25 × 16 + 16 × 9 = 544

    with tf.Graph().as_default() as graph:
        A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
        B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
        C = tf.matmul(A, B, name='output')
        
        stats_graph(graph)
    

    The output is:
    FLOPs: 8288; Trainable params: 544

    The extra 8288 − 7200 = 1088 FLOPs come from the random-normal initialization: each of the 25 × 16 + 16 × 9 = 544 variable elements is scaled by the stddev and shifted by the mean (one multiply plus one add), giving 2 × 544 = 1088 FLOPs.

    Initializing a variable with a constant initializer costs no FLOPs

    with tf.Graph().as_default() as graph:
        A = tf.get_variable(initializer=tf.constant_initializer(value=1, dtype=tf.float32), shape=(25, 16), name='A')
        B = tf.get_variable(initializer=tf.zeros_initializer(dtype=tf.float32), shape=(16, 9), name='B')
        C = tf.matmul(A, B, name='output')
        
        stats_graph(graph)
    

    The output is:
    FLOPs: 7200; Trainable params: 544

    With constant initialization, only the 7200 TF-style matmul FLOPs remain; the parameter count is unchanged.

    Frozen graph

    Usually we are not interested in the FLOPs spent on initialization, since it happens once, before training. What we care about are the FLOPs the deployed model costs at inference time in production. By freezing the graph we can strip out the initialization FLOPs and obtain only the FLOPs spent during inference after deployment.

    from tensorflow.python.framework import graph_util

    def load_pb(pb):
        # Load a frozen GraphDef (.pb file) and import it into a fresh graph.
        with tf.gfile.GFile(pb, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')
            return graph
    with tf.Graph().as_default() as graph:
        # ***** (1) Create Graph *****
        A = tf.Variable(initial_value=tf.random_normal([25, 16]))
        B = tf.Variable(initial_value=tf.random_normal([16, 9]))
        C = tf.matmul(A, B, name='output')
        
        print('stats before freezing')
        stats_graph(graph)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # ***** (2) freeze graph *****
            output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
            with tf.gfile.GFile('graph.pb', "wb") as f:
                f.write(output_graph.SerializeToString())
    # ***** (3) Load frozen graph *****
    graph = load_pb('./graph.pb')
    print('stats after freezing')
    stats_graph(graph)
    

    The output is:

    stats before freezing
    FLOPs: 8288; Trainable params: 544
    INFO:tensorflow:Froze 2 variables.
    INFO:tensorflow:Converted 2 variables to const ops.
    stats after freezing
    FLOPs: 7200; Trainable params: 0
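    For completeness, here is a minimal sketch of running inference from the frozen graph, assuming the graph.pb file and the 'output' op produced by the snippet above; since A and B have been converted to constants, no feed_dict is required:

    # A minimal sketch, reusing load_pb from above.
    graph = load_pb('./graph.pb')
    with tf.Session(graph=graph) as sess:
        output_tensor = graph.get_tensor_by_name('output:0')
        result = sess.run(output_tensor)  # A and B are constants now, so no feeds are needed
        print(result.shape)  # (25, 9)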

    Using it with Keras

    from keras import backend as K
    from keras.layers import Dense
    from keras.models import Sequential
    from keras.initializers import Constant
    model = Sequential()
    model.add(Dense(32, input_dim=4, bias_initializer=Constant(value=0), kernel_initializer=Constant(value=1)))
    sess = K.get_session()
    graph = sess.graph
    stats_graph(graph)
    

    The output is:

    Using TensorFlow backend.
    2 ops no flops stats due to incomplete shapes.
    2 ops no flops stats due to incomplete shapes.
    FLOPs: 0; Trainable params: 160

    The FLOPs count is 0 because the batch dimension of the input is None, so the profiler cannot infer complete shapes for the Dense layer's ops (hence the "incomplete shapes" warnings); the parameter count (4 × 32 weights + 32 biases = 160) is still reported. Calling model.summary() gives:

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    dense_1 (Dense)              (None, 32)                160
    =================================================================
    Total params: 160
    Trainable params: 160
    Non-trainable params: 0
    _________________________________________________________________
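    To get a non-zero FLOPs count for the Keras model, one option is to fix the batch dimension and reuse the freezing recipe from above. The sketch below is my own combination of the two snippets, not part of the original post; batch_shape=(1, 4) and the file name keras_graph.pb are arbitrary choices, and stats_graph / load_pb are the helpers defined earlier:

    import tensorflow as tf
    from keras import backend as K
    from keras.layers import Dense, Input
    from keras.models import Model
    from tensorflow.python.framework import graph_util

    # Fix the batch dimension so every op has a complete shape.
    inputs = Input(batch_shape=(1, 4))
    outputs = Dense(32)(inputs)
    model = Model(inputs, outputs)

    sess = K.get_session()
    # Freeze the Keras graph, keeping only the ops needed to compute the model output.
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), [out.op.name for out in model.outputs])
    with tf.gfile.GFile('keras_graph.pb', 'wb') as f:
        f.write(frozen.SerializeToString())

    # Profile the frozen graph with the helpers defined earlier in this post.
    stats_graph(load_pb('./keras_graph.pb'))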

