• 3D MinkowskiEngine稀疏模式重建


    3D MinkowskiEngine稀疏模式重建

    本文介绍一个简单的演示示例:训练一个3D卷积神经网络,从一个独热向量(one-hot vector)重建3D稀疏模式。这与八叉树生成网络(Octree Generating Networks,ICCV'17)类似。输入的独热向量编码的是ModelNet40数据集中某个3D计算机辅助设计(CAD)椅子模型的索引。

    使用MinkowskiEngine.MinkowskiConvolutionTranspose和MinkowskiEngine.MinkowskiPruning,依次将体素上采样2倍,再删除部分上采样得到的体素,以生成目标形状。网络的总体结构与下图类似,但细节可能有所不同。

     

     在继续之前,请先阅读“训练与数据加载”教程。

    创建稀疏模式重建网络

    要从一个向量生成定义在3D网格世界中的稀疏张量,需要从1×1×1分辨率的体素开始逐级上采样。本文使用的块由MinkowskiEngine.MinkowskiConvolutionTranspose、MinkowskiEngine.MinkowskiConvolution和MinkowskiEngine.MinkowskiPruning组成。

    在前向传播(forward pass)中,创建两条路径:1)主要特征路径;2)稀疏体素分类路径,用于删除不必要的体素。

    out = upsample_block(z)

    out_cls = classification(out).F

    out = pruning(out, out_cls > 0)

    在稀疏张量达到目标分辨率之前,网络会重复执行一系列上采样和剪枝操作,以去除不必要的体素。下图可视化了结果。注意,最终的重建非常精确地捕获了目标几何形状。图中还可视化了上采样与剪枝的分层重建过程。

     运行示例

    要训练网络,请进入MinkowskiEngine根目录,然后输入:

    python -m examples.reconstruction --train

    要可视化网络预测或尝试预先训练的模型,请输入:

    python -m examples.reconstruction

     

    该程序将可视化两个3D形状。左边的一个是目标3D形状,右边的一个是重构的网络预测。

    完整的代码可以在examples/reconstruction.py中找到。

    import os

     

    import sys

     

    import subprocess

     

    import argparse

     

    import logging

     

    import glob

     

    import numpy as np

     

    from time import time

     

    import urllib

     

    # Must be imported before large libs

     

    try:

     

    import open3d as o3d

     

    except ImportError:

     

    raise ImportError('Please install open3d and scipy with `pip install open3d scipy`.')

       
     

    import torch

     

    import torch.nn as nn

     

    import torch.utils.data

     

    import torch.optim as optim

       
     

    import MinkowskiEngine as ME

       
     

    from examples.modelnet40 import InfSampler, resample_mesh

       
     

    # Fixed 3x3 matrix passed to ``rotate`` for both point clouds in
    # ``visualize`` so every shape is shown from the same viewpoint.
    # NOTE(review): its rows are unit-norm, so presumably a rotation matrix.
    M = np.array([[0.80656762, -0.5868724, -0.07091862],
                  [0.3770505, 0.418344, 0.82632997],
                  [-0.45528188, -0.6932309, 0.55870326]])

    # Compare (major, minor) as a tuple: checking only the minor component
    # would wrongly reject releases such as 1.0.x.  Raise ImportError (like the
    # open3d import guard) instead of ``assert``, which is stripped under -O.
    _O3D_VERSION = tuple(int(part) for part in o3d.__version__.split('.')[:2])
    if _O3D_VERSION < (0, 8):
        raise ImportError(
            f'Requires open3d version >= 0.8, the current version is {o3d.__version__}')

    if not os.path.exists('ModelNet40'):
        # NOTE: the root logger is only configured further down in this file,
        # so with default logging settings this INFO message is not shown.
        logging.info('Downloading the fixed ModelNet40 dataset...')
        # check=True: a failed download stops here instead of surfacing later
        # as an empty dataset ("No file loaded").
        subprocess.run(["sh", "./examples/download_modelnet40.sh"], check=True)

       
       
     

    ###############################################################################

     

    # Utility functions

     

    ###############################################################################

     

    def PointCloud(points, colors=None):
        """Build an ``o3d.geometry.PointCloud`` from point (and optional color) data.

        Args:
            points: array-like accepted by ``o3d.utility.Vector3dVector``
                (one 3-vector per point).
            colors: optional per-point colors in the same format; the cloud's
                colors are left unset when this is ``None``.

        Returns:
            The populated Open3D point cloud.
        """
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(points)
        if colors is None:
            return cloud
        cloud.colors = o3d.utility.Vector3dVector(colors)
        return cloud

       
       
     

    def collate_pointcloud_fn(list_data):
        """Collate dataset samples into one batch dictionary.

        ``ModelNet40Dataset.__getitem__`` returns ``None`` for meshes that are
        too sparse after resampling.  Such entries are dropped here instead of
        crashing the whole batch, so the returned batch can be smaller than the
        requested batch size.

        Args:
            list_data: list of ``(coords, xyz, label)`` tuples or ``None``.

        Returns:
            dict with
              * ``'coords'``: tuple of the per-sample ``coords`` objects,
              * ``'xyzs'``: list of float32 tensors built from each numpy ``xyz``,
              * ``'labels'``: ``torch.LongTensor`` of the per-sample labels.

        Raises:
            ValueError: if every sample in the batch is ``None``.
        """
        samples = [sample for sample in list_data if sample is not None]
        if not samples:
            raise ValueError(
                'collate_pointcloud_fn: every sample in the batch was skipped')
        coords, feats, labels = zip(*samples)

        # Concatenate all lists
        return {
            'coords': coords,
            'xyzs': [torch.from_numpy(feat).float() for feat in feats],
            'labels': torch.LongTensor(labels),
        }

       
       
     

    class ModelNet40Dataset(torch.utils.data.Dataset):
        """ModelNet40 'chair' meshes, resampled into point clouds and quantized.

        All meshes matching ``./ModelNet40/chair/train/*.off`` are used; the
        ``phase`` argument only appears in log messages.  Each mesh is
        normalized into the unit cube, resampled with ``resample_mesh`` and the
        points are memoized per index in ``self.cache``.
        """

        def __init__(self, phase, transform=None, config=None):
            # ``config`` must provide ``resolution``; the ``None`` default is not usable.
            self.phase = phase
            self.cache = {}
            self.data_objects = []
            self.transform = transform
            self.resolution = config.resolution
            self.last_cache_percent = 0
            self.root = './ModelNet40'

            # Paths relative to the dataset root, in a stable (sorted) order.
            pattern = os.path.join(self.root, 'chair/train/*.off')
            self.files = sorted(
                os.path.relpath(path, self.root) for path in glob.glob(pattern))
            assert len(self.files) > 0, "No file loaded"
            logging.info(
                f"Loading the subset {phase} from {self.root} with {len(self.files)} files"
            )
            self.density = 30000

            # Ignore warnings in obj loader
            o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)

        def __len__(self):
            """Number of mesh files found under the dataset root."""
            return len(self.files)

        def _load_and_cache(self, idx, mesh_file):
            """Load, normalize and resample ``mesh_file``; memoize the points under ``idx``."""
            # Load a mesh, over sample, copy, rotate, voxelization
            assert os.path.exists(mesh_file)
            mesh = o3d.io.read_triangle_mesh(mesh_file)

            # Normalize to fit the mesh inside a unit cube while preserving aspect ratio
            vertices = np.asarray(mesh.vertices)
            lower = vertices.min(0, keepdims=True)
            longest_side = (vertices.max(0, keepdims=True) - lower).max()
            mesh.vertices = o3d.utility.Vector3dVector((vertices - lower) / longest_side)

            # Oversample points and copy
            xyz = resample_mesh(mesh, density=self.density)
            self.cache[idx] = xyz

            # Report caching progress whenever a new multiple of 10% is reached.
            cache_percent = int((len(self.cache) / len(self)) * 100)
            reached_new_step = (cache_percent > 0
                                and cache_percent % 10 == 0
                                and cache_percent != self.last_cache_percent)
            if reached_new_step:
                logging.info(
                    f"Cached {self.phase}: {len(self.cache)} / {len(self)}: {cache_percent}%"
                )
                self.last_cache_percent = cache_percent
            return xyz

        def __getitem__(self, idx):
            """Return ``(coords, xyz, idx)`` for mesh ``idx``, or ``None`` if it is too sparse.

            ``xyz`` holds the resampled points scaled by ``self.resolution`` and
            ``coords`` their floored values; both are indexed with the result of
            ``ME.utils.sparse_quantize(coords, return_index=True)``.
            """
            mesh_file = os.path.join(self.root, self.files[idx])
            try:
                xyz = self.cache[idx]
            except KeyError:
                xyz = self._load_and_cache(idx, mesh_file)

            if len(xyz) < 1000:
                logging.info(
                    f"Skipping {mesh_file}: does not have sufficient CAD sampling density after resampling: {len(xyz)}."
                )
                return None

            # Use color or other features if available
            feats = np.ones((len(xyz), 1))
            if self.transform:
                xyz, feats = self.transform(xyz, feats)

            # Get coords
            xyz = xyz * self.resolution
            coords = np.floor(xyz)
            inds = ME.utils.sparse_quantize(coords, return_index=True)
            return (coords[inds], xyz[inds], idx)

       
       
     

    def make_data_loader(phase, augment_data, batch_size, shuffle, num_workers,
                         repeat, config):
        """Build a ``DataLoader`` over ``ModelNet40Dataset(phase, config=config)``.

        Args:
            phase: subset name forwarded to ``ModelNet40Dataset``.
            augment_data: accepted for API compatibility; not used here.
            batch_size: forwarded to ``DataLoader``.
            shuffle: shuffle flag; given to ``InfSampler`` when ``repeat`` is
                true, otherwise to ``DataLoader``.
            num_workers: forwarded to ``DataLoader``.
            repeat: if true, indices come from ``InfSampler`` (presumably an
                endless sampler, judging by its name — confirm).
            config: namespace forwarded to ``ModelNet40Dataset``.

        Returns:
            A ``torch.utils.data.DataLoader`` that collates with
            ``collate_pointcloud_fn``, without pinned memory or dropped batches.
        """
        dataset = ModelNet40Dataset(phase, config=config)

        if repeat:
            ordering = {'sampler': InfSampler(dataset, shuffle)}
        else:
            ordering = {'shuffle': shuffle}

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=collate_pointcloud_fn,
            pin_memory=False,
            drop_last=False,
            **ordering)

       
       
     

    # Root logger at INFO level; records go to stdout, prefixed with the short
    # host name.
    ch = logging.StreamHandler(sys.stdout)
    logging.getLogger().setLevel(logging.INFO)
    logging.basicConfig(
        format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s',
        datefmt='%m/%d %H:%M:%S',
        handlers=[ch],
    )


    def _build_parser():
        """Return the command-line parser shared by training and visualization."""
        cli = argparse.ArgumentParser()
        # (flag, type, default), registered in this exact order.
        typed_options = [
            ('--resolution', int, 128),
            ('--max_iter', int, 30000),
            ('--val_freq', int, 1000),
            ('--batch_size', int, 16),
            ('--lr', float, 1e-2),
            ('--momentum', float, 0.9),
            ('--weight_decay', float, 1e-4),
            ('--num_workers', int, 1),
            ('--stat_freq', int, 50),
            ('--weights', str, 'modelnet_reconstruction.pth'),
            ('--load_optimizer', str, 'true'),
        ]
        for flag, value_type, default in typed_options:
            cli.add_argument(flag, type=value_type, default=default)
        cli.add_argument('--train', action='store_true')
        cli.add_argument('--max_visualization', type=int, default=4)
        return cli


    parser = _build_parser()

       
     

    ###############################################################################

     

    # End of utility functions

     

    ###############################################################################

       
       
     

    class GenerativeNet(nn.Module):
        """Generate sparse 3D shapes by repeated upsampling and pruning.

        The network has six stages.  Each stage upsamples with
        ``MinkowskiConvolutionTranspose`` (kernel 2, stride 2, new coordinates),
        refines with a kernel-3 ``MinkowskiConvolution``, predicts one occupancy
        logit per voxel with a kernel-1 classifier, and prunes voxels whose
        logit is not positive (during training, voxels in the target are kept
        as well).  ``block1`` contains two upsampling layers, so there are seven
        stride-2 transposed convolutions in total (2**7 = 128).
        """

        CHANNELS = [1024, 512, 256, 128, 64, 32, 16]

        def __init__(self, resolution, in_nchannel=512):
            nn.Module.__init__(self)

            # Stored for reference; the number of stages below is fixed.
            self.resolution = resolution

            # Input sparse tensor must have tensor stride 128.
            ch = self.CHANNELS

            # Attribute names (block1..block6, block{k}_cls), construction order
            # and the layer order inside every nn.Sequential match the original
            # explicit layout, so existing checkpoints keep loading.

            # Block 1: two upsampling steps, in_nchannel -> ch[0] -> ch[1].
            self.block1 = nn.Sequential(
                *self._upsample_layers(in_nchannel, ch[0]),
                *self._upsample_layers(ch[0], ch[1]))
            self.block1_cls = self._classifier(ch[1])

            # Blocks 2-6: one upsampling step each, ch[k-1] -> ch[k].
            self.block2 = nn.Sequential(*self._upsample_layers(ch[1], ch[2]))
            self.block2_cls = self._classifier(ch[2])

            self.block3 = nn.Sequential(*self._upsample_layers(ch[2], ch[3]))
            self.block3_cls = self._classifier(ch[3])

            self.block4 = nn.Sequential(*self._upsample_layers(ch[3], ch[4]))
            self.block4_cls = self._classifier(ch[4])

            self.block5 = nn.Sequential(*self._upsample_layers(ch[4], ch[5]))
            self.block5_cls = self._classifier(ch[5])

            self.block6 = nn.Sequential(*self._upsample_layers(ch[5], ch[6]))
            self.block6_cls = self._classifier(ch[6])

            # pruning
            self.pruning = ME.MinkowskiPruning()

        @staticmethod
        def _upsample_layers(in_channels, out_channels):
            """Layers of one upsampling step: transposed conv, BN, ELU, conv, BN, ELU."""
            return [
                ME.MinkowskiConvolutionTranspose(
                    in_channels,
                    out_channels,
                    kernel_size=2,
                    stride=2,
                    generate_new_coords=True,
                    dimension=3),
                ME.MinkowskiBatchNorm(out_channels),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(
                    out_channels, out_channels, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(out_channels),
                ME.MinkowskiELU(),
            ]

        @staticmethod
        def _classifier(channels):
            """Kernel-1 convolution producing one occupancy logit per voxel."""
            return ME.MinkowskiConvolution(
                channels, 1, kernel_size=1, has_bias=True, dimension=3)

        def _stages(self):
            """(upsampling block, classifier) pairs in forward order."""
            return [
                (self.block1, self.block1_cls),
                (self.block2, self.block2_cls),
                (self.block3, self.block3_cls),
                (self.block4, self.block4_cls),
                (self.block5, self.block5_cls),
                (self.block6, self.block6_cls),
            ]

        def get_batch_indices(self, out):
            """Row indices of ``out`` per batch, as reported by its coordinate manager."""
            return out.coords_man.get_row_indices_per_batch(out.coords_key)

        def get_target(self, out, target_key, kernel_size=1):
            """Boolean CPU mask over the rows of ``out`` marking target voxels.

            The target coordinates (``target_key``) are strided to
            ``out.tensor_stride[0]``.  A row of ``out`` is marked when the
            kernel map between ``out`` and the strided target (with
            ``kernel_size`` and ``region_type=1``) lists it as an input.
            Computed without gradients.
            """
            with torch.no_grad():
                target = torch.zeros(len(out), dtype=torch.bool)
                cm = out.coords_man
                strided_target_key = cm.stride(
                    target_key, out.tensor_stride[0], force_creation=True)
                ins, _ = cm.get_kernel_map(
                    out.coords_key,
                    strided_target_key,
                    kernel_size=kernel_size,
                    region_type=1)
                for curr_in in ins:
                    target[curr_in] = 1
            return target

        def valid_batch_map(self, batch_map):
            """Return True iff every per-batch entry of ``batch_map`` is non-empty."""
            return all(len(b) > 0 for b in batch_map)

        def forward(self, z, target_key):
            """Run all six stages, pruning after each one.

            Args:
                z: input ``ME.SparseTensor``; presumably with tensor stride 128
                    so the last stage reaches stride 1 — confirm with the caller.
                target_key: coordinate key of the target shape in ``z``'s
                    coordinate manager, used to build per-stage target masks.

            Returns:
                ``(out_cls, targets, out)``: per-stage classifier outputs, the
                per-stage boolean target masks from ``get_target``, and the
                pruned output of the last stage.
            """
            out_cls, targets = [], []
            stages = self._stages()
            out = z
            for depth, (block, classifier) in enumerate(stages, start=1):
                out = block(out)
                logits = classifier(out)
                target = self.get_target(out, target_key)
                targets.append(target)
                out_cls.append(logits)

                # squeeze(1), not squeeze(): keeps a 1-D mask even when the
                # stage holds a single voxel.
                keep = (logits.F > 0).cpu().squeeze(1)

                # If training, force target shape generation (use net.eval() to
                # disable).  The last stage is never forced, so its pruning
                # reflects the prediction only.
                if self.training and depth < len(stages):
                    keep |= target

                out = self.pruning(out, keep)

            return out_cls, targets, out

       
       
     

    def train(net, dataloader, device, config):
        """Train ``net`` to reconstruct shapes from one-hot shape indices.

        Each iteration builds a sparse tensor with one voxel per sample, whose
        feature is the one-hot encoding of the sample's label (the dataset
        index returned by ``ModelNet40Dataset``).  The loss is the mean over
        stages of the binary cross-entropy between each stage's logits and its
        target mask.  Every ``config.val_freq`` iterations a checkpoint is
        written to ``config.weights`` and the learning rate is decayed by 0.95.

        Args:
            net: module returning ``(out_cls, targets, out)`` like ``GenerativeNet``.
            dataloader: loader yielding ``collate_pointcloud_fn`` batches; it
                must supply ``config.max_iter`` batches.
            device: torch device to run on.
            config: parsed options (``lr``, ``momentum``, ``weight_decay``,
                ``max_iter``, ``resolution``, ``stat_freq``, ``val_freq``,
                ``weights``).
        """
        # One input channel per shape in the dataset (one-hot index encoding).
        in_nchannel = len(dataloader.dataset)

        optimizer = optim.SGD(
            net.parameters(),
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

        crit = nn.BCEWithLogitsLoss()

        net.train()
        train_iter = iter(dataloader)
        # get_last_lr() reports the rate currently in use; get_lr() is meant
        # for the scheduler's internal computation.
        logging.info(f'LR: {scheduler.get_last_lr()}')
        for i in range(config.max_iter):
            s = time()
            # Builtin next(): DataLoader iterators need not expose ``.next()``.
            data_dict = next(train_iter)
            d = time() - s

            optimizer.zero_grad()

            # Use the actual batch size; it can be smaller than
            # config.batch_size (e.g. when samples were dropped while collating).
            labels = data_dict['labels']
            batch_size = len(labels)

            # One voxel per sample: column 0 holds the sample's position in the
            # batch (presumably the batch index), the other columns are zero.
            init_coords = torch.zeros((batch_size, 4), dtype=torch.int)
            init_coords[:, 0] = torch.arange(batch_size)

            in_feat = torch.zeros((batch_size, in_nchannel))
            in_feat[torch.arange(batch_size), labels] = 1

            sin = ME.SparseTensor(
                feats=in_feat,
                coords=init_coords,
                allow_duplicate_coords=True,  # for classification, it doesn't matter
                tensor_stride=config.resolution,
            ).to(device)

            # Generate target sparse tensor
            cm = sin.coords_man
            target_key = cm.create_coords_key(
                ME.utils.batched_coordinates(data_dict['xyzs']),
                force_creation=True,
                allow_duplicate_coords=True)

            # Generate from a dense tensor
            out_cls, targets, sout = net(sin, target_key)
            num_layers, loss = len(out_cls), 0
            for out_cl, target in zip(out_cls, targets):
                # squeeze(1) keeps the logits 1-D, matching ``target``, even
                # when a stage has a single voxel.
                curr_loss = crit(out_cl.F.squeeze(1),
                                 target.type(out_cl.F.dtype).to(device))
                loss += curr_loss / num_layers

            loss.backward()
            optimizer.step()
            t = time() - s

            if i % config.stat_freq == 0:
                logging.info(
                    f'Iter: {i}, Loss: {loss.item():.3e}, Depths: {len(out_cls)} Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
                )

            if i % config.val_freq == 0 and i > 0:
                torch.save(
                    {
                        'state_dict': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'curr_iter': i,
                    }, config.weights)

                scheduler.step()
                logging.info(f'LR: {scheduler.get_last_lr()}')

                net.train()

       
       
     

    def visualize(net, dataloader, device, config):
        """Show up to ``config.max_visualization`` reconstructions next to their inputs.

        Runs ``net`` in eval mode without gradients.  For each sample, one
        Open3D window shows the prediction translated by ``+0.6 * resolution``
        along x and the sample's ``xyzs`` points translated by
        ``-0.6 * resolution``, both rotated by ``M``.
        """
        if config.max_visualization <= 0:
            return

        in_nchannel = len(dataloader.dataset)
        net.eval()
        n_vis = 0

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for data_dict in dataloader:
                labels = data_dict['labels']
                batch_size = len(labels)

                init_coords = torch.zeros((batch_size, 4), dtype=torch.int)
                init_coords[:, 0] = torch.arange(batch_size)

                in_feat = torch.zeros((batch_size, in_nchannel))
                in_feat[torch.arange(batch_size), labels] = 1

                sin = ME.SparseTensor(
                    feats=in_feat,
                    coords=init_coords,
                    allow_duplicate_coords=True,  # for classification, it doesn't matter
                    tensor_stride=config.resolution,
                ).to(device)

                # Generate target sparse tensor (the network's forward needs it).
                cm = sin.coords_man
                target_key = cm.create_coords_key(
                    ME.utils.batched_coordinates(data_dict['xyzs']),
                    force_creation=True,
                    allow_duplicate_coords=True)

                _, _, sout = net(sin, target_key)

                batch_coords, _ = sout.decomposed_coordinates_and_features
                for b, coords in enumerate(batch_coords):
                    # Hand Open3D plain numpy arrays (also works for GPU tensors).
                    pcd = PointCloud(coords.cpu().numpy())
                    pcd.estimate_normals()
                    pcd.translate([0.6 * config.resolution, 0, 0])
                    pcd.rotate(M)
                    opcd = PointCloud(data_dict['xyzs'][b].cpu().numpy())
                    opcd.translate([-0.6 * config.resolution, 0, 0])
                    opcd.estimate_normals()
                    opcd.rotate(M)
                    o3d.visualization.draw_geometries([pcd, opcd])

                    n_vis += 1
                    # Stop after exactly max_visualization windows.
                    if n_vis >= config.max_visualization:
                        return

       
       
     

    if __name__ == '__main__':
        # ``import urllib`` alone does not guarantee the ``urllib.request``
        # submodule is loaded; import it explicitly before it is used below.
        import urllib.request

        config = parser.parse_args()
        logging.info(config)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        dataloader = make_data_loader(
            'val',
            augment_data=True,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            repeat=True,
            config=config)
        # One input channel per shape: the network input is a one-hot index.
        in_nchannel = len(dataloader.dataset)

        net = GenerativeNet(config.resolution, in_nchannel=in_nchannel)
        net.to(device)

        logging.info(net)

        if config.train:
            train(net, dataloader, device, config)
        else:
            if not os.path.exists(config.weights):
                logging.info(
                    'Downloading pretrained weights. This might take a while...')
                urllib.request.urlretrieve(
                    "https://bit.ly/36d9m1n", filename=config.weights)

            logging.info(f'Loading weights from {config.weights}')
            # map_location lets a checkpoint saved on a GPU load on a CPU-only
            # machine.  NOTE: torch.load unpickles the file, so only load
            # checkpoints from sources you trust.
            checkpoint = torch.load(config.weights, map_location=device)
            net.load_state_dict(checkpoint['state_dict'])

            visualize(net, dataloader, device, config)

    人工智能芯片与自动驾驶
  • 相关阅读:
    bzoj3401[Usaco2009 Mar]Look Up 仰望*
    bzoj2021[Usaco2010 Jan]Cheese Towers*
    bzoj3767A+B Problem加强版
    bzoj3942[Usaco2015 Feb]Censoring*
    bzoj1673[Usaco2005 Dec]Scales 天平*
    bzoj3670[Noi2014]动物园
    stark——pop功能(admin中添加功能)
    stark——快速过滤list_filter
    stark——分页、search、actions
    stark——增删改页面
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14227771.html
Copyright © 2020-2023  润新知