• Implementing the TensorFlow training loop by hand: an example


    Reference: Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems

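    import tensorflow as tf
    from tensorflow import keras

    # X_train_scaled and y_train are assumed to be prepared beforehand
    # (scaled features and targets as NumPy arrays).

    # The same L2 regularizer is applied to the kernel of both layers.
    l2_reg = keras.regularizers.l2(0.05)
    model = keras.models.Sequential([
        keras.layers.Dense(30, activation="elu", kernel_initializer="he_normal",
                           kernel_regularizer=l2_reg),
        keras.layers.Dense(1, kernel_regularizer=l2_reg)
    ])

    n_epochs = 5
    batch_size = 32
    n_steps = len(X_train_scaled) // batch_size
    optimizer = keras.optimizers.Nadam(learning_rate=0.01)  # `lr` is deprecated
    loss_fn = keras.losses.mean_squared_error
    mean_loss = keras.metrics.Mean()  # running mean of the total loss
    metrics = [keras.metrics.MeanAbsoluteError()]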
    
    for epoch in range(1, n_epochs + 1):
        print("Epoch {}/{}".format(epoch, n_epochs))
        for step in range(1, n_steps + 1):
            X_batch, y_batch = random_batch(X_train_scaled, y_train)
            with tf.GradientTape() as tape:
                y_pred = model(X_batch, training=True)
                main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
                # model.losses holds one regularization loss per regularized layer.
                loss = tf.add_n([main_loss] + model.losses)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            for variable in model.variables:
                if variable.constraint is not None:
                    variable.assign(variable.constraint(variable))
            # Update the running mean of the total loss.
            mean_loss(loss)
            for metric in metrics:
                metric(y_batch, y_pred)
            print_status_bar(step * batch_size, len(y_train), mean_loss, metrics)
        print_status_bar(len(y_train), len(y_train), mean_loss, metrics)
        for metric in [mean_loss] + metrics:
            metric.reset_state()  # named reset_states() before TF 2.5
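
The loop above calls two helpers, `random_batch` and `print_status_bar`, that the listing does not define. The following is a minimal sketch of what they might look like, assuming `X` and `y` are NumPy arrays and that each metric object exposes `.name` and `.result()` as Keras metrics do. The `format_status` function is a hypothetical split-out added here so the formatting can be checked without printing.

```python
import numpy as np

def random_batch(X, y, batch_size=32):
    """Sample batch_size rows (with replacement) from X and the matching y."""
    idx = np.random.randint(len(X), size=batch_size)
    return X[idx], y[idx]

def format_status(iteration, total, loss, metrics=None):
    """Build a status string such as '32/100 - mean: 0.5000 - mae: 0.2500'."""
    parts = ["{}: {:.4f}".format(m.name, float(m.result()))
             for m in [loss] + list(metrics or [])]
    return "{}/{} - ".format(iteration, total) + " - ".join(parts)

def print_status_bar(iteration, total, loss, metrics=None):
    """Overwrite the current console line; end the line once the epoch is done."""
    end = "" if iteration < total else "\n"
    print("\r" + format_status(iteration, total, loss, metrics), end=end)
```

Because the batches are drawn at random with replacement, an epoch here means `n_steps` random batches rather than one full pass over every training instance.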
    

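    Because the model's layers use a regularizer, model.losses is a list containing the regularization loss of each regularized layer. The total loss that is differentiated is the main loss (the mean of the loss function over the batch) plus the sum of these regularization losses, which is what tf.add_n([main_loss] + model.losses) computes.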
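
To make the arithmetic concrete, here is a small NumPy sketch of the same sum, outside TensorFlow. It assumes the L2 regularizer computes `factor * sum(w ** 2)` over each layer's kernel (biases are not regularized in the model above). The function names and the toy values are illustrative only.

```python
import numpy as np

def l2_penalty(kernel, factor=0.05):
    """L2 regularization loss for one layer's kernel: factor * sum(w^2)."""
    return factor * float(np.sum(np.square(kernel)))

def total_loss(y_true, y_pred, kernels, factor=0.05):
    """Return (total, main_loss, reg_losses), mirroring main_loss + model.losses."""
    main_loss = float(np.mean(np.square(y_true - y_pred)))  # mean squared error
    reg_losses = [l2_penalty(k, factor) for k in kernels]   # one entry per layer
    return main_loss + sum(reg_losses), main_loss, reg_losses

# Toy values: two targets and two tiny "kernels".
y_true = np.array([1.0, 2.0])
y_pred = np.array([1.5, 1.5])
kernels = [np.array([[1.0, 2.0]]), np.array([[2.0]])]
total, main, regs = total_loss(y_true, y_pred, kernels)
print(main, regs, total)
```

With these toy values the main loss is 0.25 and the two penalties are 0.05 * 5 and 0.05 * 4, so the regularization terms contribute more than half of the total; a factor as large as 0.05 can dominate the objective when the weights are not small.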

  • Original article: https://www.cnblogs.com/yaos/p/14014159.html