from typing import List
import numpy as np
import random
import scipy.ndimage
import scipy.interpolate
import re
import torch
import logging
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_add
from torch_geometric.nn.pool.consecutive import consecutive_cluster
from torch_geometric.nn import voxel_grid
from torch_geometric.data import Data
from torch_cluster import grid_cluster
log = logging.getLogger(__name__)
# Label will be the majority label in each voxel
_INTEGER_LABEL_KEYS = ["y", "instance_labels"]
def shuffle_data(data):
""" Shuffle all point-level tensors in ``data`` with a single random permutation. """
num_points = data.pos.shape[0]
shuffle_idx = torch.randperm(num_points)
for key in set(data.keys):
item = data[key]
if torch.is_tensor(item) and num_points == item.shape[0]:
data[key] = item[shuffle_idx]
return data
def group_data(data, cluster=None, unique_pos_indices=None, mode="last", skip_keys=[]):
""" Group data based on indices in cluster.
The option ``mode`` controls how data gets aggregated within each cluster.
Parameters
----------
data : Data
Data object containing the points to be grouped.
cluster : torch.Tensor
Tensor of the same size as the number of points in data. Each element is the cluster index of that point.
unique_pos_indices : torch.Tensor
Tensor containing one index per cluster; this index is used to select the features and labels of the representative point.
mode : str
Option to select how the features and labels for each voxel are computed. Can be ``last`` or ``mean``.
``last`` selects the last point falling in a voxel as the representative, ``mean`` takes the average.
skip_keys: list
Keys of attributes to skip in the grouping
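Examples
--------
A minimal sketch with toy tensors (not taken from the library's tests): points sharing a cluster
index are averaged, while integer labels such as ``y`` take the majority vote.

>>> pos = torch.tensor([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [1.0, 1.0, 1.0]])
>>> data = Data(pos=pos, y=torch.tensor([1, 1, 0]))
>>> cluster = torch.tensor([0, 0, 1])
>>> data = group_data(data, cluster=cluster, mode="mean")
>>> data.pos.shape
torch.Size([2, 3])
>>> data.y
tensor([1, 0])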
"""
assert mode in ["mean", "last"]
if mode == "mean" and cluster is None:
raise ValueError("In mean mode the cluster argument needs to be specified")
if mode == "last" and unique_pos_indices is None:
raise ValueError("In last mode the unique_pos_indices argument needs to be specified")
num_nodes = data.num_nodes
for key, item in data:
if bool(re.search("edge", key)):
raise ValueError("Edges not supported. Wrong data type.")
if key in skip_keys:
continue
if torch.is_tensor(item) and item.size(0) == num_nodes:
if mode == "last" or key == "batch" or key == SaveOriginalPosId.KEY:
data[key] = item[unique_pos_indices]
elif mode == "mean":
is_item_bool = item.dtype == torch.bool
if is_item_bool:
item = item.int()
if key in _INTEGER_LABEL_KEYS:
item_min = item.min()
item = F.one_hot(item - item_min)
item = scatter_add(item, cluster, dim=0)
data[key] = item.argmax(dim=-1) + item_min
else:
data[key] = scatter_mean(item, cluster, dim=0)
if is_item_bool:
data[key] = data[key].bool()
return data
class GridSampling3D:
""" Clusters points into voxels with size :attr:`size`.
Parameters
----------
size: float
Size of a voxel (in each dimension).
quantize_coords: bool
If True, the points are converted into their associated sparse coordinates within the grid and stored
in a new ``coords`` attribute.
mode: str
The mode can be either ``last`` or ``mean``.
If mode is ``mean``, all the points and their features within a cell are averaged.
If mode is ``last``, one random point per cell is selected along with its associated features.
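Examples
--------
A minimal sketch with a toy point cloud (not taken from the library's tests): points that fall
in the same voxel collapse into a single sample.

>>> data = Data(pos=torch.tensor([[0.0, 0.0, 0.0], [0.05, 0.0, 0.0], [1.0, 1.0, 1.0]]))
>>> sampler = GridSampling3D(size=0.2, mode="mean")
>>> sampler(data).pos.shape
torch.Size([2, 3])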
"""
def __init__(self, size, quantize_coords=False, mode="mean", verbose=False):
self._grid_size = size
self._quantize_coords = quantize_coords
self._mode = mode
if verbose:
log.warning(
"If you need to keep track of the position of your points, use SaveOriginalPosId transform before using GridSampling3D"
)
if self._mode == "last":
log.warning(
"The tensors within data will be shuffled each time this transform is applied. Be careful that if an attribute doesn't have the size of num_points, it won't be shuffled"
)
def _process(self, data):
if self._mode == "last":
data = shuffle_data(data)
coords = torch.round((data.pos) / self._grid_size)
if "batch" not in data:
cluster = grid_cluster(coords, torch.tensor([1, 1, 1]))
else:
cluster = voxel_grid(coords, data.batch, 1)
cluster, unique_pos_indices = consecutive_cluster(cluster)
data = group_data(data, cluster, unique_pos_indices, mode=self._mode)
if self._quantize_coords:
data.coords = coords[unique_pos_indices].int()
data.grid_size = torch.tensor([self._grid_size])
return data
def __call__(self, data):
if isinstance(data, list):
data = [self._process(d) for d in data]
else:
data = self._process(data)
return data
def __repr__(self):
return "{}(grid_size={}, quantize_coords={}, mode={})".format(
self.__class__.__name__, self._grid_size, self._quantize_coords, self._mode
)
class SaveOriginalPosId:
""" Transform that adds the index of the point to the data object
This allows us to track this point from the output back to the input data object
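A minimal sketch (toy data): the transform attaches an ``origin_id`` tensor of point indices.

>>> data = SaveOriginalPosId()(Data(pos=torch.rand(4, 3)))
>>> data.origin_id
tensor([0, 1, 2, 3])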
"""
KEY = "origin_id"
def _process(self, data):
if hasattr(data, self.KEY):
return data
setattr(data, self.KEY, torch.arange(0, data.pos.shape[0]))
return data
def __call__(self, data):
if isinstance(data, list):
data = [self._process(d) for d in data]
else:
data = self._process(data)
return data
def __repr__(self):
return self.__class__.__name__
class ElasticDistortion:
"""Apply elastic distortion on sparse coordinate space. First projects the position onto a
voxel grid and then apply the distortion to the voxel grid.
Parameters
----------
granularity: List[float]
Granularity of the noise in meters
magnitude: List[float]
Noise multiplier in meters
Returns
-------
data: Data
Returns the same data object with a distorted grid
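Examples
--------
A minimal sketch with a toy point cloud (not taken from the library's tests): positions are
jittered in place by a smoothed random displacement field, so the tensor shape is preserved.

>>> transform = ElasticDistortion(apply_distorsion=True, granularity=[0.2], magnitude=[0.4])
>>> data = transform(Data(pos=torch.rand(100, 3)))
>>> data.pos.shape
torch.Size([100, 3])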
"""
def __init__(
self, apply_distorsion: bool = True, granularity: List = [0.2, 0.8], magnitude=[0.4, 1.6],
):
assert len(magnitude) == len(granularity)
self._apply_distorsion = apply_distorsion
self._granularity = granularity
self._magnitude = magnitude
@staticmethod
def elastic_distortion(coords, granularity, magnitude):
coords = coords.numpy()
blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
coords_min = coords.min(0)
# Create Gaussian noise tensor of the size given by granularity.
noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
noise = np.random.randn(*noise_dim, 3).astype(np.float32)
# Smoothing.
for _ in range(2):
noise = scipy.ndimage.convolve(noise, blurx, mode="constant", cval=0)
noise = scipy.ndimage.convolve(noise, blury, mode="constant", cval=0)
noise = scipy.ndimage.convolve(noise, blurz, mode="constant", cval=0)
# Trilinearly interpolate the smoothed noise for each spatial dimension.
ax = [
np.linspace(d_min, d_max, d)
for d_min, d_max, d in zip(coords_min - granularity, coords_min + granularity * (noise_dim - 2), noise_dim)
]
interp = scipy.interpolate.RegularGridInterpolator(ax, noise, bounds_error=0, fill_value=0)
coords = coords + interp(coords) * magnitude
return torch.tensor(coords).float()
def __call__(self, data):
# coords = data.pos / self._spatial_resolution
if self._apply_distorsion:
if random.random() < 0.95:
for i in range(len(self._granularity)):
data.pos = ElasticDistortion.elastic_distortion(data.pos, self._granularity[i], self._magnitude[i],)
return data
def __repr__(self):
return "{}(apply_distorsion={}, granularity={}, magnitude={})".format(
self.__class__.__name__, self._apply_distorsion, self._granularity, self._magnitude,
)