• 使用支持向量机(SVM)训练 MNIST 数据


      1 # encoding: utf-8
      2 import numpy as np
      3 import matplotlib.pyplot as plt
      4 import cPickle
      5 import gzip
      6 
      7 class SVC(object):
      8     def __init__(self, c=1.0, delta=0.001):  # 初始化
      9         self.N = 0
     10         self.delta = delta
     11         self.X = None
     12         self.y = None
     13         self.w = None
     14         self.wn = 0
     15         self.K = np.zeros((self.N, self.N))
     16         self.a = np.zeros((self.N, 1))
     17         self.b = 0
     18         self.C = c
     19         self.stop=1
     20         self.k=0
     21         self.cls=0
     22         self.train_result=[]
     23 
     24     def kernel_function(self,x1, x2):  # 核函数
     25         return np.dot(x1, x2)
     26 
     27     def kernel_matrix(self, x):  # 核矩阵
     28         for i in range(0, len(x)):
     29             for j in range(i, len(x)):
     30                 self.K[j][i] = self.K[i][j] = self.kernel_function(self.X[i], self.X[j])
     31 
     32     def get_w(self):  # 计算更新w
     33         ay = self.a * self.y
     34         w = np.zeros((1, self.wn))
     35         for i in range(0, self.N):
     36             w += self.X[i] * ay[i]
     37         return w
     38 
     39     def get_b(self, a1, a2, a1_old, a2_old):  # 计算更新B
     40         y1 = self.y[a1]
     41         y2 = self.y[a2]
     42         a1_new = self.a[a1]
     43         a2_new = self.a[a2]
     44         b1_new = -self.E[a1] - y1 * self.K[a1][a1] * (a1_new - a1_old) - y2 * self.K[a2][a1] * (
     45             a2_new - a2_old) + self.b
     46         b2_new = -self.E[a2] - y1 * self.K[a1][a2] * (a1_new - a1_old) - y2 * self.K[a2][a2] * (
     47             a2_new - a2_old) + self.b
     48         if (0 < a1_new) and (a1_new < self.C) and (0 < a2_new) and (a2_new < self.C):
     49             return b1_new[0]
     50         else:
     51             return (b1_new[0] + b2_new[0]) / 2.0
     52 
     53     def gx(self, x):  # 判别函数g(x)
     54         return np.dot(self.w, x) + self.b
     55 
     56     def satisfy_kkt(self, a):  # 判断样本点是否满足kkt条件
     57         index = a[1]
     58         if a[0] == 0 and self.y[index] * self.gx(self.X[index]) > 1:
     59             return 1
     60         elif a[0] < self.C and self.y[index] * self.gx(self.X[index]) == 1:
     61             return 1
     62         elif a[0] == self.C and self.y[index] * self.gx(self.X[index]) < 1:
     63             return 1
     64         return 0
     65 
     66     def clip_func(self, a_new, a1_old, a2_old, y1, y2):  # 拉格朗日乘子的裁剪函数
     67         if (y1 == y2):
     68             L = max(0, a1_old + a2_old - self.C)
     69             H = min(self.C, a1_old + a2_old)
     70         else:
     71             L = max(0, a2_old - a1_old)
     72             H = min(self.C, self.C + a2_old - a1_old)
     73         if a_new < L:
     74             a_new = L
     75         if a_new > H:
     76             a_new = H
     77         return a_new
     78 
     79     def update_a(self, a1, a2):  # 更新a1,a2
     80         partial_a2 = self.K[a1][a1] + self.K[a2][a2] - 2 * self.K[a1][a2]
     81         if partial_a2 <= 1e-9:
     82             print "error:", partial_a2
     83         a2_new_unc = self.a[a2] + (self.y[a2] * ((self.E[a1] - self.E[a2]) / partial_a2))
     84         a2_new = self.clip_func(a2_new_unc, self.a[a1], self.a[a2], self.y[a1], self.y[a2])
     85         a1_new = self.a[a1] + self.y[a1] * self.y[a2] * (self.a[a2] - a2_new)
     86         if abs(a1_new - self.a[a1]) < self.delta:
     87             return 0
     88         self.a[a1] = a1_new
     89         self.a[a2] = a2_new
     90         self.is_update = 1
     91         return 1
     92 
     93     def update(self, first_a):  # 更新拉格朗日乘子
     94         for second_a in range(0, self.N):
     95             if second_a == first_a:
     96                 continue
     97             a1_old = self.a[first_a]
     98             a2_old = self.a[second_a]
     99             if self.update_a(first_a, second_a) == 0:
    100                 return
    101             self.b= self.get_b(first_a, second_a, a1_old, a2_old)
    102             self.w = self.get_w()
    103             self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)]
    104             self.stop=0
    105 
    106     def train(self, x, y, max_iternum=100):  # SMO算法
    107         x_len = len(x)
    108         self.X = x
    109         self.N = x_len
    110         self.wn = len(x[0])
    111         self.y = np.array(y).reshape((self.N, 1))
    112         self.K = np.zeros((self.N, self.N))
    113         self.kernel_matrix(self.X)
    114         self.b = 0
    115         self.a = np.zeros((self.N, 1))
    116         self.w = self.get_w()
    117         self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)]
    118         self.is_update = 0
    119         for i in range(0, max_iternum):
    120             self.stop=1
    121             data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a))) if x > 0 and x< self.C]
    122             if len(data_on_bound) == 0:
    123                 data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]
    124             for data in data_on_bound:
    125                 if self.satisfy_kkt(data) != 1:
    126                     self.update(data[1])
    127             if self.is_update == 0:
    128                 for data in [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]:
    129                     if self.satisfy_kkt(data) != 1:
    130                         self.update(data[1])
    131             if self.stop:
    132                 break
    133         return self.w, self.b
    134 
    135     def fit(self,x, y):  # 训练模型, 一对一法k(k-1)/2个SVM进行多类分类
    136         self.cls, y = np.unique(y, return_inverse=True)
    137         self.k=len(self.cls)
    138         for i in range(self.k):
    139             for j in range(i):
    140                 a,b=self.sub_data(x,y,i,j)
    141                 self.train_result.append([i,j,self.train(a,b)])
    142 
    143     def predict(self,x_new):  # 预测
    144          p=np.zeros(self.k)
    145          for i,j,w in self.train_result:
    146              self.w=w[0]
    147              self.b=w[1]
    148              if self.classfy(x_new)==1:
    149                  p[j]+=1
    150              else:
    151                  p[i]+=1
    152          return self.cls[np.argmax(p)]
    153 
    154     def sub_data(self,x,y,i,j):  # 数据分类
    155         subx=[]
    156         suby=[]
    157         for a,b in zip(x,y):
    158             if b==i:
    159                  subx.append(a)
    160                  suby.append(-1)
    161             elif b==j:
    162                  subx.append(a)
    163                  suby.append(1)
    164         return subx,suby
    165 
    166     def classfy(self,x_new):  # 预测
    167         y_new=self.gx(x_new)
    168         cl = int(np.sign(y_new))
    169         if cl == 0:
    170             cl = 1
    171         return cl
    172 
    173 
    174 def load_data():
    175     f = gzip.open('../data/mnist.pkl.gz', 'rb')
    176     training_data, validation_data, test_data = cPickle.load(f)
    177     f.close()
    178     return (training_data, validation_data, test_data)
    179 
    180 if __name__ == "__main__":
    181     svc = SVC()
    182     np.random.seed(0)
    183     l=1000
    184     training_data, validation_data, test_data = load_data()
    185     svc.fit(training_data[0][:l],training_data[1][:l])
    186     predictions = [svc.predict(a) for a in test_data[0][:l]]
    187     num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1][:l]))
    188     print "%s of %s values correct." % (num_correct, len(test_data[1][:l]))  #72/100  #808/1000  #8194/10000(较慢)
  • 相关阅读:
    OWNER:Java配置文件解决方案 使用简介
    验证数字最简单正则表达式大全
    使用Spring进行统一日志管理 + 统一异常管理
    SpringMVC 拦截器
    Java排序
    tomcat编码配置
    日常任务
    netty入门代码学习
    redis学习
    AutoLayout And Animation
  • 原文地址:https://www.cnblogs.com/qw12/p/5744302.html
Copyright © 2020-2023  润新知