• 使用CNN做电影评论的负面检测——本质上感觉和n-gram或者LSTM类似:图像检测里的CNN卷积核一般是3x3,而文本分类则直接使用宽度为3、4、5的一维卷积核,相当于分别提取3-gram、4-gram、5-gram特征


    代码如下:

    from __future__ import division, print_function, absolute_import
    
    import tensorflow as tf
    import tflearn
    from tflearn.layers.core import input_data, dropout, fully_connected
    from tflearn.layers.conv import conv_1d, global_max_pool
    from tflearn.layers.merge_ops import merge
    from tflearn.layers.estimator import regression
    from tflearn.data_utils import to_categorical, pad_sequences
    from tflearn.datasets import imdb
    import os
    from tensorflow.contrib.learn.python import learn
    from sklearn import metrics
    from sklearn.model_selection import train_test_split
    import numpy as np
    
    # Number of word ids every review is padded/truncated to (used by the
    # VocabularyProcessor, pad_sequences and the network input shape).
    MAX_DOCUMENT_LENGTH: int = 200
    # NOTE(review): not referenced by do_cnn, whose embedding width is hard-coded to 128.
    EMBEDDING_SIZE: int = 50

    # Vocabulary size; assigned in the __main__ block after the vocabulary is
    # fitted and read by do_cnn to size the embedding layer.
    n_words = 0
    
    
    def load_one_file(filename):
        """Return the whole text content of *filename* as one string.

        The file is opened in text mode with the platform default encoding,
        exactly as before.  Reading it with a single ``read()`` call yields the
        same string as concatenating its lines, but avoids the quadratic cost
        of growing a string with ``+=`` once per line.
        """
        with open(filename) as f:
            return f.read()
    
    def load_files(rootdir, label):
        """Read every regular file directly inside *rootdir* and tag it with *label*.

        Args:
            rootdir: directory whose immediate regular files are loaded;
                sub-directories are skipped.
            label: value recorded once for every file that is loaded.

        Returns:
            ``(x, y)``: two equal-length lists where ``x`` holds each file's
            text (read with ``load_one_file``) and ``y`` holds ``label``
            repeated once per file.
        """
        x = []
        y = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # document order -- and therefore any seeded train/test split built
        # from it -- reproducible across machines and filesystems.
        for name in sorted(os.listdir(rootdir)):
            path = os.path.join(rootdir, name)
            if not os.path.isfile(path):
                continue
            y.append(label)
            x.append(load_one_file(path))
        return x, y
    
    
    def load_data(data_dir="../data/movie-review-data/review_polarity/txt_sentoken"):
        """Load the polarity movie-review corpus as parallel text/label lists.

        Args:
            data_dir: directory containing the ``pos`` and ``neg``
                sub-directories of review files.  The default is the path the
                script originally hard-coded.

        Returns:
            ``(x, y)``: all ``pos`` reviews followed by all ``neg`` reviews.
            Label 0 marks ``pos`` reviews and label 1 marks ``neg`` reviews.
        """
        x_pos, y_pos = load_files(os.path.join(data_dir, "pos"), 0)
        x_neg, y_neg = load_files(os.path.join(data_dir, "neg"), 1)
        return x_pos + x_neg, y_pos + y_neg
    def do_cnn(trainX, trainY, testX, testY, n_epoch=20, batch_size=32):
        """Build and train a 1-D text CNN with parallel 3/4/5-wide convolutions.

        Args:
            trainX, testX: sequences of word ids (rows produced by the
                VocabularyProcessor in the __main__ block); padded/truncated
                here to MAX_DOCUMENT_LENGTH.
            trainY, testY: integer class labels (0 or 1); one-hot encoded here.
            n_epoch: number of training epochs (default 20, as before).
            batch_size: mini-batch size (default 32, as before).

        Returns:
            The trained ``tflearn.DNN`` model, so callers can evaluate, predict
            or save it.  Callers that ignore the return value are unaffected.

        Reads the module-level ``n_words`` (vocabulary size) to size the
        embedding, so it must be set before this function is called.
        """
        # Data preprocessing: fix the sequence length and one-hot the labels.
        trainX = pad_sequences(trainX, maxlen=MAX_DOCUMENT_LENGTH, value=0.)
        testX = pad_sequences(testX, maxlen=MAX_DOCUMENT_LENGTH, value=0.)
        trainY = to_categorical(trainY, nb_classes=2)
        testY = to_categorical(testY, nb_classes=2)

        # Building the convolutional network.
        network = input_data(shape=[None, MAX_DOCUMENT_LENGTH], name='input')
        # NOTE(review): the embedding width is hard-coded to 128; the
        # module-level EMBEDDING_SIZE (50) is not used -- confirm which is intended.
        network = tflearn.embedding(network, input_dim=n_words + 1, output_dim=128)
        # Three parallel branches with kernel widths 3, 4 and 5 (n-gram-like).
        branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
        branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
        branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
        # Branch outputs are concatenated along axis 1 and then globally
        # max-pooled. NOTE(review): this presumably yields one max per filter
        # taken across all three branches together, rather than separately
        # pooled per-branch features -- confirm this is the intended design.
        network = merge([branch1, branch2, branch3], mode='concat', axis=1)
        network = tf.expand_dims(network, 2)
        network = global_max_pool(network)
        network = dropout(network, 0.5)
        network = fully_connected(network, 2, activation='softmax')
        network = regression(network, optimizer='adam', learning_rate=0.001,
                             loss='categorical_crossentropy', name='target')

        # Training.  NOTE: the held-out split doubles as the validation set, so
        # the reported validation accuracy is not an independent test score.
        model = tflearn.DNN(network, tensorboard_verbose=0)
        model.fit(trainX, trainY, n_epoch=n_epoch, shuffle=True,
                  validation_set=(testX, testY), show_metric=True,
                  batch_size=batch_size)
        return model
    
    if __name__ == '__main__':
        # Load the polarity movie-review corpus (pos/ and neg/ text files).
        # There is deliberately no `global n_words` here: this code already
        # runs at module scope, and a `global` declaration after the earlier
        # module-level `n_words` assignment is a SyntaxError on Python 3.6+.
        x, y = load_data()

        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, random_state=0)

        # Map each review to a fixed-length sequence of word ids.  The
        # vocabulary is fitted on the training split only so test-set word
        # statistics do not leak into preprocessing; test words unseen in
        # training presumably map to the processor's unknown-token id.
        vp = learn.preprocessing.VocabularyProcessor(max_document_length=MAX_DOCUMENT_LENGTH, min_frequency=1)
        vp.fit(x_train)
        x_train = np.array(list(vp.transform(x_train)))
        x_test = np.array(list(vp.transform(x_test)))
        n_words = len(vp.vocabulary_)
        print('Total words: %d' % n_words)

        do_cnn(x_train, y_train, x_test, y_test)
                                                          

    准确率是100%

  • 相关阅读:
    可空类型转换为不可空的普通类型
    如何使用AspNetPager分页控件和ObjectDataSource控件进行分页
    TFS映射后丢失引用的问题
    (很好用)JS时间控件实现日期的多选
    取两个日期之间的非工作日的天数(指的是周六、周日)
    在日期格式化的时候提示错误:Tostring没有采用一个参数的重载
    Linq返回的集合类型不是已有的表格类型时的写法(谨记:列表的时候用)
    系统缓存全解析6:数据库缓存依赖
    实现文本框动态限制字数的实现(好方法)
    实现GridView内容循环滚动
  • 原文地址:https://www.cnblogs.com/bonelee/p/7908346.html
Copyright © 2020-2023  润新知