• How to use class_weight and sample_weight in Keras


    Reposted from: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras

    import tensorflow as tf
    import numpy as np
    
    data_size = 100
    input_size = 3
    classes = 3
    
    # Random toy data: features in [0, 1), integer labels in {0, 1, 2}
    x_train = np.random.rand(data_size, input_size)
    y_train = np.random.randint(0, classes, data_size)
    # sample_weight_train = np.random.rand(data_size)
    x_val = np.random.rand(data_size, input_size)
    y_val = np.random.randint(0, classes, data_size)
    # sample_weight_val = np.random.rand(data_size)
    
    # shape must be a tuple: (input_size,), not (input_size)
    inputs = tf.keras.layers.Input(shape=(input_size,))
    pred = tf.keras.layers.Dense(classes, activation='softmax')(inputs)
    
    model = tf.keras.models.Model(inputs=inputs, outputs=pred)
    
    # Freeze the model so it stays identical across the scenarios below.
    # Set `trainable` before compile(): Keras only picks up the change on (re)compile.
    for layer in model.layers:
        layer.trainable = False
    
    loss = tf.keras.losses.sparse_categorical_crossentropy
    metrics = tf.keras.metrics.sparse_categorical_accuracy
    
    model.compile(loss=loss, metrics=[metrics], optimizer='adam')
    
    # Baseline: all class weights 1 (same result as passing no class_weight)
    # model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val))
    class_weights = {0: 1., 1: 1., 2: 1.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
    # Output reported in the original post (random data, so your numbers will differ):
    # loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100
    
    # Set all class weights to 0 to see which loss and metric are affected
    class_weights = {0: 0., 1: 0., 2: 0.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
    # Reported output:
    # loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100
    
    # Set all sample weights to 0 (training and validation) to see which loss and metric are affected
    sample_weight_train = np.zeros(data_size)
    sample_weight_val = np.zeros(data_size)
    model.fit(x=x_train, y=y_train, sample_weight=sample_weight_train,
              validation_data=(x_val, y_val, sample_weight_val))
    # Reported output:
    # loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100
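To see why zeroing the weights drove the training loss to 0 while accuracy stayed at 0.33, here is a minimal pure-Python sketch. It assumes the loss is reduced by dividing the weighted sum by the batch size (my understanding of Keras' default reduction) and that an accuracy passed via `metrics=` (rather than `weighted_metrics=`) ignores the weights. The functions `weighted_sparse_ce` and `accuracy` and the toy batch are hypothetical, not Keras APIs.

```python
import math

def weighted_sparse_ce(probs, labels, weights):
    """Weighted sparse categorical cross-entropy.
    Assumption: weighted sum divided by batch size (not by the sum of weights)."""
    per_sample = [-math.log(p[y]) for p, y in zip(probs, labels)]
    return sum(w * l for w, l in zip(weights, per_sample)) / len(per_sample)

def accuracy(probs, labels):
    """Unweighted accuracy: fraction of samples whose argmax equals the label."""
    hits = [max(range(len(p)), key=p.__getitem__) == y for p, y in zip(probs, labels)]
    return sum(hits) / len(hits)

# Hypothetical toy batch: softmax outputs for 3 classes and the true labels
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.3, 0.6],
         [0.3, 0.4, 0.3],
         [0.5, 0.25, 0.25]]
labels = [0, 2, 0, 1]

ones = [1.0] * len(labels)
zeros = [0.0] * len(labels)
print(weighted_sparse_ce(probs, labels, ones))   # ordinary mean cross-entropy
print(weighted_sparse_ce(probs, labels, zeros))  # 0.0: every sample contributes nothing
print(accuracy(probs, labels))                   # 0.5, whatever the weights are
```

The weights only scale each sample's loss term, so they can change the loss but never the argmax-based accuracy.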
    

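    class_weight: a weight per class (per value of the output label y); every training sample of that class is weighted by it in the training loss.
    sample_weight: a weight per individual data sample; validation sample weights can be passed as the third element of validation_data.
    In the runs above, setting either kind of weight to 0 changed only the training loss (to 0); sparse_categorical_accuracy, passed via metrics=, is not weighted and stayed at 0.33.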
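Since class_weight is just a lookup table from label to weight, it can be expanded into one weight per sample. The sketch below shows that expansion; as far as I know Keras does essentially this internally, so passing the expanded list as sample_weight should behave like passing the dict as class_weight. The helper `class_to_sample_weight` is hypothetical.

```python
def class_to_sample_weight(labels, class_weight):
    """Hypothetical helper: look up each sample's label in the {class: weight} dict."""
    return [class_weight[y] for y in labels]

labels = [0, 2, 1, 2, 0]
class_weight = {0: 1.0, 1: 5.0, 2: 0.5}
print(class_to_sample_weight(labels, class_weight))  # [1.0, 0.5, 5.0, 0.5, 1.0]

# In Keras this would be used roughly as (assumption, not run here):
# model.fit(x_train, y_train,
#           sample_weight=np.array(class_to_sample_weight(y_train, class_weights)))
```

sample_weight is the more general option: class_weight can only express "all samples of class c get weight w", while sample_weight can give every sample its own weight.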

  • Original article: https://www.cnblogs.com/yaos/p/12069527.html
Copyright © 2020-2023  润新知