• K nearest neighbor cs229


    本例在 MNIST 上实现 K 近邻分类，用三种方式（双重循环、单重循环、完全向量化）计算距离矩阵并比较耗时，以展示 vectorized code 带来的好处。

     1 import numpy as np
     2 from sklearn.datasets import fetch_mldata
     3 import time
     4 import matplotlib.pyplot as plt
     5 
     6 mnist = fetch_mldata('MNIST original')
     7 
     8 X = mnist.data.astype(float)
     9 Y = mnist.target.astype(float) 
    10 
    11 mask = np.random.permutation(range(np.shape(X)[0]))
    12 
    13 num_train = 10000
    14 num_test = 500
    15 K = 10
    16 
    17 X_train = X[mask[:num_train]]
    18 Y_train = Y[mask[:num_train]]
    19 
    20 X_mean = np.mean(X_train,axis = 0)
    21 
    22 X_train = (X_train-X_mean)/255
    23 
    24 X_test = X[mask[num_train:num_train+num_test]]
    25 
    26 X_test = (X_test - X_mean)/255
    27 
    28 Y_test = Y[mask[num_train:num_train+num_test]]
    29 
    30 
    # Report the shape of each array as a quick sanity check on the split.
    for label, arr in (('X_train', X_train), ('Y_train', Y_train),
                       ('X_test', X_test), ('Y_test', Y_test)):
        print(label, arr.shape)
    35 
    # Undo the (x - X_mean) / 255 normalisation for one training row and show
    # it as a 28x28 image. The float round trip can land a hair below the
    # original integer (e.g. 4.999...), and a bare uint8 cast truncates, so
    # round to the nearest integer before casting.
    ex_pixels = X_train[10, :] * 255 + X_mean
    ex_image = np.rint(ex_pixels).reshape(28, 28).astype(np.uint8)
    plt.imshow(ex_image, interpolation='nearest')
    38 
    39 
    40 # **Computing the distance matrix (num_test x num_train)**
    41 
    42 # Version 1 (Naive implementation using two for loops)
    43 
    # Fill the (num_test x num_train) distance matrix one entry at a time.
    start = time.time()
    dists_1 = np.zeros((num_test, num_train))
    for i in range(num_test):
        for j in range(num_train):
            # Euclidean distance: square the element-wise differences, sum
            # them, then take the root. (Summing before squaring would give
            # |sum of differences|, which is not a distance.)
            diff = X_test[i, :] - X_train[j, :]
            dists_1[i, j] = np.sqrt(np.sum(np.square(diff)))

    stop = time.time()
    time_taken = stop - start
    print('Time taken with two for loops: {}s'.format(time_taken))
    53 
    54 
    55 # Version 2(Somewhat better implementation using one for loop)
    56 
    # Fill the distance matrix one test row at a time; broadcasting subtracts
    # the test row from every training row in a single step.
    start = time.time()
    dists_2 = np.zeros((num_test, num_train))
    for i in range(num_test):
        # Euclidean distance: square first, sum over the feature axis, then
        # take the root. (Summing before squaring would give |sum of
        # differences|, which is not a distance.)
        diff = X_test[i, :] - X_train
        dists_2[i, :] = np.sqrt(np.sum(np.square(diff), axis=1))

    stop = time.time()
    time_taken = stop - start
    print('Time taken with just one for loop: {}s'.format(time_taken))
    65 
    66 
    67 # Version 3 (Fully vectorized implementation with no for loop)
    68 
    # Compute all pairwise distances at once from the identity
    #   ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    start = time.time()
    A = np.sum(np.square(X_test), axis=1)    # squared norm of each test row
    B = np.sum(np.square(X_train), axis=1)   # squared norm of each train row
    C = np.dot(X_test, X_train.T)            # all test/train dot products

    sq_dists = A[:, np.newaxis] + B[np.newaxis, :] - 2 * C
    # Floating-point round-off can push a near-zero squared distance slightly
    # below zero, which would make sqrt return NaN; clamp at zero first.
    dists_3 = np.sqrt(np.maximum(sq_dists, 0))

    stop = time.time()
    time_taken = stop - start
    print('Time taken with no for loops: {}s'.format(time_taken))
    80 
    # For each test row, training indices ordered from nearest to farthest.
    sorted_dist_indices = np.argsort(dists_3, axis=1)

    # Keep only the K nearest indices *before* looking up labels, so the
    # lookup produces a (num_test, K) array instead of a full
    # (num_test, num_train) one that is then mostly thrown away.
    closest_k = Y_train[sorted_dist_indices[:, :K]].astype(int)
    Y_pred = np.zeros_like(Y_test)

    for i in range(num_test):
        # Majority vote among the K neighbours; on a tie argmax picks the
        # smallest label because it returns the first maximum.
        Y_pred[i] = np.argmax(np.bincount(closest_k[i, :]))

    # Fraction of test rows whose predicted label equals the true label.
    accuracy = (Y_pred == Y_test).sum() / float(num_test)
    print('Prediction accuracy: {}%'.format(accuracy*100))
  • 相关阅读:
    sql server 跟踪各事件的字段项编码及解释
    sql server 有关锁的视图说明 syslockinfo
    SQL Server:查看SQL日志文件大小命令:dbcc sqlperf(logspace)
    [SqlServer]创建链接服务器
    SQL Server 2008 存储过程,带事务的存储过程(创建存储过程,删除存储过程,修改存储过
    sql server 索引分析相关sql
    IO系统性能之一:衡量性能的几个指标
    Writing to a MySQL database from SSIS
    用漫画的形式来讲解为什么MySQL数据库要用B+树存储索引?
    一份 Tomcat 和 JVM 的性能调优经验总结!拿走不谢
  • 原文地址:https://www.cnblogs.com/niuxichuan/p/8728139.html
Copyright © 2020-2023  润新知