• TensorFlow学习_LogicRegression_using_GrandientDescent


     1 #-*- coding:UTF8 -*-
     2 #梯度下降法解决线性回归
     3 import numpy as np
     4 import matplotlib.pyplot as plt
     5 import tensorflow as tf
     6 
     7 #构建数据
     8 point_nums = 100
     9 vector = []
    10 
    11 #用numpy的正态随机分布函数生成100个点
    12 #这些点(x,y)对应线性方程 y = 0.5 * x + 0.05
    13 #权重(Weight)0.5,偏差(Bias)0.05
    14 for i in range(point_nums):
    15     x1 = np.random.normal(0.0, 0.4)
    16     y1 = 0.5 * x1 + 0.05 + np.random.normal(0.0, 0.2)
    17     vector.append([x1,y1])
    18 
    19 x_data = [v[0] for v in vector]
    20 y_data = [v[1] for v in vector]
    21 
    22 
    23 plt.plot(x_data,y_data,'r*',label='original_data')
    24 plt.legend()
    25 plt.show()
    26 
    27 #构建线性模型
    28 W = tf.Variable(tf.random_uniform([1],-1,1)) #random_uniform(shape, minval,maxval,dtype=tf.float32,seed=None,name=None)
    29 b = tf.Variable(tf.zeros([1]))
    30 y = W * x_data + b
    31 
    32 #定义损失函数loss function 或 cost function
    33 loss = tf.reduce_mean(tf.square(y-y_data))
    34 
    35 #梯度下降优化器来进行优化loss function
    36 optimizer = tf.train.GradientDescentOptimizer(0.5) #learning_rate = 0.5
    37 train = optimizer.minimize(loss)
    38 
    39 #创建会话Sessions()
    40 sess = tf.Session()
    41 
    42 #全局初始化
    43 init = tf.global_variables_initializer()
    44 sess.run(init)
    45 
    46 #训练20次
    47 for step in range(20):
    48     #优化每一步
    49     sess.run(train)
    50     print('step:%s, train:%s, weight:%s, bias:%s'%(step, sess.run(train), sess.run(W), sess.run(b)))
    51 
    52 plt.plot(x_data, y_data, 'r*', label = 'original_data')
    53 plt.plot(x_data, sess.run(W) * x_data + sess.run(b), 'b', label = 'fitting_line')
    54 plt.legend()
    55 plt.show()
    56 
    57 sess.close()

    刚开始学习Tensorflow,发现tf确实是一个很强大的框架,内置函数很多,np里面有的tf内大部分也有,并且能用matplotlib包像matlab一样绘制图像。

    并且在对于在对于建模上面google做的越来越好,增加了tensorboard和playground,虽然暂时还未用到实际练习当中 : (

    对于learning rate的设置和优化loss_function的方法,还需进一步学习!

  • 相关阅读:
    忘记自己的密码了!
    MySQL ('root'@'%') does not exist的问题
    用视觉的差异和统一来表现界面信息(转)
    Localhost 本地mysql启动2013错误(windows系统下)
    修改SQL Server2005 sa密码方法
    .net中禁用TextBox和Input框的粘贴功能
    使用Visual Studio的搜索功能时间简单的代码量统计
    visifire3.6.4 以上版本去水印的办法
    网页设计的配色和排版(转)
    小米科技增设电商业务线,大家注意到没
  • 原文地址:https://www.cnblogs.com/AlexHaiY/p/9325570.html
Copyright © 2020-2023  润新知