# Source code for torch_points3d.applications.pointnet2

import os
import sys
from omegaconf import DictConfig, OmegaConf
import logging

from torch_points3d.applications.modelfactory import ModelFactory
from torch_points3d.modules.pointnet2 import *
from torch_points3d.core.base_conv.dense import DenseFPModule
from torch_points3d.models.base_architectures.unet import UnwrappedUnetBasedModel
from torch_points3d.datasets.multiscale_data import MultiScaleBatch
from torch_points3d.core.common_modules.dense_modules import Conv1D
from torch_points3d.core.common_modules.base_modules import Seq
from .utils import extract_output_nc

# Resolved (symlink-free) location of this module and its parent directory.
CUR_FILE = os.path.realpath(__file__)
DIR_PATH = os.path.dirname(CUR_FILE)
# Directory searched by PointNet2Factory for the bundled
# "<architecture>_<num_layers>_<ms|ss>.yaml" model definitions.
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/pointnet2")

log = logging.getLogger(__name__)


def PointNet2(
    architecture: str = None,
    input_nc: int = None,
    num_layers: int = None,
    config: DictConfig = None,
    multiscale=False,
    *args,
    **kwargs
):
    """Create a PointNet2 backbone model based on the architecture proposed in
    https://arxiv.org/abs/1706.02413

    Parameters
    ----------
    architecture : str, optional
        Architecture of the model. PointNet2Factory defines builders for
        "unet" and "encoder".
    input_nc : int, optional
        Number of channels for the input features.
    num_layers : int, optional
        Depth of the network; selects the bundled YAML config when ``config``
        is not given.
    config : DictConfig, optional
        Custom config, overrides the num_layers and architecture-specific
        bundled config.
    multiscale : bool, optional
        When no ``config`` is given, load the multi-scale ("ms") bundled
        config instead of the single-scale ("ss") one. Default False.
    **kwargs
        Forwarded to the factory and to the model constructor. Notably:

        output_nc : int, optional
            If specified, a Conv1D head is added at the end of the network to
            provide the requested number of output channels.

    Positional arguments beyond ``multiscale`` are not used; a warning is
    logged if any are given.

    Returns
    -------
    The model built by ``PointNet2Factory.build()``.
    """
    if args:
        # Extra positionals used to vanish silently; surface them so callers
        # notice a misplaced argument.
        log.warning("PointNet2 ignores unexpected positional arguments: %s", args)
    factory = PointNet2Factory(
        architecture=architecture,
        num_layers=num_layers,
        input_nc=input_nc,
        multiscale=multiscale,
        config=config,
        **kwargs
    )
    return factory.build()
class PointNet2Factory(ModelFactory):
    """Builds PointNet2 backbones from a user config or a bundled YAML file."""

    def _load_model_config(self, prefix):
        """Return the resolved model config for architecture ``prefix``.

        The user-supplied config is used when present; otherwise
        ``{prefix}_{num_layers}_{ms|ss}.yaml`` is loaded from PATH_TO_CONFIG.
        Placeholders in the config are then resolved against the number of
        input features and the factory keyword arguments.

        Raises
        ------
        ValueError
            If neither a config nor ``num_layers`` was provided.
        """
        if self._config:
            model_config = self._config
        else:
            if self.num_layers is None:
                raise ValueError(
                    "num_layers must be provided to build a PointNet2 {} without a custom config".format(prefix)
                )
            # ``multiscale`` is injected by PointNet2(); default to single-scale
            # when the factory is instantiated directly without it.
            scale = "ms" if self.kwargs.get("multiscale", False) else "ss"
            path_to_model = os.path.join(PATH_TO_CONFIG, "{}_{}_{}.yaml".format(prefix, self.num_layers, scale))
            model_config = OmegaConf.load(path_to_model)
        ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
        return model_config

    def _build_unet(self):
        model_config = self._load_model_config("unet")
        # This module is handed to the base architecture as its module library
        # (presumably used to look up block classes by name from the config).
        modules_lib = sys.modules[__name__]
        return PointNet2Unet(model_config, None, None, modules_lib, **self.kwargs)

    def _build_encoder(self):
        model_config = self._load_model_config("encoder")
        modules_lib = sys.modules[__name__]
        return PointNet2Encoder(model_config, None, None, modules_lib, **self.kwargs)


class BasePointnet2(UnwrappedUnetBasedModel):
    """Common base for the dense PointNet2 encoder and U-Net backbones.

    If a non-None ``output_nc`` keyword argument is given, a Conv1D head
    (batch norm, no bias) is appended that maps the backbone's output channels
    to ``output_nc``.
    """

    CONV_TYPE = "dense"

    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        super(BasePointnet2, self).__init__(model_config, model_type, dataset, modules)
        try:
            default_output_nc = extract_output_nc(model_config)
        except Exception:
            # Best effort: the channel count is only required by the optional head.
            default_output_nc = -1
            log.warning("Could not resolve number of output channels")

        self._has_mlp_head = False
        self._output_nc = default_output_nc
        output_nc = kwargs.get("output_nc")
        if output_nc is not None:
            if default_output_nc == -1:
                raise ValueError(
                    "Cannot add an output head of size {}: the number of output channels "
                    "of the backbone could not be resolved from the config".format(output_nc)
                )
            self._has_mlp_head = True
            self._output_nc = output_nc
            self.mlp = Seq()
            self.mlp.append(Conv1D(default_output_nc, output_nc, bn=True, bias=False))

    @property
    def has_mlp_head(self):
        return self._has_mlp_head

    @property
    def output_nc(self):
        return self._output_nc

    def _set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        ``data.pos`` must be a dense batched tensor of rank 3 ([B, N, 3]).
        The data is moved to ``self.device``; when present, ``data.x`` is
        transposed from [B, N, C] to the channel-first [B, C, N] layout. The
        result is stored in ``self.input``.

        Raises
        ------
        ValueError
            If ``data.pos`` is not rank 3.
        """
        if len(data.pos.shape) != 3:
            raise ValueError(
                "Expected data.pos to be a dense batch of shape [B, N, 3], got shape {}".format(
                    tuple(data.pos.shape)
                )
            )
        data = data.to(self.device)
        if data.x is not None:
            data.x = data.x.transpose(1, 2).contiguous()
        self.input = data


class PointNet2Encoder(BasePointnet2):
    """PointNet2 encoder: the down modules, then the inner module if it is not Identity."""

    def forward(self, data, *args, **kwargs):
        """
        Parameters:
        -----------
        data
            A batch object exposing ``pos``, ``x`` and ``to(device)``. Should contain
                x -- Features [B, N, C]
                pos -- Points [B, N, 3]

        Returns the encoded batch; ``x`` goes through the MLP head if one was configured.
        """
        self._set_input(data)
        data = self.input
        # No skip connections in an encoder, so intermediate outputs are not kept.
        for down_module in self.down_modules:
            data = down_module(data)
        if not isinstance(self.inner_modules[0], Identity):
            data = self.inner_modules[0](data)
        if self.has_mlp_head:
            data.x = self.mlp(data.x)
        return data


class PointNet2Unet(BasePointnet2):
    """PointNet2 U-Net with symmetrical skip connections."""

    def forward(self, data, *args, **kwargs):
        """This method does a forward on the Unet assuming symmetrical skip connections

            Input --- D1 -- D2 -- I -- U1 -- U2 -- U3 -- output
               |       |      |________|      |    |
               |       |______________________|    |
               |___________________________________|

        Parameters:
        -----------
        data
            A batch object exposing ``pos``, ``x`` and ``to(device)``. Should contain
                x -- Features [B, N, C]
                pos -- Points [B, N, 3]

        Returns the decoded batch with the collected sampling ids attached as
        attributes; ``x`` goes through the MLP head if one was configured.
        """
        self._set_input(data)
        data = self.input
        # Every down output except the last is kept for the skip connections.
        stack_down = [data]
        for down_module in self.down_modules[:-1]:
            data = down_module(data)
            stack_down.append(data)
        data = self.down_modules[-1](data)

        if not isinstance(self.inner_modules[0], Identity):
            stack_down.append(data)
            data = self.inner_modules[0](data)

        # Collected before the up pass, which consumes stack_down.
        sampling_ids = self._collect_sampling_ids(stack_down)

        for up_module in self.up_modules:
            data = up_module((data, stack_down.pop()))

        for key, value in sampling_ids.items():
            setattr(data, key, value)

        if self.has_mlp_head:
            data.x = self.mlp(data.x)
        return data