Source code for torch_points3d.applications.sparseconv3d

import os
import sys
from omegaconf import DictConfig, OmegaConf
import logging
import torch
from torch_geometric.data import Batch

from torch_points3d.applications.modelfactory import ModelFactory
import torch_points3d.modules.SparseConv3d as sp3d
from torch_points3d.core.base_conv.message_passing import *
from torch_points3d.modules.SparseConv3d.modules import *
from torch_points3d.core.base_conv.partial_dense import *
from torch_points3d.models.base_architectures.unet import UnwrappedUnetBasedModel
from torch_points3d.core.common_modules.base_modules import MLP

from .utils import extract_output_nc


# Symlink-resolved absolute path of this module, and the directory holding it.
CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(CUR_FILE)
# Directory the factory below loads its default "unet_<n>.yaml" /
# "encoder_<n>.yaml" model configs from.
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/sparseconv3d")

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def SparseConv3d(
    architecture: str = None,
    input_nc: int = None,
    num_layers: int = None,
    config: DictConfig = None,
    backend: str = "minkowski",
    *args,
    **kwargs
):
    """Create a Sparse Conv backbone model based on the architecture proposed in
    https://arxiv.org/abs/1904.08755

    Two backends are available at the moment:
        - https://github.com/mit-han-lab/torchsparse
        - https://github.com/NVIDIA/MinkowskiEngine

    The backend is selected before the model is built. If the ``SPARSE_BACKEND``
    environment variable is set and accepted by ``sp3d.nn.backend_valid`` it
    takes precedence over the ``backend`` argument; if it is set but rejected,
    a warning is logged and ``backend`` is used.

    Parameters
    ----------
    architecture : str, optional
        Architecture of the model, choose from unet, encoder and decoder
        (``SparseConv3dFactory`` in this module provides builders for unet and
        encoder).
    input_nc : int, optional
        Number of channels for the input
    num_layers : int, optional
        Depth of the network
    config : DictConfig, optional
        Custom config, overrides the num_layers and architecture parameters
    backend : str, optional
        torchsparse or minkowski
    *args
        Accepted for call compatibility; not used.
    **kwargs
        Forwarded to ``SparseConv3dFactory``. Options described for this
        entry point include:

        output_nc : int
            If specified, then we add a fully connected head at the end of the
            network to provide the requested dimension
        block :
            Type of resnet block, ResBlock by default but can be any of the
            blocks in modules/SparseConv3d/modules.py

    Returns
    -------
    The object returned by ``SparseConv3dFactory.build()``.
    """
    # Read the override once instead of looking it up in os.environ twice.
    env_backend = os.environ.get("SPARSE_BACKEND")
    if env_backend is not None and sp3d.nn.backend_valid(env_backend):
        sp3d.nn.set_backend(env_backend)
    else:
        if env_backend:
            # Previously an unusable override was dropped without any trace,
            # which made a typo in the variable very hard to notice.
            log.warning(
                "Ignoring SPARSE_BACKEND=%r (rejected by backend_valid); using backend %r instead",
                env_backend,
                backend,
            )
        sp3d.nn.set_backend(backend)

    factory = SparseConv3dFactory(
        architecture=architecture, num_layers=num_layers, input_nc=input_nc, config=config, **kwargs
    )
    return factory.build()
class SparseConv3dFactory(ModelFactory):
    """Factory producing the sparse-convolution U-Net and encoder backbones."""

    def _build_unet(self):
        """Return a ``SparseConv3dUnet``.

        Uses the custom config when one was supplied, otherwise loads
        ``unet_<num_layers>.yaml`` from ``PATH_TO_CONFIG``.
        """
        model_config = self._config
        if not model_config:
            yaml_path = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers))
            model_config = OmegaConf.load(yaml_path)
        ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
        # The current module is handed to the model as its modules library.
        this_module = sys.modules[__name__]
        return SparseConv3dUnet(model_config, None, None, this_module, **self.kwargs)

    def _build_encoder(self):
        """Return a ``SparseConv3dEncoder``.

        Uses the custom config when one was supplied, otherwise loads
        ``encoder_<num_layers>.yaml`` from ``PATH_TO_CONFIG``.
        """
        model_config = self._config
        if not model_config:
            yaml_path = os.path.join(PATH_TO_CONFIG, "encoder_{}.yaml".format(self.num_layers))
            model_config = OmegaConf.load(yaml_path)
        ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
        # The current module is handed to the model as its modules library.
        this_module = sys.modules[__name__]
        return SparseConv3dEncoder(model_config, None, None, this_module, **self.kwargs)


class BaseSparseConv3d(UnwrappedUnetBasedModel):
    """Shared base of the sparse-convolution backbones.

    Handles weight initialisation, the optional MLP head and the conversion of
    the incoming batch into a backend sparse tensor.
    """

    CONV_TYPE = "sparse"

    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        """Build the network, initialise its weights and set up the output head.

        Keyword arguments read here:

        default_output_nc
            Channel count of the backbone output; when missing or falsy it is
            obtained with ``extract_output_nc(model_config)``.
        output_nc
            When present, an ``MLP`` head mapping the backbone output to this
            many channels is created and ``output_nc`` reports this value.
        """
        super().__init__(model_config, model_type, dataset, modules)
        self.weight_initialization()

        backbone_nc = kwargs.get("default_output_nc", None) or extract_output_nc(model_config)

        self._has_mlp_head = "output_nc" in kwargs
        if self._has_mlp_head:
            self._output_nc = kwargs["output_nc"]
            self.mlp = MLP([backbone_nc, self.output_nc], activation=torch.nn.ReLU(), bias=False)
        else:
            self._output_nc = backbone_nc

    @property
    def has_mlp_head(self):
        """Whether an MLP head was requested through ``output_nc``."""
        return self._has_mlp_head

    @property
    def output_nc(self):
        """Number of channels of the features this model outputs."""
        return self._output_nc

    def weight_initialization(self):
        """Initialise convolution kernels (Kaiming normal) and batch-norm affine parameters."""
        for module in self.modules():
            if isinstance(module, (sp3d.nn.Conv3d, sp3d.nn.Conv3dTranspose)):
                torch.nn.init.kaiming_normal_(module.kernel, mode="fan_out", nonlinearity="relu")

            if isinstance(module, sp3d.nn.BatchNorm):
                # Start batch norm as the identity affine map: weight 1, bias 0.
                torch.nn.init.constant_(module.bn.weight, 1)
                torch.nn.init.constant_(module.bn.bias, 0)

    def _set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Stores a backend sparse tensor built from ``data.x``, ``data.coords``
        and ``data.batch`` in ``self.input``, and keeps ``data.pos`` (or
        ``data.coords`` when ``pos`` is ``None``) in ``self.xyz``.

        Parameters
        -----------
        data:
            Batch object exposing ``x``, ``coords``, ``batch`` and ``pos``.
        """
        self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch, self.device)
        # Prefer real positions; fall back to the voxel coordinates.
        self.xyz = data.pos if data.pos is not None else data.coords


class SparseConv3dEncoder(BaseSparseConv3d):
    """Encoder-only backbone: down modules, optional inner module, optional MLP head."""

    def forward(self, data, *args, **kwargs):
        """Run the encoder.

        Parameters:
        -----------
        data
            Batch with the attributes read by ``_set_input`` (``x``, ``coords``,
            ``batch``, ``pos``). The original documentation gives features as
            [N, C] and coords as [N, 4] -- not checked here.

        Returns
        --------
        data:
            ``Batch`` with
            - x: encoded features (documented upstream as [1, output_nc])
            - batch: first column of the sparse tensor coordinates, as long
        """
        self._set_input(data)
        sparse = self.input
        for down_module in self.down_modules:
            sparse = down_module(sparse)

        # Leave the sparse representation: column 0 of C is used as the batch index.
        out = Batch(x=sparse.F, batch=sparse.C[:, 0].long().to(sparse.F.device))

        inner = self.inner_modules[0]
        if not isinstance(inner, Identity):
            out = inner(out)

        if self.has_mlp_head:
            out.x = self.mlp(out.x)
        return out


class SparseConv3dUnet(BaseSparseConv3d):
    """U-Net backbone: down modules with skip connections into the up modules."""

    def forward(self, data, *args, **kwargs):
        """Run forward pass.

        Input --- D1 -- D2 -- D3 -- U1 -- U2 -- output
                   |      |_________|     |
                   |______________________|

        Parameters
        -----------
        data
            Batch with the attributes read by ``_set_input`` (``x``, ``coords``,
            ``batch``, ``pos``). The original documentation gives features as
            [N, C] and coords as [N, 4] -- not checked here.

        Returns
        --------
        data:
            ``Batch`` moved to ``self.device`` with
            - pos: ``self.xyz`` (coords or real pos if pos is in data)
            - x: output features, passed through the MLP head when present
        """
        self._set_input(data)
        features = self.input

        skips = []
        # Every down module except the last one feeds a skip connection.
        for index in range(len(self.down_modules) - 1):
            features = self.down_modules[index](features)
            skips.append(features)
        features = self.down_modules[-1](features)
        # The deepest level has no skip: the first up module receives None.
        skips.append(None)

        # TODO : Manage the inner module
        for up_module in self.up_modules:
            features = up_module(features, skips.pop())

        out = Batch(x=features.F, pos=self.xyz).to(self.device)
        if self.has_mlp_head:
            out.x = self.mlp(out.x)
        return out