• Text Classification (Part 1): Text Classification with PyTorch using BiLSTM + Attention


    1. Architecture

    The model consists of character and pinyin embedding layers, a stacked bidirectional LSTM, an attention layer applied to the BiLSTM outputs, and a fully connected classification head.

    2. Code

    # TRNNConfig is a hyperparameter container assumed to be defined elsewhere;
    # the __future__ import keeps the type annotation below from being evaluated eagerly.
    from __future__ import annotations

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class TextBILSTM(nn.Module):
        
        def __init__(self,
                     config:TRNNConfig,
                     char_size = 5000,
                     pinyin_size = 5000):
            super(TextBILSTM, self).__init__()
            self.num_classes = config.num_classes
            self.learning_rate = config.learning_rate
            self.keep_dropout = config.keep_dropout
            self.char_embedding_size = config.char_embedding_size
            self.pinyin_embedding_size = config.pinyin_embedding_size
            self.l2_reg_lambda = config.l2_reg_lambda
            self.hidden_dims = config.hidden_dims
            self.char_size = char_size
            self.pinyin_size = pinyin_size
            self.rnn_layers = config.rnn_layers
    
            self.build_model()
    
    
        def build_model(self):
            # Character embedding table
            self.char_embeddings = nn.Embedding(self.char_size, self.char_embedding_size)
            # Character embeddings are updated during training
            self.char_embeddings.weight.requires_grad = True
            # Pinyin embedding table, also updated during training
            self.pinyin_embeddings = nn.Embedding(self.pinyin_size, self.pinyin_embedding_size)
            self.pinyin_embeddings.weight.requires_grad = True
            # attention layer
            self.attention_layer = nn.Sequential(
                nn.Linear(self.hidden_dims, self.hidden_dims),
                nn.ReLU(inplace=True)
            )
            # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)
    
            # Stacked bidirectional LSTM (num_layers = rnn_layers)
            self.lstm_net = nn.LSTM(self.char_embedding_size, self.hidden_dims,
                                    num_layers=self.rnn_layers, dropout=self.keep_dropout,
                                    bidirectional=True)
            # Fully connected output layers
            # self.fc_out = nn.Linear(self.hidden_dims, self.num_classes)
            self.fc_out = nn.Sequential(
                nn.Dropout(self.keep_dropout),
                nn.Linear(self.hidden_dims, self.hidden_dims),
                nn.ReLU(inplace=True),
                nn.Dropout(self.keep_dropout),
                nn.Linear(self.hidden_dims, self.num_classes)
            )
    
        def attention_net_with_w(self, lstm_out, lstm_hidden):
            '''
    
            :param lstm_out:    [batch_size, len_seq, n_hidden * 2]
            :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]
            :return: [batch_size, n_hidden]
            '''
            lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
            # h [batch_size, time_step, hidden_dims]
            h = lstm_tmp_out[0] + lstm_tmp_out[1]
            # Sum over layers and directions: [batch_size, n_hidden]
            lstm_hidden = torch.sum(lstm_hidden, dim=1)
            # [batch_size, 1, n_hidden]
            lstm_hidden = lstm_hidden.unsqueeze(1)
            # atten_w [batch_size, 1, hidden_dims]
            atten_w = self.attention_layer(lstm_hidden)
            # m [batch_size, time_step, hidden_dims]
            m = nn.Tanh()(h)
            # atten_context [batch_size, 1, time_step]
            atten_context = torch.bmm(atten_w, m.transpose(1, 2))
            # softmax_w [batch_size, 1, time_step]
            softmax_w = F.softmax(atten_context, dim=-1)
            # context [batch_size, 1, hidden_dims]
            context = torch.bmm(softmax_w, h)
            result = context.squeeze(1)
            return result
    
        def forward(self, char_id, pinyin_id):
            # char_id = torch.from_numpy(np.array(input[0])).long()
            # pinyin_id = torch.from_numpy(np.array(input[1])).long()
    
            sen_char_input = self.char_embeddings(char_id)
            sen_pinyin_input = self.pinyin_embeddings(pinyin_id)
    
            # Concatenate along the sequence dimension: character tokens followed by pinyin tokens
            sen_input = torch.cat((sen_char_input, sen_pinyin_input), dim=1)
            # The LSTM expects input of shape [len_seq, batch_size, embedding_dim]
            sen_input = sen_input.permute(1, 0, 2)
            output, (final_hidden_state, final_cell_state) = self.lstm_net(sen_input)
            # output : [batch_size, len_seq, n_hidden * 2]
            output = output.permute(1, 0, 2)
            # final_hidden_state : [batch_size, num_layers * num_directions, n_hidden]
            final_hidden_state = final_hidden_state.permute(1, 0, 2)
            # final_hidden_state = torch.mean(final_hidden_state, dim=0, keepdim=True)
            # atten_out = self.attention_net(output, final_hidden_state)
            atten_out = self.attention_net_with_w(output, final_hidden_state)
            return self.fc_out(atten_out)
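
    The snippet below is a minimal usage sketch. TRNNConfig is not defined in this post, so a simple stand-in object carrying the fields read in __init__ is used here; all hyperparameter values and tensor sizes are illustrative, not taken from the post.

    from types import SimpleNamespace

    import torch

    # Hypothetical stand-in for TRNNConfig; only the attributes read by TextBILSTM are set.
    config = SimpleNamespace(
        num_classes=10,
        learning_rate=1e-3,
        keep_dropout=0.1,
        char_embedding_size=128,    # must equal pinyin_embedding_size, since the two
        pinyin_embedding_size=128,  # embeddings are concatenated along the sequence dimension
        l2_reg_lambda=0.0,
        hidden_dims=128,
        rnn_layers=2,
    )

    model = TextBILSTM(config, char_size=5000, pinyin_size=5000)

    # Dummy batch: 8 sentences, 50 character ids and 50 pinyin ids each.
    char_id = torch.randint(0, 5000, (8, 50))
    pinyin_id = torch.randint(0, 5000, (8, 50))

    logits = model(char_id, pinyin_id)
    print(logits.shape)   # torch.Size([8, 10]) -> [batch_size, num_classes]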
            

    3. Explanation

    The attention_net_with_w function proceeds as follows (a standalone sketch of these steps follows the list):

    1. Split the BiLSTM output (shape: [batch_size, time_step, hidden_dims * num_directions(=2)]) into two tensors of shape [batch_size, time_step, hidden_dims].
    2. Add the two tensors from step 1 element-wise to obtain h (shape: [batch_size, time_step, hidden_dims]).
    3. Sum the BiLSTM's final hidden state (shape: [batch_size, num_layers * num_directions, hidden_dims]) over the second dimension to obtain a new lstm_hidden (shape: [batch_size, hidden_dims]).
    4. Expand lstm_hidden from [batch_size, hidden_dims] to [batch_size, 1, hidden_dims].
    5. Apply self.attention_layer(lstm_hidden) to obtain atten_w, the vector used to compute the attention weights (shape: [batch_size, 1, hidden_dims]).
    6. Apply tanh to h to obtain m (shape: [batch_size, time_step, hidden_dims]).
    7. Compute atten_context = torch.bmm(atten_w, m.transpose(1, 2)) (shape: [batch_size, 1, time_step]).
    8. Normalize atten_context with F.softmax(atten_context, dim=-1) to obtain the context-based weights softmax_w (shape: [batch_size, 1, time_step]).
    9. Compute the attention-weighted BiLSTM output context = torch.bmm(softmax_w, h) (shape: [batch_size, 1, hidden_dims]).
    10. Squeeze out the second dimension of context to obtain result (shape: [batch_size, hidden_dims]).
    11. Return result.
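
    The steps above can be reproduced in isolation. The snippet below is a minimal, self-contained sketch that mirrors attention_net_with_w outside the model class; all tensor values are random and the sizes are illustrative, not taken from the post.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    batch_size, time_step, hidden_dims = 4, 10, 8
    num_layers, num_directions = 2, 2

    # Stand-ins for the BiLSTM outputs (random values, illustrative shapes only).
    lstm_out = torch.randn(batch_size, time_step, hidden_dims * num_directions)
    lstm_hidden = torch.randn(batch_size, num_layers * num_directions, hidden_dims)
    attention_layer = nn.Sequential(nn.Linear(hidden_dims, hidden_dims), nn.ReLU())

    # Steps 1-2: split the bidirectional output and add the two halves.
    fwd, bwd = torch.chunk(lstm_out, 2, dim=-1)
    h = fwd + bwd                                                      # [4, 10, 8]
    # Steps 3-5: pool the final hidden state and project it into an attention query.
    atten_w = attention_layer(lstm_hidden.sum(dim=1).unsqueeze(1))     # [4, 1, 8]
    # Steps 6-8: score every time step against the query and normalize.
    atten_context = torch.bmm(atten_w, torch.tanh(h).transpose(1, 2))  # [4, 1, 10]
    softmax_w = F.softmax(atten_context, dim=-1)                       # [4, 1, 10]
    # Steps 9-11: weighted sum over time steps, then drop the singleton dimension.
    result = torch.bmm(softmax_w, h).squeeze(1)                        # [4, 8]
    print(result.shape)                                                # torch.Size([4, 8])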

    4. Practical Notes

    Model performance
    A 1-layer BiLSTM reaches 99.8% accuracy on the training set and 96.5% on the test set.
    A 2-layer BiLSTM reaches 99.9% accuracy on the training set and 97.3% on the test set.

    Hyperparameter tuning
    Keep the dropout value below 0.1. This is a rule of thumb: in the author's experiments, a dropout of 0.1 gave about 0.5% higher test-set accuracy than a dropout of 0.3. A sketch of how the layer count and dropout map onto the config follows below.
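
    As a sketch of how these two knobs map onto the model, the snippet below builds the 1-layer and 2-layer variants with dropout 0.1, reusing the hypothetical stand-in config from the usage sketch above; the accuracy figures quoted here come from the author's experiments and are not reproduced by this snippet.

    from types import SimpleNamespace

    def make_model(rnn_layers, dropout):
        # Hypothetical config values; only rnn_layers and keep_dropout are varied here.
        cfg = SimpleNamespace(
            num_classes=10, learning_rate=1e-3, keep_dropout=dropout,
            char_embedding_size=128, pinyin_embedding_size=128,
            l2_reg_lambda=0.0, hidden_dims=128, rnn_layers=rnn_layers,
        )
        return TextBILSTM(cfg)

    model_1layer = make_model(rnn_layers=1, dropout=0.1)  # dropout inside nn.LSTM is ignored for a single layer
    model_2layer = make_model(rnn_layers=2, dropout=0.1)  # the 2-layer variant scored best in the author's runs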
    https://blog.csdn.net/dendi_hust/article/details/94435919

  • Original article: https://www.cnblogs.com/zhangxianrong/p/15118099.html