本周老师给的任务:
一是将VOT15数据集(世华已传到服务器上)上每个序列的第1,11,21,31,41帧分别运行Faster R-CNN检测器并保存在图片上显示的检测结果;
二是将这5帧的ground truth bounding box作为proposal得到其对应的检测器分类结果(比如网络要检测20类物体,那包括背景就是得到21类对应的检测分数值),并将每个序列的检测结果分别存成一个文本文档。
注意,使用代码的时候,可能会有路径错误,还可能是,我贴上的代码,博客园的网站给在某些语句后加了 <br> ,调错的时候细看!!我在后台竟然看不到<br>,但是浏览的时候却有!!
第一个问题已经解决,现在整理一下思路。
先将py faster rcnn 装好之后,测试运行dome.py能成功展示之后,再进行接下来的工作。
我的想法是,
(1)将vot2015数据集上的所有数据的分类统计出来(就是把vot2015下的子文件夹的名称统计出来,方便之后操作),这里直接用了( http://www.cnblogs.com/flyhigh1860/p/3896111.html )的源码进行修改
#!/usr/bin/python # -*- coding:utf8 -*- import os allFileNum = 0 def printPath(level, path): global allFileNum ''''' 打印一个目录下的所有文件夹和文件 ''' # 所有文件夹,第一个字段是次目录的级别 dirList = [] # 所有文件 fileList = [] # 返回一个列表,其中包含在目录条目的名称(google翻译) files = os.listdir(path) # 先添加目录级别 dirList.append(str(level)) for f in files: if (os.path.isdir(path + '/' + f)): # 排除隐藏文件夹。因为隐藏文件夹过多 if (f[0] == '.'): pass else: # 添加非隐藏文件夹 dirList.append(f) if (os.path.isfile(path + '/' + f)): # 添加文件 fileList.append(f) # 当一个标志使用,文件夹列表第一个级别不打印 i_dl = 0
#得到的文件夹名保存在 save_file.txt 中,使用python的追加操作 ‘a’ save_file = open('/home/user/Downloads/save_file.txt','a') for dl in dirList: if (i_dl == 0): i_dl = i_dl + 1 else: # 打印至控制台,不是第一个的目录 print '-' * (int(dirList[0])), dl
#将文件名写入save_file.txt中 save_file.write(dl) save_file.write(' ') # 打印目录下的所有文件夹和文件,目录级别+1 #printPath((int(dirList[0]) + 1), path + '/' + dl) for fl in fileList: # 打印文件 print '-' * (int(dirList[0])), fl # 随便计算一下有多少个文件 allFileNum = allFileNum + 1 if __name__ == '__main__': printPath(1, '/home/user/Downloads/vot2015') print '总文件数 =', allFileNum
这里再给出save_file.txt 文件内容
soldier butterfly hand car2 sheep birds1 motocross1 marching book road graduate fish3 fernando bag wiper gymnastics2 leaves ball1 birds2 crossing soccer1 godfather nature racing traffic pedestrian2 handball2 ball2 gymnastics1 singer2 singer1 dinosaur gymnastics3 bolt1 gymnastics4 pedestrian1 helicopter singer3 matrix octopus iceskater1 fish4 sphere car1 motocross2 girl fish1 bolt2 basketball blanket bmx shaking tiger handball1 rabbit fish2 tunnel glove iceskater2 soccer2
(2)从save_file.txt 中将分来读取出来,保存再一个list中,之后将这段代码加到 demo.py 中使用(参考了 http://www.cnblogs.com/xuxn/archive/2011/07/27/read-a-file-with-python.html 和 http://www.cnblogs.com/mxh1099/p/5680001.html)
l = [] file = open('/home/user/Downloads/save_file.txt') while 1: line = file.readline() if line != ' ': print line.replace(" ", "")
#在list中 加入去掉换行符的文件名 l.append(line.replace(" ","")) if not line: break print l
(3)需要将文件名和要遍历的每个文件夹下的文件名配合,同样,这段代码之后会用在demo.py 中
lfile = [] file = open('/home/user/Downloads/save_file.txt') while 1: line = file.readline() if line != ' ': lfile.append(line.replace(" ", "")) if not line: break im_names =['00000023.jpg','00000011.jpg','00000001.jpg'] # im_names = ['00000001.jpg', '000000011.jpg', '00000021.jpg', # '00000031.jpg', '00000041.jpg'] for litme in lfile : for im_name in im_names: im_path = str(litme) + '/' + str(im_name) print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' #print 'Demo for data/demo/{}'.format(im_name) print im_path
(4)可以对文件遍历之后,需要将生成的图片结果保存下来,参考了《演示如何实现Matplotlib绘图并保存图像但不显示图形的方法》(http://blog.csdn.net/rumswell/article/details/7342479) 和Python创建目录文件夹 (http://www.cnblogs.com/monsteryang/p/6574550.html)
最后附上我修改之后的demo.py
#!/usr/bin/env python # -------------------------------------------------------- # Faster R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # -------------------------------------------------------- """ Demo script showing detections in sample images. See README.md for installation instructions before running. """ import _init_paths from fast_rcnn.config import cfg from fast_rcnn.test import im_detect from fast_rcnn.nms_wrapper import nms from utils.timer import Timer import matplotlib import matplotlib.pyplot as plt import numpy as np import scipy.io as sio import caffe, os, sys, cv2 import argparse #add matplotlib.use('Agg') CLASSES = ('__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') NETS = {'vgg16': ('VGG16', 'VGG16_faster_rcnn_final.caffemodel'), 'zf': ('ZF', 'ZF_faster_rcnn_final.caffemodel')} #add def mkdir(path): import os path = path.strip() path = path.rstrip("\") isExists = os.path.exists(path) if not isExists: os.makedirs(path) print path + 'ok' return True else: print path + 'failed!' return False def vis_detections(image_name, im, class_name, dets, thresh=0.5): """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: return im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) ax.imshow(im, aspect='equal') for i in inds: bbox = dets[i, :4] score = dets[i, -1] ax.add_patch( plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False, edgecolor='red', linewidth=3.5) ) ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white') ax.set_title(('{} detections with ' 'p({} | box) >= {:.1f}').format(class_name, class_name, thresh), fontsize=14) plt.axis('off') plt.tight_layout() plt.draw() #add ll = [] ll = str(image_name).split('/') print ll[0] mkdir('/home/user/tmp/' + str(ll[0])) plt.savefig('/home/user/tmp/' + str(image_name)) def demo(net, image_name): """Detect object classes in an image using pre-computed object proposals.""" # Load the demo image im_file = os.path.join(cfg.DATA_DIR, 'demo','vot2015', image_name) print("%s", im_file) im = cv2.imread(im_file) # Detect all object classes and regress object bounds timer = Timer() timer.tic() #add try except try: scores, boxes = im_detect(net, im) timer.toc() print ('Detection took {:.3f}s for ' '{:d} object proposals').format(timer.total_time, boxes.shape[0]) # Visualize detections for each class CONF_THRESH = 0.8 NMS_THRESH = 0.3 for cls_ind, cls in enumerate(CLASSES[1:]): cls_ind += 1 # because we skipped background cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] cls_scores = scores[:, cls_ind] dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32) keep = nms(dets, NMS_THRESH) dets = dets[keep, :] vis_detections(image_name,im, cls, dets, thresh=CONF_THRESH) except Exception: print 'Error' def parse_args(): """Parse input arguments.""" parser = argparse.ArgumentParser(description='Faster R-CNN demo') parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int) parser.add_argument('--cpu', dest='cpu_mode', help='Use CPU mode (overrides --gpu)', action='store_true') parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]', choices=NETS.keys(), default='vgg16') args = parser.parse_args() return args if __name__ == '__main__': cfg.TEST.HAS_RPN = True # Use RPN for proposals args = parse_args() prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt') caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models', NETS[args.demo_net][1]) if not os.path.isfile(caffemodel): raise IOError(('{:s} not found. Did you run ./data/script/' 'fetch_faster_rcnn_models.sh?').format(caffemodel)) if args.cpu_mode: caffe.set_mode_cpu() else: caffe.set_mode_gpu() caffe.set_device(args.gpu_id) cfg.GPU_ID = args.gpu_id net = caffe.Net(prototxt, caffemodel, caffe.TEST) print ' Loaded network {:s}'.format(caffemodel) # Warmup on a dummy image im = 128 * np.ones((300, 500, 3), dtype=np.uint8) for i in xrange(2): _, _= im_detect(net, im) # im_names = ['000456.jpg', '000542.jpg', '001150.jpg', # '001763.jpg', '004545.jpg','00000023.jpg','00000011.jpg','00000001.jpg'] # edit lfile = [] file = open('/home/user/Downloads/save_file.txt') while 1: line = file.readline() if line != ' ': lfile.append(line.replace(" ", "")) if not line: break print lfile im_names = ['00000001.jpg', '00000011.jpg', '00000021.jpg', '00000031.jpg', '00000041.jpg'] for litme in lfile : for im_name in im_names: im_path = str(litme) + '/' + str(im_name) print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' print 'Demo for data/demo/{}'.format(im_name) try: demo(net, im_path) except Exception: print 'ERROR' #plt.show()
第二个问题先看着,没想法
在图片上显示每个IOU大于0.5的proposal对应的最高检测值的类别、分数和回归后的框,在文本文档里则保存每个proposal对应的21个类别的检测分数和回归后的边界框坐标。
对于每个类别,总会生成300个proposals,
所以,在每个proposal,都会有4个坐标
对于每个proposal,都会有一个类别值。
因为要生成每个proposal对应的21个类别的分数,就需要将分数先保存起来,再输出
还要记录回归后的边间框。
对于图片,显示每个IOU大于0.5的proposal对应的最高检测值的类别、分数和回归后的框。
也是先要将最高检测分数对应的类别和回归框记录下来。