TensorFlow线性回归
环境:TensorFlow2.0
前置知识
Tensor
TensorFlow使用tensor(张量)来表示数据,名字挺高大上,其实可以理解成多维数组。
import tensorflow as tf # 每次使用TensorFlow的第一件事
A = tf.constant([1, 2, 3]) # constant是表示A是一个常量
B = tf.constant([[1,2],[3,4]])
这里我们定义了两个张量,其中A是一维张量,B是二维张量。我们打印看看
>>> A # 形状是(3,)是一个3有个元素的向量
<tf.Tensor: id=0, shape=(3,), dtype=int32, numpy=array([1, 2, 3])>
>>> B # 形状是(2, 2)一个2*2矩阵
<tf.Tensor: id=1, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4]])> # 两者都是int32类型,可以通过numpy()方法来得到它们的值
Tensor运算
tf.add(tensor_A, tensor_B) # 矩阵元素相加
tf.subtract(tensor_A, tensor_B) # 矩阵元素相减
tf.multiply(tensor_A, tensor_B) # 矩阵元素相乘
tf.divide(tensor_A, tensor_B) # 矩阵元素相除
tf.matmul(tensor_A, tensor_B) # 矩阵乘法
tf.pow(tensor_A, num) # 矩阵元素幂运算
这上面除了矩阵乘法,也就其他对矩阵元素进行操作的都可以用数学符号的+-*/和**
来代替。
Tensor的运算会有一个广播机制,后面遇到再讲。
自动求导
身为机器(深度)学习框架,自动求导机制少不了,我们看一下代码
x = tf.Variable(initial_value=1.) # tf.Variable表示这是一个变量,里面的initial参数将其初始值设为1
with tf.GradientTape() as tape: # GradientTape:梯度带,在with中的所有过程将会被记录
y = x**2+7*x+1 # 也就是你的函数放在这里就行,甚至可以分好几步写
dy_dx = tape.gradient(y,x) # 求with过程中y关于x的梯度
print(dy_dx) # tf.Tensor(9.0, shape=(), dtype=float32)
线性回归
首先生成我们的数据,数据标签通过真实函数加上高斯噪声得到。
然后为了进行梯度计算,X、y需要转换成tf的格式。
定义了变量w和偏移量b,初值都设为0.
# 构造数据,方程y=2x+1
x_data = np.linspace(0, 1, 200).reshape((-1,1))
y_data = 2*x_data+1 + np.random.normal(0,0.02, x_data.shape)
X = tf.constant(x_data, dtype=tf.float32)
y = tf.constant(y_data, dtype=tf.float32)
w = tf.Variable(0.)
b = tf.Variable(0.)
使用梯度带自动计算梯度,然后用优化器自动更新模型的参数w和b
epoches = 1000
# 优化器,设置学习率为0.001
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
for _ in range(epoches):
with tf.GradientTape() as tape:
y_pred = X*w+b
loss = 0.5*tf.reduce_sum((y_pred-y)**2)
grads = tape.gradient(loss, [w,b])
# 通过apply_gradients来最小化损失函数,参数是梯度,变量对(grad, variable)
optimizer.apply_gradients(grads_and_vars=zip(grads, [w, b]))
打印我们的参数,很接近真实参数
>>> print(w.numpy(), b.numpy())
1.9974449 1.003515
最后
我们的线性回归拟合完成,如果想更直观的看,可以使用plot将离散点和拟合的函数画出来。
import matplotlib.pyplot as plt
y_pred = X*w+b
plt.figure()
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred, "r", lw=2)
plt.show()