• logistic回归梯度上升优化算法


      1 # Author Qian Chenglong
      2 
      3 from numpy import *
      4 from numpy.ma import arange
      5 
      6 
      7 def loadDataSet():
      8     dataMat = []
      9     labelMat = []
     10     fr = open('testSet.txt')
     11     for line in fr.readlines():
     12         lineArr = line.strip().split()
     13         dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])
     14         labelMat.append(int(lineArr[2]))
     15     return dataMat, labelMat
     16 
     17 #sigmoid归一化函数
     18 #输入:z=w1x1+w2x2+w3x3......
     19 #s输出:归一化结果
     20 def sigmoid(inX):
     21     return 1.0 / (1 + exp(-inX))
     22 
     23 
     24 '''
     25 logistic回归梯度上升优化算法
     26 param dataMatIn: 处理后的数据集
     27 param classLabels: 分类标签
     28 return: 权重值
     29  '''
     30 def gradAscent(dataMatIn, classLabels):
     31     dataMatrix = mat(dataMatIn)  # convert to NumPy matrix(矩阵)
     32     labelMat = mat(classLabels).transpose()  # convert to NumPy matrix
     33     m, n = shape(dataMatrix)          #m行  n列
     34     alpha = 0.001                       #步长
     35     maxCycles = 500
     36     weights = ones((n, 1))              #系数,权重
     37     for k in range(maxCycles):  # heavy on matrix operations
     38         h = sigmoid(dataMatrix * weights)  # matrix mult
     39         error = (labelMat - h)  # vector subtraction
     40         weights = weights + alpha * dataMatrix.transpose() * error  # transpose()矩阵转置
     41     return weights
     42 
     43 '''
     44 画出数据集和logisitic回归最佳拟合直线的函数
     45 param weights:
     46 return:
     47 最后的分割方程是y=(-w0-w1*x)/w2
     48 '''
     49 def plotBestFit(weights):
     50     import matplotlib.pyplot as plt
     51     dataMat, labelMat = loadDataSet()
     52     dataArr = array(dataMat)
     53     n = shape(dataArr)[0]
     54     xcord1 = []
     55     ycord1 = []
     56     xcord2 = []
     57     ycord2 = []
     58     for i in range(n):
     59         if int(labelMat[i]) == 1:
     60             xcord1.append(dataArr[i, 1]);
     61             ycord1.append(dataArr[i, 2])
     62         else:
     63             xcord2.append(dataArr[i, 1]);
     64             ycord2.append(dataArr[i, 2])
     65     fig = plt.figure()
     66     ax = fig.add_subplot(111)
     67     ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
     68     ax.scatter(xcord2, ycord2, s=30, c='green')
     69     x = arange(-3.0, 3.0, 0.1)
     70     y = (-weights[0] - weights[1] * x) / weights[2]
     71     ax.plot(x, y)
     72     plt.xlabel('X1')
     73     plt.ylabel('X2')
     74     plt.show()
     75 
     76 '''随机梯度上升
     77 param dataMatIn: 处理后的数据集
     78 param classLabels: 分类标签
     79 return: 权重值'''
     80 def stocGradAscent0(dataMatrix, classLabels):
     81     m, n = shape(dataMatrix)
     82     alpha = 0.01
     83     weights = ones(n)  # initialize to all ones
     84     for i in range(m):
     85         h = sigmoid(sum(dataMatrix[i] * weights))
     86         error = classLabels[i] - h
     87         weights = weights + alpha * error * dataMatrix[i]
     88     return weights
     89 
     90 '''改进的随机梯度上升
     91 param dataMatIn: 处理后的数据集
     92 param classLabels: 分类标签
     93 return: 权重值'''
     94 def stocGradAscent1(dataMatrix, classLabels, numIter=150):
     95     m, n = shape(dataMatrix)
     96     weights = ones(n)  # initialize to all ones
     97     for j in range(numIter):
     98         dataIndex = range(m)
     99         for i in range(m):
    100             alpha = 4 / (1.0 + j + i) + 0.0001  # apha decreases with iteration, does not
    101             randIndex = int(random.uniform(0, len(dataIndex)))  # go to 0 because of the constant
    102             h = sigmoid(sum(dataMatrix[randIndex] * weights))
    103             error = classLabels[randIndex] - h
    104             weights = weights + alpha * error * dataMatrix[randIndex]
    105             del (dataIndex[randIndex])
    106     return weights
  • 相关阅读:
    MySQL数据库分页
    Spring MVC
    Spring框架
    Java学习计划(转载)
    开发用户注册模块
    Ajax技术
    Jodd Email 发送邮件
    DOM技术
    MD5加密
    final关键字的使用
  • 原文地址:https://www.cnblogs.com/long5683/p/9383574.html
Copyright © 2020-2023  润新知