• Keras猫狗大战三:加载模型,预测目录中图片,画混淆矩阵


    版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

     一、加载模型,预测测试集

    %matplotlib inline
    import matplotlib.pyplot as plt
    
    import os
    import itertools
    import cv2
    
    import numpy as np
    from sklearn.metrics import confusion_matrix
    
    from keras.preprocessing.image import ImageDataGenerator
    from keras.models import load_model
    
    dst_path = r'D:BaiduNetdiskDownloadsmall'
    model_file = r"D:fastaiprojectscats_and_dogs_small_1.h5"
    test_dir = os.path.join(dst_path, 'test')
    
    batch_size = 20
    
    model = load_model(model_file)
    
    test_datagen = ImageDataGenerator(rescale=1. / 255)
    
    test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')
    
    test_loss, test_acc = model.evaluate_generator(test_generator, steps=test_generator.samples / batch_size)
    print('test acc: %.3f%%' % test_acc)
    Found 400 images belonging to 2 classes.
    test acc: 0.747%

    二、预测测试集,画混淆矩阵
    def get_input_xy(src=[]):
        pre_x = []
        true_y = []
    
        class_indices = {'cat': 0, 'dog': 1}
    
        for s in src:
            input = cv2.imread(s)
            input = cv2.resize(input, (150, 150))
            input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
            pre_x.append(input)
    
            _, fn = os.path.split(s)
            y = class_indices.get(fn[:3])
            true_y.append(y)
    
        pre_x = np.array(pre_x) / 255.0
    
        return pre_x, true_y
    
    
    def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)
    
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
        thresh = cm.max() / 2.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cm[i, j], horizontalalignment='center', color='white' if cm[i, j] > thresh else 'black')
    
        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predict label')
    
    
    test = os.listdir(test_dir)
    
    images = []
    
    # 获取每张图片的地址,并保存在列表images中
    for testpath in test:
        for fn in os.listdir(os.path.join(test_dir, testpath)):
            if fn.endswith('jpg'):
                fd = os.path.join(test_dir, testpath, fn)
                images.append(fd)
    
    # 得到规范化图片及true label
    pre_x, true_y = get_input_xy(images)
    
    # 预测
    pred_y = model.predict_classes(pre_x)
    
    # 画混淆矩阵
    confusion_mat = confusion_matrix(true_y, pred_y)
    plot_sonfusion_matrix(confusion_mat, classes=range(2))

  • 相关阅读:
    mybatis 缓存
    mybatis 动态sql
    新手必读:游戏编程入门指南
    22条常用JavaScript开发小技巧
    Unity即将全面升级 实时3D技术及大场景编辑未来可期!
    10分钟学会Python基础知识
    如何用UE4制作非写实草浪
    最适合设计师的前端学习路径有哪些?
    超全面的C++游戏开发面试问题总结
    如何学习大型项目的源码?虚幻引擎源码学习思路分享
  • 原文地址:https://www.cnblogs.com/zhengbiqing/p/11070050.html
Copyright © 2020-2023  润新知