• 自定义训练


    1. 以类的方式定义一个模型

    class Model(object):
        """Linear model ``f(x) = W * x + b`` with two scalar ``tf.Variable`` parameters."""

        def __init__(self):
            # Initialize variable to (5.0, 0.0) so the starting point is fixed
            # and reproducible for this demo.
            # In practice, these should be initialized to random values.
            self.W = tf.Variable(5.0)
            self.b = tf.Variable(0.0)

        def __call__(self, x):
            """Return the prediction ``W * x + b`` for input ``x``."""
            return self.W * x + self.b

    model = Model()
    # Sanity check: with W=5.0 and b=0.0 the model maps 3.0 to 5*3 + 0 = 15.
    prediction = model(3.0)
    assert prediction.numpy() == 15.0

    2. 定义损失函数（均方误差）

    def loss(predicted_y, desired_y):
        """Return the mean squared error between ``predicted_y`` and ``desired_y``."""
        return tf.reduce_mean(tf.square(predicted_y - desired_y))

    3. 生成带噪声的训练数据

    TRUE_W = 3.0
    TRUE_b = 2.0
    NUM_EXAMPLES = 1000

    # Synthesize outputs = TRUE_W * inputs + TRUE_b + noise, with inputs and
    # noise drawn from tf.random.normal (defaults: mean 0.0, stddev 1.0).
    # `tf.random_normal` is the old TF 1.x name and was removed in TF 2.x;
    # `tf.random.normal` is the current spelling.
    inputs = tf.random.normal(shape=[NUM_EXAMPLES])
    noise = tf.random.normal(shape=[NUM_EXAMPLES])
    outputs = inputs * TRUE_W + TRUE_b + noise

    4. 训练前绘图，并计算当前损失

    import matplotlib.pyplot as plt

    # Blue: the noisy data; red: the untrained model's predictions.
    plt.scatter(inputs, outputs, c='b', s=1)
    plt.scatter(inputs, model(inputs), c='r', linewidths=0.01)
    plt.show()

    # Print label and value in a single call: in Python 3 a trailing comma
    # after print(...) does not suppress the newline, it only builds a tuple.
    print('Current loss:', loss(model(inputs), outputs).numpy())

    5. 定义单步训练（梯度下降）过程

    def train(model, inputs, outputs, learning_rate):
        """Apply one plain gradient-descent step to ``model.W`` and ``model.b``.

        The module-level ``loss`` of the model's predictions is recorded on a
        ``tf.GradientTape``. Each variable is then updated in place by
        subtracting ``learning_rate`` times its gradient (W first, then b).
        """
        params = [model.W, model.b]
        with tf.GradientTape() as tape:
            step_loss = loss(model(inputs), outputs)
        grads = tape.gradient(step_loss, params)
        for param, grad in zip(params, grads):
            param.assign_sub(learning_rate * grad)

    6. 训练过程

    model = Model()

    # Collect the history of W-values and b-values to plot later
    Ws, bs = [], []
    epochs = range(10)
    for epoch in epochs:
        # Record parameters and loss *before* this epoch's update, so entry i
        # describes the model after i training steps.
        Ws.append(model.W.numpy())
        bs.append(model.b.numpy())
        current_loss = loss(model(inputs), outputs)

        train(model, inputs, outputs, learning_rate=0.1)
        print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %
              (epoch, Ws[-1], bs[-1], current_loss))

    # Let's plot it all: learned trajectories (solid) against the true
    # parameter values (dashed, plotted over indices 0..len(epochs)-1).
    plt.plot(epochs, Ws, 'r',
             epochs, bs, 'b')
    plt.plot([TRUE_W] * len(epochs), 'r--',
             [TRUE_b] * len(epochs), 'b--')
    plt.legend(['W', 'b', 'true W', 'true_b'])
    plt.show()

  • 相关阅读:
    BNU校赛
    Latest Common Ancestor
    Codeforces Round #482 (Div. 2)
    Persistent Line Segment Tree
    2018HNCCPC(Onsite)
    2018HNCCPC
    2017 ACM Jordanian Collegiate Programming Contest
    Codeforces Round #480 (Div. 2)
    负载均衡SLB
    windows下的端口监听、程序端口查找命令
  • 原文地址:https://www.cnblogs.com/augustone/p/10507308.html
Copyright © 2020-2023  润新知