NLP (43): Semantic matching with sentence_bert and PyTorch vector retrieval


    1. Project directory
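
    The original post shows the project tree as an image; the layout below is a plausible reconstruction based on the paths referenced in the code in the following sections:

    sentence_bert/
    ├── common/
    │   └── root_path.py
    ├── config/
    │   └── code_to_label.json
    ├── chkpt/
    │   ├── distiluse-base-multilingual-cased/   # pretrained base model
    │   └── sentence_bert_model_<timestamp>/     # fine-tuned output
    └── data/
        ├── raw_data/
        │   ├── seats/       # train.txt / dev.txt / test.txt
        │   └── semantic/    # train.txt / dev.txt / test.txt
        ├── clean_data/      # all_*.txt, sim_data.txt, negative_data.txt
        ├── bert_data/       # index.txt (sentence\tcode) for retrieval
        └── final_data/      # train.txt / dev.txt / test.txt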

    2. Generating training data with data_clean

    from common.root_path import root
    import os
    import pandas as pd
    
    class DataMerge(object):
        def __init__(self):
            self.data_path = os.path.join(root, "data", "raw_data")
            self.out_path = os.path.join(root, "data", "clean_data")
            self.neg_data = os.path.join(self.out_path, "negative_data.txt")
            self.sim_path = os.path.join(self.out_path, "sim_data.txt")
            self.final = os.path.join(root, "data", "final_data")
            self.train_sim = os.path.join(self.final, "train.txt")
            self.dev_sim = os.path.join(self.final, "dev.txt")
            self.test_sim = os.path.join(self.final, "test.txt")
    
        def data_merge(self, role):
            out_sentence, out_label = [], []
            data_path = os.path.join(self.data_path, role)
            # training set
            train_path = os.path.join(data_path, "train.txt")
            train_t = pd.read_csv(train_path, sep="\t", header=None, names=["sentence", "label"])
            out_sentence.extend(train_t["sentence"])
            out_label.extend(train_t["label"])
            # validation set
            dev_path = os.path.join(data_path, "dev.txt")
            dev_t = pd.read_csv(dev_path, sep="\t", header=None, names=["sentence", "label"])
            out_sentence.extend(dev_t["sentence"])
            out_label.extend(dev_t["label"])
            # test set
            test_path = os.path.join(data_path, "test.txt")
            test_t = pd.read_csv(test_path, sep="\t", header=None, names=["sentence", "label"])
            out_sentence.extend(test_t["sentence"])
            out_label.extend(test_t["label"])
            # deduplicate by sentence, keeping the first label seen
            clean_sentence, clean_label = [], []
            for s, l in zip(out_sentence, out_label):
                if s not in clean_sentence:
                    clean_sentence.append(s)
                    clean_label.append(l)
            # write the merged data to file
            df = pd.DataFrame(
                {
                    "sentence": clean_sentence,
                    "label": clean_label,
                }
            )
            all_data_path = os.path.join(self.out_path, "all_" + role + ".txt")
            df.to_csv(all_data_path, sep="\t", index=False, header=False)
            return df
    
        def get_all(self):
            """Merge the agent ("seats") and customer ("semantic") data files."""
            df_seats = self.data_merge("seats")
            df_semantic = self.data_merge("semantic")
            df = pd.concat([df_seats, df_semantic])
            all_data = os.path.join(self.out_path, "all_data.txt")
            df.to_csv(all_data, sep="\t", header=False, index=False)
    
        def generator_negative_data(self):
            """Generate negative pairs: align the data with a shuffled copy of
            itself and keep pairs whose labels differ (label = 0)."""
            all_data = os.path.join(self.out_path, "all_data.txt")
            df_1 = pd.read_csv(all_data, sep="\t", header=None, names=["sentence", "label"])
            df_2 = df_1.sample(frac=1)
            out_s1, out_s2, out_l = [], [], []
            for s1, l1, s2, l2 in zip(df_1["sentence"], df_1["label"],
                                      df_2["sentence"], df_2["label"]):
                # strip whitespace so the label comparison is reliable on both sides
                l1 = str(l1).replace(" ", "").replace("\n", "").replace("\r", "")
                l2 = str(l2).replace(" ", "").replace("\n", "").replace("\r", "")
                if l1 != l2:
                    out_s1.append(s1)
                    out_s2.append(s2)
                    out_l.append("0")
            df = pd.DataFrame({
                "s1": out_s1,
                "s2": out_s2,
                "label": out_l
            })
            df.to_csv(self.neg_data, sep="\t", index=False)
    
        def generator_sim_data(self):
            """Generate positive pairs: consecutive sentences under the same label (label = 1)."""
            out_s1, out_s2, out_label = list(), list(), list()
            all_data = os.path.join(self.out_path, "all_data.txt")
            t = pd.read_csv(all_data, sep="\t", header=None, names=["sentence", "label"])
            data_dict = dict()
            for index, row in t.iterrows():
                s = row["sentence"]
                l = row["label"]
                if l not in data_dict:
                    data_dict[l] = list()
                data_dict[l].append(s)
            for l, s_list in data_dict.items():
                s_list_len = len(s_list)
                for index, s in enumerate(s_list):
                    if index > s_list_len -2:
                        break
                    out_s1.append(s)
                    out_s2.append(s_list[index + 1])
                    out_label.append("1")
    
            df = pd.DataFrame({
                "s1": out_s1,
                "s2": out_s2,
                "label": out_label
            })
            df.to_csv(self.sim_path, index=False, sep="\t")
    
        def merge_sim_neg_data(self):
            """Shuffle positive and negative pairs together, then split 5% test / 5% dev / 90% train."""
            df1 = pd.read_csv(self.sim_path, sep="\t")
            df2 = pd.read_csv(self.neg_data, sep="\t")
            df_all = pd.concat([df1, df2])
            df = df_all.sample(frac=1.0)
            cut_idx_1 = int(round(0.05 * df.shape[0]))
            cut_idx_2 = int(round(0.1 * df.shape[0]))
            print(cut_idx_1, cut_idx_2)
            df_test, df_dev, df_train = df.iloc[:cut_idx_1], df.iloc[cut_idx_1:cut_idx_2], df.iloc[cut_idx_2:]
            df_test.to_csv(self.test_sim, index=False, sep='\t')
            df_dev.to_csv(self.dev_sim, index=False, sep='\t')
            df_train.to_csv(self.train_sim, index=False, sep='\t')
    
    if __name__ == '__main__':
        # run the full pipeline in order
        dm = DataMerge()
        dm.get_all()
        dm.generator_negative_data()
        dm.generator_sim_data()
        dm.merge_sim_neg_data()
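
    For reference, a sketch of the file formats involved, inferred from the read/write calls above (the example sentences and labels are invented): the raw_data files are tab-separated sentence/label pairs with no header, while the final_data files carry the s1/s2/label header written by to_csv.

    # data/raw_data/seats/train.txt  (sentence\tlabel, no header)
    # 请问怎么办理开户\topen_account
    # 我想查一下余额\tcheck_balance

    # data/final_data/train.txt  (s1\ts2\tlabel, with header)
    # s1\ts2\tlabel
    # 请问怎么办理开户\t开户需要什么材料\t1
    # 请问怎么办理开户\t我想查一下余额\t0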

    3. root_path

    import os
    __all__ = ["root"]
    # directory containing this file (the common/ package)
    _parent_path = os.path.split(os.path.realpath(__file__))[0]
    # keep everything up to the project folder name, then re-append it
    _root = _parent_path[:_parent_path.find("sentence_bert")]
    root = os.path.join(_root, "sentence_bert")
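
    A quick usage sketch: any module in the project can build absolute paths off root, for example:

    from common.root_path import root
    import os
    print(os.path.join(root, "data", "final_data", "train.txt"))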

    4. Training

    from torch.utils.data import DataLoader
    import math
    from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models
    from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
    from sentence_transformers.readers import InputExample
    import logging
    from datetime import datetime
    import os
    from common.root_path import root
    import pandas as pd
    
    class MySentenceBert():
        logging.basicConfig(format='%(asctime)s - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S',
                            level=logging.INFO,
                            handlers=[LoggingHandler()])
        def __init__(self):
    
            self.train_batch_size = 16
            self.num_epochs = 4
            data_path = os.path.join(root, "data", "final_data")
            self.train_data = pd.read_csv(os.path.join(data_path, "train.txt"), sep="\t")
            self.val_data = pd.read_csv(os.path.join(data_path, "dev.txt"), sep="\t")
            self.test_data = pd.read_csv(os.path.join(data_path, "test.txt"), sep="\t")
            self.model_save_path = os.path.join(root, "chkpt", "sentence_bert_model" +
                                                datetime.now().strftime("_%Y_%m_%d_%H_%M"))
    
        def data_generator(self):
            logging.info("generator dataset")
            train_datas = []
            dev_datas = []
            test_datas = []
            for s1, s2, l in zip(self.train_data["s1"],
                                 self.train_data["s2"],
                                 self.train_data["label"]):
                train_datas.append(InputExample(texts=[s1, s2], label=float(l)))
            for s1, s2, l in zip(self.val_data["s1"],
                                 self.val_data["s2"],
                                 self.val_data["label"]):
                dev_datas.append(InputExample(texts=[s1, s2], label=float(l)))
            for s1, s2, l in zip(self.test_data["s1"],
                                 self.test_data["s2"],
                                 self.test_data["label"]):
                test_datas.append(InputExample(texts=[s1, s2], label=float(l)))
            return train_datas, dev_datas, test_datas
    
        def train(self, train_datas, dev_datas, model):
            train_dataloader = DataLoader(train_datas, shuffle=True, batch_size=self.train_batch_size)
            train_loss = losses.CosineSimilarityLoss(model=model)
            evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_datas, name='sts-dev')
            warmup_steps = math.ceil(len(train_dataloader) * self.num_epochs * 0.1)
            logging.info("Warmup-steps: {}".format(warmup_steps))
            model.fit(train_objectives=[(train_dataloader, train_loss)],
                      evaluator=evaluator,
                      epochs=self.num_epochs,
                      evaluation_steps=1000,
                      warmup_steps=warmup_steps,
                      output_path=self.model_save_path)

        def test(self, test_samples):
            model = SentenceTransformer(self.model_save_path)
            test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
            test_evaluator(model, output_path=self.model_save_path)
    
        def main(self):
            train_datas, dev_datas, test_datas = self.data_generator()
    
            model_name = os.path.join(root, "chkpt", "distiluse-base-multilingual-cased")
            word_embedding_model = models.Transformer(model_name)
            pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                           pooling_mode_mean_tokens=True,
                                           pooling_mode_cls_token=False,
                                           pooling_mode_max_tokens=False)
            model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
            self.train(train_datas, dev_datas, model)
            self.test(test_datas)
    
    if __name__ == '__main__':
        MySentenceBert().main()
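
    A quick sanity check after training (a minimal sketch; the checkpoint path and the two sentences are illustrative):

    from sentence_transformers import SentenceTransformer, util

    model = SentenceTransformer("chkpt/sentence_bert_model_2021_08_05_18_16")
    emb = model.encode(["请问怎么办理开户", "开户需要什么材料"], convert_to_tensor=True)
    # sentences with the same intent should score close to 1
    print(util.pytorch_cos_sim(emb[0], emb[1]))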

    5. Vector retrieval

    from sentence_transformers import SentenceTransformer, util
    import os
    import pickle
    import json
    from common.root_path import root
    
    class SemanticSearch():
        def __init__(self):
            model_name = os.path.join(root, "chkpt", "sentence_bert_model_2021_08_05_18_16")
            self.model = SentenceTransformer(model_name)
            embedding_cache_path = 'semantic_search_embedding.pkl'
            dataset_path = os.path.join(root, "data", "bert_data", "index.txt")
            with open(os.path.join(root, "config", "code_to_label.json"), "r", encoding="utf8") as f:
                self.d = json.load(f)
            self.sentences = list()
            self.code = list()
            if not os.path.exists(embedding_cache_path):
                with open(dataset_path, encoding='utf8') as fIn:
                    for read_line in fIn:
                        read_line = read_line.split("\t")
                        self.sentences.append(read_line[0])
                        self.code.append(read_line[1].replace("\n", ""))
                print("Encode the corpus. This might take a while")
                self.embeddings = self.model.encode(self.sentences, show_progress_bar=True, convert_to_tensor=True)
                print("Store file on disc")
                with open(embedding_cache_path, "wb") as fOut:
                    pickle.dump({'sentences': self.sentences, 'embeddings': self.embeddings, "code": self.code}, fOut)
            else:
                print("Load pre-computed embeddings from disc")
                with open(embedding_cache_path, "rb") as fIn:
                    cache_data = pickle.load(fIn)
                    self.sentences = cache_data['sentences']
                    self.embeddings = cache_data['embeddings']
                    self.code = cache_data["code"]
    
        def query(self, query):
            """Return (label, score, matched_sentence) for the closest indexed sentence."""
            question_embedding = self.model.encode(query, convert_to_tensor=True)
            hits = util.semantic_search(question_embedding, self.embeddings)
            hit = hits[0][0]  # best hit for the first (and only) query
            score = hit['score']
            text = self.sentences[hit['corpus_id']]
            kh_code = self.code[hit['corpus_id']]
            label = self.d[kh_code][1]
            return label, score, text
    
        def main(self):
            label, score, text = self.query("你好")
            print(label, score, text)
    
    
    if __name__ == '__main__':
        SemanticSearch().main()
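
    The structure of config/code_to_label.json isn't shown in the post; from self.d[kh_code][1] it presumably maps each code to a list whose second element is a human-readable label. A hypothetical example, and how the returned triple might be consumed:

    # config/code_to_label.json (hypothetical structure):
    # {"open_account": ["C001", "开户咨询"], "check_balance": ["C002", "余额查询"]}

    searcher = SemanticSearch()
    label, score, text = searcher.query("怎么开户")
    print(f"label={label}, score={score:.4f}, matched='{text}'")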

    6. References

    https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py

    https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/semantic-search/semantic_search_quora_pytorch.py

    Original post: https://www.cnblogs.com/zhangxianrong/p/15908576.html