• reshape、shuffle、save_weights


    #-*- coding: utf-8 -*-
    
    import pandas as pd
    from random import shuffle
    import matplotlib.pyplot as plt #导入Matplotlib
    
    datafile = '../data/model.xls'
    data = pd.read_excel(datafile)
    data = data.as_matrix()
    shuffle(data)
    
    p = 0.8 #设置训练数据比例
    train = data[:int(len(data)*p),:]
    test = data[int(len(data)*p):,:]
    
    #构建LM神经网络模型
    from keras.models import Sequential #导入神经网络初始化函数
    from keras.layers.core import Dense, Activation #导入神经网络层函数、激活函数
    
    netfile = '../tmp/net.model' #构建的神经网络模型存储路径
    
    net = Sequential() #建立神经网络
    net.add(Dense(input_dim = 3, output_dim = 10)) #添加输入层(3节点)到隐藏层(10节点)的连接
    net.add(Activation('relu')) #隐藏层使用relu激活函数
    net.add(Dense(input_dim = 10, output_dim = 1)) #添加隐藏层(10节点)到输出层(1节点)的连接
    net.add(Activation('sigmoid')) #输出层使用sigmoid激活函数
    net.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics=['accuracy']) #编译模型,使用adam方法求解
    
    net.fit(train[:,:3], train[:,3], nb_epoch=50, batch_size=1) #训练模型,循环1000次
    net.save_weights(netfile) #保存模型
    #print(net.predict_classes(train[:,:3]))
    # [[1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [0]
    #  [1]
    #  [0]
    #  [1]
    #  [1]
    #  [0]
    predict_result = net.predict_classes(train[:,:3]).reshape(len(train)) #预测结果变形
    #print(predict_result)
    #[1 1 1 1 1 1 1 1 0 1 0 1 1 0 0]
    '''这里要提醒的是,keras用predict给出预测概率,predict_classes才是给出预测类别,而且两者的预测结果都是n x 1维数组,而不是通常的 1 x n'''
    
    # from cm_plot import * #导入自行编写的混淆矩阵可视化函数
    # cm_plot(train[:,3], predict_result).show() #显示混淆矩阵可视化结果
    
    from sklearn.metrics import roc_curve #导入ROC曲线函数
    
    predict_result = net.predict(test[:,:3]).reshape(len(test))
    fpr, tpr, thresholds = roc_curve(test[:,3], predict_result, pos_label=1)
    plt.plot(fpr, tpr, linewidth=2, label = 'ROC of LM') #作出ROC曲线
    plt.xlabel('False Positive Rate') #坐标轴标签
    plt.ylabel('True Positive Rate') #坐标轴标签
    plt.ylim(0,1.05) #边界范围
    plt.xlim(0,1.05) #边界范围
    plt.legend(loc=4) #图例
    plt.show() #显示作图结果
  • 相关阅读:
    spoj 3273 Treap
    hdu1074 Doing Homework
    hdu1024 Max Sum Plus Plus
    hdu1081 To the Max
    hdu1016 Prime Ring Problem
    hdu4727 The Number Off of FFF
    【判断二分图】C. Catch
    【01染色法判断二分匹配+匈牙利算法求最大匹配】HDU The Accomodation of Students
    【数轴涂色+并查集路径压缩+加速】C. String Reconstruction
    【数轴染色+并查集路径压缩+加速】数轴染色
  • 原文地址:https://www.cnblogs.com/ggzhangxiaochao/p/9115295.html
Copyright © 2020-2023  润新知