• PointNet code


    Introduction

    Components

    1. PointNet classification network

    2. Part segmentation network

    Datasets

    1. Point clouds sampled from 3D shapes
    2. ShapeNetPart dataset

    Structure

    The code is divided into three main parts:

    • Data processing
    • Model construction
    • Result selection

    Data processing

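    Converts the point clouds into a format the program can use. This is implemented in provider.py and covers data download, preprocessing (shuffle -> rotate -> jitter) and format conversion (hdf5 -> txt).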

    Shuffle

    import numpy as np

    def shuffle_data(data, labels):
        """ Shuffle data and labels with the same permutation.
            Input:
              data: B,N,... numpy array
              labels: B,... numpy array
            Return:
              shuffled data, labels and shuffle indices
        """
        idx = np.arange(len(labels))  # index array, e.g. [   0    1    2 ... 2045 2046 2047]
        np.random.shuffle(idx)        # permute the indices in place
        return data[idx, ...], labels[idx], idx
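    Because one index array idx indexes both arrays, every sample keeps its own label after shuffling. A minimal check, assuming a toy batch in which every point of sample i is filled with the value i, so the label can be read back from the data:

```python
import numpy as np

def shuffle_data(data, labels):
    """Shuffle data and labels with one shared permutation (same logic as above)."""
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx

# Toy batch: 8 samples, 4 points each, 3 coordinates; every value in sample i equals i
labels = np.arange(8)
data = np.repeat(labels[:, None, None], 4 * 3, axis=1).reshape(8, 4, 3).astype(np.float32)

shuffled_data, shuffled_labels, idx = shuffle_data(data, labels)

# Each shuffled sample still carries its own label
assert np.all(shuffled_data[:, 0, 0] == shuffled_labels)

# idx records the permutation, so the original order can be restored
restored = np.empty_like(shuffled_labels)
restored[idx] = shuffled_labels
assert np.all(restored == labels)
```

Returning idx is what makes the restore step possible, for example to map predictions on a shuffled batch back to the original file order.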

    Rotation

    import numpy as np

    def rotate_point_cloud(batch_data):
        """ Randomly rotate each point cloud about the up (y) axis.
            Input: batch_data: B,N,3 numpy array, e.g. (32, 1024, 3)
            Return: rotated B,N,3 numpy array
        """
        rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
        for k in range(batch_data.shape[0]):
            rotation_angle = np.random.uniform() * 2 * np.pi  # random angle in [0, 2π)
            cosval = np.cos(rotation_angle)
            sinval = np.sin(rotation_angle)
            rotation_matrix = np.array([[cosval, 0, sinval],
                                        [0, 1, 0],
                                        [-sinval, 0, cosval]])
            shape_pc = batch_data[k, ...]
            # reshape to (N, 3) so it can be multiplied by the (3, 3) rotation matrix
            rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
        return rotated_data
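    The matrix above rotates each point about the y axis, so y coordinates and distances from the origin are unchanged. A sketch that checks these properties with a fixed angle instead of a random one; rotation_about_y and rotate_by_angle are hypothetical helper names used only for this illustration:

```python
import numpy as np

def rotation_about_y(angle):
    """The same 3x3 matrix as above: rotation about the y (up) axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0, s],
                     [0, 1, 0],
                     [-s, 0, c]])

def rotate_by_angle(batch_data, angle):
    """Hypothetical helper: rotate every cloud in a (B, N, 3) batch by one angle.
    batch @ R broadcasts R over the batch, replacing the per-sample loop."""
    return (batch_data @ rotation_about_y(angle)).astype(np.float32)

rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 5, 3)).astype(np.float32)
rotated = rotate_by_angle(batch, np.pi / 3)

R = rotation_about_y(np.pi / 3)
assert np.allclose(R.T @ R, np.eye(3))            # orthogonal
assert np.isclose(np.linalg.det(R), 1.0)          # proper rotation, no reflection
assert np.allclose(rotated[..., 1], batch[..., 1])  # y coordinates unchanged
assert np.allclose(np.linalg.norm(rotated, axis=-1),
                   np.linalg.norm(batch, axis=-1), atol=1e-5)  # lengths unchanged
```

Since only the orientation around the vertical axis changes, this augmentation teaches the network that an object's heading does not change its class.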

    Jitter

    def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
        """ Add clipped Gaussian noise to every coordinate. batch_data: B,N,3 numpy array """
        B, N, C = batch_data.shape
        assert(clip > 0)
        # noise ~ N(0, sigma^2), limited to the range [-clip, clip]
        jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
        jittered_data += batch_data
        return jittered_data
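    np.clip bounds every noise value, so with the defaults no coordinate moves by more than 0.05. Chaining the three steps in the order given for provider.py (shuffle -> rotate -> jitter), one augmentation step for a training batch could look like the sketch below; augment_batch is a hypothetical name, and the other functions are compact restatements of the code above:

```python
import numpy as np

def shuffle_data(data, labels):
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx

def rotate_point_cloud(batch_data):
    # One random angle per cloud, rotation about the y (up) axis
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        a = np.random.uniform() * 2 * np.pi
        c, s = np.cos(a), np.sin(a)
        R = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
        rotated[k] = batch_data[k].reshape(-1, 3) @ R
    return rotated

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    assert clip > 0
    noise = np.clip(sigma * np.random.randn(*batch_data.shape), -clip, clip)
    return batch_data + noise

def augment_batch(data, labels):
    """Hypothetical helper: shuffle -> rotate -> jitter on a (B, N, 3) batch."""
    data, labels, _ = shuffle_data(data, labels)
    data = rotate_point_cloud(data)
    data = jitter_point_cloud(data)
    return data, labels

np.random.seed(0)
data = np.random.rand(4, 16, 3).astype(np.float32)
labels = np.array([3, 1, 0, 2])
aug_data, aug_labels = augment_batch(data, labels)
```

The rotation leaves y untouched and the jitter is clipped, so each augmented y coordinate stays within 0.05 of its original value.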

    Model construction

    Input transform net (T-Net)

    with tf.variable_scope('transform_net1') as sc:  # T-Net predicting a 3x3 transform
        transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
    # point_cloud: (32, 1024, 3), transform: (32, 3, 3)
    point_cloud_transformed = tf.matmul(point_cloud, transform)
    # point_cloud_transformed: (32, 1024, 3)
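    tf.matmul here multiplies each cloud in the batch by its own 3x3 matrix from the T-Net: (B, N, 3) x (B, 3, 3) -> (B, N, 3). The NumPy sketch below mirrors that batched product with np.matmul; the random transforms are only a stand-in for the T-Net output:

```python
import numpy as np

B, N, K = 32, 1024, 3
rng = np.random.default_rng(0)
point_cloud = rng.standard_normal((B, N, K))
transform = rng.standard_normal((B, K, K))   # stand-in for the T-Net output

# Batched matmul: cloud b is multiplied by transform[b] only
point_cloud_transformed = np.matmul(point_cloud, transform)
print(point_cloud_transformed.shape)  # (32, 1024, 3)

# Same result written per sample
for b in range(B):
    assert np.allclose(point_cloud_transformed[b], point_cloud[b] @ transform[b])

# An identity transform leaves every cloud unchanged
identity = np.tile(np.eye(K), (B, 1, 1))
assert np.allclose(np.matmul(point_cloud, identity), point_cloud)
```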

    Shared MLP (64, 128, 1024), implemented as 1x1 convolutions

    net = tf_util.conv2d(net_transformed, 64, [1,1],
                             padding='VALID', stride=[1,1],
                             bn=True, is_training=is_training,
                             scope='conv3', bn_decay=bn_decay)
    print('net3=',net)#(32, 1024, 1, 64)
    net = tf_util.conv2d(net, 128, [1,1],
                             padding='VALID', stride=[1,1],
                             bn=True, is_training=is_training,
                             scope='conv4', bn_decay=bn_decay)
    print('net4=',net)#(32, 1024, 1, 128)
    net = tf_util.conv2d(net, 1024, [1,1],
                             padding='VALID', stride=[1,1],
                             bn=True, is_training=is_training,
                             scope='conv5', bn_decay=bn_decay)
    print('net5=',net)#(32, 1024, 1, 1024)
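    A conv2d with a [1,1] kernel and stride [1,1] on a (B, N, 1, C) tensor applies the same weights to every point independently, which is what "shared MLP" means. A NumPy sketch of mlp(64,128,1024) on a 64-channel net_transformed, using random weights and leaving out batch normalization (both simplifications of tf_util.conv2d with bn=True); shared_mlp is a hypothetical name:

```python
import numpy as np

def shared_mlp(x, weights, biases):
    """x: (B, N, 1, C_in). Each layer is a 1x1 convolution, i.e. one matrix
    applied to every point, followed by ReLU."""
    for W, b in zip(weights, biases):
        x = np.maximum(np.einsum('bnic,cd->bnid', x, W) + b, 0.0)
    return x

rng = np.random.default_rng(0)
sizes = [64, 64, 128, 1024]   # 64 input channels, then conv3 (64), conv4 (128), conv5 (1024)
weights = [rng.standard_normal((i, o)) * 0.1 for i, o in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(o) for o in sizes[1:]]

net_transformed = rng.standard_normal((2, 16, 1, 64))   # small B and N for the demo
net = shared_mlp(net_transformed, weights, biases)
print(net.shape)  # (2, 16, 1, 1024)

# Per-point independence: permuting the points permutes the output rows the same way
perm = rng.permutation(16)
permuted_net = shared_mlp(net_transformed[:, perm], weights, biases)
assert np.allclose(permuted_net, net[:, perm])

# PointNet follows this with a max pool over the N points, giving an order-invariant global feature
global_feature = net.max(axis=1)   # (2, 1, 1024)
assert np.allclose(permuted_net.max(axis=1), global_feature)
```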

    Class voting

    Implementation

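    batch_pred_sum.shape = (?, 40)  # accumulated score of each sample for each of the 40 classes

    pred_val.shape = (?,)  # most likely class of each sample

    pred_val = np.argmax(batch_pred_sum, 1)
    # index of the maximum along axis 1, i.e. the label of the class with the highest score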
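    Voting adds up the class scores of several rotated copies of each cloud in batch_pred_sum and only then takes the argmax. A sketch of that loop, assuming a model_fn that maps a (B, N, 3) batch to (B, num_classes) scores; vote_predict, model_fn, num_votes and the evenly spaced angles are assumptions for illustration, and the toy model has 2 classes instead of 40:

```python
import numpy as np

def rotate_about_y(batch, angle):
    c, s = np.cos(angle), np.sin(angle)
    R = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
    return batch @ R

def vote_predict(batch, model_fn, num_votes=12, num_classes=40):
    """Hypothetical helper: sum scores over num_votes rotated copies, then argmax."""
    batch_pred_sum = np.zeros((batch.shape[0], num_classes))
    for v in range(num_votes):
        angle = v / num_votes * 2 * np.pi          # evenly spaced angles (assumed)
        batch_pred_sum += model_fn(rotate_about_y(batch, angle))
    pred_val = np.argmax(batch_pred_sum, 1)        # (B,)
    return pred_val, batch_pred_sum

# Toy 2-class "model": class 0 if the cloud lies mostly above y = 0, else class 1
def toy_model(batch):
    mean_y = batch[:, :, 1].mean(axis=1)
    return np.stack([mean_y, -mean_y], axis=1)

rng = np.random.default_rng(0)
batch = rng.standard_normal((3, 50, 3))
batch[0, :, 1] += 5.0    # clearly above
batch[1, :, 1] -= 5.0    # clearly below
batch[2, :, 1] += 5.0    # clearly above
pred_val, batch_pred_sum = vote_predict(batch, toy_model, num_classes=2)
print(pred_val)  # [0 1 0]
```

With a real network the per-rotation scores differ, and summing them smooths out orientations the model handles poorly.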

    Evaluation

    Output: (predicted label, true label). A label is the 0-based line index of the class name in shape_names.txt.

    /dump/pred_label.txt

    4, 4    
    0, 0
    2, 2
    8, 8
    14, 23
    ...

    shape_names.txt

    airplane
    bathtub
    bed
    bench
    bookshelf
    bottle
    bowl
    car
    chair
    cone
    cup
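    Since each line of pred_label.txt is "predicted, true" and a label is a line index into shape_names.txt (4 is bookshelf in the list above), the file can be turned back into names and an accuracy. A sketch using the five rows shown above; parse_pred_labels is a hypothetical helper, and a real run would read the lines from the dump file instead:

```python
def parse_pred_labels(lines):
    """Turn 'pred, true' lines into a list of (pred, true) integer pairs."""
    pairs = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        pred, true = (int(tok) for tok in line.split(','))
        pairs.append((pred, true))
    return pairs

# First entries of shape_names.txt as listed above; list index = label
shape_names = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle',
               'bowl', 'car', 'chair', 'cone', 'cup']

rows = ["4, 4", "0, 0", "2, 2", "8, 8", "14, 23"]
pairs = parse_pred_labels(rows)
accuracy = sum(p == t for p, t in pairs) / len(pairs)
errors = [(p, t) for p, t in pairs if p != t]
correct_names = [shape_names[t] for p, t in pairs if p == t]
print(accuracy)       # 0.8
print(errors)         # [(14, 23)]
print(correct_names)  # ['bookshelf', 'airplane', 'bed', 'chair']
```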

    Saving and visualizing misclassified samples

    /dump/<error index>_label_<true label>_pred_<predicted label>.jpg
    Naming: running index of the misclassified sample + true label + predicted label

    Example: /dump/1028_label_bed_pred_sofa.jpg

    Each image contains three renderings of the point cloud, each after rotating it to a different angle.

    Saving code

    for i in range(start_idx, end_idx):
        l = current_label[i]
        total_seen_class[l] += 1
        total_correct_class[l] += (pred_val[i-start_idx] == l)
        fout.write('%d, %d\n' % (pred_val[i-start_idx], l))  # one "pred, true" line per sample
        if pred_val[i-start_idx] != l and FLAGS.visu:  # misclassified: dump an image
            # file name: error index + true label + predicted label
            img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l],
                                                      SHAPE_NAMES[pred_val[i-start_idx]])
            img_filename = os.path.join(DUMP_DIR, img_filename)
            output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :]))
            # scipy.misc.imsave was removed in SciPy 1.2; newer setups need another writer, e.g. imageio.imwrite
            scipy.misc.imsave(img_filename, output_img)
            error_cnt += 1

    Point cloud drawing code

    draw_point_cloud()
        Input:
            points: Nx3 numpy array
        Output:
            gray image
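    draw_point_cloud turns an Nx3 array into a gray image, and point_cloud_three_views places renders from three viewpoints side by side. The sketch below is a much simpler stand-in, not the pc_util implementation: it projects x and y orthographically onto a square canvas and shades each pixel by depth (z). draw_points_sketch, three_views_sketch, canvas_size and the three angles are all assumptions:

```python
import numpy as np

def draw_points_sketch(points, canvas_size=64):
    """Project Nx3 points onto the x-y plane; brighter pixels are closer (larger z)."""
    pts = points - points.mean(axis=0)
    pts = pts / (np.abs(pts).max() + 1e-9)            # fit into [-1, 1]
    img = np.zeros((canvas_size, canvas_size))
    cols = ((pts[:, 0] + 1) / 2 * (canvas_size - 1)).round().astype(int)
    rows = ((1 - pts[:, 1]) / 2 * (canvas_size - 1)).round().astype(int)  # image rows grow downward
    shade = (pts[:, 2] + 1) / 2                        # depth mapped to [0, 1]
    np.maximum.at(img, (rows, cols), shade)            # keep the brightest point per pixel
    return img

def three_views_sketch(points, canvas_size=64):
    """Render after rotating about the y axis by three angles; concatenate horizontally."""
    views = []
    for angle in (0.0, np.pi / 2, np.pi):   # assumed angles
        c, s = np.cos(angle), np.sin(angle)
        R = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
        views.append(draw_points_sketch(points @ R, canvas_size))
    return np.concatenate(views, axis=1)

rng = np.random.default_rng(0)
cloud = rng.standard_normal((1024, 3))
img = three_views_sketch(cloud)
print(img.shape)  # (64, 192)
```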

    Logging loss and prediction accuracy

    /dump/log_evaluate.txt

    eval mean loss: 1.816358
    eval accuracy: 0.501216
    eval avg class acc: 0.421297
      airplane: 0.980
       bathtub: 0.440
           bed: 0.940
         bench: 0.450
         ...
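    "eval accuracy" counts every sample equally, whereas "eval avg class acc" first computes the accuracy of each class from total_correct_class / total_seen_class (the counters in the saving code) and then averages over classes, so small classes weigh as much as large ones. A sketch with made-up counts for three classes (illustrative numbers, not the results in the log):

```python
import numpy as np

# Illustrative counters for 3 classes (made-up numbers, not from the log above)
total_seen_class = np.array([100, 50, 10])
total_correct_class = np.array([95, 30, 6])

# Overall accuracy: every sample counts once
eval_accuracy = total_correct_class.sum() / float(total_seen_class.sum())
# Average class accuracy: accuracy per class, then an unweighted mean over classes
class_acc = total_correct_class / total_seen_class.astype(float)
eval_avg_class_acc = np.mean(class_acc)

print('eval accuracy: %f' % eval_accuracy)            # eval accuracy: 0.818750
print('eval avg class acc: %f' % eval_avg_class_acc)  # eval avg class acc: 0.716667
```

Here the two poorly classified small classes pull the class average well below the overall accuracy, the same pattern as in the log, where avg class acc is lower than accuracy.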
  • Original article: https://www.cnblogs.com/yibeimingyue/p/12005683.html