• [Paddle学习笔记][08][基于YOLOv3的昆虫检测-数据处理]


    说明:

    本例程使用YOLOv3进行昆虫检测。例程分为数据处理、模型设计、损失函数、训练模型、模型预测和测试模型六个部分。本篇为第一部分,实现了昆虫检测训练的数据预处理功能和预测和测试时读取和显示数据功能。

    数据集下载地址:https://aistudio.baidu.com/aistudio/datasetdetail/19748

    实验代码:

    数据增强:

    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from PIL import Image
    
    from source.data import get_data_list, get_data_item, augment_image
    
    # 读取数据
    train_set = './dataset/train/'
    data_list = get_data_list(train_set) # 读取数据列表
    
    data = data_list[0] # 读取第一条数据
    image, gtbox, gtcls, image_size = get_data_item(data) # 读取数据项目
    
    # 增强图像
    scale_size = (608, 608) # 缩放宽高
    image, gtbox, gtcls = augment_image(image, gtbox, gtcls, scale_size)
    
    # 显示图像
    object_names = ['Boerner', 'Leconte', 'Linnaeus', 'acuminatus', 'armandi', 'coleoptera', 'linnaeus'] # 物体名称
    color = ['r', 'g', 'b', 'c','m', 'y', 'k'] # 边框颜色
    image = Image.fromarray(image) # 转换为Image格式
    plt.figure(figsize=(10, 10)) # 设置显示图像大小
    currentAxis = plt.gca() # 获取图像当前坐标
    
    gtbox[:, 0] = gtbox[:, 0]*float(scale_size[1]) # 计算边框x真实值
    gtbox[:, 1] = gtbox[:, 1]*float(scale_size[0]) # 计算边框y真实值
    gtbox[:, 2] = gtbox[:, 2]*float(scale_size[1]) # 计算边框w真实值
    gtbox[:, 3] = gtbox[:, 3]*float(scale_size[0]) # 计算边框h真实值
    
    for i in range(len(gtbox)): 
        # 绘制边框
        if gtbox[i, 2] > 1e-3 and gtbox[i, 3] > 1e-3:
            # 获取数据
            x = gtbox[i, 0] - gtbox[i, 2]/2
            y = gtbox[i, 1] - gtbox[i, 3]/2
            w = gtbox[i, 2]
            h = gtbox[i, 3]
            index = int(gtcls[i]) # 类别索引
            names = object_names[index] # 类别名称
            
            # 绘制边框
            rectangle = patches.Rectangle((x, y), w, h, 
                                          linewidth=1, edgecolor=color[index], facecolor=color[index], 
                                          fill=False, linestyle='-')
            currentAxis.add_patch(rectangle)
            
            # 绘制类别
            plt.text(x, y, names, fontsize=12, color=color[index])
    
    plt.imshow(image)
    plt.show()

     结果:

    单线程读取批次数据:

    from source.data import single_thread_reader
    
    # 批次读取训练数据
    train_set = './dataset/train/'
    train_reader = single_thread_reader(train_set, 2, 'train')
    
    train_data = next(train_reader())
    print('train_data - image:{}, gtbox:{}, gtcls:{}, image_size:{}'.format(
        train_data[0].shape, train_data[1].shape, train_data[2].shape, train_data[3].shape))

    结果:train_data - image:(2, 3, 448, 448), gtbox:(2, 50, 4), gtcls:(2, 50), image_size:(2, 2)

    多线程读取批次数据:

    from source.data import multip_thread_reader
    
    # 多线程读取训练数据
    valid_set = './dataset/val/'
    valid_reader = multip_thread_reader(valid_set, 2, 'valid')
    
    valid_data = next(valid_reader())
    print('valid_data - image:{}, gtbox:{}, gtcls:{}, image_size:{}'.format(
        valid_data[0].shape, valid_data[1].shape, valid_data[2].shape, valid_data[3].shape))

    结果:valid_data - image:(2, 3, 512, 512), gtbox:(2, 50, 4), gtcls:(2, 50), image_size:(2, 2)

    data.py文件

    import os
    import random
    
    import numpy as np
    import xml.etree.ElementTree as ET
    
    from PIL import Image, ImageEnhance
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    
    import paddle
    
    object_names = ['Boerner', 'Leconte', 'Linnaeus', 'acuminatus', 'armandi', 'coleoptera', 'linnaeus'] # 物体名称
    def get_object_gtcls():
        """
        功能:
            将物体名称映射成物体类别
        输入:
        输出:
            object_gtcls - 物体类别
        """
        object_gtcls = {} # 物体类别字典
        for key, value in enumerate(object_names):
            object_gtcls[value] = key # 将物体名称映射成物体类别
            
        return object_gtcls
    
    def get_data_list(data_path):
        """
        功能:
            读取数据列表
        输入:
            data_path - 数据路径
        输出:
            data_list - 数据列表
        """
        file_list = os.listdir(os.path.join(data_path, 'annotations', 'xmls')) # 文件列表
        data_list = [] # 数据列表
        
        for file_id, file_name in enumerate(file_list):
            # 读取数据文件
            file_path = os.path.join(data_path, 'annotations', 'xmls', file_name) # 文件路径
            xml_tree = ET.parse(file_path) # 解析文件
            
            # 读取图像数据
            image_path = os.path.join(data_path, 'images', file_name.split('.')[0] + '.jpeg') # 图像路径
            image_w = float(xml_tree.find('size').find('width').text) # 图像宽度
            image_h = float(xml_tree.find('size').find('height').text) # 图像高度
            
            # 读取物体数据
            object_list = xml_tree.findall('object') # 物体列表
            object_gtbox = np.zeros((len(object_list), 4), dtype=np.float32) # 物体边框
            object_gtcls = np.zeros((len(object_list),  ), dtype=np.int32) # 物体类别
            
            for object_id, object_item in enumerate(object_list):
                # 读取物体类别
                object_name = object_item.find('name').text # 读取物体名称
                object_gtcls[object_id] = get_object_gtcls()[object_name] # 将物体名称映射成物体类别
                
                # 读取物体边框
                x_min = float(object_item.find('bndbox').find('xmin').text)
                y_min = float(object_item.find('bndbox').find('ymin').text)
                x_max = float(object_item.find('bndbox').find('xmax').text)
                y_max = float(object_item.find('bndbox').find('ymax').text)
                
                x_min = max(0.0, x_min)
                y_min = max(0.0, y_min)
                x_max = min(x_max, image_w - 1.0)
                y_max = min(y_max, image_h - 1.0)
                
                object_gtbox[object_id] = [
                    (x_min + x_max) / 2.0, (y_min + y_max) / 2.0,
                    (x_max - x_min + 1.0), (y_max - y_min + 1.0)] # 计算物体边框: xywh格式
            
            # 保存数据列表
            data = {'image_path': image_path, 'image_width': image_w, 'image_height': image_h,
                    'object_gtbox': object_gtbox, 'object_gtcls': object_gtcls}
            
            if len(data) != 0:
                data_list.append(data)
                
        return data_list
    
    ##############################################################################################################
    
    def get_object_item(object_gtbox, object_gtcls, image_size):
        """
        功能:
            读取物体项目
        输入:
            object_gtbox - 物体边框
            object_gtcls - 物体类别
            image_size   - 图像高宽
        输出:
            gtbox        - 边框列表
            gtcls        - 类别列表
        """
        # 添加物体数据列表
        max_num = 50 # 最大物体数量
        gtbox = np.zeros((max_num, 4)) # 边框列表
        gtcls = np.zeros((max_num,  )) # 类别列表
        
        for i in range(len(object_gtbox)):
            gtbox[i, :] = object_gtbox[i, :]
            gtcls[i] = object_gtcls[i]
            
            if i > max_num: # 是否超过最大物体数量
                break
        
        # 转换为相对真实图像比例的边框
        gtbox[:, 0] = gtbox[:, 0]/float(image_size[1]) # 计算边框x相对值
        gtbox[:, 1] = gtbox[:, 1]/float(image_size[0]) # 计算边框y相对值
        gtbox[:, 2] = gtbox[:, 2]/float(image_size[1]) # 计算边框w相对值
        gtbox[:, 3] = gtbox[:, 3]/float(image_size[0]) # 计算边框h相对值
        
        return gtbox, gtcls
    
    def get_data_item(data):
        """
        功能: 
            读取数据项目
        输入: 
            data       - 数据
        输出:
            image      - 图像数据
            gtbox      - 边框列表
            gtcls      - 类别列表
            image_size - 图像高宽
        """
        # 读取数据项目
        image_path = data['image_path'] # 图像路径
        image_w = data['image_width'] # 图像宽度
        image_h = data['image_height'] # 图像高度
        object_gtbox = data['object_gtbox'] # 物体边框
        object_gtcls = data['object_gtcls'] # 物体类别
        
        # 打开图像文件
        image = Image.open(image_path) # 打开图像
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        image = np.array(image) # 转换为ndarray格式, 读取数据类型为HWC,uint8类型
        
        # 检查图像高宽
        assert image.shape[0] == int(image_h), 
            'image path: {}, image.shape[0]:{} != image_h: {}'.format(
            image_path, image.shape[0], image_h)
        assert image.shape[1] == int(image_w), 
            'image path: {}, image.shape[1]:{} != image_w: {}'.format(
            image_path, image.shape[1], image_w)
        
        # 读取物体项目
        image_size = (image_h, image_w)
        gtbox, gtcls = get_object_item(object_gtbox, object_gtcls, image_size)
        
        return image, gtbox, gtcls, image_size
    
    ##############################################################################################################
    
    def random_distort_image(image):
        """
        功能: 
            随机变换图像
        输入: 
            image - 图像数据
        输出:
            image - 图像数据
        """
        # 随机变换饱和度
        def random_distort_saturation(image):
            if random.random() > 0.5:
                delta = random.uniform(-0.5, 0.5) + 1
                image = ImageEnhance.Color(image).enhance(delta)
            return image
        
        # 随机变换明亮度
        def random_distort_brightness(image):
            if random.random() > 0.5:
                delta = random.uniform(-0.5, 0.5) + 1
                image = ImageEnhance.Brightness(image).enhance(delta)
            return image
        
        # 随机变换对比度
        def random_distort_constract(image):
            if random.random() > 0.5:
                delta = random.uniform(-0.5, 0.5) + 1
                image = ImageEnhance.Contrast(image).enhance(delta)
            return image
        
        # 随机变换色调
        def random_distort_hue(image):
            if random.random() > 0.5:
                delta = random.uniform(-18, 18)
                image = np.array(image.convert('HSV'))
                image[:, :, 0] = image[:, :, 0] + delta # 变换色调
                image = Image.fromarray(image, mode='HSV').convert('RGB')
            return image
        
        # 随机变换图像
        distort = [random_distort_saturation, random_distort_brightness, random_distort_constract, random_distort_hue] # 变换方法列表
        np.random.shuffle(distort) # 打乱变换顺序
        
        image = Image.fromarray(image) # 转换为Image格式
        image = distort[0](image)
        image = distort[1](image)
        image = distort[2](image)
        image = distort[3](image)
        image = np.asarray(image) # 转换为ndarray格式
        
        return image
    
    def random_expand_image(image, gtbox, keep_ratio=True):
        """
        功能: 
            随机填充图像
        输入: 
            image      - 图像数据
            gtbox      - 边框列表
            keep_ratio - 保持宽高比例
        输出:
            image      - 图像数据
            gtbox      - 边框列表
        """
        # 是否填充图像
        if random.random() > 0.5:
            return image, gtbox
        
        # 生成填充比例
        max_ratio = 4
        x_ratio = random.uniform(1, max_ratio) # 随机产生x填充比例
        
        if keep_ratio: # 是否保持宽高比例
            y_ratio = x_ratio # y填充比例等于x填充比例
        else:
            y_ratio = random.uniform(1, max_ratio) # 随机产生y填充比例
            
        # 计算填充宽高
        image_h, image_w, image_channel = image.shape # 获取图像高宽和通道数
        
        expand_w = int(image_w * x_ratio) # 计算填充宽度
        expand_h = int(image_h * y_ratio) # 计算填充高度
        
        x_offset = random.randint(0, expand_w - image_w) # 随机生成原图在填充图中x坐标位置
        y_offset = random.randint(0, expand_h - image_h) # 随机生成原图在填充图中y坐标位置
        
        # 生成填充图像
        expand_image = np.zeros((expand_h, expand_w, image_channel)) # float32类型
        
        mean_value = [0.485, 0.456, 0.406] # COCO数据集通道平均值
        for i in range(image_channel):
            expand_image[:, :, i] = mean_value[i] * 255.0 # 使用均值填充每个通道
        
        # 填充原始图像
        expand_image[y_offset : y_offset + image_h, x_offset : x_offset + image_w, :] = image # 填充原始图像
        image = expand_image.astype('uint8') # 转换为uint8类型
        
        # 计算相对边框
        gtbox[:, 0] = ((gtbox[:, 0] * image_w) + x_offset) / float(expand_w) # 计算边框x相对值
        gtbox[:, 1] = ((gtbox[:, 1] * image_h) + y_offset) / float(expand_h) # 计算边框y相对值
        gtbox[:, 2] = gtbox[:, 2] / x_ratio # 计算边框w相对值
        gtbox[:, 3] = gtbox[:, 3] / y_ratio # 计算边框h相对值
        
        return image, gtbox
    
    def get_boxes_ious_xywh(box1, box2):
        """
        功能:
            计算边框列表的交并比
        输入:
            box1 - 边框列表1
            box2 - 边框列表2
        输出:
            ious - 交并比值列表
        """
        # 判断边框维度
        assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
        assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
        
        # 计算交集面积
        x1_min = box1[:, 0] - box1[:, 2]/2.0
        y1_min = box1[:, 1] - box1[:, 3]/2.0
        x1_max = box1[:, 0] + box1[:, 2]/2.0
        y1_max = box1[:, 1] + box1[:, 3]/2.0
        
        x2_min = box2[:, 0] - box2[:, 2]/2.0
        y2_min = box2[:, 1] - box2[:, 3]/2.0
        x2_max = box2[:, 0] + box2[:, 2]/2.0
        y2_max = box2[:, 1] + box2[:, 3]/2.0
        
        x_min = np.maximum(x1_min, x2_min)
        y_min = np.maximum(y1_min, y2_min)
        x_max = np.minimum(x1_max, x2_max)
        y_max = np.minimum(y1_max, y2_max)
        
        w = np.maximum(x_max - x_min, 0.0)
        h = np.maximum(y_max - y_min, 0.0)
        
        intersection = w * h # 交集面积
        
        # 计算并集面积
        s1 = box1[:, 2] * box1[:, 3]
        s2 = box2[:, 2] * box2[:, 3]
        
        union = s1 + s2 - intersection # 并集面积
        
        # 计算交并比值
        ious = intersection / union
        
        return ious
    
    def get_cpbox_item(cpbox, gtbox, gtcls, image_size):
        """
        功能:
            计算裁剪边框
        输入:
            cpbox      - 裁剪边框: 真实位置
            gtbox      - 真实边框: 相对位置
            gtcls      - 真实类别
            image_size - 图像尺寸
        输出:
            gtbox      - 裁剪边框
            gtcls      - 裁剪类别
            box_number - 边框数量
        """
        # 拷贝真实边框列表
        gtbox = gtbox.copy() # 防止本次返回失败, 影响下次真实边框值
        gtcls = gtcls.copy() # 防止本次返回失败, 影响下次真实类别值
        
        # 转换真实边框位置:x1, y1, x2, y2
        image_w, image_h = map(float, image_size) # 读取图像宽高
    
        gtbox[:, 0], gtbox[:, 2] = 
            (gtbox[:, 0] - gtbox[:, 2]/2) * image_w, (gtbox[:, 0] + gtbox[:, 2]/2) * image_w # 先计算右边,再赋值左边,避免影响左边原始值
        gtbox[:, 1], gtbox[:, 3] = 
            (gtbox[:, 1] - gtbox[:, 3]/2) * image_h, (gtbox[:, 1] + gtbox[:, 3]/2) * image_h # 原址计算避免内存拷贝,分开计算,需要先拷贝gtbox
        
        # 转换裁剪边框位置
        cpbox_x, cpbox_y, cpbox_w, cpbox_h = map(float, cpbox) # 读取边框位置
        cpbox = np.array([cpbox_x, cpbox_y, cpbox_x + cpbox_w, cpbox_y + cpbox_h]) # 转换为ndarray格式:xyxy格式
        
        # 计算真实边框中心在裁剪边框内的个数的掩码
        gtbox_centers = (gtbox[:, :2] + gtbox[:, 2:]) / 2.0 # 计算所有真实边框的中心位置: x=(x1+x2)/2, y=(y1+y2)/2
        
        mask = np.logical_and(cpbox[:2] <= gtbox_centers, gtbox_centers <= cpbox[2:]).all(axis=1) # 中心位置掩码:真实边框中心是否在裁剪边框内
        
        # 计算裁剪边框位置
        gtbox[:, :2] = np.maximum(gtbox[:, :2], cpbox[:2]) # 计算真实边框与裁剪边框最大x1y1值
        gtbox[:, 2:] = np.minimum(gtbox[:, 2:], cpbox[2:]) # 计算真实边框与裁剪边框最小x2y2值
        gtbox[:, :2] -= cpbox[:2] # 计算真实边框x1y1值, 相对于裁剪边框的左上角位置
        gtbox[:, 2:] -= cpbox[:2] # 计算真实边框x2y2值, 相对于裁剪边框的左上角位置
        
        # 计算裁剪边框中左上角位置小于右下角的掩码
        mask = np.logical_and(mask, (gtbox[:, :2] < gtbox[:, 2:]).all(axis=1) ) # 裁剪边框的右下坐标是否在左上角左边之下
        
        # 计算裁剪边框位置
        gtbox = gtbox * np.expand_dims(mask.astype('float32'), axis=1) # 使用掩码计算合格的裁剪边框
        
        # 转换相对边框位置: x, y, w, h
        gtbox[:, 0], gtbox[:, 2] = 
            (gtbox[:, 0] + gtbox[:, 2])/2 / cpbox_w, (gtbox[:, 2] - gtbox[:, 0]) / cpbox_w # 先计算右边,再赋值左边,避免影响左边原始值
        gtbox[:, 1], gtbox[:, 3] = 
            (gtbox[:, 1] + gtbox[:, 3])/2 / cpbox_h, (gtbox[:, 3] - gtbox[:, 1]) / cpbox_h # 原址计算避免内存拷贝,分开计算,需要先拷贝gtbox
        
        # 计算裁剪边框类别
        gtcls = gtcls * mask.astype('float32') # 使用掩码计算合格的裁剪边框
        
        # 计算裁剪边框数量
        box_number = mask.sum() # 统计掩码为1的数量
        
        return gtbox, gtcls, box_number
    
    def random_crop_image(image, gtbox, gtcls, crop_scope=[0.3, 1.0], max_ratio=2.0, iou_scopes=None, max_trial=50):
        """
        功能: 
            随机裁剪图像
        输入: 
            image      - 图像数据
            gtbox      - 边框列表
            gtcls      - 类别列表
            crop_scope - 裁剪比例范围
            max_ratio  - 最大裁剪比例
            iou_scopes - 交并比值范围
            max_trial  - 最大试验次数
        输出:
            image      - 图像数据
            gtbox      - 边框列表
            gtcls      - 类别列表
        """
        # 是否存在物体
        if random.random() > 0.5:
            return image, gtbox, gtcls
        if len(gtbox) == 0: # 如果物体边框数为零, 则返回原始图像数据项目
            return image, gtbox, gtcls
        
        # 交并比值范围
        if not iou_scopes: # 如果不存在交并比值范围列表, 则使用交默认并比值范围列表
            iou_scopes = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0), (0.9, 1.0), (0.0, 1.0)]
        
        # 转换图像格式
        image = Image.fromarray(image) # 转换为Image格式
        image_w, image_h = image.size # 获取图像宽高
        
        # 计算裁剪边框
        cpbox_list = [(0, 0, image_w, image_h)] # 裁剪边框列表
        
        for min_iou, max_iou in iou_scopes:
            for i in range(max_trial):
                # 随机生成裁剪比例
                crop_ratio = random.uniform(crop_scope[0], crop_scope[1]) # 随机生成裁剪宽高比例
                aspect_ratio = random.uniform(
                    max(1 / max_ratio, (crop_ratio * crop_ratio)),
                    min(max_ratio, 1 / (crop_ratio * crop_ratio))) # 随机生成裁剪纵横比例
                
                # 计算裁剪边框位置
                crop_h = int(image_h * crop_ratio / np.sqrt(aspect_ratio)) # 计算随机生成的裁剪高度
                crop_w = int(image_w * crop_ratio * np.sqrt(aspect_ratio)) # 计算随机生成的裁剪宽度
                
                crop_x = random.randint(0, image_w - crop_w) # 随机生成裁剪x坐标
                crop_y = random.randint(0, image_h - crop_h) # 随机生成裁剪y坐标
                
                # 计算边框交并集值
                cpbox = np.array(
                    [[(crop_x + crop_w / 2.0) / float(image_w),
                      (crop_y + crop_h / 2.0) / float(image_h),
                      crop_w / float(image_w),
                      crop_h / float(image_h)]]) # 计算裁剪边框相对位置: xywh格式
                
                ious = get_boxes_ious_xywh(cpbox, gtbox) # 裁剪边框形状为(1, 4), 真实边框形状为(50, 4)
                
                # 添加裁剪边框列表
                if min_iou <= ious.min() and ious.max() <= max_iou: # 如果符合交并比值范围, 则添加一个裁剪边框, 并结束循环
                    cpbox_list.append((crop_x, crop_y, crop_w, crop_h)) # 裁剪边框为真实位置
                    break
                    
        # 随机裁剪图像
        for i in range(len(cpbox_list)):
            # 弹出裁剪边框
            cpbox = cpbox_list.pop(random.randint(0, len(cpbox_list) - 1)) # 随机弹出裁剪边框
            
            # 计算裁剪边框
            crop_boxes, crop_gtcls, box_number = get_cpbox_item(cpbox, gtbox, gtcls, image.size)
            
            # 开始裁剪图像
            if box_number > 0: # 如果裁剪边框数量大于0, 则裁剪图像,并结束循环
                image = image.crop( (cpbox[0], cpbox[1], cpbox[0] + cpbox[2], cpbox[1] + cpbox[3]) ) # 用真实位置裁剪图像
                image = image.resize(image.size, Image.LANCZOS) # 高质量缩放到原来大小
                
                gtbox = crop_boxes # 返回裁剪边框
                gtcls = crop_gtcls # 返回裁剪类别
                
                break
        
        # 转换图像格式
        image = np.asarray(image) # 转换为ndarray格式
        
        return image, gtbox, gtcls
    
    def random_interpolate_image(image, scale_size, interpolation=None):
        """
        功能: 
            随机插值图像
        输入: 
            image         - 图像数据
            scale_size    - 缩放宽高
            interpolation - 插值方法
        输出:
            image         - 图像数据
        """
        # 转换图像格式
        image = Image.fromarray(image) # 转换为Image格式
        
        # 随机缩放图像
        interpolation_method = [Image.NEAREST,Image.BILINEAR ,Image.BICUBIC, Image.LANCZOS] # 插值方法列表
        
        if not interpolation or interpolation not in interpolation_method:
            interpolation = interpolation_method[random.randint(0, len(interpolation_method) - 1)] # 随机选取插值方法
        
        image = image.resize(scale_size, interpolation)
        
        # 转换图像格式
        image = np.asarray(image) # 转换为ndarray格式,数据类型为HWC,uint8类型
        
        return image
    
    def random_flip_image(image, gtbox):
        """
        功能: 
            随机翻转图像
        输入: 
            image - 图像数据
            gtbox - 边框列表
        输出:
            image - 图像数据
            gtbox - 边框列表
        """
        if random.random() > 0.5:
            image = image[:, ::-1, :] # 水平翻转图像, 列全部倒序排列
            gtbox[:, 0] = 1.0 - gtbox[:, 0] # 水平翻转边框, x坐标变为相反数
            
        return image, gtbox
    
    def random_shuffle_gtbox(gtbox, gtcls):
        """
        功能: 
            随机打乱边框
        输入: 
            gtbox - 边框列表
            gtcls - 类别列表
        输出:
            gtbox - 边框列表
            gtcls - 类别列表
        """
        # 连接边框和类别
        data_list = np.concatenate([gtbox, gtcls[:, np.newaxis]], axis=1)
        
        # 打乱列表的顺序
        index = np.arange(data_list.shape[0])
        np.random.shuffle(index)
        
        data_list = data_list[index, :]
        
        # 保存边框和类别
        gtbox = data_list[:, :4]
        gtcls = data_list[:, -1]
        
        return gtbox, gtcls
    
    def augment_image(image, gtbox, gtcls, scale_size):
        """
        功能:
            增强图像
        输入:
            image      - 图像数据
            gtbox      - 边框列表
            gtcls      - 类别列表
            scale_size - 缩放尺寸
        输出:
            image      - 图像数据
            gtbox      - 边框列表
            gtcls      - 类别列表
        """
        # 随机变换图像
        image = random_distort_image(image)
        
        # 随机填充图像
        image, gtbox = random_expand_image(image, gtbox)
        
        # 随机裁剪图像
        image, gtbox, gtcls = random_crop_image(image, gtbox, gtcls)
            
        # 随机插值图像
        image = random_interpolate_image(image, scale_size)
        
        # 随机翻转图像
        image, gtbox = random_flip_image(image, gtbox)
        
        # 随机打乱边框
        gtbox, gtcls = random_shuffle_gtbox(gtbox, gtcls)
        
        return image, gtbox, gtcls
    
    ##############################################################################################################
    
    def get_scale_size(mode):
        """
        功能:
            获取缩放尺寸
        输入:
            mode       - 获取模式
        输出:
            scale_size - 缩放尺寸
        """
        if (mode == 'train') or (mode == 'valid'): # 如果是训练或验证模式, 随机生成缩放宽高
            scale_size = 320 + 32 * random.randint(0, 9) # 随机生成宽高,范围为[320, 608],步进为32
        else:
            scale_size = 608
        
        scale_size = (scale_size, scale_size) # 组合缩放宽高[w,h]
        
        return scale_size
    
    def get_data_array(batch_data):
        """
        功能:
            将数据列表转换为数组构成的元组
        输入:
            batch_data       - 批次数据列表
        输出:
            image_array      - 图像数据数组
            gtbox_array      - 物体边框数组
            gtcls_array      - 图像类别数组
            image_size_array - 图像宽高数组
        """
        image_array = np.array([item[0] for item in batch_data], dtype='float32')
        gtbox_array = np.array([item[1] for item in batch_data], dtype='float32')
        gtcls_array = np.array([item[2] for item in batch_data], dtype='int32')
        image_size_array = np.array([item[3] for item in batch_data], dtype='int32')
        
        return image_array, gtbox_array, gtcls_array, image_size_array
    
    def data_reader(data, scale_size):
        """
        功能:
            读取一条数据
        输入:
            data       - 数据项目
            scale_size - 缩放图像宽高
        输出:
            image      - 图像数据
            gtbox      - 边框列表
            gtcls      - 类别列表
            image_size - 原始图像高宽
        """
        # 读取数据
        image, gtbox, gtcls, image_size = get_data_item(data)
    
        # 增强图像
        image, gtbox, gtcls = augment_image(image, gtbox, gtcls, scale_size)
        
        # 减去均值
        mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1)) # COCO数据集通道平均值
        stdv = np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1)) # COCO数据集通道标准差
        
        image = (image/255.0 - mean) / stdv # 对图像进行归一化
        image = image.astype('float32').transpose((2, 0, 1)) # 转换图片格式:[H,W,C]到[C,H,W]
        
        return image, gtbox, gtcls, image_size
    
    def single_thread_reader(data_path, batch_size=8, mode='train'):
        """
        功能:
            单线程读取批次数据
        输入:
            data_path  - 数据集路径
            batch_size - 每批数据大小
            mode       - 读取数据模式: train或valid
        输出:
            reader     - 数据读取器
        """
        # 读取数据列表
        data_list = get_data_list(data_path)
        
        # 读取数据项目
        def reader():
            # 设置读取模式
            if mode == 'train': # 如果是训练模式, 则打乱数据列表
                np.random.shuffle(data_list)
            scale_size = get_scale_size(mode) # 设置缩放宽高
            
            # 输出批次数据
            batch_data = [] # 批次数据
            
            for item in data_list:
                # 读取数据
                image, gtbox, gtcls, image_size = data_reader(item, scale_size)
                
                # 输出数据
                batch_data.append((image, gtbox, gtcls, image_size))
                if len(batch_data) == batch_size: # 如果压入一批数据, 则弹出数据
                    # 弹出数据
                    yield get_data_array(batch_data)
                    
                    # 重置数据
                    batch_data = []
                    image_size = get_scale_size(mode)
    
            # 输出剩余数据
            if len(batch_data) > 0:
                yield get_data_array(batch_data)
        
        return reader
    
    def multip_thread_reader(data_path, batch_size=8, mode='train'):
        """
        功能:
            多线程读取批次数据
        输入:
            data_path  - 数据集路径
            batch_size - 每批数据大小
            mode       - 读取数据模式: train 或 valid
        输出:
            reader     - 数据读取器
        """
        # 读取数据列表
        data_list = get_data_list(data_path)
        
        # 读取数据项目
        def item_loader():
            # 设置读取模式
            if mode == 'train': # 如果是训练模式, 则打乱数据列表
                np.random.shuffle(data_list)
            scale_size = get_scale_size(mode) # 设置缩放宽高
            
            # 输出批次项目
            batch_item = [] # 批次项目
            
            for item in data_list:
                # 输出项目
                batch_item.append((item, scale_size))
                if len(batch_item) == batch_size: # 如果压入一批项目, 则弹出项目
                    # 弹出数据
                    yield batch_item
                    
                    # 重置项目
                    batch_item = []
                    image_size = get_scale_size(mode)
    
            # 输出剩余数据
            if len(batch_item) > 0:
                yield batch_item
        
        # 读取数据内容
        def data_loader(batch_item):
            batch_data = [] # 批次数据
            
            for item, scale_size in batch_item:
                image, gtbox, label, image_size = data_reader(item, scale_size)
                batch_data.append((image, gtbox, label, image_size))
            
            return get_data_array(batch_data)
        
        # 多线程读取器
        reader = paddle.reader.xmap_readers(data_loader, item_loader, process_num=4, buffer_size=16)
        
        return reader
    
    ##############################################################################################################
    
    def single_test_reader(image_path, scale_size=(608, 608)):
        """
        功能:
            读取一张预测图像
        输入:
            image_path - 图像路径
            scale_size - 缩放高宽
        输出:
            image      - 图像数据
            image_size - 原始高宽
        """
        # 读取图像
        image = Image.open(image_path)                   # 读取图像
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_size = (image.size[0], image.size[1])      # 读取尺寸
        image = image.resize(scale_size, Image.BILINEAR) # 缩放图像
        
        # 转换格式
        image = np.array(image, dtype='float32')
        image_size = np.array(image_size, dtype='int32')
        
        # 减去均值
        mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1)) # COCO数据集通道平均值
        stdv = np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1)) # COCO数据集通道标准差
        
        image = (image/255.0 - mean) / stdv                  # 对图像进行归一化
        image = image.astype('float32').transpose((2, 0, 1)) # 转换图片格式:[H,W,C]到[C,H,W]
        
        # 增加维度
        image = np.expand_dims(image, axis=0) # 增加维度: (1,3,608,608)
        image_size = np.expand_dims(image_size, axis=0) # 增加维度: (1,2)
        
        return image, image_size
    
    def display_infer(infer, image_path):
        """
        功能:
            显示预测结果
        输入:
            infer       - 预测结果
            image_path  - 图像路径
        输出:
        """
        # 读取图像
        image = Image.open(image_path)   # 读取图像
        if image.mode != 'RGB':
            image = image.convert('RGB') # 转换格式
        
        # 绘制结果
        object_names = ['Boerner','Leconte','Linnaeus','acuminatus','armandi','coleoptera','linnaeus'] # 物体名称
        color = ['r', 'g', 'b', 'c','m', 'y', 'k']                                                     # 边框颜色
        
        plt.figure(figsize=(10, 10)) # 设置显示图像大小
        currentAxis = plt.gca()      # 获取图像当前坐标
        
        for item in infer: # 遍历预测结果
            # 获取结果
            index = int(item[0])        # 类别索引
            names = object_names[index] # 类别名称
            pdbox = item[2:6]           # 边框位置
            
            # 绘制边框
            rectangle = patches.Rectangle(                                       # 设置边框
                (pdbox[0], pdbox[1]), pdbox[2]-pdbox[0]+1, pdbox[3]-pdbox[1]+1, 
                linewidth=1, edgecolor=color[index], facecolor=color[index], 
                fill=False, linestyle='-')
            currentAxis.add_patch(rectangle)                                     # 绘制边框
            plt.text(pdbox[0], pdbox[1], names, fontsize=12, color=color[index]) # 绘制类别
        
        # 显示结果
        plt.imshow(image)
        plt.show()
    
    def get_test_array(batch_data):
        """
        功能:
            将数据列表转换为数组构成的元组
        输入:
            batch_data       - 数据列表
        输出:
            image_name_array - 图像名字数组
            image_array      - 图像数据数组
            image_size_array - 图像宽高数组
        """
        image_name_array = np.array([item[0] for item in batch_data])
        image_array = np.array([item[1] for item in batch_data], dtype='float32')
        image_size_array = np.array([item[2] for item in batch_data], dtype='int32')
        
        return image_name_array, image_array, image_size_array
    
    def multip_test_reader(data_path, batch_size=1, scale_size=(608, 608)):
        """
        功能:
            读取一批测试数据
        输入:
            data_path  - 数据目录
            batch_size - 每批大小
            scale_size - 缩放高宽
        输出:
            image_name - 图像名称
            image      - 图像数据
            image_size - 原始高宽
        """
        # 读取数据列表
        data_list = os.listdir(data_path)
        
        # 读取一批数据
        def reader():
            # 输出批次数据
            batch_data = [] # 批次数据
            
            for image_name in data_list:
                # 读取图像路径
                image_path = os.path.join(data_path, image_name)
                
                # 读取一张图像
                image = Image.open(image_path)                   # 读取图像
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_size = (image.size[0], image.size[1])      # 图像大小
                image = image.resize(scale_size, Image.BILINEAR) # 缩放图像
                
                # 转换格式
                image = np.array(image, dtype='float32')
                image_size = np.array(image_size, dtype='int32')
    
                # 减去均值
                mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1)) # COCO数据集通道平均值
                stdv = np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1)) # COCO数据集通道标准差
    
                image = (image/255.0 - mean) / stdv                  # 对图像进行归一化
                image = image.astype('float32').transpose((2, 0, 1)) # 转换图片格式:[H,W,C]到[C,H,W]
                
                # 输出数据
                batch_data.append((image_name.split('.')[0], image, image_size))
                if len(batch_data) == batch_size: # 如果压入一批数据, 则弹出数据
                    # 弹出数据
                    yield get_test_array(batch_data)
                    
                    # 重置数据
                    batch_data = []
    
            # 输出剩余数据
            if len(batch_data) > 0:
                yield get_test_array(batch_data)

    参考资料:

    https://aistudio.baidu.com/aistudio/projectdetail/742781

    https://aistudio.baidu.com/aistudio/projectdetail/672017

    https://aistudio.baidu.com/aistudio/projectdetail/868589

    https://aistudio.baidu.com/aistudio/projectdetail/122277

  • 相关阅读:
    mysql性能调优
    java面试大全
    JVM调优总结
    大数据行业跳槽面试前你需要做什么
    什么是分布式锁?实现分布式锁的方式
    如何保障mysql和redis之间的数据一致性?
    数据倾斜的原因和解决方案
    hive优化
    c# 系统换行符
    12种增强CSS技能并加快开发速度的资源
  • 原文地址:https://www.cnblogs.com/d442130165/p/13683863.html
Copyright © 2020-2023  润新知