• caffe---测试模型分类结果并输出(python )


    当训练好一个model之后,我们通常会根据这个model最终的loss和在验证集上的accuracy来判断它的好坏。但是,对于分类问题,我们如果只是知道整体的分类正确率

    显然还不够,所以只有知道模型对于每一类的分类结果以及正确率这样才能更好的理解这个模型。

    下面就是一个用训练好的模型,来对测试集进行测试,并输出每个样本的分类结果的实现。

    代码如下:

    #coding=utf-8  
          
    import os
    import caffe  
    import numpy as np  
    root='/home/liuyun/caffe/'   #根目录  
    deploy=root + 'examples/DR_grade/deploy.prototxt'    #deploy文件  
    caffe_model=root + 'models/DR/model1/DRnet_iter_40000.caffemodel'  #训练好的 caffemodel  
    
    
    import os
    dir = root+'examples/DR_grade/test_512/'
    filelist=[]
    filenames = os.listdir(dir)
    for fn in filenames:
       fullfilename = os.path.join(dir,fn)
       filelist.append(fullfilename)
    
    
    # img=root+'data/DRIVE/test/60337.jpg'   #随机找的一张待测图片  
    
    def Test(img):
         
        net = caffe.Net(deploy,caffe_model,caffe.TEST)   #加载model和network  
          
        #图片预处理设置  
        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})  #设定图片的shape格式(1,3,28,28)  
        transformer.set_transpose('data', (2,0,1))    #改变维度的顺序,由原始图片(28,28,3)变为(3,28,28)  
        #transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))    #减去均值,前面训练模型时没有减均值,这儿就不用  
        transformer.set_raw_scale('data', 255)    # 缩放到【0,255】之间  
        transformer.set_channel_swap('data', (2,1,0))   #交换通道,将图片由RGB变为BGR  
          
        im=caffe.io.load_image(img)                   #加载图片  
        net.blobs['data'].data[...] = transformer.preprocess('data',im)      #执行上面设置的图片预处理操作,并将图片载入到blob中  
          
        #执行测试  
        out = net.forward()  
          
        labels = np.loadtxt(labels_filename, str, delimiter='	')   #读取类别名称文件  
        prob= net.blobs['prob'].data[0].flatten() #取出最后一层(prob)属于某个类别的概率值,并打印,'prob'为最后一层的名称 
        print prob  
        order=prob.argsort()[4]  #将概率值排序,取出最大值所在的序号 ,9指的是分为0-9十类  
        #argsort()函数是从小到大排列  
        print 'the class is:',labels[order]   #将该序号转换成对应的类别名称,并打印  
        f=file("/home/liuyun/caffe/examples/DR_grade/label.txt","a+")
        f.writelines(img+' '+labels[order]+'
    ')
    
    labels_filename = root +'examples/DR_grade/DR.txt'    #类别名称文件,将数字标签转换回类别名称  
    
    for i in range(0, len(filelist)):
        img= filelist[i]
        Test(img)
    
  • 相关阅读:
    SQL Server, Timeout expired.all pooled connections were in use and max pool size was reached
    javascript 事件调用顺序
    Best Practices for Speeding Up Your Web Site
    C语言程序设计 使用VC6绿色版
    破解SQL Prompt 3.9的几步操作
    Master page Path (MasterPage 路径)
    几个小型数据库的比较
    CSS+DIV 完美实现垂直居中的方法
    由Response.Redirect引发的"Thread was being aborted. "异常的处理方法
    Adsutil.vbs 在脚本攻击中的妙用
  • 原文地址:https://www.cnblogs.com/Allen-rg/p/5834551.html
Copyright © 2020-2023  润新知