基于训练好的 Transformer 模型提取注意力矩阵，并对注意力进行可视化。
首先安装依赖：tensorflow 1.13.1 + tensor2tensor 1.13.1
可视化部分请在 Jupyter Notebook 中运行。以下代码改编自 tensor2tensor/tensor2tensor/visualization/visualization.py。
# coding=utf-8 # Copyright 2020 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Shared code for visualizing transformer attentions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np # To register the hparams set from tensor2tensor import models # pylint: disable=unused-import from tensor2tensor import problems from tensor2tensor.utils import registry from tensor2tensor.utils import trainer_lib import tensorflow.compat.v1 as tf from tensor2tensor.utils import usr_dir EOS_ID = 1 class AttentionVisualizer2(object): """Helper object for creating Attention visualizations.""" def __init__( self, hparams_set,hparams,t2t_usr_dir, model_name, data_dir, problem_name, beam_size=1): inputs, targets, samples, att_mats = build_model( hparams_set,hparams, t2t_usr_dir, model_name, data_dir, problem_name, beam_size=beam_size) # Fetch the problem ende_problem = problems.problem(problem_name) encoders = ende_problem.feature_encoders(data_dir) self.inputs = inputs self.targets = targets self.att_mats = att_mats self.samples = samples self.encoders = encoders def encode(self, input_str): """Input str to features dict, ready for inference.""" inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID] batch_inputs = np.reshape(inputs, [1, -1, 1, 1]) # Make it 3D. 
return batch_inputs def decode(self, integers): """List of ints to str.""" integers = list(np.squeeze(integers)) return self.encoders["targets"].decode(integers) def encode_list(self, integers): """List of ints to list of str.""" integers = list(np.squeeze(integers)) return self.encoders["inputs"].decode_list(integers) def decode_list(self, integers): """List of ints to list of str.""" integers = list(np.squeeze(integers)) return self.encoders["targets"].decode_list(integers) def get_vis_data_from_string(self, sess, input_string): """Constructs the data needed for visualizing attentions. Args: sess: A tf.Session object. input_string: The input sentence to be translated and visualized. Returns: Tuple of ( output_string: The translated sentence. input_list: Tokenized input sentence. output_list: Tokenized translation. att_mats: Tuple of attention matrices; ( enc_atts: Encoder self attention weights. A list of `num_layers` numpy arrays of size (batch_size, num_heads, inp_len, inp_len) dec_atts: Decoder self attention weights. A list of `num_layers` numpy arrays of size (batch_size, num_heads, out_len, out_len) encdec_atts: Encoder-Decoder attention weights. A list of `num_layers` numpy arrays of size (batch_size, num_heads, out_len, inp_len) ) """ encoded_inputs = self.encode(input_string) # Run inference graph to get the translation. out = sess.run(self.samples, { self.inputs: encoded_inputs, }) # Run the decoded translation through the training graph to get the # attention tensors. att_mats = sess.run(self.att_mats, { self.inputs: encoded_inputs, self.targets: np.reshape(out, [1, -1, 1, 1]), }) output_string = self.decode(out) input_list = self.encode_list(encoded_inputs) output_list = self.decode_list(out) return output_string, input_list, output_list, att_mats def build_model(hparams_set, hparams,t2t_usr_dir, model_name, data_dir, problem_name, beam_size=1): """Build the graph required to fetch the attention weights. 
Args: hparams_set: HParams set to build the model with. model_name: Name of model. data_dir: Path to directory containing training data. problem_name: Name of problem. beam_size: (Optional) Number of beams to use when decoding a translation. If set to 1 (default) then greedy decoding is used. Returns: Tuple of ( inputs: Input placeholder to feed in ids to be translated. targets: Targets placeholder to feed to translation when fetching attention weights. samples: Tensor representing the ids of the translation. att_mats: Tensors representing the attention weights. ) """ print(model_name) usr_dir.import_usr_dir(t2t_usr_dir) hparams = trainer_lib.create_hparams( hparams_set,hparams, data_dir=data_dir, problem_name=problem_name) # print(hparams) translate_model = registry.model(model_name)( hparams, tf.estimator.ModeKeys.EVAL) inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs") targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="targets") translate_model({ "inputs": inputs, "targets": targets, }) # Must be called after building the training graph, so that the dict will # have been filled with the attention tensors. BUT before creating the # inference graph otherwise the dict will be filled with tensors from # inside a tf.while_loop from decoding and are marked unfetchable. atts = get_att_mats(translate_model,model_name) with tf.variable_scope(tf.get_variable_scope(), reuse=True): samples = translate_model.infer({ "inputs": inputs, }, beam_size=beam_size)["outputs"] return inputs, targets, samples, atts def get_att_mats(translate_model,model_name): """Get's the tensors representing the attentions from a build model. The attentions are stored in a dict on the Transformer object while building the graph. Args: translate_model: Transformer object to fetch the attention weights from. Returns: Tuple of attention matrices; ( enc_atts: Encoder self attention weights. 
A list of `num_layers` numpy arrays of size (batch_size, num_heads, inp_len, inp_len) dec_atts: Decoder self attetnion weights. A list of `num_layers` numpy arrays of size (batch_size, num_heads, out_len, out_len) encdec_atts: Encoder-Decoder attention weights. A list of `num_layers` numpy arrays of size (batch_size, num_heads, out_len, inp_len) ) """ enc_atts = [] dec_atts = [] encdec_atts = [] prefix = "%s/body/"%(model_name) postfix_self_attention = "/multihead_attention/dot_product_attention" if translate_model.hparams.self_attention_type == "dot_product_relative": postfix_self_attention = ("/multihead_attention/" "dot_product_attention_relative") postfix_encdec = "/multihead_attention/dot_product_attention" for i in range(translate_model.hparams.num_hidden_layers): enc_att = translate_model.attention_weights[ "%sencoder/layer_%i/self_attention%s" % (prefix, i, postfix_self_attention)] dec_att = translate_model.attention_weights[ "%sdecoder/layer_%i/self_attention%s" % (prefix, i, postfix_self_attention)] encdec_att = translate_model.attention_weights[ "%sdecoder/layer_%i/encdec_attention%s" % (prefix, i, postfix_encdec)] enc_atts.append(enc_att) dec_atts.append(dec_att) encdec_atts.append(encdec_att) return enc_atts, dec_atts, encdec_atts from IPython.display import display def call_html(): import IPython display(IPython.core.display.HTML(''' <script src="/static/components/requirejs/require.js"></script> <script> requirejs.config({ paths: { base: '/static/base', "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min", jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min', }, }); </script> ''')) import os from tensor2tensor import problems from tensor2tensor.bin import t2t_decoder # To register the hparams set # from tensor2tensor.utils import registry from tensor2tensor.utils import trainer_lib from tensor2tensor.visualization import attention # from src.visualization import visualization os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" # HParams 
problem_name = 'translate_ende_wmt32k'  # problem (dataset) name
data_dir = os.path.expanduser(
    '/home/usrname/collaboration/t2t_data/%s' % (problem_name))  # data path
model_name = "collaboration"  # registered model name
hparams_set = "collaboration_base"  # registered hparams set
# Custom hparams overrides (adjust to your own needs).
hparams = 'max_length=128,num_hidden_layers=6,usedegray=1.0,reuse_n=0'
t2t_usr_dir = './src/'  # directory containing the user-defined model code

visualizer = AttentionVisualizer2(hparams_set, hparams, t2t_usr_dir,
                                  model_name, data_dir, problem_name,
                                  beam_size=1)

# Make sure the graph has exactly one `global_step` variable before the
# Saver is created. get_or_create reuses a global step that building the
# model may already have created. A raw tf.Variable(name='global_step')
# would instead be renamed to 'global_step_1', and that name is not in the
# checkpoint, so restoring would fail.
tf.train.get_or_create_global_step()
接着运行以下代码（恢复 checkpoint 并可视化样本）：
# Restore the trained weights, then translate and visualize one sample.
# Everything that uses `sess` stays inside the `with` block, because the
# session is closed when the block exits.
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = 'averaged.ckpt-0'  # checkpoint path
    print(ckpt)
    saver.restore(sess, ckpt)

    # Sample sentence to visualize. Alternative example:
    # input_sentence = "It is in this spirit that a majority of American governments have passed new laws since 2009 making the registration or voting process more difficult."
    input_sentence = "The Law will never be perfect, but its application should be just - this is what we are missing, in my opinion."

    output_string, inp_text, out_text, att_mats = (
        visualizer.get_vis_data_from_string(sess, input_sentence)
    )
    print(output_string)

    # Configure require.js first, then render the attention view.
    call_html()
    attention.show(inp_text, out_text, *att_mats)
可视化结果: