• yolo-tensorflow复现解析


    看到有人使用tensorflow复现了yoloV3,来此记录下代码阅读。感觉复现的代码写的不是很好,会加一部分其他人用keras复现的代码。

    tensorflow代码地址:https://blog.csdn.net/IronMastiff/article/details/79940118

    源代码分为以下几部分:

    Train.py为主程序train.py部分为训练自己的数据集,eval.py为利用训练好的权重来进行预测。Reader为读取数据标签等,config.yml为训练过程中的一些参数设置,eval_config.yml为预测过程中的一些参数设置。Utils包为其中的一些网络结构,IOU等中间步骤。下面先介绍utils的中的程序

    Net.py: 设置网络结构,提取图片信息。Darknet-53网络结构如下。1x,2x等分别表示该结构重复了1次两次等,Residual表示和前面方框外的结构进行按维度叠加,类似于残差网络。

    feature_extractor 函数该提取三个尺度的信息scale1,scale2,scale3,分别为倒数第1,2,3个方框中网络结构的输出。之后,scales函数通过1x1,3x3卷积分别对三个尺度的特征单元进行特征交互。返回交互后的三个尺度信息。
    最后输出的三个不同尺度分别为 13X13X75,26X26X75,52X52X75具体交互信息可参考代码结构。最后训练时会选取不同scale参与计算loss,select_things就是选取不同scale的特征。
    不太理解如何使用?是需要手动更改配置文件的scale选择?如何选择?

     

    IOU.py 为NMS筛选anchor,用来参与最后的loss计算get_loss.py.先来解释IOU

     1 def IOU_calculator( x, y, width, height, l_x, l_y, l_width, l_height ):
     2     '''
     3     x,y,width,height分别为预测框的中心坐标及宽,高,l_x, l_y, l_width, l_height分别为真实框的中心坐标及宽,高
     4     '''
     5     ##x_min=x-w/2,y_min=y-h/2,x_max=x+w/2,y_max=y+h/2 此段意义为分别求出四个角坐标
     6     x_max = calculate_max( x , width / 2 )
     7     y_max = calculate_max( y, height / 2 )
     8     x_min = calculate_min( x, width / 2 )
     9     y_min = calculate_min( y, height / 2 )
    10 
    11     l_x_max = calculate_max( l_x, width / 2 )
    12     l_y_max = calculate_max( l_y, height / 2 )
    13     l_x_min = calculate_min( l_x, width / 2 )
    14     l_y_min = calculate_min( l_y, height / 2 )
    15 
    16     '''求相交部分的面积'''
    17     xend = tf.minimum( x_max, l_x_max )
    18     xstart = tf.maximum( x_min, l_x_min )
    19 
    20     yend = tf.minimum( y_max, l_y_max )
    21     ystart = tf.maximum( y_min, l_y_min )
    22 
    23     area_width = xend - xstart
    24     area_height = yend - ystart
    25 
    26     '''IOU=A & B/(A+B-A & B)若A与B交集为0,则返回1e-8'''
    27     area = area_width * area_height
    28 
    29     all_area = tf.cond( ( width * height + l_width * l_height - area ) <= 0, lambda : tf.cast( 1e-8, tf.float32 ), lambda : ( width * height + l_width * l_height - area ) )
    30 
    31     IOU = area / all_area
    32 
    33     IOU = tf.cond( area_width < 0, lambda : tf.cast( 1e-8, tf.float32 ), lambda : IOU )
    34     IOU = tf.cond( area_height < 0, lambda : tf.cast( 1e-8, tf.float32 ), lambda : IOU )
    35 
    36     return IOU

    get_loss.py为计算损失函数,他的损失函数计算是按照yolov1来计算的,有点问题。

     1 def objectness_loss( input, switch, l_switch, alpha = 0.5 ):
     2     '''
     3     input为IOU,switch为若预测该框内有object则为1,否则为0,l_switch为实际该框有object则为1,否则为0
     4     '''
     5 
     6     IOU_loss = tf.square( l_switch - input * switch )  ##input * switch类别置信度C
     7     loss_max = tf.square( l_switch * 0.5 - input * switch )
     8 
     9     IOU_loss = tf.cond( IOU_loss < loss_max, lambda : tf.cast( 1e-8, tf.float32 ), lambda : IOU_loss )
    10 
    11     IOU_loss = tf.cond( l_switch < 1, lambda : IOU_loss * alpha, lambda : IOU_loss )
    12 
    13     return IOU_loss
    14 
    15 def location_loss( x, y, width, height, l_x, l_y, l_width, l_height, alpha = 5 ):
    16     point_loss = ( tf.square( l_x - x ) + tf.square( l_y - y ) ) * alpha
    17     size_loss = ( tf.square( tf.sqrt( l_width ) - tf.sqrt( width ) ) + tf.square( tf.sqrt( l_height ) - tf.sqrt( height ) ) ) * alpha
    18 
    19     location_loss = point_loss + size_loss
    20 
    21     return location_loss
    22 
    23 def class_loss( inputs, labels ):
    24     classloss = tf.square( labels - inputs )
    25     loss_sum = tf.reduce_sum( classloss )
    26 
    27     return loss_sum

    接下来是提取训练数据的程序extract_labels.py 可下载pascal voc数据集,对照数据的格式来读数据。比较麻烦,但是这个也是训练程序与预测程序最大的不同点,这份代码最大的亮点也在此,其他部分实现个人感觉并不是很好。数据既可以读取类别标签,也可读取物体框的信息。

    粗略的写了一份程序解读,因为只能找到一个tensorflow代码实现,个人认为不是很好,希望有人有比较好的复现可以说一下。同时学好C++很重要呀,就可以直接读取源码了。



  • 相关阅读:
    [翻译] GCDObjC
    [翻译] ValueTrackingSlider
    [翻译] SWTableViewCell
    使用 NSPropertyListSerialization 持久化字典与数组
    [翻译] AsyncImageView 异步下载图片
    KVC中setValuesForKeysWithDictionary:
    [翻译] Working with NSURLSession: AFNetworking 2.0
    数据库引擎
    什么是数据库引擎
    网站添加百度分享按钮代码实例
  • 原文地址:https://www.cnblogs.com/the-home-of-123/p/9733855.html
Copyright © 2020-2023  润新知