• TensorFlow SSD代码的运行,视频检测 2017-10-18


    主要是修改检测程序:

    1. 原来使用image,改为读取avi
    2. 原来使用visualization.plt_bboxes(img, rclasses, rscores, rbboxes)函数直接画图,修改为使用visualization.bboxes_draw_on_img(image_np, rclasses, rscores, rbboxes)将检测框画到image上
    3. 修改visualization文件,使原来只能以数字显示的类别,改为直接显示英文类别名称

    主代码修改为:

    # -*- coding: utf-8 -*-
    """
    Created on Tue Oct 17 10:30:37 2017
    ssd_ttest_300
    @author: IRay
    """
    
    
    import os
    import math
    import random
    
    
    import numpy as np
    import tensorflow as tf
    import cv2
    
    slim = tf.contrib.slim
    
    
    #%matplotlib inline
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    
    import sys
    sys.path.append('../')
    
    
    from tensorflow.models.SSD_Tensorflow_master.nets import ssd_vgg_300, ssd_common, np_methods
    from tensorflow.models.SSD_Tensorflow_master.preprocessing import ssd_vgg_preprocessing
    from tensorflow.models.SSD_Tensorflow_master.notebooks import visualization
    
    from tensorflow.models.SSD_Tensorflow_master.datasets import pascalvoc_2007
    
    
    # TensorFlow session: let GPU memory grow on demand instead of letting
    # TensorFlow reserve all of it up front.
    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
    isess = tf.InteractiveSession(config=config)


    # Input placeholder: a single uint8 image with 3 channels and free height/width.
    net_shape = (300, 300)
    data_format = 'NHWC'
    img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
    # Evaluation pre-processing: resize to SSD net shape.
    image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
        img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
    # Add the leading batch dimension before feeding the network.
    image_4d = tf.expand_dims(image_pre, 0)

    # Define the SSD model. Variables are reused when this code is re-run in a
    # namespace that already holds an `ssd_net` (e.g. an interactive session).
    reuse = True if 'ssd_net' in locals() else None
    ssd_net = ssd_vgg_300.SSDNet()
    with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
        predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)

    # Restore SSD model.
    # NOTE(review): the original literal contained TAB characters and no
    # backslashes (the separators were lost); the path below is the
    # reconstruction -- confirm it matches the local installation.
    ckpt_filename = r'D:\software\anaconda\envs\tensorflow\Lib\site-packages\tensorflow\models\SSD_Tensorflow_master\checkpoints/ssd_300_vgg.ckpt'
    # ckpt_filename = '../checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'
    isess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(isess, ckpt_filename)

    # SSD default anchor boxes.
    ssd_anchors = ssd_net.anchors(net_shape)
    
    
    #print(ssd_net)
    #ssd_net.anchors(net_shape)
    
    
    
    # Main image processing routine.
    def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
        """Run the SSD network on one image and return the filtered detections.

        Args:
            img: image fed to the ``img_input`` placeholder.
            select_threshold: score threshold forwarded to
                ``np_methods.ssd_bboxes_select``.
            nms_threshold: threshold forwarded to ``np_methods.bboxes_nms``.
            net_shape: passed as ``img_shape`` when decoding the raw outputs.

        Returns:
            A ``(classes, scores, bboxes)`` triple after selection, clipping,
            sorting (top 400), non-maximum suppression and resizing.
        """
        # Run SSD network; the pre-processed image itself is fetched but unused.
        fetches = [image_4d, predictions, localisations, bbox_img]
        _, net_predictions, net_localisations, net_bbox_img = isess.run(
            fetches, feed_dict={img_input: img})

        # Get classes and bboxes from the net outputs.
        det_classes, det_scores, det_bboxes = np_methods.ssd_bboxes_select(
            net_predictions, net_localisations, ssd_anchors,
            select_threshold=select_threshold, img_shape=net_shape,
            num_classes=21, decode=True)

        det_bboxes = np_methods.bboxes_clip(net_bbox_img, det_bboxes)
        det_classes, det_scores, det_bboxes = np_methods.bboxes_sort(
            det_classes, det_scores, det_bboxes, top_k=400)
        det_classes, det_scores, det_bboxes = np_methods.bboxes_nms(
            det_classes, det_scores, det_bboxes, nms_threshold=nms_threshold)
        # Resize bboxes to original image shape. Note: useless for Resize.WARP!
        det_bboxes = np_methods.bboxes_resize(net_bbox_img, det_bboxes)
        return det_classes, det_scores, det_bboxes
    
    
    # Test on some demo image and visualize output.
    #path = r'D:\software\anaconda\envs\tensorflow\Lib\site-packages\tensorflow\models\SSD_Tensorflow_master/demo/'
    #image_names = sorted(os.listdir(path))
    #
    #img = mpimg.imread(path + image_names[-2])
    #rclasses, rscores, rbboxes =  process_image(img)
    #
    ## visualization.bboxes_draw_on_img(img, rclasses, rscores, rbboxes, visualization.colors_plasma)
    #visualization.plt_bboxes(img, rclasses, rscores, rbboxes)
    
    
    # Open the input video.
    # NOTE(review): the original literal was split across two lines and had lost
    # its backslashes; the path below is a reconstruction -- confirm it.
    cap = cv2.VideoCapture(r'D:\rawData\movies\highway2raw.avi')
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fourcc = cap.get(cv2.CAP_PROP_FOURCC)
    print('fps=%d,size=%r,fourcc=%r'%(fps,size,fourcc))

    # Delay between frames, in milliseconds. The reported fps is 0 when the file
    # cannot be opened (division by zero in the original), and a delay of 0 makes
    # cv2.waitKey block until a key is pressed, so keep it at 1 ms or more.
    fps_int = int(fps)
    delay = max(1, int(30 / fps_int)) if fps_int > 0 else 1


    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or read error.
                break
            # OpenCV delivers frames in BGR order, whereas the image demo fed
            # matplotlib-loaded (RGB) images to the network.
            # NOTE(review): confirm the checkpoint expects RGB input.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Actual detection.
            rclasses, rscores, rbboxes = process_image(rgb_frame)
            # Draw the detections on the original (BGR) frame for display.
            visualization.bboxes_draw_on_img(frame, rclasses, rscores, rbboxes)
            cv2.imshow('frame', frame)
            cv2.waitKey(delay)
            print('Ongoing...')
    finally:
        # Release the capture and close the window even if detection fails.
        cap.release()
        cv2.destroyAllWindows()

    visualization文件代码修改为:

    # Copyright 2017 Paul Balanca. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    import cv2
    import random
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    import matplotlib.cm as mpcm
    
    
    #################added
    def num2class(n, labels=None):
        """Return the class name whose numeric id equals ``n``.

        Args:
            n: numeric class id (a Python or NumPy integer).
            labels: optional mapping ``name -> record``; when omitted, the
                ``VOC_LABELS`` table of the SSD project is used.
                Assumes each record is a sequence whose first element is the
                numeric id -- TODO confirm against pascalvoc_common.

        Returns:
            The matching class name, or ``str(n)`` when no entry has that id,
            so callers can always format the result as text.
        """
        if labels is None:
            import tensorflow.models.SSD_Tensorflow_master.datasets.pascalvoc_2007 as pas
            labels = pas.pascalvoc_common.VOC_LABELS
        for name, item in labels.items():
            # Compare with the id field only; testing membership in the whole
            # record would also compare the id against the other fields.
            if item[0] == n:
                return name
        return str(n)
    ###########################added
    
    
    
    # =========================================================================== #
    # Some colormaps.
    # =========================================================================== #
    def colors_subselect(colors, num_classes=21):
        """Pick ``num_classes`` colours from ``colors`` at a fixed stride.

        The stride is ``len(colors) // num_classes``. A colour whose first
        component is a ``float`` has every component multiplied by 255 and
        truncated to ``int``; any other colour is copied into a new list as is.

        Returns:
            A list of ``num_classes`` colours, each a list of components.
        """
        stride = len(colors) // num_classes

        def _as_component_list(color):
            if isinstance(color[0], float):
                return [int(component * 255) for component in color]
            return list(color)

        return [_as_component_list(colors[index * stride])
                for index in range(num_classes)]
    
    # Palette of 21 colours sampled at a fixed stride from matplotlib's plasma
    # colormap by colors_subselect.
    colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)
    # Hard-coded palette of 21 integer colour triples (components in 0-255);
    # the first entry is white.
    colors_tableau = [(255, 255, 255), (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
                      (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
                      (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
                      (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
                      (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
    
    
    # =========================================================================== #
    # OpenCV drawing.
    # =========================================================================== #
    def draw_lines(img, lines, color=[255, 0, 0], thickness=2):
        """Draw every segment of a nested collection of lines onto ``img``.

        ``lines`` is iterated two levels deep: each element is itself an
        iterable of ``(x1, y1, x2, y2)`` segments, drawn with ``cv2.line``.
        """
        segments = (segment for group in lines for segment in group)
        for x_start, y_start, x_end, y_end in segments:
            cv2.line(img, (x_start, y_start), (x_end, y_end), color, thickness)
    
    
    def draw_rectangle(img, p1, p2, color=[255, 0, 0], thickness=2):
        """Draw a rectangle on ``img`` spanning the corners ``p1`` and ``p2``.

        Both points are reversed before being handed to ``cv2.rectangle``, so
        callers supply them in the opposite order to the one OpenCV receives.
        """
        first_corner = p1[::-1]
        second_corner = p2[::-1]
        cv2.rectangle(img, first_corner, second_corner, color, thickness)
    
    
    def draw_bbox(img, bbox, shape, label, color=[255, 0, 0], thickness=2):
        """Draw one bounding box and its label on ``img``.

        ``bbox[0]`` and ``bbox[2]`` are scaled by ``shape[0]``; ``bbox[1]`` and
        ``bbox[3]`` by ``shape[1]``. Each scaled pair is swapped before being
        passed to OpenCV. The label text is anchored 15 units further along
        the first axis than the box's first corner.
        """
        start_0 = int(bbox[0] * shape[0])
        start_1 = int(bbox[1] * shape[1])
        end_0 = int(bbox[2] * shape[0])
        end_1 = int(bbox[3] * shape[1])
        cv2.rectangle(img, (start_1, start_0), (end_1, end_0), color, thickness)
        text_anchor = (start_1, start_0 + 15)
        cv2.putText(img, str(label), text_anchor, cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1)
    
    
    # Colours generated for callers that do not pass their own palette. Kept at
    # module level so a class keeps the same colour on every call (for example
    # across successive video frames).
    _DEFAULT_BBOX_COLORS = {}


    def bboxes_draw_on_img(img, classes, scores, bboxes, colors=None, thickness=2):
        """Draw bounding boxes with "<class name>/<score>" labels onto ``img``.

        Args:
            img: image array that is drawn on in place; its ``shape`` scales
                the box coordinates.
            classes: numeric class id of each detection.
            scores: score of each detection.
            bboxes: one row of four coordinates per detection; elements 0 and 2
                are scaled by ``img.shape[0]``, elements 1 and 3 by
                ``img.shape[1]``.
            colors: optional palette. A ``dict`` keyed by class id is filled in
                with random colours for classes it does not contain yet; any
                other indexable palette (e.g. ``colors_plasma``) is only read.
                When omitted, a shared module-level dict is used.
            thickness: line thickness of the rectangles.
        """
        if colors is None:
            colors = _DEFAULT_BBOX_COLORS
        shape = img.shape
        for cls, score, bbox in zip(classes, scores, bboxes):
            cls_id = int(cls)
            if isinstance(colors, dict) and cls_id not in colors:
                # OpenCV expects 0-255 components on 8-bit images; floats in
                # [0, 1) would be drawn as almost black.
                colors[cls_id] = tuple(random.randint(0, 255) for _ in range(3))
            color = colors[cls_id]
            # Draw bounding box...
            p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
            p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
            cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
            # Draw text: class name instead of the bare numeric id, falling
            # back to the id when no name is known.
            class_name = num2class(cls_id)
            if class_name is None:
                class_name = cls_id
            s = '%s/%.3f' % (class_name, score)
            p1 = (p1[0]-5, p1[1])
            cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.4, color, 1)
    
    
    # =========================================================================== #
    # Matplotlib show...
    # modified by wangjc, 2017.10.18
    # =========================================================================== #
    
    
    
    def plt_bboxes(img, classes, scores, bboxes, figsize=(10,10), linewidth=1.5):
        """Visualize bounding boxes. Largely inspired by SSD-MXNET!

        Opens a new matplotlib figure, shows ``img`` and overlays one rectangle
        plus a "<class name> | <score>" caption per detection whose class id is
        non-negative. Each class gets one random colour per call.

        Args:
            img: image array; ``shape[0]`` / ``shape[1]`` are used as height / width.
            classes: numeric class id of each detection.
            scores: score of each detection.
            bboxes: 2-D array indexed ``[i, 0..3]`` as ymin, xmin, ymax, xmax,
                each scaled by the image height or width.
            figsize: size passed to ``plt.figure``.
            linewidth: line width of the rectangles.
        """
        plt.figure(figsize=figsize)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        axes = plt.gca()
        colors = dict()
        for i in range(classes.shape[0]):
            cls_id = int(classes[i])
            if cls_id < 0:
                continue
            score = scores[i]
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            ymin = int(bboxes[i, 0] * height)
            xmin = int(bboxes[i, 1] * width)
            ymax = int(bboxes[i, 2] * height)
            xmax = int(bboxes[i, 3] * width)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=linewidth)
            axes.add_patch(rect)
            # Show the class name rather than the numeric id; fall back to the
            # id because '{:s}' raises TypeError when given None.
            class_name = num2class(cls_id)
            if class_name is None:
                class_name = str(cls_id)
            axes.text(xmin, ymin - 2,
                      '{:s} | {:.3f}'.format(class_name, score),
                      bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                      fontsize=12, color='white')
        plt.show()
  • 相关阅读:
    Android WindowManager和WindowManager.LayoutParams的使用以及实现悬浮窗口的方法
    Android 自定义控件之圆形扩散View(DiffuseView)
    Android线性渐变
    Android Drawable之getIntrinsicWidth()和getIntrinsicHeight()
    Android 用Handler和Message实现计时效果及其中一些疑问
    CentOS6.5下nginx-1.8.1.tar.gz的单节点搭建(图文详解)
    Zeppelin的入门使用系列之创建新的Notebook(一)
    hadoop报错java.io.IOException: Incorrect configuration: namenode address dfs.namenode.servicerpc-address or dfs.namenode.rpc-address is not configured
    ubuntu系统里vi编辑器时,按方向箭头输入是乱码的ABCD字母?(图文详解)
    VirtualBox里如何正确安装增强工具(图文详解)
  • 原文地址:https://www.cnblogs.com/Osler/p/8427865.html
Copyright © 2020-2023  润新知