• 寒假学习进度12: 线性回归tensorflow2.0实现


    参考博客:https://blog.csdn.net/weixin_45665788/article/details/104919669

    import  matplotlib.pyplot as plt
    import  numpy as np
    import tensorflow as tf
    # 载入随机种子
    np.random.seed(5)
    #生成100个等差序列,每个值在-1 - 1 之间
    x_data = np.linspace(-1,1,1000)
    #y = 2x + 1 + 噪声,噪声的维度和x_Data一致
    y_data = 2 * x_data +1.0 +np.random.randn(*x_data.shape) * 0.4 #*表示把元组拆分为一个个单独的实参
    plt.scatter(x_data,y_data)
    plt.plot(x_data,2*x_data+1,color = 'red' ,linewidth = 3)
    
    #定义模型函数以及线性函数的斜率和截距
    def model(x,w,b):
        return tf.multiply(x,w)+b
    
    #设置损失函数,这里使用均方差作为损失函数
    def loss_fun(x,y,w,b):
        err = model(x,w,b)-y
        squared_err = tf.square(err)
        return tf.reduce_mean(squared_err)
    
    #返回梯度向量
    def grad(x,y,w,b):
        with tf.GradientTape() as tape:
            loss_ = loss_fun(x,y,w,b)
        return tape.gradient(loss_,[w,b])
    
    if __name__ == '__main__':
        #因为模型比较简单,因此超参的迭代次数设置的比较小
        # 构建线性函数的斜率和截距
        w = tf.Variable(np.random.randn(), tf.float32)
        b = tf.Variable(0.0, tf.float32)
        # 设置迭代次数和学习率
        train_epochs = 10
        learning_rate = 0.01
        loss = []
        count = 0
        display_count = 10  # 控制显示粒度的参数,每训练10个样本输出一次损失值
    
        # 开始训练,轮数为epoch,采用SGD随机梯度下降优化方法
        for epoch in range(train_epochs):
            for xs, ys in zip(x_data, y_data):
                # 计算损失,并保存本次损失计算结果
                loss_ = loss_fun(xs, ys, w, b)
                loss.append(loss_)
                # 计算当前[w,b]的梯度
                delta_w, delta_b = grad(xs, ys, w, b)
                change_w = delta_w * learning_rate
                change_b = delta_b * learning_rate
                w.assign_sub(change_w)
                b.assign_sub(change_b)
                # 训练步数加1
                count = count + 1
                if count % display_count == 0:
                    print('train epoch : ', '%02d' % (epoch + 1), 'step:%03d' % (count), 'loss= ', '{:.9f}'.format(loss_))
            # 完成一轮训练后,画图
            plt.plot(x_data, w.numpy() * x_data + b.numpy())
            plt.show()
    

      

  • 相关阅读:
    mysql事物中行锁与表锁
    https的实现原理
    基于http的断点续传和多线程下载
    Cookie与Session
    centos 7 安装python3
    为CentOS下的Docker安装配置python3【转】
    Jmeter如何提取响应头部的JSESSIONID【转】
    centOS7 安装nginx
    centos 7.X关闭防火墙和selinux
    (四)从输入URL到页面加载发生了什么
  • 原文地址:https://www.cnblogs.com/yangqqq/p/14459774.html
Copyright © 2020-2023  润新知