Source code for torch_points3d.core.data_transform
import sys
import numpy as np
import torch_geometric.transforms as T
from .transforms import *
from .grid_transform import *
from .sparse_transforms import *
from .inference_transforms import *
from .feature_augment import *
from .features import *
from .filters import *
from .precollate import *
from .prebatchcollate import *
from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig
from omegaconf import OmegaConf
# Handles on this module and on torch_geometric.transforms. They are used for
# the name-collision check below and for looking transforms up by name.
_custom_transforms = sys.modules[__name__]
_torch_geometric_transforms = sys.modules["torch_geometric.transforms"]

# Names without a leading underscore that exist in both namespaces.
_intersection_names = set(_custom_transforms.__dict__) & set(_torch_geometric_transforms.__dict__)
_intersection_names = {module for module in _intersection_names if not module.startswith("_")}
L_intersection_names = len(_intersection_names) > 0

# Among the shared names, keep the objects of this module whose string
# representation does not mention torch_geometric.transforms: those are
# local definitions shadowing a torch_geometric name.
_intersection_cls = []
for transform_name in _intersection_names:
    transform_cls = getattr(_custom_transforms, transform_name)
    if "torch_geometric.transforms." not in str(transform_cls):
        _intersection_cls.append(transform_cls)
L_intersection_cls = len(_intersection_cls) > 0

# Any shared name aborts the import; the message depends on whether the clash
# comes from a local definition or from a torch_geometric object imported here.
if L_intersection_names and L_intersection_cls:
    raise Exception(
        "It seems that you are overiding a transform from pytorch gemetric, "
        "this is forbiden, please rename your classes {} from {}".format(
            _intersection_names, _intersection_cls
        )
    )
elif L_intersection_names:
    raise Exception(
        "It seems you are importing transforms {} from pytorch geometric within the current code base. "
        "Please, remove them or add them within a class, function, etc.".format(
            _intersection_names
        )
    )
def instantiate_transform(transform_option, attr="transform"):
    """Create a single transform (or filter) from one configuration entry.

    The entry is typically an OmegaConf dict such as::

        transform: GridSampling3D
        params:
            size: 0.01

    The class is looked up by name, first in this module and then in
    ``torch_geometric.transforms``.

    Parameters
    ----------
    transform_option:
        Configuration entry. The class name is read from the attribute named
        ``attr``; the optional ``params`` (keyword arguments) and ``lparams``
        (positional arguments) are read with ``.get``.
    attr: str
        Name of the entry holding the class name (``"transform"`` by default).

    Returns
    -------
    The object obtained by calling the class with ``*lparams`` and ``**params``.

    Raises
    ------
    ValueError
        If ``transform_option`` has no ``attr`` entry, or if no class of that
        name exists in this module or in ``torch_geometric.transforms``.
    """
    tr_name = getattr(transform_option, attr, None)
    if tr_name is None:
        # Fail early with a readable message; otherwise the name lookup below
        # would raise an opaque TypeError about a non-string attribute name.
        raise ValueError('No "{}" entry found in transform option {}'.format(attr, transform_option))

    # ``.get`` replaces plain attribute access since the update to OmegaConf 2.0.
    try:
        tr_params = transform_option.get("params")
    except KeyError:
        tr_params = None
    try:
        lparams = transform_option.get("lparams")
    except KeyError:
        lparams = None

    # Local transforms take precedence over torch_geometric ones.
    cls = getattr(_custom_transforms, tr_name, None)
    if not cls:
        cls = getattr(_torch_geometric_transforms, tr_name, None)
    if not cls:
        raise ValueError("Transform %s is nowhere to be found" % tr_name)

    # Missing or empty argument containers mean "no arguments of that kind".
    args = lparams if lparams else ()
    kwargs = tr_params if tr_params else {}
    return cls(*args, **kwargs)
def instantiate_transforms(transform_options):
    """Build a torch_geometric ``Compose`` from a list of transform entries.

    The input is typically an OmegaConf list such as::

        - transform: GridSampling3D
          params:
              size: 0.01
        - transform: NormaliseScale

    Each entry is turned into a transform with ``instantiate_transform`` and
    the results are wrapped, in order, in ``T.Compose``.
    """
    return T.Compose([instantiate_transform(option) for option in transform_options])
def instantiate_filters(filter_options):
    """Build an ``FCompose`` from a list of filter entries.

    Each entry is instantiated with ``instantiate_transform``, reading the
    class name from its ``filter`` key instead of ``transform``.
    """
    return FCompose([instantiate_transform(option, "filter") for option in filter_options])
class LotteryTransform(object):
    """
    Transform which draws one transform at random among several transforms
    indicated in transform options, and applies it. A new draw is made on
    every call.

    Example (on the yaml)

    .. code-block:: yaml

        transform: LotteryTransform
        params:
            transform_options:
                - transform: GridSampling3D
                  params:
                      size: 0.1
                - transform: RandomNoise
                  params:
                      sigma: 0.05

    Parameters
    ----------
    transform_options Omegaconf list which contains the transform
    """

    def __init__(self, transform_options):
        # Held as a composite only to store the instantiated transforms;
        # the composite itself is never applied as a whole.
        self.random_transforms = instantiate_transforms(transform_options)

    def __call__(self, data):
        """Apply one randomly selected transform to ``data`` and return the result."""
        candidates = self.random_transforms.transforms
        chosen = candidates[np.random.randint(len(candidates))]
        return chosen(data)

    def __repr__(self):
        # Every entry is followed by ", " (including the last one).
        entries = "".join("{}, ".format(repr(trans)) for trans in self.random_transforms.transforms)
        return "LotteryTransform([" + entries + "])"
class ComposeTransform(object):
    """
    Chain several transforms from a YAML description (according to the
    original note, Compose of torch_geometric does not work for this).

    Example :

    .. code-block:: yaml

        - transform: ComposeTransform
          params:
              transform_options:
                  - transform: GridSampling3D
                    params:
                        size: 0.1
                  - transform: RandomNoise
                    params:
                        sigma: 0.05

    Parameters:
        transform_options: Omegaconf Dict
            contains a list of transform
    """

    def __init__(self, transform_options):
        self.transform = instantiate_transforms(transform_options)

    def __call__(self, data):
        """Run ``data`` through the composed transform and return the result."""
        composed = self.transform
        return composed(data)

    def __repr__(self):
        # Every entry is followed by ", " (including the last one).
        pieces = ["{}, ".format(step.__repr__()) for step in self.transform.transforms]
        return "ComposeTransform([" + "".join(pieces) + "])"
class RandomParamTransform(object):
    """
    Create a transform with random parameters. The parameters are drawn again
    and the underlying transform is re-created on every call.

    Example (on the yaml)

    .. code-block:: yaml

        transform: RandomParamTransform
        params:
            transform_name: GridSampling3D
            transform_params:
                size:
                    min: 0.1
                    max: 0.3
                    type: "float"
                mode:
                    value: "last"

    We can also draw random numbers for two parameters, integer or float

    .. code-block:: yaml

        transform: RandomParamTransform
        params:
            transform_name: RandomSphereDropout
            transform_params:
                radius:
                    min: 1
                    max: 2
                    type: "float"
                num_sphere:
                    min: 1
                    max: 5
                    type: "int"

    Parameters
    ----------
    transform_name: string:
        the name of the transform
    transform_params: Omegaconf Dict
        maps each parameter name either to ``min`` / ``max`` / ``type``
        (``"float"`` or ``"int"``) describing the range and type of the value
        to draw, or to ``value`` to pass a fixed value (see Example above).
        ``max`` must be strictly greater than ``min``. Integer values are
        obtained with ``int()`` on a draw taken below ``max``, so ``max``
        itself is not produced.
    """

    def __init__(self, transform_name, transform_params):
        self.transform_name = transform_name
        self.transform_params = transform_params
        self.random_transform = self._instanciate_transform_with_random_params()

    def _instanciate_transform_with_random_params(self):
        """Draw a value for every parameter and instantiate the transform with them.

        Raises
        ------
        AssertionError
            If a range has ``max`` not strictly greater than ``min``.
        NotImplementedError
            If a range has a ``type`` other than ``"float"`` / ``"int"``, or if
            a parameter provides neither ``min`` and ``max`` nor ``value``.
        """
        dico = dict()
        for p, rang in self.transform_params.items():
            if "max" in rang and "min" in rang:
                low, high = rang["min"], rang["max"]
                assert high - low > 0, "parameter {}: max ({}) must be greater than min ({})".format(
                    p, high, low
                )
                v = np.random.random() * (high - low) + low
                if rang["type"] == "float":
                    v = float(v)
                elif rang["type"] == "int":
                    v = int(v)
                else:
                    raise NotImplementedError(
                        "parameter {}: unsupported type {}".format(p, rang["type"])
                    )
                dico[p] = v
            elif "value" in rang:
                dico[p] = rang["value"]
            else:
                raise NotImplementedError(
                    "parameter {}: expected min/max/type or value".format(p)
                )

        # Reuse the generic instantiation path by building the same kind of
        # entry a YAML file would provide.
        trans_opt = DictConfig(dict(params=dico, transform=self.transform_name))
        random_transform = instantiate_transform(trans_opt, attr="transform")
        return random_transform

    def __call__(self, data):
        """Re-draw the parameters, rebuild the transform and apply it to ``data``."""
        self.random_transform = self._instanciate_transform_with_random_params()
        return self.random_transform(data)

    def __repr__(self):
        return "RandomParamTransform({}, params={})".format(self.transform_name, self.transform_params)