• [Object Detection] Detecting objects in a video with YOLOv3 and writing out a new video


    This project uses YOLOv3 to detect objects in a video and writes the annotated result out as a new video.

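    Following the official YOLO page (https://pjreddie.com/darknet/yolo/), install darknet and build it with OpenCV and GPU support (set GPU=1 and OPENCV=1 in darknet's Makefile before running make).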

    The main steps are:

    1) Download and build darknet:
      git clone https://github.com/pjreddie/darknet
      cd darknet
      make
    2) Download the YOLOv3 weights:
      wget https://pjreddie.com/media/files/yolov3.weights
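    3) Copy the compiled libdarknet.so into the python folder; darknet.py loads this shared library with ctypes (from the current working directory) and calls the C detection functions it exports;
    4) Modify darknet.py to run detection on camera and video input;
    5) Save every annotated frame, then use OpenCV to combine all the frames into a video.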
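    Step 5 depends on reading the saved frames back in capture order. os.listdir returns names in arbitrary order, and a plain string sort puts result_frame_10.jpg before result_frame_2.jpg. A minimal sketch of a numeric sort key, assuming the frames are named result_frame_N.jpg as in the code below (frame_index and ordered_frames are illustrative helper names, not part of darknet.py):

```python
import re


def frame_index(filename):
    """Hypothetical helper: return N from a name like 'result_frame_N.jpg'."""
    match = re.search(r"(\d+)(?=\.\w+$)", filename)  # digits right before the extension
    if match is None:
        raise ValueError(f"no frame number in {filename!r}")
    return int(match.group(1))


def ordered_frames(filenames, suffix=".jpg"):
    """Keep only image files and sort them by frame number instead of as strings."""
    frames = [f for f in filenames if f.endswith(suffix)]
    return sorted(frames, key=frame_index)


names = ["result_frame_10.jpg", "result_frame_2.jpg", "result_frame_1.jpg", "notes.txt"]
print(sorted(names))
# ['notes.txt', 'result_frame_1.jpg', 'result_frame_10.jpg', 'result_frame_2.jpg']
print(ordered_frames(names))
# ['result_frame_1.jpg', 'result_frame_2.jpg', 'result_frame_10.jpg']
```

    The script's save_video() avoids the problem by opening result_frame_%d.jpg by index; a key like this helps when the names come from a directory listing, such as the sorted test_list in picture mode.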

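    The official site mainly uses the C interface to run detection on images, videos, or live input. The bundled Python code offers far fewer features, so here I modify it to implement the steps above.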
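    Everything the Python side does goes through ctypes: each C function exported by libdarknet.so gets its argument and return types declared (argtypes/restype), and C structs such as IMAGE are mirrored as ctypes.Structure subclasses. The sketch below shows the same pattern against the C library's strlen instead of libdarknet.so, so it runs without darknet; it assumes Linux or macOS, where ctypes.CDLL(None) exposes the symbols already loaded into the Python process. wrap_pixels is a hypothetical helper that mimics what array_to_image() does:

```python
import ctypes

# CDLL(None) opens the symbols already loaded into this Python process (Linux/macOS),
# including the C library; darknet.py instead opens CDLL("./libdarknet.so", RTLD_GLOBAL).
libc = ctypes.CDLL(None)

# Declare the C signature size_t strlen(const char *s) before calling it, the same way
# darknet.py declares argtypes/restype for load_network, get_network_boxes, ...
strlen = libc.strlen
strlen.argtypes = [ctypes.c_char_p]
strlen.restype = ctypes.c_size_t


class IMAGE(ctypes.Structure):
    # Same field layout as the IMAGE class in darknet.py
    _fields_ = [("w", ctypes.c_int),
                ("h", ctypes.c_int),
                ("c", ctypes.c_int),
                ("data", ctypes.POINTER(ctypes.c_float))]


def wrap_pixels(pixels, w, h, c):
    """Hypothetical helper: build an IMAGE pointing at a Python-owned float buffer.

    The buffer is returned alongside the struct because IMAGE.data is only a pointer;
    if the buffer were garbage-collected, the pointer would dangle. array_to_image()
    in darknet.py returns (im, arr) for the same reason.
    """
    buf = (ctypes.c_float * len(pixels))(*pixels)
    im = IMAGE(w, h, c, ctypes.cast(buf, ctypes.POINTER(ctypes.c_float)))
    return im, buf


print(strlen(b"yolov3.cfg"))  # 10 (c_char_p takes bytes, hence the b"..." paths)
im, buf = wrap_pixels([0.0, 0.5, 1.0, 0.25, 0.75, 0.1], w=2, h=1, c=3)
print(im.w, im.h, im.c, im.data[1])  # 2 1 3 0.5
```

    Passing a str instead of bytes to a c_char_p argument raises ctypes.ArgumentError, which is why load_model() passes its paths as b"..." or with .encode('utf-8').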

    The main changes to darknet.py in the python folder are:

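    1) Assign one color to the bounding boxes of each of the 80 classes;
    2) Draw a bounding box (and its confidence) around each detected object;
    3) Save every annotated frame;
    4) Combine all the frames into a 30 fps video.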
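    Changes 1) and 2) come down to two small helpers. The sketch below restates them with the standard library only: a seeded generator gives each of the 80 classes a fixed color (the script does this with np.random.seed(80)), and a darknet box, which is centre x, centre y, width and height in pixels, is turned into the corner points that cv2.rectangle expects. class_colors and center_to_corners are illustrative names, and the clipping to the image border is an addition that convertBack() does not do:

```python
import random


def class_colors(num_classes, seed=80):
    """One reproducible (B, G, R) color per class, each channel in 0..255."""
    rng = random.Random(seed)  # private generator, so other random calls are unaffected
    return [tuple(rng.randint(0, 255) for _ in range(3)) for _ in range(num_classes)]


def center_to_corners(x, y, w, h, img_w, img_h):
    """Convert a darknet box (centre x, centre y, width, height) to pixel corners,
    clipped so that boxes touching the border stay inside the image."""
    xmin = max(0, int(round(x - w / 2)))
    ymin = max(0, int(round(y - h / 2)))
    xmax = min(img_w - 1, int(round(x + w / 2)))
    ymax = min(img_h - 1, int(round(y + h / 2)))
    return xmin, ymin, xmax, ymax


colors = class_colors(80)
print(len(colors), colors == class_colors(80))  # 80 True: same colors on every run
print(center_to_corners(100, 50, 40, 20, img_w=640, img_h=480))  # (80, 40, 120, 60)
print(center_to_corners(10, 10, 40, 40, img_w=640, img_h=480))   # (0, 0, 30, 30)
```

    cv2.rectangle copes with points outside the image on its own, but clipping matters once the box is used to crop, because a negative NumPy slice index counts from the end of the array.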

    The final merged video is at https://www.bilibili.com/video/av94163474; the code is below.

    The same code can also run detection on a batch of images, or in real time from a camera.
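    In the script the mode is chosen by editing the state variable. If you would rather pick it on the command line, a minimal argparse front end might look like the sketch below; the positional mode argument and the --source option are hypothetical, not part of the original script. It relies on cv2.VideoCapture treating an int as a camera index and a string as a file path or stream URL:

```python
import argparse

MODES = ("picture", "video", "real_time")


def parse_args(argv=None):
    """Hypothetical command-line replacement for the hard-coded `state` switch."""
    parser = argparse.ArgumentParser(description="YOLOv3 detection via darknet.py")
    parser.add_argument("mode", choices=MODES, help="picture | video | real_time")
    parser.add_argument("--source", default=None,
                        help="image folder, video file, stream URL, or camera index (e.g. 0)")
    args = parser.parse_args(argv)
    # An all-digit source is turned into an int so VideoCapture opens a camera
    if args.source is not None and args.source.isdigit():
        args.source = int(args.source)
    return args


args = parse_args(["real_time", "--source", "0"])
print(args.mode, repr(args.source))  # real_time 0
```

    With this in place, args.mode could stand in for state, and args.source for the fixed video path in mode_select().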

     

     

    The modified darknet.py:

      from ctypes import *
      import random
      import cv2
      import numpy as np
      import os
      import time


      def sample(probs):
          # Draw an index with probability proportional to probs (kept from the stock darknet.py; unused here)
          s = sum(probs)
          probs = [a / s for a in probs]
          r = random.uniform(0, 1)
          for i in range(len(probs)):
              r = r - probs[i]
              if r <= 0:
                  return i
          return len(probs) - 1


      def c_array(ctype, values):
          arr = (ctype * len(values))()
          arr[:] = values
          return arr


      # ctypes mirrors of darknet's C structs; field order and types must match the C definitions
      class BOX(Structure):
          _fields_ = [("x", c_float),
                      ("y", c_float),
                      ("w", c_float),
                      ("h", c_float)]


      class DETECTION(Structure):
          _fields_ = [("bbox", BOX),
                      ("classes", c_int),
                      ("prob", POINTER(c_float)),
                      ("mask", POINTER(c_float)),
                      ("objectness", c_float),
                      ("sort_class", c_int)]


      class IMAGE(Structure):
          _fields_ = [("w", c_int),
                      ("h", c_int),
                      ("c", c_int),
                      ("data", POINTER(c_float))]


      class METADATA(Structure):
          _fields_ = [("classes", c_int),
                      ("names", POINTER(c_char_p))]


      # libdarknet.so is opened relative to the current working directory (see step 3);
      # each binding below declares the C argument types (argtypes) and return type (restype)
      lib = CDLL("./libdarknet.so", RTLD_GLOBAL)
      lib.network_width.argtypes = [c_void_p]
      lib.network_width.restype = c_int
      lib.network_height.argtypes = [c_void_p]
      lib.network_height.restype = c_int

      predict = lib.network_predict
      predict.argtypes = [c_void_p, POINTER(c_float)]
      predict.restype = POINTER(c_float)

      set_gpu = lib.cuda_set_device
      set_gpu.argtypes = [c_int]

      make_image = lib.make_image
      make_image.argtypes = [c_int, c_int, c_int]
      make_image.restype = IMAGE

      get_network_boxes = lib.get_network_boxes
      get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
      get_network_boxes.restype = POINTER(DETECTION)

      make_network_boxes = lib.make_network_boxes
      make_network_boxes.argtypes = [c_void_p]
      make_network_boxes.restype = POINTER(DETECTION)

      free_detections = lib.free_detections
      free_detections.argtypes = [POINTER(DETECTION), c_int]

      free_ptrs = lib.free_ptrs
      free_ptrs.argtypes = [POINTER(c_void_p), c_int]

      network_predict = lib.network_predict
      network_predict.argtypes = [c_void_p, POINTER(c_float)]

      reset_rnn = lib.reset_rnn
      reset_rnn.argtypes = [c_void_p]

      load_net = lib.load_network
      load_net.argtypes = [c_char_p, c_char_p, c_int]
      load_net.restype = c_void_p

      do_nms_obj = lib.do_nms_obj
      do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

      do_nms_sort = lib.do_nms_sort
      do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

      free_image = lib.free_image
      free_image.argtypes = [IMAGE]

      letterbox_image = lib.letterbox_image
      letterbox_image.argtypes = [IMAGE, c_int, c_int]
      letterbox_image.restype = IMAGE

      load_meta = lib.get_metadata
      lib.get_metadata.argtypes = [c_char_p]
      lib.get_metadata.restype = METADATA

      load_image = lib.load_image_color
      load_image.argtypes = [c_char_p, c_int, c_int]
      load_image.restype = IMAGE

      rgbgr_image = lib.rgbgr_image
      rgbgr_image.argtypes = [IMAGE]

      predict_image = lib.network_predict_image
      predict_image.argtypes = [c_void_p, IMAGE]
      predict_image.restype = POINTER(c_float)


      def convertBack(x, y, w, h):
          """Convert a darknet box (centre x, centre y, width, height) to integer corner coordinates."""
          xmin = int(round(x - (w / 2)))
          xmax = int(round(x + (w / 2)))
          ymin = int(round(y - (h / 2)))
          ymax = int(round(y + (h / 2)))
          return xmin, ymin, xmax, ymax


      def array_to_image(arr):
          """Wrap an OpenCV HxWxC uint8 array as a darknet IMAGE (planar CxHxW floats in [0, 1]).

          The float array is returned too: IMAGE.data only points into it, so the caller
          must keep it referenced until darknet is done with the image.
          """
          arr = arr.transpose(2, 0, 1)
          c, h, w = arr.shape[0:3]
          arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
          data = arr.ctypes.data_as(POINTER(c_float))
          im = IMAGE(w, h, c, data)
          return im, arr


      def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
          """Run YOLO on a BGR image; return [(name_bytes, prob, (x, y, w, h)), ...] sorted by prob."""
          im, arr = array_to_image(image)  # arr must stay alive while darknet reads im.data
          rgbgr_image(im)  # OpenCV frames are BGR, darknet expects RGB
          num = c_int(0)
          pnum = pointer(num)
          predict_image(net, im)
          dets = get_network_boxes(net, im.w, im.h, thresh,
                                   hier_thresh, None, 0, pnum)
          num = pnum[0]
          if nms:
              do_nms_obj(dets, num, meta.classes, nms)

          res = []
          for j in range(num):
              probs = dets[j].prob[0:meta.classes]
              for i in np.nonzero(probs)[0]:
                  b = dets[j].bbox
                  res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))

          res = sorted(res, key=lambda x: -x[1])
          # im.data points into the NumPy array, so there is no darknet-allocated image to free
          free_detections(dets, num)
          return res


      def mode_select(state):
          """Return a cv2.VideoCapture for 'video'/'real_time', or the sentinel 1 for 'picture'."""
          if state not in {'picture', 'video', 'real_time'}:
              raise ValueError('{} is not a valid argument!'.format(state))
          if state == 'picture':
              return 1
          if state == 'real_time':
              # video = "http://admin:admin@192.168.0.13:8081"  # e.g. an IP-camera stream
              video = 0  # default camera
          else:
              video = '../test/test_video/video7.mp4'
          cap = cv2.VideoCapture(video)
          if not cap.isOpened():
              raise IOError('cannot open video source {!r}'.format(video))
          return cap


      def find_object_in_picture(ret, img):
          """Draw a colored box and label for each detection onto img (in place) and return it.

          Relies on the module-level LABELS, COLORS and state set in __main__.
          """
          for name, prob, (x, y, w, h) in ret:
              label = name.decode()
              color = COLORS[LABELS.index(label)].tolist()
              xmin, ymin, xmax, ymax = convertBack(float(x), float(y), float(w), float(h))
              cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 3)
              if state == 'video':
                  text = label
              else:
                  text = label + " [" + str(round(prob * 100, 2)) + "]"
              (text_w, text_h), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
              # filled strip in the class color above the box, then black text on top of it
              cv2.rectangle(img, (xmin, ymin - text_h - baseline), (xmin + text_w, ymin), color, -1)
              cv2.putText(img, text, (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
          return img


      def save_video(state, out_video, frame_dir='../result/result_frame',
                     video_path='../result/result_video/result_video.avi', fps=30.0):
          """Merge result_frame_0.jpg, result_frame_1.jpg, ... into one XVID video (video mode only)."""
          if state != 'video' or not out_video:
              return 0
          first = cv2.imread(os.path.join(frame_dir, 'result_frame_0.jpg'), 1)
          if first is None:
              print('no frames found in', frame_dir)
              return 0
          frame_height, frame_width = first.shape[:2]
          os.makedirs(os.path.dirname(video_path), exist_ok=True)
          fourcc = cv2.VideoWriter_fourcc(*'XVID')
          out = cv2.VideoWriter(video_path, fourcc, fps, (frame_width, frame_height), True)
          n_frames = len([f for f in os.listdir(frame_dir) if f.endswith('.jpg')])
          print('the number of video frames is', n_frames)
          # open the frames by index so they are written in capture order
          for i in range(n_frames):
              frame = cv2.imread(os.path.join(frame_dir, 'result_frame_%d.jpg' % i), 1)
              if frame is None:  # gap in the numbering: skip instead of writing None
                  continue
              out.write(frame)
          out.release()
          print('video has been saved.')
          return 1


      def load_model():
          # Absolute paths from the author's machine; point them at your own darknet checkout
          net1 = load_net(b"/home/dengjie/dengjie/project/detection/from_darknet/cfg/yolov3.cfg",
                          b"/home/dengjie/dengjie/project/detection/from_darknet/cfg/yolov3.weights",
                          0)
          meta1 = load_meta("/home/dengjie/dengjie/project/detection/from_darknet/cfg/coco.data".encode('utf-8'))
          label_path = '../data/coco.names'
          with open(label_path) as f:
              LABELS1 = f.read().strip().split("\n")
          num_class = len(LABELS1)
          return net1, meta1, LABELS1, num_class


      def random_color(num):
          """Return a (num, 3) uint8 array holding one random BGR color per class.

          The fixed seed gives each class the same color on every run and every frame.
          """
          np.random.seed(80)
          COLORS = np.random.randint(0, 256, size=(num, 3), dtype='uint8')
          return COLORS


      if __name__ == "__main__":
          path = '../test/test_pic'              # input images for 'picture' mode
          frame_path = '../result/result_frame'  # annotated frames for 'video' mode
          pic_path = '../result/result_pic'      # annotated images for 'picture' mode
          for d in (frame_path, pic_path, '../result/result_video'):
              os.makedirs(d, exist_ok=True)  # cv2.imwrite returns False instead of raising if the folder is missing

          state = 'video'  # detection mode: 'video', 'picture' or 'real_time'

          net, meta, LABELS, class_num = load_model()
          cap = mode_select(state)
          COLORS = random_color(class_num)
          print('start detect')

          if state == 'picture':
              test_list = sorted(os.listdir(path))
              k = 0      # number of images processed
              sum_t = 0  # total time, excluding the first (warm-up) image
              for j in test_list:
                  time_p = time.time()
                  img = cv2.imread(os.path.join(path, j), 1)
                  if img is None:  # skip files OpenCV cannot read as images
                      continue
                  r = detect(net, meta, img)
                  # r looks like [(b'person', 0.6372514963150024,
                  #   (414.55322265625, 279.70245361328125, 483.99005126953125, 394.2349853515625))]:
                  # class name, confidence, box centre x, centre y, width, height (in pixels)
                  image = find_object_in_picture(r, img)
                  t = time.time() - time_p
                  if k > 0:
                      sum_t += t
                  print('process ' + j + ' spend %.5fs' % t)
                  cv2.imwrite(os.path.join(pic_path, 'result_%d.jpg' % k), image)
                  k += 1
                  cv2.imshow("img", image)
                  cv2.waitKey()  # wait for a key press before the next image
                  cv2.destroyAllWindows()
              print('Have processed %d pictures.' % k)
              print('Total picture-processing time is %.5fs' % sum_t)
              if k > 1:
                  print('Average processing time is %.5fs' % (sum_t / (k - 1)))
              print('Have Done!')
          else:
              sum_v = 0    # total processing time
              sum_fps = 0  # sum of per-frame FPS, excluding the first frame
              i = 0        # number of frames processed
              k = 0        # index of the next saved frame
              while True:
                  time_v = time.time()
                  ret, img = cap.read()
                  if not ret:  # end of the video, or the camera stopped delivering frames
                      break
                  i += 1
                  r = detect(net, meta, img)
                  image = find_object_in_picture(r, img)
                  cv2.imshow("window", image)
                  t_v = time.time() - time_v
                  fps = 1 / t_v
                  if i > 1:  # the first frame includes warm-up, so leave it out of the average
                      print('FPS %.3f' % fps)
                      sum_fps += fps
                  sum_v += t_v
                  if state == 'video':
                      cv2.imwrite(os.path.join(frame_path, 'result_frame_%d.jpg' % k), image)
                      k += 1
                  # waitKey(1) waits up to 1 ms for a key press; ord('q') is the key code of 'q'
                  if cv2.waitKey(1) & 0xFF == ord('q'):
                      break
              cap.release()
              cv2.destroyAllWindows()
              print('Total processing time is %.5fs' % sum_v)
              print('Detected frames : %d ' % i)
              if i > 1:
                  print('Average fps is %.3f' % (sum_fps / (i - 1)))
              val = save_video(state, True)
              print('Have Done!' if val == 1 else 'Detection has finished.')
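    detect() passes nms=.45 to darknet's do_nms_obj, which removes duplicate boxes around the same object. For reference, here is a plain-Python sketch of greedy non-maximum suppression with an IoU threshold. It is not darknet's implementation: as far as I can tell from darknet's box.c, do_nms_obj ranks boxes by objectness and suppresses overlaps regardless of class, while do_nms_sort works per class. The sketch ranks by the reported score and offers both behaviours through a per_class flag (iou and greedy_nms are illustrative names):

```python
def iou(a, b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(detections, nms_thresh=0.45, per_class=True):
    """Greedy NMS over (label, score, (xmin, ymin, xmax, ymax)) tuples.

    Boxes are visited from highest to lowest score; a box is kept only if its IoU with
    every already-kept box (of the same label, when per_class=True) is at most nms_thresh.
    """
    kept = []
    for det in sorted(detections, key=lambda d: d[1], reverse=True):
        suppressed = any(
            (not per_class or k[0] == det[0]) and iou(k[2], det[2]) > nms_thresh
            for k in kept
        )
        if not suppressed:
            kept.append(det)
    return kept


dets = [("person", 0.90, (10, 10, 110, 210)),
        ("person", 0.75, (15, 12, 115, 205)),   # near-duplicate of the first box
        ("person", 0.60, (300, 40, 380, 200)),  # a second, separate person
        ("dog", 0.80, (12, 10, 108, 200))]      # overlaps the first box, other class

print([(d[0], d[1]) for d in greedy_nms(dets)])
# [('person', 0.9), ('dog', 0.8), ('person', 0.6)]
print([(d[0], d[1]) for d in greedy_nms(dets, per_class=False)])
# [('person', 0.9), ('person', 0.6)]
```

    A higher threshold (e.g. 0.6) keeps more overlapping boxes; a lower one merges more aggressively.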
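    The script reports "Average fps" as the mean of the per-frame values 1/t, skipping the first (warm-up) frame. When frame times vary, that mean comes out higher than the throughput actually achieved, which is frames divided by total time. A small sketch comparing the two (fps_stats is an illustrative helper and the frame times are made-up numbers):

```python
def fps_stats(frame_times, warmup=1):
    """Return (mean of per-frame FPS, overall FPS) for frame durations in seconds.

    The first `warmup` frames are skipped, as the script does with its first frame.
    """
    times = frame_times[warmup:]
    if not times:
        raise ValueError("need more frames than the warm-up count")
    mean_of_fps = sum(1.0 / t for t in times) / len(times)  # what the script prints
    overall_fps = len(times) / sum(times)                     # frames per second of wall time
    return mean_of_fps, overall_fps


# One slow warm-up frame, three frames at 25 ms, then a 100 ms hiccup
mean_fps, overall = fps_stats([1.5, 0.025, 0.025, 0.025, 0.100])
print(round(mean_fps, 2), round(overall, 2))  # 32.5 22.86
```

    The overall figure is the one to compare with the source video's frame rate when judging whether real-time detection keeps up.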
  • Original post: https://www.cnblogs.com/DJames23/p/12456011.html