• TensorFlow 使用预训练模型的部分参数


       

    # Restore only part of a pretrained model.
    # Every global variable whose name contains one of the substrings below is
    # left out of the Saver, so restore() neither reads nor overwrites it; all
    # other variables are loaded from the latest checkpoint.
    # NOTE(review): presumably these are layers added on top of the pretrained
    # network that do not exist in the checkpoint — confirm against the model
    # definition. 'word_embedding1s' is matched literally; verify it is the real
    # variable name, otherwise that variable is included and restore() will fail.
    # NOTE(review): excluded variables keep whatever value they already hold, so
    # they must be initialised separately — confirm the caller does this.
    excluded_name_parts = ('bi-lstm_secondLayer', 'word_embedding1s', 'proj_secondLayer')
    all_vars = tf.global_variables()  # not named `vars`, to avoid shadowing the builtin
    net_var = [var for var in all_vars
               if not any(part in var.name for part in excluded_name_parts)]

    # A Saver built from an explicit variable list only handles those variables.
    saver_pre = tf.train.Saver(net_var)

    ckpt_dir = self.config.dir_model_storepath_pre
    ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_path is None:
        # latest_checkpoint returns None when no checkpoint exists; report which
        # directory was searched instead of restore()'s generic None-path error.
        raise ValueError('No checkpoint found in %r' % (ckpt_dir,))
    saver_pre.restore(self.sess, ckpt_path)

       

    # Disabled alternative, kept for reference: the code inside this string
    # literal is never executed. Instead of filtering tf.global_variables() by
    # name, it builds a Saver from an explicit {checkpoint_name: variable}
    # mapping (word embeddings plus the forward/backward LSTM kernels and biases
    # fetched with reuse=True from the 'bi-lstm' scope), prints the trainable
    # variable names, and restores from the latest checkpoint.
    # NOTE(review): the bodies of the `with` and `for` statements below are not
    # indented, so this code would not parse if moved out of the string as-is.
    '''

    with tf.variable_scope('bi-lstm',reuse=True):

    fwk=tf.get_variable('bidirectional_rnn/fw/lstm_cell/kernel')

    fwb=tf.get_variable('bidirectional_rnn/fw/lstm_cell/bias')

    bwk = tf.get_variable('bidirectional_rnn/bw/lstm_cell/kernel')

    bwb = tf.get_variable('bidirectional_rnn/bw/lstm_cell/bias')

       

    saver_pre= tf.train.Saver({'words/_word_embeddings':self._word_embeddings,

    'bi-lstm/bidirectional_rnn/fw/lstm_cell/kernel':fwk,

    'bi-lstm/bidirectional_rnn/fw/lstm_cell/bias':fwb,

    'bi-lstm/bidirectional_rnn/bw/lstm_cell/kernel':bwk,

    'bi-lstm/bidirectional_rnn/bw/lstm_cell/bias':bwb})

    for x in tf.trainable_variables():

    print(x.name)

       

    #mysaver = tf.train.import_meta_graph(self.config.dir_model_storepath_pre_graph)

       

    saver_pre.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))

    '''

  • 相关阅读:
    《财富自由之路》读后感及读书笔记
    echarts3.x 入门
    Ubuntu 16.04 硬盘安装
    语义化版本控制的规范(转载)
    appcan IDE 无法 请求数据
    jQuery extend 函数
    63342 接口 奇遇 IDEA
    C++调用Java的Jar包
    无法打开 源 文件“stdafx.h”的解决方法
    CString的头文件
  • 原文地址:https://www.cnblogs.com/wuxiangli/p/10330907.html
Copyright © 2020-2023  润新知