from typing import List
import itertools
import numpy as np
import math
import re
import torch
import random
from tqdm.auto import tqdm as tq
from sklearn.neighbors import KDTree
from functools import partial
from torch.nn import functional as F
from torch_geometric.nn.pool.pool import pool_pos, pool_batch
from torch_geometric.data import Data, Batch
from torch_scatter import scatter_add, scatter_mean
from torch_geometric.transforms import FixedPoints as FP
from torch_points_kernels.points_cpu import ball_query
import numba
from torch_points3d.datasets.multiscale_data import MultiScaleData
from torch_points3d.datasets.registration.pair import Pair
from torch_points3d.utils.transform_utils import SamplingStrategy
from torch_points3d.utils.config import is_list
from torch_points3d.utils import is_iterable
from .grid_transform import group_data, GridSampling3D, shuffle_data
from .features import Random3AxisRotation
# Attribute name under which a sklearn ``KDTree`` is cached on a data object
# (written by ComputeKDTree, SphereSampling and CylinderSampling; read by the
# grid samplers). Note: CylinderSampling / GridCylinderSampling build their tree
# on ``pos[:, :-1]`` while the other users build it on the full ``pos``.
KDTREE_KEY = "kd_tree"
class RemoveAttributes(object):
    """This transform removes unnecessary attributes from data for optimization purposes.

    Parameters
    ----------
    attr_names: list
        Names of the attributes to remove from data.
    strict: bool=False
        If True, raise an exception when a provided attr_name is not within the data keys.
        If False, attributes that are missing are silently skipped.
    """

    def __init__(self, attr_names=None, strict=False):
        # ``None`` instead of a shared mutable ``[]`` default
        self._attr_names = attr_names if attr_names is not None else []
        self._strict = strict

    def __call__(self, data):
        keys = set(data.keys)
        if self._strict:
            for attr_name in self._attr_names:
                if attr_name not in keys:
                    raise Exception("attr_name: {} isn t within keys: {}".format(attr_name, keys))
        for attr_name in self._attr_names:
            # Non-strict mode must not crash on missing (or already removed,
            # e.g. duplicated) attributes: only delete what is actually there.
            if hasattr(data, attr_name):
                delattr(data, attr_name)
        return data

    def __repr__(self):
        return "{}(attr_names={}, strict={})".format(self.__class__.__name__, self._attr_names, self._strict)
class PointCloudFusion(object):
    """This transform performs a point cloud fusion from a list of data.

    - If a list of data is provided -> Create one Batch object with all data
      (a single-element list is returned as its element).
    - If a list of list of data is provided -> Create a list of fused point clouds,
      one per inner list.
    """

    def _process(self, data_list):
        """Concatenate ``data_list`` into one ``Batch`` without its collate bookkeeping."""
        if len(data_list) == 0:
            return Data()
        data = Batch.from_data_list(data_list)
        # ``batch``/``ptr`` are collate artefacts. ``ptr`` does not exist on older
        # torch_geometric releases, so only delete what is actually present.
        for key in ("batch", "ptr"):
            if hasattr(data, key):
                delattr(data, key)
        return data

    def __call__(self, data_list: List[Data]):
        if len(data_list) == 0:
            raise Exception("A list of data should be provided")
        if isinstance(data_list[0], list):
            # Fuse every inner list, even when there is only one. Previously a
            # single inner list was returned unfused.
            return [self._process(d) for d in data_list]
        if len(data_list) == 1:
            return data_list[0]
        return self._process(data_list)

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
class GridSphereSampling(object):
    """Fits the point cloud to a grid and, for each point of this grid,
    samples a sphere of radius r from the original cloud.

    Parameters
    ----------
    radius: float
        Radius of the sphere to be sampled.
    grid_size: float, optional
        Grid_size to be used with GridSampling3D to select sphere centres. If None, radius is used.
    delattr_kd_tree: bool, optional
        If True, KDTREE_KEY is deleted as an attribute if it exists.
    center: bool, optional
        If True, each sphere is translated so that its centre is at the origin.
    """

    KDTREE_KEY = KDTREE_KEY

    def __init__(self, radius, grid_size=None, delattr_kd_tree=True, center=True):
        # NOTE(review): ``eval`` lets configs pass expressions such as "2/3";
        # it must only ever receive trusted configuration strings.
        self._radius = eval(radius) if isinstance(radius, str) else float(radius)
        if grid_size is not None:
            # float(None) used to raise here although None is the documented default
            grid_size = eval(grid_size) if isinstance(grid_size, str) else float(grid_size)
        # A missing (or zero) grid size falls back to the sphere radius.
        self._grid_sampling = GridSampling3D(size=grid_size if grid_size else self._radius)
        self._delattr_kd_tree = delattr_kd_tree
        self._center = center

    def _process(self, data):
        if not hasattr(data, self.KDTREE_KEY):
            tree = KDTree(np.asarray(data.pos), leaf_size=50)
        else:
            tree = getattr(data, self.KDTREE_KEY)

        # The kdtree has been attached to data for optimization reasons.
        # It won't be used further down the transform pipeline and should be removed before any collate call.
        if hasattr(data, self.KDTREE_KEY) and self._delattr_kd_tree:
            delattr(data, self.KDTREE_KEY)

        # apply grid sampling: each resulting point becomes a sphere centre
        grid_data = self._grid_sampling(data.clone())

        datas = []
        for grid_center in np.asarray(grid_data.pos):
            pts = np.asarray(grid_center)[np.newaxis]

            # Label of the original point closest to the grid centre
            ind = torch.LongTensor(tree.query(pts, k=1)[1][0])
            grid_label = data.y[ind]

            # SphereSampling performs its own radius search on the original data
            sampler = SphereSampling(self._radius, grid_center, align_origin=self._center)
            new_data = sampler(data)
            new_data.center_label = grid_label

            datas.append(new_data)
        return datas

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in tq(data)]
            data = list(itertools.chain(*data))  # 2d list needs to be flattened
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(radius={}, center={})".format(self.__class__.__name__, self._radius, self._center)
class GridCylinderSampling(object):
    """Fits the point cloud to a grid and, for each distinct horizontal grid position,
    samples a cylinder of radius r (unbounded along the last axis) from the original cloud.

    Parameters
    ----------
    radius: float
        Radius of the cylinder to be sampled.
    grid_size: float, optional
        Grid_size to be used with GridSampling3D to select cylinder centres. If None, radius is used.
    delattr_kd_tree: bool, optional
        If True, KDTREE_KEY is deleted as an attribute if it exists.
    center: bool, optional
        If True, each cylinder is translated so that its axis passes through the origin.
    """

    KDTREE_KEY = KDTREE_KEY

    def __init__(self, radius, grid_size=None, delattr_kd_tree=True, center=True):
        # NOTE(review): ``eval`` lets configs pass expressions such as "2/3";
        # it must only ever receive trusted configuration strings.
        self._radius = eval(radius) if isinstance(radius, str) else float(radius)
        if grid_size is not None:
            # float(None) used to raise here although None is the documented default
            grid_size = eval(grid_size) if isinstance(grid_size, str) else float(grid_size)
        # A missing (or zero) grid size falls back to the cylinder radius.
        self._grid_sampling = GridSampling3D(size=grid_size if grid_size else self._radius)
        self._delattr_kd_tree = delattr_kd_tree
        self._center = center

    def _process(self, data):
        if not hasattr(data, self.KDTREE_KEY):
            # 2D tree over all but the last coordinate: cylinders are unbounded along it
            tree = KDTree(np.asarray(data.pos[:, :-1]), leaf_size=50)
        else:
            tree = getattr(data, self.KDTREE_KEY)

        # The kdtree has been attached to data for optimization reasons.
        # It won't be used further down the transform pipeline and should be removed before any collate call.
        if hasattr(data, self.KDTREE_KEY) and self._delattr_kd_tree:
            delattr(data, self.KDTREE_KEY)

        # apply grid sampling
        grid_data = self._grid_sampling(data.clone())

        datas = []
        # Several grid points can share a horizontal position; sample one cylinder per position.
        for grid_center in np.unique(grid_data.pos[:, :-1], axis=0):
            pts = np.asarray(grid_center)[np.newaxis]

            # Label of the original point closest (horizontally) to the grid centre
            ind = torch.LongTensor(tree.query(pts, k=1)[1][0])
            grid_label = data.y[ind]

            # CylinderSampling performs its own radius search on the original data
            sampler = CylinderSampling(self._radius, grid_center, align_origin=self._center)
            new_data = sampler(data)
            new_data.center_label = grid_label

            datas.append(new_data)
        return datas

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in tq(data)]
            data = list(itertools.chain(*data))  # 2d list needs to be flattened
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(radius={}, center={})".format(self.__class__.__name__, self._radius, self._center)
class ComputeKDTree(object):
    """Build a sklearn ``KDTree`` over ``data.pos`` and store it as ``data.kd_tree``.

    Accepts a single data object or a list of them.

    Parameters
    -----------
    leaf_size: int
        Size of the leaf node.
    """

    def __init__(self, leaf_size):
        self._leaf_size = leaf_size

    def _process(self, data):
        points = np.asarray(data.pos)
        data.kd_tree = KDTree(points, leaf_size=self._leaf_size)
        return data

    def __call__(self, data):
        if not isinstance(data, list):
            return self._process(data)
        return [self._process(item) for item in data]

    def __repr__(self):
        return "{}(leaf_size={})".format(type(self).__name__, self._leaf_size)
class RandomSphere(object):
    """Keep the points inside a sphere of fixed radius around a sampled centre.

    The centre is a point of the cloud, selected by a ``SamplingStrategy``.

    Parameters
    ----------
    radius: float
        Radius of the sphere to be sampled.
    strategy: str
        choose between `random` and `freq_class_based`. The `freq_class_based` \
        favors points with low frequency class. This can be used to balance unbalanced datasets
    class_weight_method: str
        forwarded to ``SamplingStrategy``
    center: bool
        if True then the sphere will be moved to the origin
    """

    def __init__(self, radius, strategy="random", class_weight_method="sqrt", center=True):
        # NOTE(review): ``eval`` on string radii (config expressions) - trusted input only.
        self._radius = eval(radius) if isinstance(radius, str) else float(radius)
        self._sampling_strategy = SamplingStrategy(strategy=strategy, class_weight_method=class_weight_method)
        self._center = center

    def _process(self, data):
        centre_index = self._sampling_strategy(data)
        centre = np.asarray(data.pos[centre_index])[np.newaxis]
        sampler = SphereSampling(self._radius, centre, align_origin=self._center)
        return sampler(data)

    def __call__(self, data):
        if not isinstance(data, list):
            return self._process(data)
        return [self._process(item) for item in data]

    def __repr__(self):
        return "{}(radius={}, center={}, sampling_strategy={})".format(
            type(self).__name__, self._radius, self._center, self._sampling_strategy
        )
class SphereSampling:
    """ Samples points within a sphere

    Builds (and caches on ``data`` under KDTREE_KEY) a KDTree on ``data.pos`` when
    none is attached, then copies into a new ``Data`` every attribute, restricting
    per-point tensors to the points within ``radius`` of the centre.

    Parameters
    ----------
    radius : float
        Radius of the sphere
    sphere_centre : torch.Tensor or np.array
        Centre of the sphere (1D array that contains (x,y,z))
    align_origin : bool, optional
        move resulting point cloud to origin
    """

    KDTREE_KEY = KDTREE_KEY

    def __init__(self, radius, sphere_centre, align_origin=True):
        self._radius = radius
        self._centre = np.asarray(sphere_centre)
        if len(self._centre.shape) == 1:
            # sklearn query_radius expects a 2D array of query points
            self._centre = np.expand_dims(self._centre, 0)
        self._align_origin = align_origin

    def __call__(self, data):
        num_points = data.pos.shape[0]
        if not hasattr(data, self.KDTREE_KEY):
            tree = KDTree(np.asarray(data.pos), leaf_size=50)
            # cached on the input so further samplings of the same cloud reuse it
            setattr(data, self.KDTREE_KEY, tree)
        else:
            tree = getattr(data, self.KDTREE_KEY)

        t_center = torch.FloatTensor(self._centre)
        ind = torch.LongTensor(tree.query_radius(self._centre, r=self._radius)[0])
        new_data = Data()
        for key in set(data.keys):
            if key == self.KDTREE_KEY:
                continue
            item = data[key]
            # ``item.dim() > 0`` guards 0-d tensors, whose ``shape[0]`` raises IndexError
            if torch.is_tensor(item) and item.dim() > 0 and num_points == item.shape[0]:
                # advanced indexing returns a copy, so the in-place shift below is safe
                item = item[ind]
                if self._align_origin and key == "pos":  # Center the sphere.
                    item -= t_center
            elif torch.is_tensor(item):
                item = item.clone()
            setattr(new_data, key, item)
        return new_data

    def __repr__(self):
        return "{}(radius={}, center={}, align_origin={})".format(
            self.__class__.__name__, self._radius, self._centre, self._align_origin
        )
class CylinderSampling:
    """ Samples points within a cylinder whose axis is parallel to the last coordinate

    The neighbourhood search uses only ``pos[:, :-1]``, so the cylinder is unbounded
    along the last axis. A KDTree on those coordinates is built and cached on
    ``data`` under KDTREE_KEY when none is attached.

    Parameters
    ----------
    radius : float
        Radius of the cylinder
    cylinder_centre : torch.Tensor or np.array
        Centre of the cylinder ((x,y,z) or (x,y), optionally with a leading axis of size 1).
        The last coordinate of a 3-component centre is ignored.
    align_origin : bool, optional
        move resulting point cloud so the cylinder axis passes through the origin
    """

    KDTREE_KEY = KDTREE_KEY

    def __init__(self, radius, cylinder_centre, align_origin=True):
        self._radius = radius
        centre = np.asarray(cylinder_centre)
        # Drop the vertical coordinate based on the *last* axis so that (1, 3)
        # shaped centres and plain lists are handled like 1D (3,) arrays.
        if centre.shape[-1] == 3:
            centre = centre[..., :-1]
        if centre.ndim == 1:
            # sklearn query_radius expects a 2D array of query points
            centre = np.expand_dims(centre, 0)
        self._centre = centre
        self._align_origin = align_origin

    def __call__(self, data):
        num_points = data.pos.shape[0]
        if not hasattr(data, self.KDTREE_KEY):
            tree = KDTree(np.asarray(data.pos[:, :-1]), leaf_size=50)
            setattr(data, self.KDTREE_KEY, tree)
        else:
            tree = getattr(data, self.KDTREE_KEY)

        t_center = torch.FloatTensor(self._centre)
        ind = torch.LongTensor(tree.query_radius(self._centre, r=self._radius)[0])

        new_data = Data()
        for key in set(data.keys):
            if key == self.KDTREE_KEY:
                continue
            item = data[key]
            # ``item.dim() > 0`` guards 0-d tensors, whose ``shape[0]`` raises IndexError
            if torch.is_tensor(item) and item.dim() > 0 and num_points == item.shape[0]:
                # advanced indexing returns a copy, so the in-place shift below is safe
                item = item[ind]
                if self._align_origin and key == "pos":  # Center the cylinder.
                    item[:, :-1] -= t_center
            elif torch.is_tensor(item):
                item = item.clone()
            setattr(new_data, key, item)
        return new_data

    def __repr__(self):
        return "{}(radius={}, center={}, align_origin={})".format(
            self.__class__.__name__, self._radius, self._centre, self._align_origin
        )
class Select:
    """ Selects given points from a data object

    Returns a new ``Data`` where tensors with one row per point are indexed with
    ``indices`` and every other tensor is cloned; the KDTREE_KEY attribute is dropped.

    Parameters
    ----------
    indices : torch.Tensor
        indices of the points to keep. Can also be a boolean mask
    """

    def __init__(self, indices=None):
        self._indices = indices

    def __call__(self, data):
        num_points = data.pos.shape[0]
        new_data = Data()
        for key in data.keys:
            if key == KDTREE_KEY:
                continue
            item = data[key]
            # ``item.dim() > 0`` guards 0-d tensors, whose ``shape[0]`` raises IndexError
            if torch.is_tensor(item) and item.dim() > 0 and num_points == item.shape[0]:
                item = item[self._indices].clone()
            elif torch.is_tensor(item):
                item = item.clone()
            setattr(new_data, key, item)
        return new_data
class CylinderNormalizeScale(object):
    """ Normalize points within a cylinder

    ``pos`` is centred on its mean. The horizontal coordinates (all but the last
    column) are scaled jointly so their largest absolute value is 0.999999. If
    ``normalize_z`` is set, the last column is scaled the same way independently.
    A zero extent is left unscaled instead of producing NaNs.
    """

    def __init__(self, normalize_z=True):
        self._normalize_z = normalize_z

    @staticmethod
    def _scale_factor(values):
        """Return the factor mapping ``values`` into (-1, 1), or 1 when they are all zero."""
        extent = values.abs().max()
        if extent == 0:
            # degenerate extent (e.g. a single point): 1 / 0 would give inf * 0 = nan
            return 1.0
        return (1 / extent) * 0.999999

    def _process(self, data):
        if data.pos.shape[0] == 0:
            # nothing to normalise; ``max`` of an empty tensor would raise
            return data
        data.pos -= data.pos.mean(dim=0, keepdim=True)
        data.pos[:, :-1] = data.pos[:, :-1] * self._scale_factor(data.pos[:, :-1])
        if self._normalize_z:
            data.pos[:, -1] = data.pos[:, -1] * self._scale_factor(data.pos[:, -1])
        return data

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in data]
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(normalize_z={})".format(self.__class__.__name__, self._normalize_z)
class RandomSymmetry(object):
    """ Randomly mirror the point cloud along selected axes.

    For each enabled axis, with probability 0.5, ``data.pos`` is reflected in place
    along that axis as ``max - value``.

    Parameters
    ----------
    axis: Tuple[bool,bool,bool], optional
        flags selecting which of the x, y, z axes may be mirrored
    """

    def __init__(self, axis=[False, False, False]):
        self.axis = axis

    def __call__(self, data):
        enabled_axes = [dim for dim, flag in enumerate(self.axis) if flag]
        for dim in enabled_axes:
            if torch.rand(1) >= 0.5:
                continue
            column = data.pos[:, dim]
            data.pos[:, dim] = torch.max(column) - column
        return data

    def __repr__(self):
        return "Random symmetry of axes: x={}, y={}, z={}".format(*self.axis)
class RandomNoise(object):
    """ Jitter positions with clipped isotropic additive Gaussian noise.

    Parameters
    ----------
    sigma:
        standard deviation of the noise (it scales ``torch.randn``)
    clip:
        each noise component is clamped to ``[-clip, clip]``
    """

    def __init__(self, sigma=0.01, clip=0.05):
        self.sigma = sigma
        self.clip = clip

    def __call__(self, data):
        jitter = torch.randn(data.pos.shape).mul(self.sigma).clamp(-self.clip, self.clip)
        data.pos = data.pos + jitter
        return data

    def __repr__(self):
        return f"{type(self).__name__}(sigma={self.sigma}, clip={self.clip})"
class ScalePos:
    """Multiply ``data.pos`` in place by a constant factor.

    Parameters
    ----------
    scale:
        factor applied to every coordinate (must be set; ``None`` cannot multiply a tensor)
    """

    def __init__(self, scale=None):
        self.scale = scale

    def __call__(self, data):
        factor = self.scale
        data.pos *= factor
        return data

    def __repr__(self):
        return f"{type(self).__name__}(scale={self.scale})"
class RandomScaleAnisotropic:
    r""" Scales node positions by a randomly sampled factor ``s1, s2, s3`` within a
    given interval, *e.g.*, resulting in the transformation matrix

    .. math::
        \left[
        \begin{array}{ccc}
            s1 & 0 & 0 \\
            0 & s2 & 0 \\
            0 & 0 & s3 \\
        \end{array}
        \right]

    for three-dimensional positions. Normals stored in ``data.norm`` are divided
    by the same factors and re-normalised.

    Parameters
    -----------
    scales:
        scaling factor interval, e.g. ``(a, b)``, then scale \
        is randomly sampled from the range \
        ``a <= b``. \
    anisotropic:
        accepted for configuration compatibility; it is not used
    """

    def __init__(self, scales=None, anisotropic=True):
        assert is_iterable(scales) and len(scales) == 2
        assert scales[0] <= scales[1]
        self.scales = scales

    def __call__(self, data):
        low, high = self.scales[0], self.scales[1]
        factors = low + torch.rand((3,)) * (high - low)
        data.pos = data.pos * factors
        normals = getattr(data, "norm", None)
        if normals is not None:
            # normals transform with the inverse scaling, then are made unit-length again
            data.norm = torch.nn.functional.normalize(normals / factors, dim=1)
        return data

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self.scales)
class MeshToNormal(object):
    """ Computes one unit normal per mesh face (IN PROGRESS)

    When ``data`` has a ``face`` attribute, each of its first three rows is used
    to gather corner positions. The cross product of the two edges leaving the
    first corner gives the normal, which is stored in ``data.normals``. This
    assumes ``face`` is laid out as (3, num_faces); TODO confirm.
    """

    def __init__(self):
        pass

    def __call__(self, data):
        if not hasattr(data, "face"):
            return data
        corners = [data.pos[index_row] for index_row in data.face]
        edge_ab = corners[0] - corners[1]
        edge_ac = corners[0] - corners[2]
        data.normals = F.normalize(torch.cross(edge_ab, edge_ac, dim=1))
        return data

    def __repr__(self):
        return type(self).__name__
class ShuffleData(object):
    """ This transform shuffles the feature, pos and label tensors within data.

    Accepts a single data object or a list of them. A list is shuffled
    element-wise and returned as a list of the same length.
    """

    def _process(self, data):
        return shuffle_data(data)

    def __call__(self, data):
        if isinstance(data, list):
            # shuffle_data is expected to return one data object per input, so
            # there is nothing to flatten. Chaining the results, as before,
            # iterated over each data object's (key, value) items.
            return [self._process(d) for d in tq(data)]
        return self._process(data)

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
class PairTransform(object):
    """Apply one transform independently to the source and the target of a pair."""

    def __init__(self, transform):
        """
        apply the transform for a pair of data
        (as defined in torch_points3d/datasets/registration/pair.py)
        """
        self.transform = transform

    def __call__(self, data):
        source, target = data.to_data()
        # arguments are evaluated left to right: source is transformed first
        return Pair.make_pair(self.transform(source), self.transform(target))

    def __repr__(self):
        return "{}()".format(type(self).__name__)
class ShiftVoxels:
    """ Trick to make Sparse conv invariant to even and odds coordinates
    https://github.com/chrischoy/SpatioTemporalSegmentation/blob/master/lib/train.py#L78

    Adds one random integer offset in [0, 100), shared by all points, to the first
    three columns of ``data.coords``.

    Parameters
    -----------
    apply_shift: bool:
        Whether to apply the shift on indices
    """

    def __init__(self, apply_shift=True):
        self._apply_shift = apply_shift

    def __call__(self, data):
        if not self._apply_shift:
            return data
        if not hasattr(data, "coords"):
            raise Exception("should quantize first using GridSampling3D")
        if not isinstance(data.coords, torch.IntTensor):
            raise Exception("The pos are expected to be coordinates, so torch.IntTensor")
        offset = (torch.rand(3) * 100).type_as(data.coords)
        data.coords[:, :3] += offset
        return data

    def __repr__(self):
        return f"{type(self).__name__}(apply_shift={self._apply_shift})"
class RandomDropout:
    """ Randomly drop points from the input data

    With probability ``dropout_application_ratio``, a random subset of
    ``int(N * (1 - dropout_ratio))`` distinct points out of N is kept.

    Parameters
    ----------
    dropout_ratio : float, optional
        Ratio that gets dropped
    dropout_application_ratio : float, optional
        chances of the dropout to be applied
    """

    def __init__(self, dropout_ratio: float = 0.2, dropout_application_ratio: float = 0.5):
        self.dropout_ratio = dropout_ratio
        self.dropout_application_ratio = dropout_application_ratio

    def __call__(self, data):
        if random.random() < self.dropout_application_ratio:
            N = len(data.pos)
            # FixedPoints samples *with* replacement by default, which would
            # duplicate points instead of dropping them.
            data = FP(int(N * (1 - self.dropout_ratio)), replace=False)(data)
        return data

    def __repr__(self):
        return "{}(dropout_ratio={}, dropout_application_ratio={})".format(
            self.__class__.__name__, self.dropout_ratio, self.dropout_application_ratio
        )
def apply_mask(data, mask, skip_keys=None):
    """Filter, in place, every per-point attribute of ``data`` with ``mask``.

    An attribute is treated as per-point when its length equals ``len(data.pos)``.
    Attributes without a length (Python scalars, 0-d tensors, ``None``, KD-trees, ...)
    and the keys listed in ``skip_keys`` are left untouched.

    Parameters
    ----------
    data:
        object exposing ``pos``, ``keys`` and item access (e.g. a torch_geometric ``Data``)
    mask:
        boolean mask or index tensor selecting the points to keep
    skip_keys: list, optional
        attribute names that must not be filtered

    Returns
    -------
    The same ``data`` object.
    """
    skip_keys = skip_keys if skip_keys is not None else []
    size_pos = len(data.pos)
    for key in data.keys:
        if key in skip_keys:
            continue
        try:
            size = len(data[key])
        except TypeError:
            # no length, so it cannot be a per-point attribute
            continue
        if size == size_pos:
            data[key] = data[key][mask]
    return data
@numba.jit(nopython=True, cache=True)
def rw_mask(pos, ind, dist, mask_vertices, random_ratio=0.04, num_iter=5000):
    """Mark the vertices visited by a random walk as dropped.

    Starting from a uniformly random vertex, the walk runs ``num_iter`` steps.
    At each step the current vertex is set to False in ``mask_vertices``. Then,
    with probability ``random_ratio``, or when the current vertex has no
    neighbour with ``dist > 0``, the walk jumps to a uniformly random vertex.
    Otherwise it moves to a uniformly chosen neighbour taken from ``ind``.

    Parameters
    ----------
    pos: np.ndarray
        point positions; only ``len(pos)`` is used here
    ind: np.ndarray
        neighbour indices, row ``i`` holding the candidates of vertex ``i``
        (assumes the dense ball_query layout - TODO confirm)
    dist: np.ndarray
        same layout as ``ind``; only entries with ``dist > 0`` are treated as valid
        neighbours (presumably this skips the vertex itself and padding - confirm)
    mask_vertices: np.ndarray of bool
        modified in place and returned
    random_ratio: float
        per-step probability of a random restart
    num_iter: int
        number of walk steps; at most this many vertices get dropped

    Returns
    -------
    The updated ``mask_vertices`` (False = dropped).
    """
    rand_ind = np.random.randint(0, len(pos))
    for _ in range(num_iter):
        mask_vertices[rand_ind] = False
        if np.random.rand() < random_ratio:
            # random restart
            rand_ind = np.random.randint(0, len(pos))
        else:
            neighbors = ind[rand_ind][dist[rand_ind] > 0]
            if len(neighbors) == 0:
                # isolated vertex: restart anywhere instead of getting stuck
                rand_ind = np.random.randint(0, len(pos))
            else:
                n_i = np.random.randint(0, len(neighbors))
                rand_ind = neighbors[n_i]
    return mask_vertices
class RandomWalkDropout(object):
    """
    Randomly drop points from the input data along a random walk.

    A radius graph is built with ``ball_query``. The points visited by
    ``rw_mask``'s random walk are then removed with ``apply_mask``.

    Parameters
    ----------
    dropout_ratio: float, optional
        passed to ``rw_mask`` as ``random_ratio``, i.e. the per-step probability that
        the walk restarts from a random point (the number of dropped points is
        bounded by ``num_iter``, not by this ratio)
    num_iter: int, optional
        number of iterations (walk steps)
    radius: float, optional
        radius of the neighborhood search to create the graph
    max_num: int optional
        max number of neighbors
    skip_keys: List optional
        skip_keys where we don't apply the mask
    """

    def __init__(
        self,
        dropout_ratio: float = 0.05,
        num_iter: int = 5000,
        radius: float = 0.5,
        max_num: int = -1,
        skip_keys: List = [],
    ):
        self.dropout_ratio = dropout_ratio
        self.num_iter = num_iter
        self.radius = radius
        self.max_num = max_num
        self.skip_keys = skip_keys

    def __call__(self, data):
        points = data.pos.detach().cpu().numpy()
        neighbour_idx, neighbour_dist = ball_query(
            data.pos, data.pos, radius=self.radius, max_num=self.max_num, mode=0
        )
        keep = rw_mask(
            pos=points,
            ind=neighbour_idx.detach().cpu().numpy(),
            dist=neighbour_dist.detach().cpu().numpy(),
            mask_vertices=np.ones(len(points), dtype=bool),
            num_iter=self.num_iter,
            random_ratio=self.dropout_ratio,
        )
        return apply_mask(data, keep, self.skip_keys)

    def __repr__(self):
        return "{}(dropout_ratio={}, num_iter={}, radius={}, max_num={}, skip_keys={})".format(
            type(self).__name__, self.dropout_ratio, self.num_iter, self.radius, self.max_num, self.skip_keys
        )
class RandomSphereDropout(object):
    """
    Drop the points lying inside ``num_sphere`` random balls of fixed radius.

    Ball centres are drawn uniformly (with repetition possible) among the points
    of a grid-sampled copy of the cloud.

    Parameters
    ----------
    num_sphere: int, optional
        number of random spheres
    radius: float, optional
        radius of the spheres
    grid_size_center: float, optional
        grid size of the sampling used to pick candidate centres
    """

    def __init__(self, num_sphere: int = 10, radius: float = 5, grid_size_center: float = 0.01):
        self.num_sphere = num_sphere
        self.radius = radius
        self.grid_sampling = GridSampling3D(grid_size_center, mode="last")

    def __call__(self, data):
        candidates = self.grid_sampling(data.clone())
        picks = torch.randint(0, len(candidates.pos), (self.num_sphere,))
        centres = candidates.pos[picks]
        ind, dist = ball_query(data.pos, centres, radius=self.radius, max_num=-1, mode=1)
        hits = ind[dist[:, 0] >= 0]
        keep = torch.ones(len(data.pos), dtype=torch.bool)
        keep[hits[:, 0]] = False
        return apply_mask(data, keep)

    def __repr__(self):
        return f"{type(self).__name__}(num_sphere={self.num_sphere}, radius={self.radius})"
class FixedSphereDropout(object):
    """
    Drop the points lying inside spheres of a fixed radius.

    The sphere centres are either the fixed ``centers`` or, when ``name_ind``
    is given, the points of ``data.pos`` whose indices are stored in
    ``data[name_ind]``.

    Parameters
    ----------
    centers: list of list of float, optional
        centres of the spheres (default: a single sphere at the origin)
    name_ind: str, optional
        name of a data attribute holding point indices to use as centres
    radius: float, optional
        radius of the spheres
    """

    def __init__(self, centers: List[List[float]] = [[0, 0, 0]], name_ind=None, radius: float = 1):
        # Stored as floating point: ``torch.tensor`` would infer int64 for integer
        # centres (like the default), which does not match ``data.pos``.
        self.centers = torch.tensor(centers, dtype=torch.float)
        self.radius = radius
        self.name_ind = name_ind

    def __call__(self, data):
        if self.name_ind is None:
            # match dtype/device of the positions before the neighbour search
            center = self.centers.to(data.pos)
        else:
            center = data.pos[data[self.name_ind].long()]
        ind, dist = ball_query(data.pos, center, radius=self.radius, max_num=-1, mode=1)
        # ``>= 0`` keeps neighbours at distance 0. ``> 0`` spared the points lying
        # exactly on a centre, e.g. the centre points themselves when name_ind is used.
        ind = ind[dist[:, 0] >= 0]
        mask = torch.ones(len(data.pos), dtype=torch.bool)
        mask[ind[:, 0]] = False
        return apply_mask(data, mask)

    def __repr__(self):
        return "{}(centers={}, radius={})".format(self.__class__.__name__, self.centers, self.radius)
class SphereCrop(object):
    """
    Crop the point cloud to a sphere: a ball of the given radius is centred on a
    random point of the cloud and every point outside the ball is rejected.

    Parameters
    ----------
    radius: float, optional
        radius of the sphere
    """

    def __init__(self, radius: float = 50):
        self.radius = radius

    def __call__(self, data):
        i = torch.randint(0, len(data.pos), (1,))
        ind, dist = ball_query(data.pos, data.pos[i].view(1, 3), radius=self.radius, max_num=-1, mode=1)
        # ``>= 0`` keeps neighbours at distance 0, notably the randomly chosen
        # centre itself, which ``> 0`` used to drop from its own crop.
        ind = ind[dist[:, 0] >= 0]
        # index every per-point attribute with the kept point indices
        return apply_mask(data, ind[:, 0])

    def __repr__(self):
        return "{}(radius={})".format(self.__class__.__name__, self.radius)
class CubeCrop(object):
    """
    Crop the point cloud with a randomly rotated cube of half size ``c``.

    The cube is centred on a random point of a grid-sampled copy of the cloud.
    The cloud copy is rotated about that centre and tested against the
    axis-aligned cube, which amounts to cropping with a rotated cube. Points
    outside are rejected.

    Parameters
    ----------
    c: float, optional
        half size of the cube
    rot_x: float, optional
        rotation range of the cube around the x axis
    rot_y: float, optional
        rotation range of the cube around the y axis
    rot_z: float, optional
        rotation range of the cube around the z axis
    grid_size_center: float, optional
        grid size of the sampling used to pick candidate centres
    """

    def __init__(
        self, c: float = 1, rot_x: float = 180, rot_y: float = 180, rot_z: float = 180, grid_size_center: float = 0.01
    ):
        self.c = c
        self.random_rotation = Random3AxisRotation(rot_x=rot_x, rot_y=rot_y, rot_z=rot_z)
        self.grid_sampling = GridSampling3D(grid_size_center, mode="last")

    def __call__(self, data):
        candidates = self.grid_sampling(data.clone())
        rotated = data.clone()
        pick = torch.randint(0, len(candidates.pos), (1,))
        centre = candidates.pos[pick]
        lower = centre - self.c
        upper = centre + self.c

        rotated.pos = rotated.pos - centre
        rotated = self.random_rotation(rotated)
        rotated.pos = rotated.pos + centre

        above_lower = ((rotated.pos - lower) > 0).all(dim=1)
        below_upper = ((upper - rotated.pos) > 0).all(dim=1)
        return apply_mask(data, above_lower & below_upper)

    def __repr__(self):
        return f"{type(self).__name__}(c={self.c}, rotation={self.random_rotation})"
class EllipsoidCrop(object):
    """
    Crop the points lying inside a randomly rotated ellipsoid centred on a random point.
    """

    def __init__(
        self, a: float = 1, b: float = 1, c: float = 1, rot_x: float = 180, rot_y: float = 180, rot_z: float = 180
    ):
        """
        Crop with respect to an ellipsoid.

        A copy of the cloud is randomly rotated and translated so the chosen point
        is at the origin. Points with ``x^2/a^2 + y^2/b^2 + z^2/c^2 < 1`` in that
        frame are kept.

        Parameters
        ----------
        a: float, optional
            semi-axis length along x
        b: float, optional
            semi-axis length along y
        c: float, optional
            semi-axis length along z
        rot_x, rot_y, rot_z: float, optional
            rotation ranges forwarded to Random3AxisRotation
        """
        self._a2, self._b2, self._c2 = a ** 2, b ** 2, c ** 2
        self.random_rotation = Random3AxisRotation(rot_x=rot_x, rot_y=rot_y, rot_z=rot_z)

    def _compute_mask(self, pos: torch.Tensor):
        x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
        return (x ** 2 / self._a2 + y ** 2 / self._b2 + z ** 2 / self._c2) < 1

    def __call__(self, data):
        rotated = data.clone()
        pick = torch.randint(0, len(data.pos), (1,))
        rotated = self.random_rotation(rotated)
        local_pos = rotated.pos - rotated.pos[pick]
        return apply_mask(data, self._compute_mask(local_pos))

    def __repr__(self):
        return "{}(a={}, b={}, c={}, rotation={})".format(
            type(self).__name__, np.sqrt(self._a2), np.sqrt(self._b2), np.sqrt(self._c2), self.random_rotation
        )
class DensityFilter(object):
    """
    Remove points with a low density. The density is the number of neighbours
    found by a radius search.

    A point is kept when strictly more than ``min_num`` of its ``ball_query``
    entries have ``dist > 0``. Zero and negative entries presumably correspond to
    the point itself and padding; TODO confirm.

    Parameters
    ----------
    radius_nn: float, optional
        radius for the neighbors search
    min_num: int, optional
        a point needs more than this many neighbours to be kept
    skip_keys: list, optional
        list of attributes of data to skip when we apply the mask
    """

    def __init__(self, radius_nn: float = 0.04, min_num: int = 6, skip_keys: List = []):
        self.radius_nn = radius_nn
        self.min_num = min_num
        self.skip_keys = skip_keys

    def __call__(self, data):
        _, dist = ball_query(data.pos, data.pos, radius=self.radius_nn, max_num=-1, mode=0)
        neighbour_count = (dist > 0).sum(1)
        return apply_mask(data, neighbour_count > self.min_num, self.skip_keys)

    def __repr__(self):
        return "{}(radius_nn={}, min_num={}, skip_keys={})".format(
            type(self).__name__, self.radius_nn, self.min_num, self.skip_keys
        )
class IrregularSampling(object):
    """
    Soft crop: the further a point is from a random centre, the less likely it is to be kept.

    The centre is drawn among the points of a grid-sampled copy of the cloud. A
    point with ``d_p = sum(|pos - centre| ** p)`` is kept with probability
    ``exp(-d_p / (2 * sigma_2))``, where ``sigma_2 = d_half ** p / (2 ln 2)``.
    Points with ``d_p = d_half ** p`` are therefore kept with probability 1/2.
    """

    def __init__(self, d_half=2.5, p=2, grid_size_center=0.1, skip_keys=[]):
        self.d_half = d_half
        self.p = p
        self.skip_keys = skip_keys
        self.grid_sampling = GridSampling3D(grid_size_center, mode="last")

    def __call__(self, data):
        candidates = self.grid_sampling(data.clone())
        pick = torch.randint(0, len(candidates.pos), (1,))
        centre = candidates.pos[pick]
        d_p = (torch.abs(data.pos - centre) ** self.p).sum(1)
        sigma_2 = (self.d_half ** self.p) / (2 * np.log(2))
        keep_probability = torch.exp(-d_p / (2 * sigma_2))
        keep = torch.rand(len(data.pos)) < keep_probability
        return apply_mask(data, keep, self.skip_keys)

    def __repr__(self):
        return f"{type(self).__name__}(d_half={self.d_half}, p={self.p}, skip_keys={self.skip_keys})"
class PeriodicSampling(object):
    """
    Sample points at a periodic distance: keep the points lying on concentric
    shells around a random centre.

    The centre is drawn uniformly, per axis, in
    ``[min, min + box_multiplier * (max - min)]`` of ``data.pos``. A point at
    distance d is kept when ``cos(2*pi*d/period) > cos(pi*prop)``, i.e. when d is
    within ``prop * period / 2`` of a multiple of ``period``.

    Parameters
    ----------
    period: float, optional
        distance between two consecutive shells
    prop: float, optional
        fraction of each period that is kept
    box_multiplier: float, optional
        scales the box in which the centre is drawn
    skip_keys: list, optional
        attributes of data to skip when the mask is applied
    """

    def __init__(self, period=0.1, prop=0.1, box_multiplier=1, skip_keys=None):
        self.pulse = 2 * np.pi / period
        self.thresh = np.cos(self.pulse * prop * period * 0.5)
        self.box_multiplier = box_multiplier
        # ``None`` instead of a shared mutable ``[]`` default
        self.skip_keys = skip_keys if skip_keys is not None else []

    def __call__(self, data):
        pos = data.pos
        # only reductions over ``pos`` are needed: no need to clone the whole data object
        max_p = pos.max(0)[0]
        min_p = pos.min(0)[0]
        # one random draw per coordinate (3 for usual point clouds)
        center = self.box_multiplier * torch.rand(pos.shape[1]) * (max_p - min_p) + min_p
        d_p = torch.norm(pos - center, dim=1)
        mask = torch.cos(self.pulse * d_p) > self.thresh
        return apply_mask(data, mask, self.skip_keys)

    def __repr__(self):
        return "{}(pulse={}, thresh={}, box_mullti={}, skip_keys={})".format(
            self.__class__.__name__, self.pulse, self.thresh, self.box_multiplier, self.skip_keys
        )