Source code for torch_points3d.datasets.segmentation.shapenet

import os
import os.path as osp
import shutil
import json
from tqdm.auto import tqdm as tq
from itertools import repeat, product
import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset, extract_zip
from torch_geometric.io import read_txt_array
import torch_geometric.transforms as T
from torch_points3d.core.data_transform import SaveOriginalPosId
from torch_points3d.metrics.shapenet_part_tracker import ShapenetPartTracker
from torch_points3d.datasets.base_dataset import BaseDataset, save_used_properties
from torch_points3d.utils.download import download_url



class ShapeNet(InMemoryDataset):
    r"""The ShapeNet part level segmentation dataset from the `"A Scalable
    Active Framework for Region Annotation in 3D Shape Collections"
    <http://web.stanford.edu/~ericyi/papers/part_annotation_16_small.pdf>`_
    paper, containing about 17,000 3D shape point clouds from 16 shape
    categories. Each category is annotated with 2 to 6 parts.

    Args:
        root (string): Root directory where the dataset should be saved.
        categories (string or [string], optional): The category of the CAD
            models (one or a combination of :obj:`"Airplane"`, :obj:`"Bag"`,
            :obj:`"Cap"`, :obj:`"Car"`, :obj:`"Chair"`, :obj:`"Earphone"`,
            :obj:`"Guitar"`, :obj:`"Knife"`, :obj:`"Lamp"`, :obj:`"Laptop"`,
            :obj:`"Motorbike"`, :obj:`"Mug"`, :obj:`"Pistol"`,
            :obj:`"Rocket"`, :obj:`"Skateboard"`, :obj:`"Table"`).
            Can be explicitly set to :obj:`None` to load all categories.
            (default: :obj:`None`)
        include_normals (bool, optional): If set to :obj:`False`, will not
            include normal vectors as input features. (default: :obj:`True`)
        split (string, optional): If :obj:`"train"`, loads the training
            dataset. If :obj:`"val"`, loads the validation dataset.
            If :obj:`"trainval"`, loads the training and validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"trainval"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """

    url = (
        "https://shapenet.cs.stanford.edu/media/"
        "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
    )

    category_ids = {
        "Airplane": "02691156",
        "Bag": "02773838",
        "Cap": "02954340",
        "Car": "02958343",
        "Chair": "03001627",
        "Earphone": "03261776",
        "Guitar": "03467517",
        "Knife": "03624134",
        "Lamp": "03636649",
        "Laptop": "03642806",
        "Motorbike": "03790512",
        "Mug": "03797390",
        "Pistol": "03948459",
        "Rocket": "04099429",
        "Skateboard": "04225987",
        "Table": "04379243",
    }

    seg_classes = {
        "Airplane": [0, 1, 2, 3],
        "Bag": [4, 5],
        "Cap": [6, 7],
        "Car": [8, 9, 10, 11],
        "Chair": [12, 13, 14, 15],
        "Earphone": [16, 17, 18],
        "Guitar": [19, 20, 21],
        "Knife": [22, 23],
        "Lamp": [24, 25, 26, 27],
        "Laptop": [28, 29],
        "Motorbike": [30, 31, 32, 33, 34, 35],
        "Mug": [36, 37],
        "Pistol": [38, 39, 40],
        "Rocket": [41, 42, 43],
        "Skateboard": [44, 45, 46],
        "Table": [47, 48, 49],
    }

    def __init__(
        self,
        root,
        categories=None,
        include_normals=True,
        split="trainval",
        transform=None,
        pre_transform=None,
        pre_filter=None,
        is_test=False,
    ):
        if categories is None:
            categories = list(self.category_ids.keys())
        if isinstance(categories, str):
            categories = [categories]
        assert all(category in self.category_ids for category in categories)
        self.categories = categories
        self.is_test = is_test
        super(ShapeNet, self).__init__(root, transform, pre_transform, pre_filter)

        if split == "train":
            path = self.processed_paths[0]
            raw_path = self.processed_raw_paths[0]
        elif split == "val":
            path = self.processed_paths[1]
            raw_path = self.processed_raw_paths[1]
        elif split == "test":
            path = self.processed_paths[2]
            raw_path = self.processed_raw_paths[2]
        elif split == "trainval":
            path = self.processed_paths[3]
            raw_path = self.processed_raw_paths[3]
        else:
            raise ValueError(
                f"Split {split} found, but expected either train, val, trainval or test"
            )

        self.data, self.slices, self.y_mask = self.load_data(path, include_normals)

        # We perform a slight memory optimization if no pre-transform was used,
        # cf. self._process_filenames: in that case no separate raw dataset is
        # saved and get_raw_data falls back to the regular get.
        if os.path.exists(raw_path):
            self.raw_data, self.raw_slices, _ = self.load_data(raw_path, include_normals)
        else:
            self.get_raw_data = self.get

    def load_data(self, path, include_normals):
        """Used twice: loads both the raw and the pre-transformed data."""
        data, slices = torch.load(path)
        data.x = data.x if include_normals else None

        y_mask = torch.zeros((len(self.seg_classes.keys()), 50), dtype=torch.bool)
        for i, labels in enumerate(self.seg_classes.values()):
            y_mask[i, labels] = 1

        return data, slices, y_mask

    @property
    def raw_file_names(self):
        return list(self.category_ids.values()) + ["train_test_split"]

    @property
    def processed_raw_paths(self):
        cats = "_".join([cat[:3].lower() for cat in self.categories])
        processed_raw_paths = [
            os.path.join(self.processed_dir, "raw_{}_{}".format(cats, s))
            for s in ["train", "val", "test", "trainval"]
        ]
        return processed_raw_paths

    @property
    def processed_file_names(self):
        cats = "_".join([cat[:3].lower() for cat in self.categories])
        return [
            os.path.join("{}_{}.pt".format(cats, split))
            for split in ["train", "val", "test", "trainval"]
        ]

    def download(self):
        if self.is_test:
            return
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.rmtree(self.raw_dir)
        name = self.url.split("/")[-1].split(".")[0]
        os.rename(osp.join(self.root, name), self.raw_dir)

    def get_raw_data(self, idx, **kwargs):
        """Returns the idx-th raw sample, sliced out of the collated raw data."""
        data = self.raw_data.__class__()

        if hasattr(self.raw_data, "__num_nodes__"):
            data.num_nodes = self.raw_data.__num_nodes__[idx]

        for key in self.raw_data.keys:
            item, slices = self.raw_data[key], self.raw_slices[key]
            start, end = slices[idx].item(), slices[idx + 1].item()
            if torch.is_tensor(item):
                s = list(repeat(slice(None), item.dim()))
                s[self.raw_data.__cat_dim__(key, item)] = slice(start, end)
            elif start + 1 == end:
                s = slices[start]
            else:
                s = slice(start, end)
            data[key] = item[s]
        return data

    def _process_filenames(self, filenames):
        data_raw_list = []
        data_list = []
        categories_ids = [self.category_ids[cat] for cat in self.categories]
        cat_idx = {categories_ids[i]: i for i in range(len(categories_ids))}

        has_pre_transform = self.pre_transform is not None

        id_scan = -1
        for name in tq(filenames):
            cat = name.split(osp.sep)[0]
            if cat not in categories_ids:
                continue
            id_scan += 1
            data = read_txt_array(osp.join(self.raw_dir, name))
            pos = data[:, :3]
            x = data[:, 3:6]
            y = data[:, -1].type(torch.long)
            category = torch.ones(x.shape[0], dtype=torch.long) * cat_idx[cat]
            id_scan_tensor = torch.from_numpy(np.asarray([id_scan])).clone()
            data = Data(pos=pos, x=x, y=y, category=category, id_scan=id_scan_tensor)
            data = SaveOriginalPosId()(data)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            data_raw_list.append(data.clone() if has_pre_transform else data)

            if has_pre_transform:
                data = self.pre_transform(data)
                data_list.append(data)

        if not has_pre_transform:
            return [], data_raw_list
        return data_raw_list, data_list

    def _save_data_list(self, datas, path_to_datas, save_bool=True):
        if save_bool:
            torch.save(self.collate(datas), path_to_datas)

    def _re_index_trainval(self, trainval):
        if len(trainval) == 0:
            return trainval
        train, val = trainval
        for v in val:
            v.id_scan += len(train)
        assert (train[-1].id_scan + 1 == val[0].id_scan).item(), (
            train[-1].id_scan,
            val[0].id_scan,
        )
        return train + val

    def process(self):
        if self.is_test:
            return
        raw_trainval = []
        trainval = []
        for i, split in enumerate(["train", "val", "test"]):
            path = osp.join(self.raw_dir, "train_test_split", f"shuffled_{split}_file_list.json")
            with open(path, "r") as f:
                filenames = [
                    osp.sep.join(name.split("/")[1:]) + ".txt" for name in json.load(f)
                ]  # Removing first directory.
            data_raw_list, data_list = self._process_filenames(sorted(filenames))
            if split == "train" or split == "val":
                if len(data_raw_list) > 0:
                    raw_trainval.append(data_raw_list)
                trainval.append(data_list)
            self._save_data_list(data_list, self.processed_paths[i])
            self._save_data_list(
                data_raw_list, self.processed_raw_paths[i], save_bool=len(data_raw_list) > 0
            )

        self._save_data_list(self._re_index_trainval(trainval), self.processed_paths[3])
        self._save_data_list(
            self._re_index_trainval(raw_trainval),
            self.processed_raw_paths[3],
            save_bool=len(raw_trainval) > 0,
        )

    def __repr__(self):
        return "{}({}, categories={})".format(self.__class__.__name__, len(self), self.categories)
class ShapeNetDataset(BaseDataset):
    """ Wrapper around ShapeNet that creates train and test datasets.

    Parameters
    ----------
    dataset_opt: omegaconf.DictConfig
        Config dictionary that should contain

            - dataroot
            - category: List of categories or All
            - normal: bool, include normals or not
            - pre_transforms
            - train_transforms
            - test_transforms
            - val_transforms
    """

    FORWARD_CLASS = "forward.shapenet.ForwardShapenetDataset"

    def __init__(self, dataset_opt):
        super().__init__(dataset_opt)
        # Read is_test outside the try block so it stays defined even when
        # the category option is missing.
        is_test = dataset_opt.get("is_test", False)
        try:
            self._category = dataset_opt.category
        except KeyError:
            self._category = None

        self.train_dataset = ShapeNet(
            self._data_path,
            self._category,
            include_normals=dataset_opt.normal,
            split="train",
            pre_transform=self.pre_transform,
            transform=self.train_transform,
            is_test=is_test,
        )

        self.val_dataset = ShapeNet(
            self._data_path,
            self._category,
            include_normals=dataset_opt.normal,
            split="val",
            pre_transform=self.pre_transform,
            transform=self.val_transform,
            is_test=is_test,
        )

        self.test_dataset = ShapeNet(
            self._data_path,
            self._category,
            include_normals=dataset_opt.normal,
            split="test",
            transform=self.test_transform,
            pre_transform=self.pre_transform,
            is_test=is_test,
        )
        self._categories = self.train_dataset.categories

    @property  # type: ignore
    @save_used_properties
    def class_to_segments(self):
        classes_to_segment = {}
        for key in self._categories:
            classes_to_segment[key] = ShapeNet.seg_classes[key]
        return classes_to_segment

    @property
    def is_hierarchical(self):
        return len(self._categories) > 1

    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """Factory method for the tracker

        Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using TensorBoard

        Returns:
            [BaseTracker] -- tracker
        """
        return ShapenetPartTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)
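
# A minimal sketch, not part of the original module, of driving the
# ShapeNetDataset wrapper from an OmegaConf dictionary using only the options
# its docstring documents. The concrete values are assumptions for
# illustration; the transform options are left unset here, which BaseDataset
# is expected to resolve to None.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    dataset_opt = OmegaConf.create(
        {
            "dataroot": "data",      # hypothetical data root
            "category": "Airplane",  # a single category, a list, or "All"
            "normal": True,          # feed normals as input features
        }
    )
    dataset = ShapeNetDataset(dataset_opt)
    tracker = dataset.get_tracker(wandb_log=False, tensorboard_log=False)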