• fastText short-text classification with Keras


    ### train_model.py ###

    #!/usr/bin/env python
    # coding=utf-8
    
    import codecs
    import simplejson as json
    import numpy as np
    import pandas as pd
    from keras.models import Sequential, load_model
    from keras.callbacks import EarlyStopping, ModelCheckpoint
    from keras.preprocessing import sequence
    from keras.utils import to_categorical
    from keras.layers import *
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelEncoder
    from sklearn.externals import joblib
    import logging
    import re
    import pickle as pkl
    
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(filename)s: %(message)s', datefmt='%Y-%m-%d %H:%M', filename='log/train_model.log', filemode='a+')
    
    ngram_range = 1      # 1 = char unigrams only; >1 also adds n-gram features
    max_features = 6500  # size of the base character vocabulary
    maxlen = 120         # pad/truncate every sequence to this length
    
    fw = open('error_line_test.txt', 'wb')
    
    DIRTY_LABEL = re.compile(r'\W+')  # labels containing non-word characters are treated as dirty
    # set([u'业务',u'代销',u'施工',u'策划',u'设计',u'销售',u'除外',u'零售',u'食品'])
    STOP_WORDS = pkl.load(open('./data/stopwords.pkl', 'rb'))
    
    
    def load_data(fname='data/12315_industry_business_train.csv', nrows=None):
        """
        载入训练数据
        """
        data, labels = [], []
        char2idx = json.load(open('data/char2idx.json'))
        used_keys = set(['name', 'business'])
        df = pd.read_csv(fname, encoding='utf-8', nrows=nrows)
        for idx, item in df.iterrows():
            item = item.to_dict()
            line = ''
            for key, value in item.iteritems():
                if key in used_keys:
                line += value  # concatenate the used fields into one text, matching api_tgind.predict
        
            data.append([char2idx[char] for char in line if char in char2idx])
            labels.append(item['label'])
    
        le = LabelEncoder()
        logging.info('%d nb_class: %s' % (len(np.unique(labels)), str(np.unique(labels))))
        onehot_label = to_categorical(le.fit_transform(labels))
        joblib.dump(le, 'model/tgind_labelencoder.h5')
        x_train, x_test, y_train, y_test = train_test_split(data, onehot_label, test_size=0.1)
        return (x_train, y_train), (x_test, y_test)
    
    
    def create_ngram_set(input_list, ngram_value=2):
        """
        Extract the set of contiguous n-grams from a list of token indices.
        """
        return set(zip(*[input_list[i:] for i in range(ngram_value)]))
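
    # for example, with the definition above:
    #   create_ngram_set([1, 4, 9, 4, 1, 4], ngram_value=2)
    #   -> {(1, 4), (4, 9), (9, 4), (4, 1)}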
    
    
    def add_ngram(sequences, token_indice, ngram_range=2):
        """
        Augment the input list of sequences by appending n-grams values
    
        """
        new_sequences = []
        for input_list in sequences:
            new_list = input_list[:]
            for i in range(len(new_list) - ngram_range + 1):
                for ngram_value in range(2, ngram_range+1):
                    ngram = tuple(new_list[i:i+ngram_value])
                    if ngram in token_indice:
                        new_list.append(token_indice[ngram])
            new_sequences.append(new_list)
    
        return new_sequences
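
    # for example:
    #   add_ngram([[1, 3, 4, 5]], {(1, 3): 1337, (4, 5): 2017}, ngram_range=2)
    #   -> [[1, 3, 4, 5, 1337, 2017]]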
    
    (x_train, y_train), (x_test, y_test) = load_data()
    nb_class = y_train.shape[1]
    
    
    logging.info('x_train size: %d' % (len(x_train)))
    logging.info('x_test size: %d' % (len(x_test)))
    logging.info('x_train sent average len: %.2f' % (np.mean(list(map(len, x_train)))))
    print 'x_train sent avg length: %.2f' % (np.mean(list(map(len, x_train))))
    
    if ngram_range>1:
        print 'add {}-gram features'.format(ngram_range)
        ngram_set = set()
        for input_list in x_train:
            for i in range(2, ngram_range+1):
                set_of_ngram = create_ngram_set(input_list, ngram_value=i)
                ngram_set.update(set_of_ngram)
    
        start_index = max_features + 1
        token_indice = {v: k+start_index for k,v in enumerate(ngram_set)}
        indice_token = {token_indice[k]: k for k in token_indice}
    
        max_features = np.max(list(indice_token.keys()))+1
    
        x_train = add_ngram(x_train, token_indice, ngram_range)
        x_test = add_ngram(x_test, token_indice, ngram_range)
    
    
    print 'pad sequences (samples x time)'
    x_train = sequence.pad_sequences(x_train, maxlen=maxlen, padding='post', truncating='post')
    x_test = sequence.pad_sequences(x_test, maxlen=maxlen, padding='post', truncating='post')
    
    logging.info('x_train.shape: %s' % (str(x_train.shape)))
    
    print 'build model...'
    
    def cal_accuracy(x_test, y_test):
        """
        准确率统计
        """
        y_test = np.argmax(y_test, axis=1)
        y_pred = model.predict_classes(x_test)
        correct_cnt = np.sum(y_pred==y_test)
        return float(correct_cnt)/len(y_test)
    
    DEBUG = False
    if DEBUG:
        model = Sequential()
        model.add(Embedding(max_features, 200, input_length=maxlen))
        model.add(GlobalAveragePooling1D())
        model.add(Dropout(0.3))
        model.add(Dense(nb_class, activation='softmax'))
    else:
        model = load_model('./model/tgind_dalei.h5')
    
    #model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    earlystop = EarlyStopping(monitor='val_loss', patience=8)
    checkpoint = ModelCheckpoint(filepath='./model/tgind_dalei.h5', monitor='val_loss', save_best_only=True, save_weights_only=False)
    
    
    model.fit(x_train, y_train, shuffle=True, batch_size=64, epochs=80, validation_split=0.1, callbacks=[checkpoint, earlystop])
    
    loss, acc = model.evaluate(x_test, y_test)
    print '\nlast model: loss', loss
    print 'acc', acc
    
    
    model = load_model('model/tgind_dalei.h5')
    loss, acc = model.evaluate(x_test, y_test)
    print '\ncur best model: loss', loss
    print 'accuracy', acc
    logging.info('loss: %.4f ;accuracy: %.4f' % (loss, acc))
    
    logging.info('\nmodel acc: %.4f' % acc)
    logging.info('\nmodel config:\n %s' % model.get_config())

    ### test_model.py ###

    #!/usr/bin/env python
    # coding=utf-8
    
    import matplotlib.pyplot as plt
    from api_tgind import TgIndustry
    import pandas as pd
    import codecs
    import json
    from collections import OrderedDict
    
    ###########  compute accuracy at probability thresholds  ###########
    
    
    def cal_model_acc(model, fname='./data/industry_dalei_test_sample2k.txt', nrows=None):
        """
        载入数据, 并计算前5的准确率
        """
        res = {}
        res['y_pred'] = []
        res['y_true'] = []
        with codecs.open(fname, encoding='utf-8') as fr:
            for idx, line in enumerate(fr):
                tokens = line.strip().split()
                if len(tokens)>3:
                    tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                    tmp = {}
                    tmp['business'] = ''.join(tokens)
                    res['y_pred'].append(model.predict(tmp))
                    res['y_true'].append(label)
                if nrows and idx>nrows:
                    break
        json.dump(res, codecs.open('log/total_acc_output.json', 'wb', encoding='utf-8'))
        return res
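
    # each line read by cal_model_acc is space-separated tokens followed by a
    # fastText-style label (the label value below is only illustrative):
    #   token1 token2 ... tokenN __label__F52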
    
    def cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv', nrows=None):
        """
        直接根据csv预测结果
        """
        res = {}
        res['y_pred'] = []
        res['y_true'] = []
        df = pd.read_csv(fname, encoding='utf-8')
        for idx, item in df.iterrows():
            try:
                res['y_pred'].append(model.predict(item.to_dict()))
            except Exception as e:
                print e
                print idx
                print item['name']
                continue
            res['y_true'].append(item['label'])
    
            if nrows and idx>nrows:
                break
        json.dump(res, codecs.open('log/total_acc_output.json', 'wb', encoding='utf-8'))
        return res
    
    
    
    
    
    def get_model_acc_menlei(res, topk=5, threshold=0.8):
        """
        Compute first-level (menlei) accuracy and recall at a probability threshold.
        """
        correct_cnt, total_cnt = 0, 0
        for idx, y_pred in enumerate(res['y_pred']):
            y_pred_tuple = sorted(y_pred.iteritems(), key=lambda x: float(x[1]), reverse=True)  # sort classes by predicted probability, descending
            y_pred = OrderedDict()
            for c, s in y_pred_tuple:
                y_pred[c] = float(s)

            if y_pred.values()[0] > threshold:    # only count predictions whose top probability exceeds the threshold
                if res['y_true'][idx][0] in map(lambda x: x[0], y_pred.keys()[:topk]):  # compare first-level codes (first char of the label)
                    correct_cnt += 1
                total_cnt += 1
        acc = float(correct_cnt)/total_cnt
        recall = float(total_cnt)/len(res['y_true'])  # fraction of samples above the threshold
        return acc, recall
    
    def get_model_acc_dalei(res, topk=5, threshold=0.8):
        """
        Compute second-level (dalei) accuracy and recall at a probability threshold.
        """
        correct_cnt, total_cnt = 0, 0
        for idx, y_pred in enumerate(res['y_pred']):
            y_pred_tuple = sorted(y_pred.iteritems(), key=lambda x: float(x[1]), reverse=True)  # sort classes by predicted probability, descending
            y_pred = OrderedDict()
            for c, s in y_pred_tuple:
                y_pred[c] = float(s)

            if y_pred.values()[0] >= threshold:    # only count predictions whose top probability reaches the threshold
                if res['y_true'][idx] in y_pred.keys()[:topk]:
                    correct_cnt += 1
                total_cnt += 1

        acc = float(correct_cnt)/total_cnt
        recall = float(total_cnt)/len(res['y_true'])  # fraction of samples above the threshold
        return acc, recall
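
    # toy example with topk=1, threshold=0.5:
    #   res = {'y_pred': [{'A': 0.9, 'B': 0.1}, {'A': 0.6, 'B': 0.4}],
    #          'y_true': ['A', 'B']}
    #   get_model_acc_dalei(res, topk=1, threshold=0.5) -> (0.5, 1.0)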
    
    
    def plot_accuracy(title, df, number):
        """
        Plot accuracy/recall against the probability threshold.
        """
        for topk in range(1, 5):
            tmpdf = df[df.topk==topk]
            fig = plt.figure()
            ax1 = fig.add_subplot(111)
            plt.subplots_adjust(top=0.85)
            ax1.plot(tmpdf['threshold'], tmpdf['accuracy'], 'ro-', label='accuracy')
    #        ax2 = ax1.twinx()
            ax1.plot(tmpdf['threshold'], tmpdf['recall'], 'g^-', label='recall')
            ax1.set_ylim(0.3, 1.0)
            ax1.legend(loc=3)
            ax1.set_xlabel('threshold')
            plt.grid(True)
            plt.title('%s Industry Classify Result\n topk=%d, number=%d\n' % (title, topk, number))
            plt.savefig('log/test_%s_acc_topk%d.png' % (title, topk))
            print topk, 'done!'
    
    
    def gen_plot_data(model_acc, ctype='2nd'):
        """
        Generate the accuracy/recall data behind the plots.
        """
        res = {}
        res['accuracy'] = []
        res['threshold'] = []
        res['topk'] = []
        res['recall'] = []
        for topk in range(1, 5):
            for threshold in range(0, 10):
                threshold = 0.1*threshold
                if ctype == '1st':
                    acc, recall = get_model_acc_menlei(model_acc, topk, threshold)
                else:
                    acc, recall = get_model_acc_dalei(model_acc, topk, threshold)
                res['accuracy'].append(acc)
                res['recall'].append(recall)
                res['threshold'].append(threshold)
                res['topk'].append(topk)
            print ctype, topk, acc
        json.dump(res, open('log/test_model_threshold_%s.log' % ctype, 'wb'))
        df = pd.DataFrame(res)
        df.to_csv('log/test_model_result_%s.csv' % ctype, index=False)
        plot_accuracy(ctype, df, len(model_acc['y_true']))
        return df
    
    if __name__=='__main__':
        
        model = TgIndustry()
        # model_acc = cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv')
        model_acc = json.load(codecs.open('log/total_acc_output_12315.json', encoding='utf-8'))
        gen_plot_data(model_acc, '1st')
        gen_plot_data(model_acc, '2nd')

    ### api_tgind.py ###

    #!/usr/bin/env python
    # coding=utf-8
    
    import numpy as np
    import codecs
    import simplejson as json
    from keras.models import load_model
    from keras.preprocessing import sequence
    from sklearn.externals import joblib
    from collections import OrderedDict
    import pickle as pkl
    import re, os
    import jieba
    import time
    
    """
    行业分类调用Api
    
    __author__: jkmiao
    __date__: 2017-07-05
    
    """
    
    
    class TgIndustry(object):
    
        def __init__(self, model_path='model/tgind_dalei_acc76.h5'):
    
            base_path = os.path.dirname(__file__)
            model_path = os.path.join(base_path, model_path)
    
            # load the pre-trained model
            self.model = load_model(model_path)
            # load the fitted LabelEncoder
            self.le = joblib.load(os.path.join(base_path, './model/tgind_labelencoder.h5'))
            # load the char-to-index mapping
            self.char2idx = json.load(open(os.path.join(base_path, 'data/char2idx.json')))
            # load the stop-word list
            # self.stop_words = set([line.strip() for line in codecs.open('./data/stopwords.txt', encoding='utf-8')])
            self.stop_words = pkl.load(open(os.path.join(base_path, './data/stopwords.pkl'), 'rb'))
            # load the mappings from category code to category name
            self.menlei_label2name = json.load(open(os.path.join(base_path, 'data/menlei_label2name.json')))  # first-level categories
            self.dalei_label2name = json.load(open(os.path.join(base_path, 'data/dalei_label2name.json')))  # second-level categories
    
    
        def predict(self, company_info, topk=2, firstIndustry=False, final_name=False):
            """
            :type company_info: 公司相关信息
            :rtype business: str: 对应 label
            """
            line = ''
            for key, value in company_info.iteritems():
                if key in ['name', 'business']:  # company fields used: name and business scope
                    line +=  company_info[key]
                
            if not isinstance(line, unicode):
                line = line.decode('utf-8')
                
            # remove stop words from the text
            line = ''.join([token for token in jieba.cut(line) if token not in self.stop_words])
            data = [self.char2idx[char] for char in line if char in self.char2idx]
            data = sequence.pad_sequences([data], maxlen=100, padding='post', truncating='post')  # maxlen here must match the training-time setting
            y_pred_proba = self.model.predict(data, verbose=0)
            y_pred_idx_list = [c[-topk:][::-1] for c in np.argsort(y_pred_proba, axis=-1)][0]
            res = OrderedDict()
            for y_pred_idx in y_pred_idx_list:
                y_pred_label = self.le.inverse_transform(y_pred_idx)
                if final_name:
                    y_pred_label = self.dalei_label2name[y_pred_label]
                if firstIndustry:
                    res[y_pred_label[0]] = round(y_pred_proba[0, y_pred_idx], 3)  # also expose the first-level code (first char of the label)
                res[y_pred_label] = round(y_pred_proba[0, y_pred_idx], 3)  # probability rounded to 3 decimals
            return res
    
    
    if __name__ == '__main__':
    
        DIRTY_LABEL = re.compile(r'\W+')
        test = TgIndustry()
        cnt, total_cnt = 0, 0
        start_time = time.time()
        fw2 = codecs.open('./output/industry_dalei_test_sample2k_error.txt', 'wb', encoding='utf-8')
        with codecs.open('./data/industry_dalei_test_sample2k.txt', encoding='utf-8') as fr:
            for idx, line in enumerate(fr):
                tokens = line.strip().split()
                if len(tokens)>3:
                    tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                    if len(label) not in [2, 3] or DIRTY_LABEL.search(label):
                        print 'error line:'
                        print idx, line, label
                        continue
                    tmp = {}
                    tmp['business'] = ''.join(tokens)
                    y_pred = test.predict(tmp, topk=1)
                    if label in y_pred:
                        cnt += 1
                    elif y_pred.values()[0] < 0.3:
                        print 'error: ', ''.join(tokens), y_pred, 'y_true:', label
                        fw2.write(''.join(tokens) + '\n')
                    total_cnt +=1
                    print label 
                    print json.dumps(y_pred, ensure_ascii=False) 
                    print idx, '=='*20, float(cnt)/total_cnt
                    if idx>200:
                        break
            print 'avg cost time:', float(time.time()-start_time)/idx
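
    # minimal usage sketch (the label 'F52' below is only illustrative; real
    # labels are the 2-3 character industry codes from the training data):
    #   model = TgIndustry()
    #   res = model.predict({'business': u'软件开发;技术服务'}, topk=3)
    #   # res is an OrderedDict, label -> probability, highest first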
  • Original post: https://www.cnblogs.com/jkmiao/p/7210276.html