# Source code for torch_points3d.datasets.segmentation.s3dis

import os
import os.path as osp
from itertools import repeat, product
import numpy as np
import h5py
import torch
import random
import glob
from plyfile import PlyData, PlyElement
from torch_geometric.data import InMemoryDataset, Data, extract_zip, Dataset
from torch_geometric.data.dataset import files_exist
from torch_geometric.data import DataLoader
from torch_geometric.datasets import S3DIS as S3DIS1x1
import torch_geometric.transforms as T
import logging
from sklearn.neighbors import NearestNeighbors, KDTree
from tqdm.auto import tqdm as tq
import csv
import pandas as pd
import pickle
import gdown
import shutil

from torch_points3d.datasets.samplers import BalancedRandomSampler
import torch_points3d.core.data_transform as cT
from torch_points3d.datasets.base_dataset import BaseDataset

# Directory containing this file (used to locate the bundled s3dis.patch).
DIR = os.path.dirname(os.path.realpath(__file__))
log = logging.getLogger(__name__)

# Number of semantic classes in S3DIS.
S3DIS_NUM_CLASSES = 13

# Integer label -> class name. Label 12 ("clutter") doubles as the fallback
# for unknown object names (see object_name_to_label).
INV_OBJECT_LABEL = {
    0: "ceiling",
    1: "floor",
    2: "wall",
    3: "beam",
    4: "column",
    5: "window",
    6: "door",
    7: "chair",
    8: "table",
    9: "bookcase",
    10: "sofa",
    11: "board",
    12: "clutter",
}

# RGB color per class for PLY export (row index == label); the extra last
# row (black) is for unlabelled points.
OBJECT_COLOR = np.asarray(
    [
        [233, 229, 107],  # 'ceiling' .-> .yellow
        [95, 156, 196],  # 'floor' .-> . blue
        [179, 116, 81],  # 'wall'  ->  brown
        [241, 149, 131],  # 'beam'  ->  salmon
        [81, 163, 148],  # 'column'  ->  bluegreen
        [77, 174, 84],  # 'window'  ->  bright green
        [108, 135, 75],  # 'door'   ->  dark green
        [41, 49, 101],  # 'chair'  ->  darkblue
        [79, 79, 76],  # 'table'  ->  dark grey
        [223, 52, 52],  # 'bookcase'  ->  red
        [89, 47, 95],  # 'sofa'  ->  purple
        [81, 109, 114],  # 'board'   ->  grey
        [233, 233, 229],  # 'clutter'  ->  light grey
        [0, 0, 0],  # unlabelled .->. black
    ]
)

# Class name -> integer label (inverse of INV_OBJECT_LABEL).
OBJECT_LABEL = {name: i for i, name in INV_OBJECT_LABEL.items()}

# Room type prefix -> integer room label (room folders are named e.g. "office_3").
ROOM_TYPES = {
    "conferenceRoom": 0,
    "copyRoom": 1,
    "hallway": 2,
    "office": 3,
    "pantry": 4,
    "WC": 5,
    "auditorium": 6,
    "storage": 7,
    "lounge": 8,
    "lobby": 9,
    "openspace": 10,
}

# Rooms held out from the training split to form the validation split.
VALIDATION_ROOMS = [
    "hallway_1",
    "hallway_6",
    "hallway_11",
    "office_1",
    "office_6",
    "office_11",
    "office_16",
    "office_21",
    "office_26",
    "office_31",
    "office_36",
    "WC_2",
    "storage_1",
    "storage_5",
    "conferenceRoom_2",
    "auditorium_1",
]

################################### UTILS #######################################


def object_name_to_label(object_class):
    """Map an S3DIS object class name to its integer label.

    Unknown class names fall back to the "clutter" label.
    """
    try:
        return OBJECT_LABEL[object_class]
    except KeyError:
        return OBJECT_LABEL["clutter"]


def read_s3dis_format(train_file, room_name, label_out=True, verbose=False, debug=False):
    """Extract point data (and optionally labels) from one S3DIS room folder.

    Parameters
    ----------
    train_file: str
        path to the room folder (contains ``<room_name>.txt`` and ``Annotations/``)
    room_name: str
        name of the room, e.g. ``office_3`` (prefix must be a key of ROOM_TYPES)
    label_out: bool
        if False, only ``(xyz, rgb)`` is returned
    verbose: bool
        log each annotated object while processing
    debug: bool
        only sanity-check the raw file line by line and return True

    Returns
    -------
    ``(xyz, rgb, semantic_labels, instance_labels, room_label)`` as torch tensors,
    ``(xyz, rgb)`` as numpy arrays when ``label_out`` is False, or ``True`` in debug mode.
    """
    room_type = room_name.split("_")[0]
    room_label = ROOM_TYPES[room_type]
    raw_path = osp.join(train_file, "{}.txt".format(room_name))
    if debug:
        # NOTE(review): delimiter="\n" reads whole lines as a single column; this
        # is deprecated in recent pandas and the default header consumes the first
        # row — debug mode only, so left as-is. TODO confirm before upgrading pandas.
        reader = pd.read_csv(raw_path, delimiter="\n")
        RECOMMENDED = 6  # x y z r g b
        for idx, row in enumerate(reader.values):
            row = row[0].split(" ")
            if len(row) != RECOMMENDED:
                log.info("1: {} row {}: {}".format(raw_path, idx, row))

            try:
                for r in row:
                    r = float(r)
            except ValueError:
                # Non-numeric token found on this line; report and keep scanning.
                log.info("2: {} row {}: {}".format(raw_path, idx, row))

        return True
    else:
        room_ver = pd.read_csv(raw_path, sep=" ", header=None).values
        xyz = np.ascontiguousarray(room_ver[:, 0:3], dtype="float32")
        try:
            rgb = np.ascontiguousarray(room_ver[:, 3:6], dtype="uint8")
        except ValueError:
            # Some S3DIS rooms ship corrupted color columns; fall back to black.
            rgb = np.zeros((room_ver.shape[0], 3), dtype="uint8")
            log.warning("WARN - corrupted rgb data for file %s" % raw_path)
        if not label_out:
            return xyz, rgb
        n_ver = len(room_ver)
        del room_ver
        # Object annotation files repeat point coordinates; a 1-NN search maps
        # each annotated point back to its index in the full room cloud.
        nn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(xyz)
        semantic_labels = np.zeros((n_ver,), dtype="int64")  # default: 0 ("ceiling"), overwritten below
        room_label = np.asarray([room_label])
        instance_labels = np.zeros((n_ver,), dtype="int64")  # 0 = no instance
        objects = glob.glob(osp.join(train_file, "Annotations/*.txt"))
        i_object = 1
        for single_object in objects:
            object_name = os.path.splitext(os.path.basename(single_object))[0]
            if verbose:
                log.debug("adding object " + str(i_object) + " : " + object_name)
            object_class = object_name.split("_")[0]
            object_label = object_name_to_label(object_class)
            obj_ver = pd.read_csv(single_object, sep=" ", header=None).values
            _, obj_ind = nn.kneighbors(obj_ver[:, 0:3])
            semantic_labels[obj_ind] = object_label
            instance_labels[obj_ind] = i_object
            i_object = i_object + 1

        return (
            torch.from_numpy(xyz),
            torch.from_numpy(rgb),
            torch.from_numpy(semantic_labels),
            torch.from_numpy(instance_labels),
            torch.from_numpy(room_label),
        )


def to_ply(pos, label, file):
    """Write a labelled point cloud to a binary PLY file using the S3DIS color scheme.

    Parameters
    ----------
    pos : array-like, shape (N, 3)
        point positions
    label : array-like, shape (N,)
        per-point integer labels (indices into OBJECT_COLOR)
    file : str
        destination path
    """
    assert len(label.shape) == 1
    assert pos.shape[0] == label.shape[0]
    pos = np.asarray(pos)
    colors = OBJECT_COLOR[np.asarray(label)]
    dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")]
    ply_array = np.ones(pos.shape[0], dtype=dtype)
    # Fill coordinate fields then color fields, column by column.
    for axis, field in enumerate(("x", "y", "z")):
        ply_array[field] = pos[:, axis]
    for channel, field in enumerate(("red", "green", "blue")):
        ply_array[field] = colors[:, channel]
    el = PlyElement.describe(ply_array, "S3DIS")
    PlyData([el], byte_order=">").write(file)


################################### 1m cylinder s3dis ###################################


class S3DIS1x1Dataset(BaseDataset):
    """Wrapper around torch_geometric's S3DIS dataset (1m x 1m columns).

    Builds a train and a test dataset from ``dataset_opt`` (uses ``fold`` as
    the held-out test area) and optionally attaches class weights.
    """

    def __init__(self, dataset_opt):
        super().__init__(dataset_opt)

        pre_transform = self.pre_transform
        self.train_dataset = S3DIS1x1(
            self._data_path,
            test_area=self.dataset_opt.fold,
            train=True,
            pre_transform=self.pre_transform,
            transform=self.train_transform,
        )
        self.test_dataset = S3DIS1x1(
            self._data_path,
            test_area=self.dataset_opt.fold,
            train=False,
            pre_transform=pre_transform,
            transform=self.test_transform,
        )

        if dataset_opt.class_weight_method:
            self.add_weights(class_weight_method=dataset_opt.class_weight_method)

    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """Factory method for the tracker

        Arguments:
            wandb_log - Log using weight and biases
            tensorboard_log - Log using tensorboard
        Returns:
            [BaseTracker] -- tracker
        """
        from torch_points3d.metrics.segmentation_tracker import SegmentationTracker

        return SegmentationTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)
################################### Used for fused s3dis radius sphere ###################################
class S3DISOriginalFused(InMemoryDataset):
    """ Original S3DIS dataset. Each area is loaded individually and can be processed using a pre_collate transform.
    This transform can be used for example to fuse the area into a single space and split it into
    spheres or smaller regions. If no fusion is applied, each element in the dataset is a single room by default.

    http://buildingparser.stanford.edu/dataset.html

    Parameters
    ----------
    root: str
        path to the directory where the data will be saved
    test_area: int
        number between 1 and 6 that denotes the area used for testing
    split: str
        can be one of train, trainval, val or test
    pre_collate_transform:
        Transforms to be applied before the data is assembled into samples (apply fusing here for example)
    keep_instance: bool
        set to True if you wish to keep instance data
    pre_transform
    transform
    pre_filter
    """

    form_url = (
        "https://docs.google.com/forms/d/e/1FAIpQLScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1"
    )
    download_url = "https://drive.google.com/uc?id=0BweDykwS9vIobkVPN0wzRzFwTDg&export=download"
    zip_name = "Stanford3dDataset_v1.2_Version.zip"
    path_file = osp.join(DIR, "s3dis.patch")  # patch shipped with the repo fixing known data glitches
    file_name = "Stanford3dDataset_v1.2"
    folders = ["Area_{}".format(i) for i in range(1, 7)]
    num_classes = S3DIS_NUM_CLASSES

    def __init__(
        self,
        root,
        test_area=6,
        split="train",
        transform=None,
        pre_transform=None,
        pre_collate_transform=None,
        pre_filter=None,
        keep_instance=False,
        verbose=False,
        debug=False,
    ):
        assert test_area >= 1 and test_area <= 6
        self.transform = transform
        self.pre_collate_transform = pre_collate_transform
        self.test_area = test_area
        self.keep_instance = keep_instance
        self.verbose = verbose
        self.debug = debug
        self._split = split
        # Parent __init__ triggers download()/process() if needed.
        super(S3DISOriginalFused, self).__init__(root, transform, pre_transform, pre_filter)
        if split == "train":
            path = self.processed_paths[0]
        elif split == "val":
            path = self.processed_paths[1]
        elif split == "test":
            path = self.processed_paths[2]
        elif split == "trainval":
            path = self.processed_paths[3]
        else:
            raise ValueError((f"Split {split} found, but expected either " "train, val, trainval or test"))

        self._load_data(path)

        if split == "test":
            # Keep the raw (un-transformed) test area around for metric computation.
            self.raw_test_data = torch.load(self.raw_areas_paths[test_area - 1])

    @property
    def center_labels(self):
        # Present only if a pre_collate transform added center labels.
        if hasattr(self.data, "center_label"):
            return self.data.center_label
        else:
            return None

    @property
    def raw_file_names(self):
        return self.folders

    @property
    def pre_processed_path(self):
        pre_processed_file_names = "preprocessed.pt"
        return os.path.join(self.processed_dir, pre_processed_file_names)

    @property
    def raw_areas_paths(self):
        return [os.path.join(self.processed_dir, "raw_area_%i.pt" % i) for i in range(6)]

    @property
    def processed_file_names(self):
        test_area = self.test_area
        return (
            ["{}_{}.pt".format(s, test_area) for s in ["train", "val", "test", "trainval"]]
            + self.raw_areas_paths
            + [self.pre_processed_path]
        )

    @property
    def raw_test_data(self):
        return self._raw_test_data

    @raw_test_data.setter
    def raw_test_data(self, value):
        self._raw_test_data = value

    def download(self):
        """Download and patch the dataset (requires the user to fill the S3DIS form)."""
        raw_folders = os.listdir(self.raw_dir)
        if len(raw_folders) == 0:
            if not os.path.exists(osp.join(self.root, self.zip_name)):
                log.info("WARNING: You are downloading S3DIS dataset")
                log.info("Please, register yourself by filling up the form at {}".format(self.form_url))
                log.info("***")
                log.info(
                    "Press any key to continue, or CTRL-C to exit. By continuing, you confirm filling up the form."
                )
                input("")
                gdown.download(self.download_url, osp.join(self.root, self.zip_name), quiet=False)
            extract_zip(os.path.join(self.root, self.zip_name), self.root)
            shutil.rmtree(self.raw_dir)
            os.rename(osp.join(self.root, self.file_name), self.raw_dir)
            # Apply the bundled patch to fix known corrupted annotation files.
            shutil.copy(self.path_file, self.raw_dir)
            cmd = "patch -ruN -p0 -d  {} < {}".format(self.raw_dir, osp.join(self.raw_dir, "s3dis.patch"))
            os.system(cmd)
        else:
            # Partial raw data: wipe and retry a clean download.
            intersection = len(set(self.folders).intersection(set(raw_folders)))
            if intersection != 6:
                shutil.rmtree(self.raw_dir)
                os.makedirs(self.raw_dir)
                self.download()

    def process(self):
        """Read every room, group per area, apply transforms and save the four splits."""
        if not os.path.exists(self.pre_processed_path):
            train_areas = [f for f in self.folders if str(self.test_area) not in f]
            test_areas = [f for f in self.folders if str(self.test_area) in f]

            train_files = [
                (f, room_name, osp.join(self.raw_dir, f, room_name))
                for f in train_areas
                for room_name in os.listdir(osp.join(self.raw_dir, f))
                if os.path.isdir(osp.join(self.raw_dir, f, room_name))
            ]

            test_files = [
                (f, room_name, osp.join(self.raw_dir, f, room_name))
                for f in test_areas
                for room_name in os.listdir(osp.join(self.raw_dir, f))
                if os.path.isdir(osp.join(self.raw_dir, f, room_name))
            ]

            # Gather data per area
            data_list = [[] for _ in range(6)]
            if self.debug:
                areas = np.zeros(7)
            for (area, room_name, file_path) in tq(train_files + test_files):
                if self.debug:
                    # In debug mode only sanity-check up to 5 rooms per area.
                    area_idx = int(area.split("_")[-1])
                    if areas[area_idx] == 5:
                        continue
                    else:
                        print(area_idx)
                        areas[area_idx] += 1
                area_num = int(area[-1]) - 1
                if self.debug:
                    read_s3dis_format(file_path, room_name, label_out=True, verbose=self.verbose, debug=self.debug)
                    continue
                else:
                    xyz, rgb, semantic_labels, instance_labels, room_label = read_s3dis_format(
                        file_path, room_name, label_out=True, verbose=self.verbose, debug=self.debug
                    )

                    rgb_norm = rgb.float() / 255.0
                    data = Data(pos=xyz, y=semantic_labels, rgb=rgb_norm)
                    if room_name in VALIDATION_ROOMS:
                        data.validation_set = True
                    else:
                        data.validation_set = False

                    if self.keep_instance:
                        data.instance_labels = instance_labels

                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue

                    data_list[area_num].append(data)

            raw_areas = cT.PointCloudFusion()(data_list)
            for i, area in enumerate(raw_areas):
                torch.save(area, self.raw_areas_paths[i])

            # Apply pre_transform. BUGFIX: the result must be written back into the
            # list — previously `data = self.pre_transform(data)` discarded the
            # returned object, so only in-place mutating transforms were persisted.
            if self.pre_transform is not None:
                for area_datas in data_list:
                    for i, data in enumerate(area_datas):
                        area_datas[i] = self.pre_transform(data)

            torch.save(data_list, self.pre_processed_path)
        else:
            data_list = torch.load(self.pre_processed_path)

        if self.debug:
            return

        # Split each non-test area into train/val using VALIDATION_ROOMS flags.
        train_data_list = {}
        val_data_list = {}
        trainval_data_list = {}
        for i in range(6):
            if i != self.test_area - 1:
                train_data_list[i] = []
                val_data_list[i] = []
                for data in data_list[i]:
                    validation_set = data.validation_set
                    del data.validation_set
                    if validation_set:
                        val_data_list[i].append(data)
                    else:
                        train_data_list[i].append(data)
                trainval_data_list[i] = val_data_list[i] + train_data_list[i]

        train_data_list = list(train_data_list.values())
        val_data_list = list(val_data_list.values())
        trainval_data_list = list(trainval_data_list.values())
        test_data_list = data_list[self.test_area - 1]

        if self.pre_collate_transform:
            log.info("pre_collate_transform ...")
            log.info(self.pre_collate_transform)
            train_data_list = self.pre_collate_transform(train_data_list)
            val_data_list = self.pre_collate_transform(val_data_list)
            test_data_list = self.pre_collate_transform(test_data_list)
            trainval_data_list = self.pre_collate_transform(trainval_data_list)

        self._save_data(train_data_list, val_data_list, test_data_list, trainval_data_list)

    def _save_data(self, train_data_list, val_data_list, test_data_list, trainval_data_list):
        # Collate into the InMemoryDataset (data, slices) format, one file per split.
        torch.save(self.collate(train_data_list), self.processed_paths[0])
        torch.save(self.collate(val_data_list), self.processed_paths[1])
        torch.save(self.collate(test_data_list), self.processed_paths[2])
        torch.save(self.collate(trainval_data_list), self.processed_paths[3])

    def _load_data(self, path):
        self.data, self.slices = torch.load(path)
class S3DISSphere(S3DISOriginalFused):
    """ Small variation of S3DISOriginalFused that allows random sampling of spheres
    within an Area during training and validation. Spheres have a radius of 2m. If sample_per_epoch is not specified, spheres
    are taken on a 2m grid.

    http://buildingparser.stanford.edu/dataset.html

    Parameters
    ----------
    root: str
        path to the directory where the data will be saved
    test_area: int
        number between 1 and 6 that denotes the area used for testing
    train: bool
        Is this a train split or not
    pre_collate_transform:
        Transforms to be applied before the data is assembled into samples (apply fusing here for example)
    keep_instance: bool
        set to True if you wish to keep instance data
    sample_per_epoch
        Number of spheres that are randomly sampled at each epoch (-1 for fixed grid)
    radius
        radius of each sphere
    pre_transform
    transform
    pre_filter
    """

    def __init__(self, root, sample_per_epoch=100, radius=2, *args, **kwargs):
        self._sample_per_epoch = sample_per_epoch
        self._radius = radius
        # Low-resolution grid used to pick candidate sphere centres.
        self._grid_sphere_sampling = cT.GridSampling3D(size=radius / 10.0)
        super().__init__(root, *args, **kwargs)

    def __len__(self):
        if self._sample_per_epoch > 0:
            return self._sample_per_epoch
        else:
            return len(self._test_spheres)

    def len(self):
        return len(self)

    def get(self, idx):
        # Random sampling during training, fixed grid spheres otherwise.
        if self._sample_per_epoch > 0:
            return self._get_random()
        else:
            return self._test_spheres[idx].clone()

    def process(self):  # We have to include this method, otherwise the parent class skips processing
        super().process()

    def download(self):  # We have to include this method, otherwise the parent class skips download
        super().download()

    def _get_random(self):
        # Random spheres biased towards getting more low frequency classes
        chosen_label = np.random.choice(self._labels, p=self._label_counts)
        valid_centres = self._centres_for_sampling[self._centres_for_sampling[:, 4] == chosen_label]
        # BUGFIX: int(random.random() * (n - 1)) could never pick the last centre;
        # randrange draws uniformly over all n candidates.
        centre_idx = random.randrange(valid_centres.shape[0])
        centre = valid_centres[centre_idx]
        # Column 3 of a centre row holds the index of its source area.
        area_data = self._datas[centre[3].int()]
        sphere_sampler = cT.SphereSampling(self._radius, centre[:3], align_origin=False)
        return sphere_sampler(area_data)

    def _save_data(self, train_data_list, val_data_list, test_data_list, trainval_data_list):
        # Unlike the parent, splits are saved uncollated (lists of fused areas).
        torch.save(train_data_list, self.processed_paths[0])
        torch.save(val_data_list, self.processed_paths[1])
        torch.save(test_data_list, self.processed_paths[2])
        torch.save(trainval_data_list, self.processed_paths[3])

    def _load_data(self, path):
        self._datas = torch.load(path)
        if not isinstance(self._datas, list):
            self._datas = [self._datas]
        if self._sample_per_epoch > 0:
            # Build candidate centres [x, y, z, area_idx, label] from a coarse grid.
            self._centres_for_sampling = []
            for i, data in enumerate(self._datas):
                assert not hasattr(
                    data, cT.SphereSampling.KDTREE_KEY
                )  # Just to make we don't have some out of date data in there
                low_res = self._grid_sphere_sampling(data.clone())
                centres = torch.empty((low_res.pos.shape[0], 5), dtype=torch.float)
                centres[:, :3] = low_res.pos
                centres[:, 3] = i
                centres[:, 4] = low_res.y
                self._centres_for_sampling.append(centres)
                tree = KDTree(np.asarray(data.pos), leaf_size=10)
                setattr(data, cT.SphereSampling.KDTREE_KEY, tree)

            self._centres_for_sampling = torch.cat(self._centres_for_sampling, 0)
            uni, uni_counts = np.unique(np.asarray(self._centres_for_sampling[:, -1]), return_counts=True)
            # Inverse-sqrt weighting: rarer classes get picked more often.
            uni_counts = np.sqrt(uni_counts.mean() / uni_counts)
            self._label_counts = uni_counts / np.sum(uni_counts)
            self._labels = uni
        else:
            grid_sampler = cT.GridSphereSampling(self._radius, self._radius, center=False)
            self._test_spheres = grid_sampler(self._datas)
class S3DISCylinder(S3DISSphere):
    """Variant of S3DISSphere that samples vertical cylinders instead of spheres."""

    def _get_random(self):
        # Random spheres biased towards getting more low frequency classes
        chosen_label = np.random.choice(self._labels, p=self._label_counts)
        valid_centres = self._centres_for_sampling[self._centres_for_sampling[:, 4] == chosen_label]
        # BUGFIX: int(random.random() * (n - 1)) could never pick the last centre;
        # randrange draws uniformly over all n candidates.
        centre_idx = random.randrange(valid_centres.shape[0])
        centre = valid_centres[centre_idx]
        # Column 3 of a centre row holds the index of its source area.
        area_data = self._datas[centre[3].int()]
        cylinder_sampler = cT.CylinderSampling(self._radius, centre[:3], align_origin=False)
        return cylinder_sampler(area_data)

    def _load_data(self, path):
        self._datas = torch.load(path)
        if not isinstance(self._datas, list):
            self._datas = [self._datas]
        if self._sample_per_epoch > 0:
            # Build candidate centres [x, y, z, area_idx, label] from a coarse grid.
            self._centres_for_sampling = []
            for i, data in enumerate(self._datas):
                assert not hasattr(
                    data, cT.CylinderSampling.KDTREE_KEY
                )  # Just to make we don't have some out of date data in there
                low_res = self._grid_sphere_sampling(data.clone())
                centres = torch.empty((low_res.pos.shape[0], 5), dtype=torch.float)
                centres[:, :3] = low_res.pos
                centres[:, 3] = i
                centres[:, 4] = low_res.y
                self._centres_for_sampling.append(centres)
                # The KD-tree drops the last coordinate: cylinder queries are 2D (x, y).
                tree = KDTree(np.asarray(data.pos[:, :-1]), leaf_size=10)
                setattr(data, cT.CylinderSampling.KDTREE_KEY, tree)

            self._centres_for_sampling = torch.cat(self._centres_for_sampling, 0)
            uni, uni_counts = np.unique(np.asarray(self._centres_for_sampling[:, -1]), return_counts=True)
            # Inverse-sqrt weighting: rarer classes get picked more often.
            uni_counts = np.sqrt(uni_counts.mean() / uni_counts)
            self._label_counts = uni_counts / np.sum(uni_counts)
            self._labels = uni
        else:
            grid_sampler = cT.GridCylinderSampling(self._radius, self._radius, center=False)
            self._test_spheres = grid_sampler(self._datas)
class S3DISFusedDataset(BaseDataset):
    """ Wrapper around S3DISSphere that creates train and test datasets.

    http://buildingparser.stanford.edu/dataset.html

    Parameters
    ----------
    dataset_opt: omegaconf.DictConfig
        Config dictionary that should contain

            - dataroot
            - fold: test_area parameter
            - pre_collate_transform
            - train_transforms
            - test_transforms
    """

    INV_OBJECT_LABEL = INV_OBJECT_LABEL

    def __init__(self, dataset_opt):
        super().__init__(dataset_opt)

        # Sphere or cylinder sampling, selected from the config (sphere by default).
        sampling_format = dataset_opt.get("sampling_format", "sphere")
        dataset_cls = S3DISCylinder if sampling_format == "cylinder" else S3DISSphere

        self.train_dataset = dataset_cls(
            self._data_path,
            sample_per_epoch=3000,
            test_area=self.dataset_opt.fold,
            split="train",
            pre_collate_transform=self.pre_collate_transform,
            transform=self.train_transform,
        )

        self.val_dataset = dataset_cls(
            self._data_path,
            sample_per_epoch=-1,
            test_area=self.dataset_opt.fold,
            split="val",
            pre_collate_transform=self.pre_collate_transform,
            transform=self.val_transform,
        )

        self.test_dataset = dataset_cls(
            self._data_path,
            sample_per_epoch=-1,
            test_area=self.dataset_opt.fold,
            split="test",
            pre_collate_transform=self.pre_collate_transform,
            transform=self.test_transform,
        )

        if dataset_opt.class_weight_method:
            self.add_weights(class_weight_method=dataset_opt.class_weight_method)

    @property
    def test_data(self):
        return self.test_dataset[0].raw_test_data

    @staticmethod
    def to_ply(pos, label, file):
        """ Allows to save s3dis predictions to disk using s3dis color scheme

        Parameters
        ----------
        pos : torch.Tensor
            tensor that contains the positions of the points
        label : torch.Tensor
            predicted label
        file : string
            Save location
        """
        to_ply(pos, label, file)

    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """Factory method for the tracker

        Arguments:
            wandb_log - Log using weight and biases
            tensorboard_log - Log using tensorboard
        Returns:
            [BaseTracker] -- tracker
        """
        from torch_points3d.metrics.s3dis_tracker import S3DISTracker

        return S3DISTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)