学习进度笔记09
TensorFlow K近邻算法
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# Nearest-neighbour (k = 1) classification of MNIST digits with TensorFlow 1.x.
#
# Every test image is compared against a fixed set of training images using
# the L1 (Manhattan) distance; the label of the closest training image is
# taken as the prediction.

# Expose only GPU 0 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Load MNIST with one-hot labels, then draw the reference samples (train)
# and the samples to classify (test).
mnist = input_data.read_data_sets("/home/yxcx/tf_data/MNIST_data", one_hot=True)
Xtr, Ytr = mnist.train.next_batch(5000)
Xte, Yte = mnist.test.next_batch(200)

# tf Graph Input: a batch of 784-value reference rows and a single
# 784-value query row.
xtr = tf.placeholder("float", [None, 784])
xte = tf.placeholder("float", [784])

# L1 distance between the query and every reference row: |xtr - xte| summed
# over the feature axis. `axis` replaces the deprecated `reduction_indices`.
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), axis=1)

# Index of the reference row with the smallest distance to the query.
pred = tf.argmin(distance, 0)

accuracy = 0
init = tf.global_variables_initializer()

# Classify the test samples one at a time.
with tf.Session() as sess:
    sess.run(init)

    # Count hits as an integer and divide once at the end; adding
    # 1./len(Xte) per hit accumulates floating-point rounding error.
    correct = 0
    for i, (test_image, test_label) in enumerate(zip(Xte, Yte)):
        # Get the index of the nearest neighbour among the reference rows.
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: test_image})

        # Labels are one-hot, so argmax recovers the class index.
        predicted_class = np.argmax(Ytr[nn_index])
        true_class = np.argmax(test_label)
        print("Test", i, "Prediction:", predicted_class, "True Class:", true_class)

        if predicted_class == true_class:
            correct += 1

# Leave accuracy at its initial 0 when there were no test samples.
if len(Xte):
    accuracy = float(correct) / len(Xte)

# Output labels are kept exactly as the original script printed them.
print("Done!")
print("accuacy:", accuracy)