• Using class_weight and sample_weight in Keras


    Reposted from: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras

    import tensorflow as tf
    import numpy as np
    
    data_size = 100
    input_size = 3
    classes = 3

    x_train = np.random.rand(data_size, input_size)
    y_train = np.random.randint(0, classes, data_size)
    # sample_weight_train = np.random.rand(data_size)
    x_val = np.random.rand(data_size, input_size)
    y_val = np.random.randint(0, classes, data_size)
    # sample_weight_val = np.random.rand(data_size)
    
    # shape must be a tuple, hence the trailing comma
    inputs = tf.keras.layers.Input(shape=(input_size,))
    pred = tf.keras.layers.Dense(classes, activation='softmax')(inputs)
    
    model = tf.keras.models.Model(inputs=inputs, outputs=pred)
    
    loss = tf.keras.losses.sparse_categorical_crossentropy
    metrics = tf.keras.metrics.sparse_categorical_accuracy
    
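    # Freeze the model so it can be compared across the scenarios below.
    # `trainable` is set before compile(), since changes made afterwards
    # may only take effect once compile() is called again.
    for layer in model.layers:
        layer.trainable = False

    model.compile(loss=loss, metrics=[metrics], optimizer='adam')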
    
    # Baseline: every class weight is 1 (same result as passing no class_weight)
    # model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val))
    class_weights = {0: 1., 1: 1., 2: 1.}
    model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
    # which outputs:
    # > loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100
    
    # Set the class weights to zero to see which losses and metrics are affected
    class_weights = {0: 0, 1: 0, 2: 0}
    model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
    # which outputs:
    # > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100
    
    # Set the sample weights to zero to see which losses and metrics are affected
    sample_weight_train = np.zeros(data_size)
    sample_weight_val = np.zeros(data_size)
    model.fit(x=x_train, y=y_train, sample_weight=sample_weight_train, validation_data=(x_val, y_val, sample_weight_val))
    # which outputs:
    # > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100
    

    class_weight: a weight per output class (a dict mapping class index to weight)
    sample_weight: a weight per data sample (an array with one entry per sample)
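
The relationship between the two can be sketched in plain Python. This is an illustration of the idea only, not Keras's actual implementation: each label looks up its class weight to get a per-sample weight, and each per-sample loss is scaled by that weight before averaging. The helper names (`per_sample_loss`, `weights_from_classes`, `weighted_mean_loss`) are hypothetical, and dividing by the number of samples is an assumed reduction.

```python
import math

def per_sample_loss(probs, label):
    """Sparse categorical cross-entropy for a single sample."""
    return -math.log(probs[label])

def weights_from_classes(labels, class_weight):
    """Expand a class_weight dict into one weight per sample (hypothetical helper)."""
    return [class_weight[y] for y in labels]

def weighted_mean_loss(losses, weights):
    """Scale each loss by its weight, then divide by the sample count (assumed reduction)."""
    return sum(l * w for l, w in zip(losses, weights)) / len(losses)

labels = [0, 1, 2, 1]
uniform = [1 / 3, 1 / 3, 1 / 3]  # a model that predicts every class equally
losses = [per_sample_loss(uniform, y) for y in labels]

ones = weights_from_classes(labels, {0: 1.0, 1: 1.0, 2: 1.0})
zeros = weights_from_classes(labels, {0: 0.0, 1: 0.0, 2: 0.0})

baseline = weighted_mean_loss(losses, ones)   # same as the plain mean loss
zeroed = weighted_mean_loss(losses, zeros)    # every term is scaled to zero
print(baseline, zeroed)
```

With all weights at zero the weighted loss collapses to 0, which is consistent with the `loss: 0.0000e+00` lines recorded above. The accuracy metric in those runs stays at 0.3300, since it is reported without weighting.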

  • Original post: https://www.cnblogs.com/yaos/p/14014218.html