• fashion_mnist 计算准确率、召回率、F1值


    fashion_mnist 计算准确率、召回率、F1值

    1、定义

    首先需要明确几个概念:

    假设某次预测结果统计为下图:

    image-20201227200400240

    那么各个指标的计算方法为:

    • A类的准确率:TP1/(TP1+FP5+FP9+FP13+FP17) 即预测为A的结果中,真正为A的比例
    • A类的召回率:TP1/(TP1+FP1+FP2+FP3+FP4) 即实际上所有为A的样例中,能预测出来多少个A(的比例)
    • A类的F1值:(准确率*召回率*2)/(准确率+召回率)

    实际上我们在训练出某个模型后,会将测试集中每个测试样例进行一次结果预测,因此只需统计这些结果,经过计算即可得到各类数据的准确率、召回率、F1值

    2、使用fashion_mnist

    需要提前pip安装tensorflow、prettytable、numpy

    from tensorflow import keras
    import numpy as np
    import prettytable
    
    # 下载数据集
    fashion_mnist = keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    
    # 制作标签名称
    class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boot']
    # 图片数据归一化
    train_images = train_images / 255.0
    test_images = test_images / 255.0
    
    # 构建3层DNN模型,使用激活函数softmax
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    # 定义模型的损失函数,优化器与评估指标
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss=keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy']
    )
    # 训练模型
    model.fit(train_images, train_labels, epochs=5)
    # 评估模型
    predictions = model.predict(test_images)
    train_result = np.zeros((10, 10), dtype=int)
    for i in range(10000):
        train_result[test_labels[i]][np.argmax(predictions[i])] += 1
    
    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Type', 'Accu', 'Recall', 'F1']
    for i in range(10):
        ac = train_result[i][i] / sum(train_result.T[i])
        rc = train_result[i][i] / sum(train_result[i])
        result_table.add_row([class_names[i], round(ac, 3), round(rc, 3), round(ac * rc * 2 / (ac + rc), 3)])
    
    print(result_table)
    
    

    实际效果:

    image-20201227205815716

  • 相关阅读:
    记一次省赛总结
    护网杯一道密码学的感想
    配置phpstudy+phpstorm+xdebug环境
    python 模板注入
    hash扩展攻击本地实验
    kali rolling更新源之gpg和dirmngr问题
    web信息泄露注意事项
    ctf常见php弱类型分析
    文件上传小结
    ctf变量覆盖漏洞
  • 原文地址:https://www.cnblogs.com/soowin/p/14198663.html
Copyright © 2020-2023  润新知