• tf.estimator.Estimator


    1.定义

    tf.estimator.Estimator(model_fn=model_fn) # model_fn is a callable that builds the model

    2.定义model_fn:

        def model_fn_builder(self, bert_config, num_labels, init_checkpoint):
            """Build the ``model_fn`` closure to hand to ``tf.estimator.Estimator``.

            :param bert_config: model configuration, forwarded to ``self.creat_model``.
            :param num_labels: number of output labels, forwarded to ``self.creat_model``.
            :param init_checkpoint: checkpoint passed to
                ``modeling.get_assignment_map_from_checkpoint`` and
                ``tf.train.init_from_checkpoint``.
            :return: the ``model_fn`` callable.
            """
            def model_fn(features, labels, mode, params):
                """Build the graph and return a predictions-only ``EstimatorSpec``.

                All four parameters are declared even though some are unused here;
                the original author notes the signature must be defined this way.

                :param features: mapping provided by the estimator; must contain the
                    keys ``input_ids``, ``input_mask`` and ``segment_ids``.
                :param labels: data labels (not used in this function).
                :param mode: forwarded unchanged to ``tf.estimator.EstimatorSpec``.
                :param params: not used in this function.
                :return: a ``tf.estimator.EstimatorSpec`` whose ``predictions`` are
                    the value returned by ``self.creat_model``.
                """
                # Core step: build the model and pick which tensor to expose.
                probabilities = self.creat_model(
                    bert_config,
                    features['input_ids'],
                    features['input_mask'],
                    features['segment_ids'],
                    num_labels)

                # Map graph variables onto checkpoint variables, then load them.
                trainable_vars = tf.trainable_variables()
                assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                    trainable_vars, init_checkpoint)
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

                # ``predictions`` is the op/value already produced by creat_model.
                return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities)

            return model_fn

    def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
      """Map checkpoint variables onto the matching variables in ``tvars``.

      Only names present both in ``tvars`` and in the checkpoint are kept.

      :param tvars: iterable of variables; each must expose a ``name`` attribute.
      :param init_checkpoint: checkpoint passed to ``tf.train.list_variables``.
      :return: tuple ``(assignment_map, initialized_variable_names)`` where
          ``assignment_map`` is an ``OrderedDict`` mapping each shared name to
          itself, and ``initialized_variable_names`` is a dict holding both
          ``name`` and ``name + ":0"`` (value ``1``) for each shared name.
      """
      initialized_variable_names = {}

      # Index the graph variables by name with any trailing ":<digits>" removed.
      name_to_variable = collections.OrderedDict()
      for var in tvars:
        name = var.name
        # Raw string: "\d" in a normal string literal is an invalid escape.
        m = re.match(r"^(.*):\d+$", name)
        if m is not None:
          name = m.group(1)
        name_to_variable[name] = var

      init_vars = tf.train.list_variables(init_checkpoint)

      assignment_map = collections.OrderedDict()
      for x in init_vars:
        # Each entry's first element is the variable name; the rest is unused.
        name = x[0]
        if name not in name_to_variable:
          continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

      return (assignment_map, initialized_variable_names)
        def creat_model(self, bert_config, input_ids, input_mask, segment_ids, num_labels):
            """Build the BERT classifier graph and return the class probabilities.

            :param bert_config: configuration passed to ``modeling.BertModel``.
            :param input_ids: passed to ``modeling.BertModel`` as ``input_ids``.
            :param input_mask: passed to ``modeling.BertModel`` as ``input_mask``.
            :param segment_ids: passed to ``modeling.BertModel`` as ``token_type_ids``.
            :param num_labels: number of output classes; sizes the output layer.
            :return: softmax over the logits along the last axis.
            """
            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=False)

            output_layer = model.get_pooled_output()

            hidden_size = output_layer.shape[-1].value

            # Output-layer parameters. The original note says these hold the
            # already-trained values; NOTE(review): that presumably relies on the
            # caller restoring them from a checkpoint -- confirm.
            output_weights = tf.get_variable(
                "output_weights", [num_labels, hidden_size],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            output_bias = tf.get_variable(
                "output_bias", [num_labels], initializer=tf.zeros_initializer())

            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probabilities = tf.nn.softmax(logits, axis=-1)
            return probabilities

    3.使用estimator.predict

    def predict(self, text_a, text_b):
        """Serialize one text pair to the record writer and start prediction.

        :param text_a: first text, forwarded to ``self.convert_single_example``.
        :param text_b: second text, forwarded to ``self.convert_single_example``.
        :return: whatever ``self.estimator.predict`` returns for
            ``self.predict_input_fn``.
        """

        def create_int_feature(values):
            """Wrap an iterable of ints in a ``tf.train.Feature`` int64 list."""
            return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

        input_ids, input_mask, segment_ids = self.convert_single_example(text_a, text_b)

        features = collections.OrderedDict()
        features['input_ids'] = create_int_feature(input_ids)
        features['input_mask'] = create_int_feature(input_mask)
        features['segment_ids'] = create_int_feature(segment_ids)

        # Convert the features into an Example.
        tf_example = tf.train.Example(features=tf.train.Features(feature=features))

        # Serialize the Example and write it to the TFRecord file.
        # NOTE(review): the writer is not flushed here; confirm the record is
        # visible to the reader before prediction starts.
        self.writer.write(tf_example.SerializeToString())

        result = self.estimator.predict(input_fn=self.predict_input_fn)
        # The original computed ``result`` but never handed it back to the caller.
        return result
        def file_based_input_fn_builder(self):
            """Build an ``input_fn`` that reads examples from ``self.predict_file``.

            :return: a zero-argument callable that returns the batched dataset.
            """
            # Every feature is a fixed-length int64 vector of MAX_SEQ_LENGTH entries.
            feature_spec = {
                key: tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64)
                for key in ("input_ids", "input_mask", "segment_ids")
            }

            def parse_record(serialized, spec):
                """Parse one serialized example according to ``spec``.

                :param serialized: a serialized example record.
                :param spec: mapping of feature name to feature description.
                :return: the result of ``tf.parse_single_example``.
                """
                return tf.parse_single_example(serialized, spec)

            def input_fn():
                """Read the TFRecord file and return it parsed, in batches of one.

                :return: the dataset; per the original author's note, its elements
                    are what reaches ``model_fn`` as ``features``.
                """
                records = tf.data.TFRecordDataset(self.predict_file)  # read the TFRecord file
                parse_and_batch = tf.data.experimental.map_and_batch(
                    # Map each serialized record onto the feature dictionary.
                    lambda record: parse_record(record, feature_spec),
                    batch_size=1,
                    drop_remainder=False)
                return records.apply(parse_and_batch)

            return input_fn

    1

  • 相关阅读:
    javaweb毕业设计
    Maven入门----MyEclipse创建maven项目(二)
    Maven入门----介绍及环境搭建(一)
    SpringMvc入门五----文件上传
    SpringMvc入门四----rest风格Url
    SpringMvc入门三----控制器
    SpringMvc入门二----HelloWorld
    SpringMvc入门一----介绍
    分析setup/hold电气特性从D触发器内部结构角度
    33. Search in Rotated Sorted Array
  • 原文地址:https://www.cnblogs.com/callyblog/p/10216058.html
Copyright © 2020-2023  润新知