3D MinkowskiEngine稀疏模式重建
本文介绍一个简单的演示示例:训练一个3D卷积神经网络,使其根据独热向量(one-hot vector)重构3D稀疏模式。这类似于Octree生成网络(Octree Generating Networks,ICCV'17)。输入的独热向量表示ModelNet40数据集中3D计算机辅助设计(CAD)椅子模型的索引。
使用MinkowskiEngine.MinkowskiConvolutionTranspose和 MinkowskiEngine.MinkowskiPruning,依次将体素上采样2倍,然后删除一些上采样的体素,以生成目标形状。常规的网络体系结构看起来类似于下图,但是细节可能有所不同。
在继续之前,请先阅读训练和数据加载。
创建稀疏模式重建网络
要从向量创建定义在3D网格世界中的稀疏张量,需要从1×1×1分辨率的体素开始依次上采样。本文使用一个由MinkowskiEngine.MinkowskiConvolutionTranspose,MinkowskiEngine.MinkowskiConvolution和MinkowskiEngine.MinkowskiPruning组成的块。
在前向传播(forward pass)过程中,为1)主要特征和2)稀疏体素分类创建两条路径,后者用于删除不必要的体素。
out = upsample_block(z)
out_cls = classification(out).F
out = pruning(out, out_cls > 0)
在输入的稀疏张量达到目标分辨率之前,网络会重复执行一系列的上采样和修剪操作,以去除不必要的体素。下图对结果进行了可视化。注意,最终的重建非常精确地捕获了目标几何体。下图还展示了上采样和修剪的分层重建过程。
运行示例
要训练网络,请转到Minkowski Engine根目录,然后键入:
python -m examples.reconstruction --train
要可视化网络预测或尝试预先训练的模型,请输入:
python -m examples.reconstruction
该程序将可视化两个3D形状。左边的一个是目标3D形状,右边的一个是重构的网络预测。
完整的代码可以在examples/reconstruction.py找到。
import os |
|
import sys |
|
import subprocess |
|
import argparse |
|
import logging |
|
import glob |
|
import numpy as np |
|
from time import time |
|
import urllib |
|
# Must be imported before large libs |
|
try: |
|
import open3d as o3d |
|
except ImportError: |
|
raise ImportError('Please install open3d and scipy with `pip install open3d scipy`.') |
|
import torch |
|
import torch.nn as nn |
|
import torch.utils.data |
|
import torch.optim as optim |
|
import MinkowskiEngine as ME |
|
from examples.modelnet40 import InfSampler, resample_mesh |
|
# 3x3 matrix handed to ``rotate`` on both point clouds in ``visualize``
# (presumably a fixed viewing rotation - confirm against the open3d API).
M = np.array([[0.80656762, -0.5868724, -0.07091862],
              [0.3770505, 0.418344, 0.82632997],
              [-0.45528188, -0.6932309, 0.55870326]])

# Compare (major, minor) as a tuple: testing the minor component alone would
# wrongly reject any release whose major version is above 0 (e.g. 1.0).
assert tuple(int(part) for part in o3d.__version__.split('.')[:2]) >= (
    0, 8
), f'Requires open3d version >= 0.8, the current version is {o3d.__version__}'

# Fetch the dataset on first use. Both the existence check and the helper
# script path are relative to the current working directory.
if not os.path.exists('ModelNet40'):
    logging.info('Downloading the fixed ModelNet40 dataset...')
    subprocess.run(["sh", "./examples/download_modelnet40.sh"])
|
############################################################################### |
|
# Utility functions |
|
############################################################################### |
|
def PointCloud(points, colors=None):
    """Wrap raw point data in an ``o3d.geometry.PointCloud``.

    Args:
        points: values passed to ``o3d.utility.Vector3dVector`` and stored
            as the cloud's points.
        colors: optional values passed to ``o3d.utility.Vector3dVector`` and
            stored as the cloud's colors; skipped when ``None``.

    Returns:
        The populated ``o3d.geometry.PointCloud``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    if colors is None:
        return cloud
    cloud.colors = o3d.utility.Vector3dVector(colors)
    return cloud
|
def collate_pointcloud_fn(list_data):
    """Merge per-sample ``(coords, feats, label)`` tuples into one batch dict.

    Args:
        list_data: sequence of 3-tuples as returned by the dataset's
            ``__getitem__``.

    Returns:
        A dict with
        ``'coords'``: tuple of the per-sample first elements, untouched;
        ``'xyzs'``: list of float tensors built from the per-sample second
        elements (NumPy arrays);
        ``'labels'``: ``torch.LongTensor`` of the per-sample third elements.
    """
    coords, feats, labels = zip(*list_data)

    xyzs = []
    for feat in feats:
        xyzs.append(torch.from_numpy(feat).float())

    batch = {}
    batch['coords'] = coords
    batch['xyzs'] = xyzs
    batch['labels'] = torch.LongTensor(labels)
    return batch
|
class ModelNet40Dataset(torch.utils.data.Dataset):
    """Point clouds sampled from the ModelNet40 chair meshes.

    Every item is a tuple ``(coords, xyz, idx)``: the floored coordinates
    kept by ``ME.utils.sparse_quantize``, the matching scaled points, and
    the sample index (used by the caller as the label).
    """

    def __init__(self, phase, transform=None, config=None):
        """Index the mesh files and prepare an empty sample cache.

        Args:
            phase: name of the subset; only stored and used in log messages.
            transform: optional callable ``(xyz, feats) -> (xyz, feats)``.
            config: object exposing ``resolution``; it is dereferenced
                unconditionally, so the ``None`` default is not usable.
        """
        self.phase = phase
        self.transform = transform
        self.resolution = config.resolution
        self.cache = {}
        self.data_objects = []
        self.last_cache_percent = 0
        self.root = './ModelNet40'
        self.density = 30000

        # The same chair training meshes are listed whatever `phase` is.
        pattern = os.path.join(self.root, 'chair/train/*.off')
        self.files = sorted(
            os.path.relpath(path, self.root) for path in glob.glob(pattern))
        assert len(self.files) > 0, "No file loaded"
        logging.info(
            f"Loading the subset {phase} from {self.root} with {len(self.files)} files"
        )

        # Ignore warnings in obj loader
        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)

    def __len__(self):
        """Return the number of indexed mesh files."""
        return len(self.files)

    def _load_and_cache(self, idx, mesh_file):
        """Read one mesh, normalise and resample it, and cache the points."""
        assert os.path.exists(mesh_file)
        mesh = o3d.io.read_triangle_mesh(mesh_file)

        # Shift to the origin and divide by the longest extent so the mesh
        # fits inside a unit cube while preserving its aspect ratio.
        vertices = np.asarray(mesh.vertices)
        upper = vertices.max(0, keepdims=True)
        lower = vertices.min(0, keepdims=True)
        longest_side = (upper - lower).max()
        mesh.vertices = o3d.utility.Vector3dVector(
            (vertices - lower) / longest_side)

        xyz = resample_mesh(mesh, density=self.density)
        self.cache[idx] = xyz

        # Report caching progress once per distinct multiple of 10 percent.
        cache_percent = int((len(self.cache) / len(self)) * 100)
        is_new_milestone = (cache_percent > 0 and cache_percent % 10 == 0
                            and cache_percent != self.last_cache_percent)
        if is_new_milestone:
            logging.info(
                f"Cached {self.phase}: {len(self.cache)} / {len(self)}: {cache_percent}%"
            )
            self.last_cache_percent = cache_percent
        return xyz

    def __getitem__(self, idx):
        """Return ``(coords, xyz, idx)`` for one mesh.

        Returns ``None`` instead when fewer than 1000 points were sampled.
        """
        mesh_file = os.path.join(self.root, self.files[idx])
        try:
            xyz = self.cache[idx]
        except KeyError:
            xyz = self._load_and_cache(idx, mesh_file)

        # Constant one-valued feature per point.
        feats = np.ones((len(xyz), 1))

        if len(xyz) < 1000:
            logging.info(
                f"Skipping {mesh_file}: does not have sufficient CAD sampling density after resampling: {len(xyz)}."
            )
            return None

        if self.transform:
            xyz, feats = self.transform(xyz, feats)

        # Scale to the voxel grid, then keep the entries selected by
        # sparse_quantize on the floored coordinates.
        xyz = xyz * self.resolution
        coords = np.floor(xyz)
        inds = ME.utils.sparse_quantize(coords, return_index=True)
        return (coords[inds], xyz[inds], idx)
|
def make_data_loader(phase, augment_data, batch_size, shuffle, num_workers,
                     repeat, config):
    """Build a ``DataLoader`` over ``ModelNet40Dataset``.

    Args:
        phase: forwarded to the dataset constructor.
        augment_data: accepted but not used by this function.
        batch_size: samples per batch.
        shuffle: shuffling flag, given either to ``InfSampler`` or to the
            loader itself depending on ``repeat``.
        num_workers: number of loader worker processes.
        repeat: when true, sample through ``InfSampler``; otherwise use the
            loader's own ``shuffle`` option.
        config: forwarded to the dataset constructor.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    dataset = ModelNet40Dataset(phase, config=config)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_pointcloud_fn,
        pin_memory=False,
        drop_last=False,
    )
    # `sampler` and `shuffle` are alternatives, so exactly one is set.
    if repeat:
        loader_kwargs['sampler'] = InfSampler(dataset, shuffle)
    else:
        loader_kwargs['shuffle'] = shuffle

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
|
# Send INFO-and-above records from the root logger to stdout, prefixed with
# the short host name (node name up to the first dot).
ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
_short_hostname = os.uname()[1].split('.')[0]
logging.basicConfig(
    format=_short_hostname + ' %(asctime)s %(message)s',
    datefmt='%m/%d %H:%M:%S',
    handlers=[ch],
)

# Command-line options. Registration order is kept because it determines the
# order of the generated --help text.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
        ('--resolution', int, 128),
        ('--max_iter', int, 30000),
        ('--val_freq', int, 1000),
        ('--batch_size', int, 16),
        ('--lr', float, 1e-2),
        ('--momentum', float, 0.9),
        ('--weight_decay', float, 1e-4),
        ('--num_workers', int, 1),
        ('--stat_freq', int, 50),
        ('--weights', str, 'modelnet_reconstruction.pth'),
        ('--load_optimizer', str, 'true'),
):
    parser.add_argument(_flag, type=_type, default=_default)
parser.add_argument('--train', action='store_true')
parser.add_argument('--max_visualization', type=int, default=4)
|
############################################################################### |
|
# End of utility functions |
|
############################################################################### |
|
class GenerativeNet(nn.Module):
    """Sparse generative decoder that grows a shape from a single voxel.

    The network is a chain of six blocks. Each block upsamples the sparse
    tensor with stride-2 transposed convolutions, a 1x1 convolution head
    scores every resulting voxel, and voxels whose score is not positive are
    pruned before the next block runs.
    """

    CHANNELS = [1024, 512, 256, 128, 64, 32, 16]

    def __init__(self, resolution, in_nchannel=512):
        """Create the six upsampling blocks and their classification heads.

        Args:
            resolution: stored on the instance as ``self.resolution``.
            in_nchannel: number of feature channels of the input tensor.
        """
        nn.Module.__init__(self)

        self.resolution = resolution

        # Input sparse tensor must have tensor stride 128.
        ch = self.CHANNELS

        # Attribute names and creation order are kept exactly as before so
        # that state_dict keys (e.g. for saved checkpoints) and the order in
        # which parameters are initialised do not change.
        # Block 1 is the only block made of two upsampling stages.
        self.block1 = nn.Sequential(
            *self._upsample_layers(in_nchannel, ch[0]),
            *self._upsample_layers(ch[0], ch[1]))
        self.block1_cls = self._occupancy_head(ch[1])

        self.block2 = nn.Sequential(*self._upsample_layers(ch[1], ch[2]))
        self.block2_cls = self._occupancy_head(ch[2])

        self.block3 = nn.Sequential(*self._upsample_layers(ch[2], ch[3]))
        self.block3_cls = self._occupancy_head(ch[3])

        self.block4 = nn.Sequential(*self._upsample_layers(ch[3], ch[4]))
        self.block4_cls = self._occupancy_head(ch[4])

        self.block5 = nn.Sequential(*self._upsample_layers(ch[4], ch[5]))
        self.block5_cls = self._occupancy_head(ch[5])

        self.block6 = nn.Sequential(*self._upsample_layers(ch[5], ch[6]))
        self.block6_cls = self._occupancy_head(ch[6])

        # pruning
        self.pruning = ME.MinkowskiPruning()

    @staticmethod
    def _upsample_layers(in_ch, out_ch):
        """Return, in order, the six layers of one stride-2 upsampling stage."""
        return [
            ME.MinkowskiConvolutionTranspose(
                in_ch,
                out_ch,
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(out_ch, out_ch, kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
        ]

    @staticmethod
    def _occupancy_head(in_ch):
        """Return the 1x1 convolution producing one keep/prune logit per voxel."""
        return ME.MinkowskiConvolution(
            in_ch, 1, kernel_size=1, has_bias=True, dimension=3)

    def get_batch_indices(self, out):
        """Return the per-batch row indices of ``out`` from its coords manager."""
        return out.coords_man.get_row_indices_per_batch(out.coords_key)

    def get_target(self, out, target_key, kernel_size=1):
        """Build the boolean keep-target for the voxels of ``out``.

        The target coordinates are strided to the tensor stride of ``out``;
        every voxel of ``out`` that appears as an input of the kernel map
        between the two coordinate sets is marked ``True``.

        Args:
            out: sparse tensor whose voxels are being labelled.
            target_key: coords key of the ground-truth coordinates.
            kernel_size: kernel size used for the kernel map.

        Returns:
            A CPU ``torch.bool`` tensor with one entry per voxel of ``out``.
        """
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool)
            cm = out.coords_man
            strided_target_key = cm.stride(
                target_key, out.tensor_stride[0], force_creation=True)
            ins, outs = cm.get_kernel_map(
                out.coords_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1)
            for curr_in in ins:
                target[curr_in] = 1
        return target

    def valid_batch_map(self, batch_map):
        """Return ``False`` if any entry of ``batch_map`` is empty."""
        for b in batch_map:
            if len(b) == 0:
                return False
        return True

    def forward(self, z, target_key):
        """Upsample, classify and prune through all six blocks.

        Args:
            z: input sparse tensor.
            target_key: coords key of the ground-truth coordinates.

        Returns:
            ``(out_cls, targets, out)``: the list of per-block logit sparse
            tensors, the list of matching boolean targets, and the final
            pruned sparse tensor.
        """
        out_cls, targets = [], []
        stages = (
            (self.block1, self.block1_cls),
            (self.block2, self.block2_cls),
            (self.block3, self.block3_cls),
            (self.block4, self.block4_cls),
            (self.block5, self.block5_cls),
            (self.block6, self.block6_cls),
        )
        last_stage = len(stages) - 1

        out = z
        for depth, (block, classifier) in enumerate(stages):
            out = block(out)
            logits = classifier(out)
            target = self.get_target(out, target_key)
            targets.append(target)
            out_cls.append(logits)

            # squeeze(1) rather than squeeze(): a single-voxel output must
            # stay a 1-D mask instead of collapsing to a 0-dim tensor.
            keep = (logits.F > 0).cpu().squeeze(1)

            # If training, force target shape generation, use net.eval() to
            # disable. The last stage is pruned by its prediction alone.
            if self.training and depth != last_stage:
                keep += target

            out = self.pruning(out, keep)

        return out_cls, targets, out
|
def train(net, dataloader, device, config):
    """Train ``net`` to reconstruct each shape from its one-hot index.

    Runs ``config.max_iter`` SGD iterations. Every ``config.val_freq``
    iterations (except iteration 0) a checkpoint is written to
    ``config.weights`` and the exponential LR scheduler is stepped.

    Args:
        net: network called as ``net(sparse_input, target_key)`` and
            returning ``(out_cls, targets, sout)``.
        dataloader: loader yielding dicts with ``'labels'`` and ``'xyzs'``;
            it must provide at least ``config.max_iter`` batches of exactly
            ``config.batch_size`` samples.
        device: device the input sparse tensor and targets are moved to.
        config: parsed options (``lr``, ``momentum``, ``weight_decay``,
            ``max_iter``, ``batch_size``, ``resolution``, ``stat_freq``,
            ``val_freq``, ``weights``).
    """
    # One input channel per dataset item: the input is a one-hot sample index.
    in_nchannel = len(dataloader.dataset)

    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)
    # get_last_lr() is the supported way to read the rate currently in use;
    # fall back to get_lr() on PyTorch releases that predate it.
    current_lr = getattr(scheduler, 'get_last_lr', scheduler.get_lr)

    crit = nn.BCEWithLogitsLoss()

    net.train()
    train_iter = iter(dataloader)
    # val_iter = iter(val_dataloader)
    logging.info(f'LR: {current_lr()}')
    for i in range(config.max_iter):

        s = time()
        # Built-in next(): iterators have no .next() method in Python 3.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()

        # One voxel at the origin per batch element; column 0 is the batch
        # index.
        init_coords = torch.zeros((config.batch_size, 4), dtype=torch.int)
        init_coords[:, 0] = torch.arange(config.batch_size)

        in_feat = torch.zeros((config.batch_size, in_nchannel))
        in_feat[torch.arange(config.batch_size), data_dict['labels']] = 1

        sin = ME.SparseTensor(
            feats=in_feat,
            coords=init_coords,
            allow_duplicate_coords=True,  # for classification, it doesn't matter
            tensor_stride=config.resolution,
        ).to(device)

        # Generate target sparse tensor
        cm = sin.coords_man
        target_key = cm.create_coords_key(
            ME.utils.batched_coordinates(data_dict['xyzs']),
            force_creation=True,
            allow_duplicate_coords=True)

        # Generate from a dense tensor
        out_cls, targets, sout = net(sin, target_key)

        # Average the per-depth binary cross-entropy losses.
        num_layers, loss = len(out_cls), 0
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(),
                             target.type(out_cl.F.dtype).to(device))
            loss += curr_loss / num_layers

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f'Iter: {i}, Loss: {loss.item():.3e}, Depths: {len(out_cls)} Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_iter': i,
                }, config.weights)

            scheduler.step()
            logging.info(f'LR: {current_lr()}')

            net.train()
|
def visualize(net, dataloader, device, config):
    """Display network reconstructions next to their target point clouds.

    For each sample the predicted cloud is shifted by ``+0.6 * resolution``
    along x and the target cloud by ``-0.6 * resolution``; both are rotated
    by ``M`` and shown in one open3d window. At most
    ``config.max_visualization`` samples are shown.

    Args:
        net: network called as ``net(sparse_input, target_key)`` and
            returning ``(out_cls, targets, sout)``.
        dataloader: loader yielding dicts with ``'labels'`` and ``'xyzs'``
            in batches of exactly ``config.batch_size`` samples.
        device: device the input sparse tensor is moved to.
        config: parsed options (``batch_size``, ``resolution``,
            ``max_visualization``).
    """
    if config.max_visualization <= 0:
        return

    # One input channel per dataset item: the input is a one-hot sample index.
    in_nchannel = len(dataloader.dataset)
    net.eval()
    n_vis = 0

    for data_dict in dataloader:

        # One voxel at the origin per batch element; column 0 is the batch
        # index.
        init_coords = torch.zeros((config.batch_size, 4), dtype=torch.int)
        init_coords[:, 0] = torch.arange(config.batch_size)

        in_feat = torch.zeros((config.batch_size, in_nchannel))
        in_feat[torch.arange(config.batch_size), data_dict['labels']] = 1

        sin = ME.SparseTensor(
            feats=in_feat,
            coords=init_coords,
            allow_duplicate_coords=True,  # for classification, it doesn't matter
            tensor_stride=config.resolution,
        ).to(device)

        # Generate target sparse tensor
        cm = sin.coords_man
        target_key = cm.create_coords_key(
            ME.utils.batched_coordinates(data_dict['xyzs']),
            force_creation=True,
            allow_duplicate_coords=True)

        # Inference only: no autograd graph is needed, and the loss that used
        # to be computed here was never read.
        with torch.no_grad():
            out_cls, targets, sout = net(sin, target_key)

        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            pcd = PointCloud(coords)
            pcd.estimate_normals()
            pcd.translate([0.6 * config.resolution, 0, 0])
            pcd.rotate(M)

            opcd = PointCloud(data_dict['xyzs'][b])
            opcd.translate([-0.6 * config.resolution, 0, 0])
            opcd.estimate_normals()
            opcd.rotate(M)

            o3d.visualization.draw_geometries([pcd, opcd])

            n_vis += 1
            # `>=` so exactly max_visualization samples are shown; the former
            # `>` displayed one extra.
            if n_vis >= config.max_visualization:
                return
|
if __name__ == '__main__':
    # `import urllib` alone does not guarantee that the `request` submodule
    # is loaded; import it explicitly before calling urlretrieve below.
    import urllib.request

    config = parser.parse_args()
    logging.info(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The same endlessly repeating loader serves training and visualization.
    dataloader = make_data_loader(
        'val',
        augment_data=True,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config)

    # One input channel per dataset item: the input is a one-hot sample index.
    in_nchannel = len(dataloader.dataset)

    net = GenerativeNet(config.resolution, in_nchannel=in_nchannel)
    net.to(device)

    logging.info(net)

    if config.train:
        train(net, dataloader, device, config)
    else:
        if not os.path.exists(config.weights):
            logging.info(
                f'Downloaing pretrained weights. This might take a while...')
            urllib.request.urlretrieve(
                "https://bit.ly/36d9m1n", filename=config.weights)

        logging.info(f'Loading weights from {config.weights}')
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine.
        checkpoint = torch.load(config.weights, map_location=device)
        net.load_state_dict(checkpoint['state_dict'])

        visualize(net, dataloader, device, config)