• 机器学习之路: tensorflow 一个最简单的神经网络


    git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/

     1 import tensorflow as tf
     2 # 利用numpy生成数据模拟数据集
     3 from numpy.random import RandomState
     4 
     5 
     6 # 定义一个训练数据batch的大小
     7 batch_size = 8
     8 
     9 # 定义神经网络的参数
    10 w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
    11 w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
    12 
    13 # 行数指定为None  会根据传入的行数动态变化
    14 x = tf.placeholder(tf.float32, shape=(None, 2), name="x_input")
    15 y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y_input")
    16 
    17 # 定义神经网络前向传播
    18 a = tf.matmul(x, w1)
    19 y = tf.matmul(a, w2)
    20 
    21 # 定义损失函数和反向传播的算法
    22 cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    23 train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
    24 
    25 # 利用随机数生成数据集
    26 rdm = RandomState(1)
    27 dataset_size = 128
    28 X = rdm.rand(dataset_size, 2)
    29 Y = [[int(x1+x2 < 1)] for (x1, x2) in X]
    30 
    31 # 开启会话
    32 with tf.Session() as sess:
    33     init_op = tf.initialize_all_variables()
    34     # 初始化变量
    35     sess.run(init_op)
    36     print("训练之前的权重w1和2:")
    37     print(sess.run(w1))
    38     print(sess.run(w2))
    39 
    40     # 设定迭代次数
    41     STEPS = 5000
    42     for i in range(STEPS):
    43         # 每次选batch个样本进行训练
    44         start = (i * batch_size) % dataset_size
    45         end = min(start + batch_size, dataset_size)
    46         # 选取样本训练神经网络
    47         sess.run(train_step, feed_dict={x: X[start: end], y_: Y[start: end]})
    48         # 每隔一段时间计算所有数据上的交叉熵并输出
    49         if i % 1000 == 0:
    50             total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
    51             print("第%d次迭代:交叉熵为%s" % (i, total_cross_entropy))
    52     print("训练之后的权重w1和w2分别是:")
    53     print(sess.run(w1))
    54     print(sess.run(w2))
    55 
    56 
    57 '''
    58 练之前的权重w1和2:
    59 [[-0.8113182   1.4845988   0.06532937]
    60  [-2.4427042   0.0992484   0.5912243 ]]
    61 [[-0.8113182 ]
    62  [ 1.4845988 ]
    63  [ 0.06532937]]
    64  
    65 第0次迭代:交叉熵为0.067492485
    66 第1000次迭代:交叉熵为0.016338505
    67 第2000次迭代:交叉熵为0.009075474
    68 第3000次迭代:交叉熵为0.007144361
    69 第4000次迭代:交叉熵为0.005784708
    70 
    71 训练之后的权重w1和w2分别是:
    72 [[-1.9618274  2.582354   1.6820377]
    73  [-3.4681718  1.0698233  2.11789  ]]
    74 [[-1.8247149]
    75  [ 2.6854665]
    76  [ 1.418195 ]]
    77 '''
  • 相关阅读:
    Android平台架构及特性
    MySQL 数据库性能优化之索引优化
    排序自己总结
    存储过程中“ 警告: 聚合或其他 SET 操作消除了 Null 值” 导致错误的解决
    存储过程的output及return的区别
    如果ssh端口转发时候g没有效果解决方案
    sql语句显示复选内容, indication 为复选框的累计value(整数),显示所有的
    kprfakesu.c Linux su密码欺骗 源码
    Unix/Linux上的后门技术和防范
    iis7应用程序池经常自动停止如何解决?
  • 原文地址:https://www.cnblogs.com/Lin-Yi/p/9145186.html
Copyright © 2020-2023  润新知