tensorflow学习
-
鸢尾花分类
步骤
1 · 准备数据,包括数据集读入、数据集乱序,把训练集和测试集中的数据配成输入特征和标签对,生成 train 和 test 即永不相见的训练集和测试集;
2 · 搭建网络,定义神经网络中的所有可训练参数;
3 · 优化这些可训练的参数,利用嵌套循环在 with 结构中求得损失函数 loss对每个可训练参数的偏导数,更改这些可训练参数,为了查看效果,程序中可以加入每遍历一次数据集显示当前准确率,还 可以画出准确率 acc 和损失函数 loss的变化曲线图。
代码实现
from sklearn import datasets
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
#读入数据并进行数据分割处理
x_data =datasets.load_iris().data
y_data =datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
x_train = tf.cast(x_train,tf.float32)
x_test = tf.cast(x_test,tf.float32)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 设置参数
w1 = tf.Variable(tf.random.truncated_normal([4,3],stddev =0.1,seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))
lr =0.1#学习率
train_loss_results=[]
test_acc=[]
epoch = 500#循环次数
loss_all = 0
# 训练部分
for epoch in range(epoch):
for step,(x_train,y_train) in enumerate(train_db):
with tf.GradientTape() as tape:
y = tf.matmul(x_train, w1) + b1
y = tf.nn.softmax(y)
y_ = tf.one_hot(y_train, depth=3)
loss = tf.reduce_mean(tf.square(y_ - y))
loss_all += loss.numpy()
grads = tape.gradient(loss,[w1,b1])
# 更新参数
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
print("Epoch {}, loss: {}".format(epoch, loss_all/4))
train_loss_results.append(loss_all/4)
loss_all=0
total_correct, total_number = 0,0
for x_test , y_test in test_db:
y = tf.matmul(x_test,w1)+b1
y =tf.nn.softmax(y)
pred = tf.argmax(y,axis=1)
pred = tf.cast(pred,dtype=y_test.dtype)
correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
correct = tf.reduce_sum(correct) # 将每个 batch 的 correct 数加起来
total_correct += int(correct) # 将所有 batch 中的 correct 数加起来
total_number += x_test.shape[0]
acc = total_correct / total_number
test_acc.append(acc)
print("test_acc:", acc)
print("--------------------------------")
# 画图
plt.title('Loss Function Curve') # 图片标题
plt.xlabel('Epoch') # x 轴名称
plt.ylabel('Loss') # y 轴名称
plt.plot(train_loss_results, label="$Loss$") #
plt.legend()
plt.show()
plt.title('Acc Curve') # 图片标题
plt.xlabel('Epoch') # x 轴名称
plt.ylabel('Acc') # y 轴名称
plt.plot(test_acc, label="$Accuracy$") # 逐点画出 test_acc 值并连线
plt.legend()
plt.show()
结果