• [阿里DIN] 模型保存,加载和使用


    [阿里DIN] 模型保存,加载和使用

    0x00 摘要

    Deep Interest Network(DIN)是阿里妈妈精准定向检索及基础算法团队在2017年6月提出的。其针对电子商务领域(e-commerce industry)的CTR预估,重点在于充分利用/挖掘用户历史行为数据中的信息。

    本系列文章会解读论文以及源码,顺便梳理一些深度学习相关概念和TensorFlow的实现。

    本文是系列第 12 篇 :介绍DIN模型的保存,加载和使用。

    0x01 TensorFlow模型

    1.1 模型文件

    TensorFlow模型会保存在checkpoint相关文件中。因为TensorFlow会将计算图的结构和图上参数取值分开保存,所以保存后在相关文件夹中会出现3个文件。

    下面就是DIN,DIEN相关生成的文件,可以通过名称来判别。

    checkpoint				
    
    ckpt_noshuffDIN3.data-00000-of-00001
    ckpt_noshuffDIN3.meta
    ckpt_noshuffDIN3.index
    
    ckpt_noshuffDIEN3.data-00000-of-00001	
    ckpt_noshuffDIEN3.index			
    ckpt_noshuffDIEN3.meta
    

    所以我们可以认为和保存的模型直接相关的是以下这四个文件:

    • checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是TensorFlow自动生成且自动维护的。在 checkpoint文件中维护了由一个TensorFlow持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
    • .meta文件 保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构。
      TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名。
    • .index文件保存了当前参数名。
    • model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。

    1.2 freeze_graph

    正如前文所述,tensorflow在训练过程中,通常不会将权重数据保存的格式文件里,反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。

    freeze_graph.py是怎么做的呢?

    • 它先加载模型文件
    • 提供checkpoint文件地址后,它从checkpoint文件读取权重数据初始化到模型里的权重变量;
    • 将权重变量转换成权重常量 (因为常量能随模型一起保存在同一个文件里);
    • 再通过指定的输出节点没用于输出推理的Op节点从图中剥离掉;
    • 使用tf.train.writegraph保存图,这个图会提供给freeze_graph使用;
    • 再使用freeze_graph重新保存到指定的文件里;

    0x02 DIN代码

    因为 DIN 源码中没有实现此部分,所以我们需要自行添加。

    2.1 输出结点

    首先,在model.py中,需要声明输出结点。

    def build_fcn_net(self, inp, use_dice = False):
        .....
        # 此处需要给 y_hat 添加一个name
        self.y_hat = tf.nn.softmax(dnn3, name='final_output') + 0.00000001
    

    2.2 保存函数

    其次,需要添加一个保存函数,调用 freeze_graph 来进行保存。

    需要注意几点:

    • write_graph 的 as_text 参数默认是 True,我们这里设置为 False。有的环境如果设置为 True 会有问题;
    • 因为write_graph 的 as_text 参数做了设置,所以freeze_graph的参数也做相应设置: input_binary=True
    • input_checkpoint 参数需要针对DIN或者DIEN做相应调整;

    具体代码如下:

    def din_freeze_graph(sess):
        # 模型持久化,将变量值固定
        output_graph_def = convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph_def, # 等于:sess.graph_def
                output_node_names=['final_output']) # 如果有多个输出节点,以逗号隔开
        tf.train.write_graph(output_graph_def, 'dnn_best_model', 'model.pb', False)
    
        freeze_graph.freeze_graph(
                input_graph='./dnn_best_model/model.pb',
                input_saver='',
                input_binary=True,
                input_checkpoint='./dnn_best_model/ckpt_noshuffDIN3',
                output_node_names='final_output', # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                output_graph='./dnn_best_model/frozen_model.pb',
                clear_devices=False,
                initializer_nodes=''
                )
    
    

    2.2 调用保存

    我们在train函数中,存储模型之后,进行调用。

    def train(...):
                    if (iter % save_iter) == 0:
                        print('save model iter: %d' %(iter))
                        model.save(sess, model_path+"--"+str(iter))
                        freeze_graph(sess) # 此处调用
    

    0x03 验证

    3.1 加载

    加载函数如下:

    def load_graph(fz_gh_fn):
        with tf.gfile.GFile(fz_gh_fn, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    
            with tf.Graph().as_default() as graph:
                tf.import_graph_def(
                    graph_def,
                    input_map=None,
                    return_elements=None,
                    name="prefix"  # 此处可以自己修改
                )
        return graph
    

    调用加载函数如下,我们在加载之后,打印出图中对应节点:

    graph = load_graph('./dnn_best_model/frozen_model.pb')
    for op in graph.get_operations():
        print(op.name, op.values())
    

    从打印结果我们可以看出来,有些op是Inputs相关,final_output节点则是我们之前设定的。

    (u'prefix/Inputs/mid_his_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
    (u'prefix/Inputs/cat_his_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_his_batch_ph:0' shape=(?, ?) dtype=int32>,))
    (u'prefix/Inputs/uid_batch_ph', (<tf.Tensor 'prefix/Inputs/uid_batch_ph:0' shape=(?,) dtype=int32>,))
    (u'prefix/Inputs/mid_batch_ph', (<tf.Tensor 'prefix/Inputs/mid_batch_ph:0' shape=(?,) dtype=int32>,))
    (u'prefix/Inputs/cat_batch_ph', (<tf.Tensor 'prefix/Inputs/cat_batch_ph:0' shape=(?,) dtype=int32>,))
    (u'prefix/Inputs/mask', (<tf.Tensor 'prefix/Inputs/mask:0' shape=(?, ?) dtype=float32>,))
    (u'prefix/Inputs/seq_len_ph', (<tf.Tensor 'prefix/Inputs/seq_len_ph:0' shape=(?,) 
                                   
    ......            
                                   
    (u'prefix/final_output', (<tf.Tensor 'prefix/final_output:0' shape=(?, 2) dtype=float32>,))
    

    3.2 验证

    验证数据可以自己炮制,或者就是从测试数据中取出两条即可,我们的验证文件名字为 local_predict_splitByUser

    0	A3BI7R43VUZ1TY	B00JNHU0T2	Literature & Fiction	0989464105B00B01691C14778097321608442845	BooksLiterature & FictionBooksBooks
    
    1	A3BI7R43VUZ1TY	0989464121	Books	0989464105B00B01691C14778097321608442845	BooksLiterature & FictionBooksBooks
    

    验证代码如下,其中feed_dict如何填充,需要根据上节的输出结果来进行相关配置。

    def predict(
            graph,
            predict_file = "local_predict_splitByUser",
            uid_voc = "uid_voc.pkl",
            mid_voc = "mid_voc.pkl",
            cat_voc = "cat_voc.pkl",
            batch_size = 128,
            maxlen = 100):
        gpu_options = tf.GPUOptions(allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph = graph) as sess:
            predict_data = DataIterator(predict_file, uid_voc, mid_voc, cat_voc, batch_size, maxlen)
            for src, tgt in predict_data:
                uids, mids, cats, mid_his, cat_his, mid_mask, target, sl, noclk_mids, noclk_cats = prepare_data(src, tgt, maxlen, return_neg=True)
                final_output = "prefix/final_output:0"
                feed_dict = {
                    'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
                    'prefix/Inputs/cat_his_batch_ph:0':cat_his,
                    'prefix/Inputs/uid_batch_ph:0':uids,
                    'prefix/Inputs/mid_batch_ph:0':mids,
                    'prefix/Inputs/cat_batch_ph:0':cats,
                    'prefix/Inputs/mask:0':mid_mask,
                    'prefix/Inputs/seq_len_ph:0':sl
                }
                y_hat = sess.run(final_output, feed_dict = feed_dict)
                print(y_hat)
    

    预测结果如下:

    [[0.95820646 0.04179354]
     [0.09431148 0.9056886 ]]
    

    3.3 为什么要在tensor后面加:0

    在上节中,我们可以看到在feed_dict之中,给定的tensor名字后面都带了 :0

    feed_dict = {
        'prefix/Inputs/mid_his_batch_ph:0' : mid_his,
        'prefix/Inputs/cat_his_batch_ph:0':cat_his,
        'prefix/Inputs/uid_batch_ph:0':uids,
        'prefix/Inputs/mid_batch_ph:0':mids,
        'prefix/Inputs/cat_batch_ph:0':cats,
        'prefix/Inputs/mask:0':mid_mask,
        'prefix/Inputs/seq_len_ph:0':sl
    }
    

    这里需要注意,TensorFlow的运算结果不是一个数,而是一个张量结构。张量的命名形式:“node : src_output”,node为节点的名称,src_output 表示当前张量来自来自节点的第几个输出。

    在我们这里,prefix/Inputs/mid_batch_ph 是操作节点,prefix/Inputs/mid_batch_ph:0 才是变量的名字。冒号后面的数字编号表示这个张量是计算节点上的第几个结果

    0xFF 参考

    【TensorFlow】freeze_graph

    [深度学习] TensorFlow中模型的freeze_graph

    TensorFlow模型冷冻以及为什么tensor名字要加:0

    tensorflow实战笔记(19)----使用freeze_graph.py将ckpt转为pb文件

    Tensorflow-GraphDef、MetaGraph、CheckPoint

  • 相关阅读:
    桌面上嵌入窗口(桌面日历)原理探索(将该窗口的Owner设置成桌面的Shell 窗口,可使用SetWindowLong更改窗口的GWL_HWNDPARENT,还要使用SetWindowPos设置Z-Order)
    QQ截图时窗口自动识别的原理(WindowFromPoint, ChildWindowFromPoint, ChildWindowFromPointEx,RealChildWindowFromPoint)
    如何给开源的DUILib支持Accessibility(论述了DUILib的六个缺点,很精彩)
    从点击Button到弹出一个MessageBox, 背后发生了什么(每个UI线程都有一个ThreadInfo结构, 里面包含4个队列和一些标志位)
    Sessions, Window Stations and Desktops(GetDesktopWindow函数得到的桌面句柄, 是Csrss.exe创建的一个窗口)
    skip list
    理解对象模型图(Reading OMDS)
    Javascript与当前项目的思考
    Stub和Mock的理解
    https学习总结
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/14019176.html
Copyright © 2020-2023  润新知