import os
import os.path as osp
from itertools import repeat, product
import numpy as np
import h5py
import torch
import random
import glob
from plyfile import PlyData, PlyElement
from torch_geometric.data import InMemoryDataset, Data, extract_zip, Dataset
from torch_geometric.data.dataset import files_exist
from torch_geometric.data import DataLoader
from torch_geometric.datasets import S3DIS as S3DIS1x1
import torch_geometric.transforms as T
import logging
from sklearn.neighbors import NearestNeighbors, KDTree
from tqdm.auto import tqdm as tq
import csv
import pandas as pd
import pickle
import gdown
import shutil
from torch_points3d.datasets.samplers import BalancedRandomSampler
import torch_points3d.core.data_transform as cT
from torch_points3d.datasets.base_dataset import BaseDataset
DIR = os.path.dirname(os.path.realpath(__file__))
log = logging.getLogger(__name__)
S3DIS_NUM_CLASSES = 13
INV_OBJECT_LABEL = {
0: "ceiling",
1: "floor",
2: "wall",
3: "beam",
4: "column",
5: "window",
6: "door",
7: "chair",
8: "table",
9: "bookcase",
10: "sofa",
11: "board",
12: "clutter",
}
OBJECT_COLOR = np.asarray(
[
        [233, 229, 107],  # 'ceiling' -> yellow
        [95, 156, 196],  # 'floor' -> blue
[179, 116, 81], # 'wall' -> brown
[241, 149, 131], # 'beam' -> salmon
[81, 163, 148], # 'column' -> bluegreen
[77, 174, 84], # 'window' -> bright green
[108, 135, 75], # 'door' -> dark green
[41, 49, 101], # 'chair' -> darkblue
[79, 79, 76], # 'table' -> dark grey
[223, 52, 52], # 'bookcase' -> red
[89, 47, 95], # 'sofa' -> purple
[81, 109, 114], # 'board' -> grey
[233, 233, 229], # 'clutter' -> light grey
        [0, 0, 0],  # unlabelled -> black
]
)
OBJECT_LABEL = {name: i for i, name in INV_OBJECT_LABEL.items()}
ROOM_TYPES = {
"conferenceRoom": 0,
"copyRoom": 1,
"hallway": 2,
"office": 3,
"pantry": 4,
"WC": 5,
"auditorium": 6,
"storage": 7,
"lounge": 8,
"lobby": 9,
"openspace": 10,
}
VALIDATION_ROOMS = [
"hallway_1",
"hallway_6",
"hallway_11",
"office_1",
"office_6",
"office_11",
"office_16",
"office_21",
"office_26",
"office_31",
"office_36",
"WC_2",
"storage_1",
"storage_5",
"conferenceRoom_2",
"auditorium_1",
]
################################### UTILS #######################################
def object_name_to_label(object_class):
"""convert from object name in S3DIS to an int"""
object_label = OBJECT_LABEL.get(object_class, OBJECT_LABEL["clutter"])
return object_label
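# A minimal sanity check of the mapping above: known class names map to their
# index, anything unexpected falls back to "clutter".
#
#   >>> object_name_to_label("chair")
#   7
#   >>> object_name_to_label("whiteboard")  # unknown name -> clutter
#   12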
def read_s3dis_format(train_file, room_name, label_out=True, verbose=False, debug=False):
"""extract data from a room folder"""
room_type = room_name.split("_")[0]
room_label = ROOM_TYPES[room_type]
raw_path = osp.join(train_file, "{}.txt".format(room_name))
if debug:
reader = pd.read_csv(raw_path, delimiter="\n")
RECOMMENDED = 6
for idx, row in enumerate(reader.values):
row = row[0].split(" ")
if len(row) != RECOMMENDED:
log.info("1: {} row {}: {}".format(raw_path, idx, row))
            try:
                # Check that every field in the row parses as a float
                for r in row:
                    r = float(r)
            except ValueError:
                log.info("2: {} row {}: {}".format(raw_path, idx, row))
return True
else:
room_ver = pd.read_csv(raw_path, sep=" ", header=None).values
xyz = np.ascontiguousarray(room_ver[:, 0:3], dtype="float32")
try:
rgb = np.ascontiguousarray(room_ver[:, 3:6], dtype="uint8")
except ValueError:
rgb = np.zeros((room_ver.shape[0], 3), dtype="uint8")
log.warning("WARN - corrupted rgb data for file %s" % raw_path)
if not label_out:
return xyz, rgb
n_ver = len(room_ver)
del room_ver
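        # The annotation files repeat subsets of the room's points: a 1-NN search
        # maps each annotated point back to its index in the full room cloud so
        # that per-point semantic and instance labels can be written in place.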
nn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(xyz)
semantic_labels = np.zeros((n_ver,), dtype="int64")
room_label = np.asarray([room_label])
instance_labels = np.zeros((n_ver,), dtype="int64")
objects = glob.glob(osp.join(train_file, "Annotations/*.txt"))
i_object = 1
for single_object in objects:
object_name = os.path.splitext(os.path.basename(single_object))[0]
if verbose:
log.debug("adding object " + str(i_object) + " : " + object_name)
object_class = object_name.split("_")[0]
object_label = object_name_to_label(object_class)
obj_ver = pd.read_csv(single_object, sep=" ", header=None).values
_, obj_ind = nn.kneighbors(obj_ver[:, 0:3])
semantic_labels[obj_ind] = object_label
instance_labels[obj_ind] = i_object
i_object = i_object + 1
return (
torch.from_numpy(xyz),
torch.from_numpy(rgb),
torch.from_numpy(semantic_labels),
torch.from_numpy(instance_labels),
torch.from_numpy(room_label),
)
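# Usage sketch for read_s3dis_format (the room path below is hypothetical and
# assumes the raw S3DIS layout Area_X/<room_name>/<room_name>.txt with an
# Annotations/ subfolder next to it):
#
#   room_dir = "/data/s3dis/raw/Area_1/office_1"
#   xyz, rgb, semantic_labels, instance_labels, room_label = read_s3dis_format(
#       room_dir, "office_1", label_out=True
#   )
#   # xyz: (N, 3) float32, rgb: (N, 3) uint8, both labels: (N,) int64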
def to_ply(pos, label, file):
    """Write a point cloud to a PLY file, coloring each point by its class label."""
assert len(label.shape) == 1
assert pos.shape[0] == label.shape[0]
pos = np.asarray(pos)
colors = OBJECT_COLOR[np.asarray(label)]
ply_array = np.ones(
pos.shape[0], dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")]
)
ply_array["x"] = pos[:, 0]
ply_array["y"] = pos[:, 1]
ply_array["z"] = pos[:, 2]
ply_array["red"] = colors[:, 0]
ply_array["green"] = colors[:, 1]
ply_array["blue"] = colors[:, 2]
el = PlyElement.describe(ply_array, "S3DIS")
PlyData([el], byte_order=">").write(file)
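# Usage sketch for to_ply with toy tensors (the output path is arbitrary):
#
#   pos = torch.rand(1000, 3)
#   label = torch.randint(0, S3DIS_NUM_CLASSES, (1000,))
#   to_ply(pos, label, "/tmp/room_preds.ply")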
################################### 1m cylinder s3dis ###################################
class S3DIS1x1Dataset(BaseDataset):
def __init__(self, dataset_opt):
super().__init__(dataset_opt)
pre_transform = self.pre_transform
self.train_dataset = S3DIS1x1(
self._data_path,
test_area=self.dataset_opt.fold,
train=True,
pre_transform=self.pre_transform,
transform=self.train_transform,
)
self.test_dataset = S3DIS1x1(
self._data_path,
test_area=self.dataset_opt.fold,
train=False,
pre_transform=pre_transform,
transform=self.test_transform,
)
if dataset_opt.class_weight_method:
self.add_weights(class_weight_method=dataset_opt.class_weight_method)
def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
"""Factory method for the tracker
Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using TensorBoard
Returns:
[BaseTracker] -- tracker
"""
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker
return SegmentationTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)
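# Usage sketch (the config keys shown are the ones read in this file; BaseDataset
# expects additional keys such as the transform definitions, so treat this as a
# minimal illustration rather than a complete config):
#
#   from omegaconf import OmegaConf
#   dataset_opt = OmegaConf.create(
#       {"dataroot": "/data", "fold": 5, "class_weight_method": None}
#   )
#   dataset = S3DIS1x1Dataset(dataset_opt)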
################################### Used for fused s3dis radius sphere ###################################
class S3DISOriginalFused(InMemoryDataset):
""" Original S3DIS dataset. Each area is loaded individually and can be processed using a pre_collate transform.
This transform can be used for example to fuse the area into a single space and split it into
spheres or smaller regions. If no fusion is applied, each element in the dataset is a single room by default.
http://buildingparser.stanford.edu/dataset.html
Parameters
----------
root: str
path to the directory where the data will be saved
test_area: int
number between 1 and 6 that denotes the area used for testing
split: str
can be one of train, trainval, val or test
pre_collate_transform:
Transforms to be applied before the data is assembled into samples (apply fusing here for example)
keep_instance: bool
set to True if you wish to keep instance data
pre_transform
transform
pre_filter
"""
form_url = (
"https://docs.google.com/forms/d/e/1FAIpQLScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1"
)
download_url = "https://drive.google.com/uc?id=0BweDykwS9vIobkVPN0wzRzFwTDg&export=download"
zip_name = "Stanford3dDataset_v1.2_Version.zip"
path_file = osp.join(DIR, "s3dis.patch")
file_name = "Stanford3dDataset_v1.2"
folders = ["Area_{}".format(i) for i in range(1, 7)]
num_classes = S3DIS_NUM_CLASSES
def __init__(
self,
root,
test_area=6,
split="train",
transform=None,
pre_transform=None,
pre_collate_transform=None,
pre_filter=None,
keep_instance=False,
verbose=False,
debug=False,
):
assert test_area >= 1 and test_area <= 6
self.transform = transform
self.pre_collate_transform = pre_collate_transform
self.test_area = test_area
self.keep_instance = keep_instance
self.verbose = verbose
self.debug = debug
self._split = split
super(S3DISOriginalFused, self).__init__(root, transform, pre_transform, pre_filter)
if split == "train":
path = self.processed_paths[0]
elif split == "val":
path = self.processed_paths[1]
elif split == "test":
path = self.processed_paths[2]
elif split == "trainval":
path = self.processed_paths[3]
else:
raise ValueError((f"Split {split} found, but expected either " "train, val, trainval or test"))
self._load_data(path)
if split == "test":
self.raw_test_data = torch.load(self.raw_areas_paths[test_area - 1])
@property
def center_labels(self):
if hasattr(self.data, "center_label"):
return self.data.center_label
else:
return None
@property
def raw_file_names(self):
return self.folders
@property
def pre_processed_path(self):
pre_processed_file_names = "preprocessed.pt"
return os.path.join(self.processed_dir, pre_processed_file_names)
@property
def raw_areas_paths(self):
return [os.path.join(self.processed_dir, "raw_area_%i.pt" % i) for i in range(6)]
@property
def processed_file_names(self):
test_area = self.test_area
return (
["{}_{}.pt".format(s, test_area) for s in ["train", "val", "test", "trainval"]]
+ self.raw_areas_paths
+ [self.pre_processed_path]
)
@property
def raw_test_data(self):
return self._raw_test_data
@raw_test_data.setter
def raw_test_data(self, value):
self._raw_test_data = value
def download(self):
raw_folders = os.listdir(self.raw_dir)
if len(raw_folders) == 0:
if not os.path.exists(osp.join(self.root, self.zip_name)):
log.info("WARNING: You are downloading S3DIS dataset")
log.info("Please, register yourself by filling up the form at {}".format(self.form_url))
log.info("***")
log.info(
"Press any key to continue, or CTRL-C to exit. By continuing, you confirm filling up the form."
)
input("")
gdown.download(self.download_url, osp.join(self.root, self.zip_name), quiet=False)
extract_zip(os.path.join(self.root, self.zip_name), self.root)
shutil.rmtree(self.raw_dir)
os.rename(osp.join(self.root, self.file_name), self.raw_dir)
shutil.copy(self.path_file, self.raw_dir)
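            # The patch repairs known formatting errors in the raw annotation
            # files (notably a stray character in one of the Area_5 rooms) so
            # that every room can be parsed.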
cmd = "patch -ruN -p0 -d {} < {}".format(self.raw_dir, osp.join(self.raw_dir, "s3dis.patch"))
os.system(cmd)
else:
intersection = len(set(self.folders).intersection(set(raw_folders)))
if intersection != 6:
shutil.rmtree(self.raw_dir)
os.makedirs(self.raw_dir)
self.download()
def process(self):
if not os.path.exists(self.pre_processed_path):
train_areas = [f for f in self.folders if str(self.test_area) not in f]
test_areas = [f for f in self.folders if str(self.test_area) in f]
train_files = [
(f, room_name, osp.join(self.raw_dir, f, room_name))
for f in train_areas
for room_name in os.listdir(osp.join(self.raw_dir, f))
if os.path.isdir(osp.join(self.raw_dir, f, room_name))
]
test_files = [
(f, room_name, osp.join(self.raw_dir, f, room_name))
for f in test_areas
for room_name in os.listdir(osp.join(self.raw_dir, f))
if os.path.isdir(osp.join(self.raw_dir, f, room_name))
]
# Gather data per area
data_list = [[] for _ in range(6)]
if self.debug:
areas = np.zeros(7)
for (area, room_name, file_path) in tq(train_files + test_files):
if self.debug:
area_idx = int(area.split("_")[-1])
if areas[area_idx] == 5:
continue
else:
                        log.info(area_idx)
areas[area_idx] += 1
area_num = int(area[-1]) - 1
if self.debug:
read_s3dis_format(file_path, room_name, label_out=True, verbose=self.verbose, debug=self.debug)
continue
else:
xyz, rgb, semantic_labels, instance_labels, room_label = read_s3dis_format(
file_path, room_name, label_out=True, verbose=self.verbose, debug=self.debug
)
rgb_norm = rgb.float() / 255.0
data = Data(pos=xyz, y=semantic_labels, rgb=rgb_norm)
if room_name in VALIDATION_ROOMS:
data.validation_set = True
else:
data.validation_set = False
if self.keep_instance:
data.instance_labels = instance_labels
if self.pre_filter is not None and not self.pre_filter(data):
continue
data_list[area_num].append(data)
raw_areas = cT.PointCloudFusion()(data_list)
for i, area in enumerate(raw_areas):
torch.save(area, self.raw_areas_paths[i])
            # Apply pre_transform and write the transformed rooms back into the list
            if self.pre_transform is not None:
                for area_datas in data_list:
                    for i, data in enumerate(area_datas):
                        area_datas[i] = self.pre_transform(data)
            torch.save(data_list, self.pre_processed_path)
else:
data_list = torch.load(self.pre_processed_path)
if self.debug:
return
train_data_list = {}
val_data_list = {}
trainval_data_list = {}
for i in range(6):
if i != self.test_area - 1:
train_data_list[i] = []
val_data_list[i] = []
for data in data_list[i]:
validation_set = data.validation_set
del data.validation_set
if validation_set:
val_data_list[i].append(data)
else:
train_data_list[i].append(data)
trainval_data_list[i] = val_data_list[i] + train_data_list[i]
train_data_list = list(train_data_list.values())
val_data_list = list(val_data_list.values())
trainval_data_list = list(trainval_data_list.values())
test_data_list = data_list[self.test_area - 1]
if self.pre_collate_transform:
log.info("pre_collate_transform ...")
log.info(self.pre_collate_transform)
train_data_list = self.pre_collate_transform(train_data_list)
val_data_list = self.pre_collate_transform(val_data_list)
test_data_list = self.pre_collate_transform(test_data_list)
trainval_data_list = self.pre_collate_transform(trainval_data_list)
self._save_data(train_data_list, val_data_list, test_data_list, trainval_data_list)
def _save_data(self, train_data_list, val_data_list, test_data_list, trainval_data_list):
torch.save(self.collate(train_data_list), self.processed_paths[0])
torch.save(self.collate(val_data_list), self.processed_paths[1])
torch.save(self.collate(test_data_list), self.processed_paths[2])
torch.save(self.collate(trainval_data_list), self.processed_paths[3])
def _load_data(self, path):
self.data, self.slices = torch.load(path)
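# Usage sketch for the fused dataset (root path hypothetical; the first
# instantiation downloads the raw data and preprocesses all six areas, which
# takes a while):
#
#   dataset = S3DISOriginalFused(
#       "/data/s3dis", test_area=5, split="train",
#       pre_collate_transform=cT.PointCloudFusion(),
#   )
#   sample = dataset[0]  # a torch_geometric.data.Data with pos, rgb, y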
class S3DISSphere(S3DISOriginalFused):
""" Small variation of S3DISOriginalFused that allows random sampling of spheres
within an Area during training and validation. Spheres have a radius of 2m. If sample_per_epoch is not specified, spheres
are taken on a 2m grid.
http://buildingparser.stanford.edu/dataset.html
Parameters
----------
root: str
path to the directory where the data will be saved
test_area: int
number between 1 and 6 that denotes the area used for testing
train: bool
Is this a train split or not
pre_collate_transform:
Transforms to be applied before the data is assembled into samples (apply fusing here for example)
keep_instance: bool
set to True if you wish to keep instance data
sample_per_epoch
Number of spheres that are randomly sampled at each epoch (-1 for fixed grid)
radius
radius of each sphere
pre_transform
transform
pre_filter
"""
def __init__(self, root, sample_per_epoch=100, radius=2, *args, **kwargs):
self._sample_per_epoch = sample_per_epoch
self._radius = radius
self._grid_sphere_sampling = cT.GridSampling3D(size=radius / 10.0)
super().__init__(root, *args, **kwargs)
def __len__(self):
if self._sample_per_epoch > 0:
return self._sample_per_epoch
else:
return len(self._test_spheres)
def len(self):
return len(self)
def get(self, idx):
if self._sample_per_epoch > 0:
return self._get_random()
else:
return self._test_spheres[idx].clone()
def process(self): # We have to include this method, otherwise the parent class skips processing
super().process()
def download(self): # We have to include this method, otherwise the parent class skips download
super().download()
def _get_random(self):
# Random spheres biased towards getting more low frequency classes
chosen_label = np.random.choice(self._labels, p=self._label_counts)
valid_centres = self._centres_for_sampling[self._centres_for_sampling[:, 4] == chosen_label]
        centre_idx = random.randint(0, valid_centres.shape[0] - 1)
centre = valid_centres[centre_idx]
area_data = self._datas[centre[3].int()]
sphere_sampler = cT.SphereSampling(self._radius, centre[:3], align_origin=False)
return sphere_sampler(area_data)
def _save_data(self, train_data_list, val_data_list, test_data_list, trainval_data_list):
torch.save(train_data_list, self.processed_paths[0])
torch.save(val_data_list, self.processed_paths[1])
torch.save(test_data_list, self.processed_paths[2])
torch.save(trainval_data_list, self.processed_paths[3])
def _load_data(self, path):
self._datas = torch.load(path)
if not isinstance(self._datas, list):
self._datas = [self._datas]
if self._sample_per_epoch > 0:
self._centres_for_sampling = []
for i, data in enumerate(self._datas):
                assert not hasattr(
                    data, cT.SphereSampling.KDTREE_KEY
                )  # Make sure we don't have any out-of-date data in there
low_res = self._grid_sphere_sampling(data.clone())
centres = torch.empty((low_res.pos.shape[0], 5), dtype=torch.float)
centres[:, :3] = low_res.pos
centres[:, 3] = i
centres[:, 4] = low_res.y
self._centres_for_sampling.append(centres)
tree = KDTree(np.asarray(data.pos), leaf_size=10)
setattr(data, cT.SphereSampling.KDTREE_KEY, tree)
self._centres_for_sampling = torch.cat(self._centres_for_sampling, 0)
            uni, uni_counts = np.unique(np.asarray(self._centres_for_sampling[:, -1]), return_counts=True)
            # Inverse square-root frequency weighting, normalised into a probability
            # distribution so that rare classes are sampled more often
            uni_counts = np.sqrt(uni_counts.mean() / uni_counts)
            self._label_counts = uni_counts / np.sum(uni_counts)
self._labels = uni
else:
grid_sampler = cT.GridSphereSampling(self._radius, self._radius, center=False)
self._test_spheres = grid_sampler(self._datas)
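# Usage sketch (root path hypothetical). With sample_per_epoch > 0 the dataset
# serves that many randomly drawn spheres per epoch, re-balanced towards rare
# classes; with sample_per_epoch=-1 it serves a deterministic grid of spheres:
#
#   train_set = S3DISSphere("/data/s3dis", sample_per_epoch=3000, radius=2,
#                           test_area=5, split="train")
#   val_set = S3DISSphere("/data/s3dis", sample_per_epoch=-1, radius=2,
#                         test_area=5, split="val")
#   sphere = train_set[0]  # Data restricted to a 2m-radius neighbourhood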
class S3DISCylinder(S3DISSphere):
def _get_random(self):
        # Random cylinders biased towards getting more low frequency classes
        chosen_label = np.random.choice(self._labels, p=self._label_counts)
        valid_centres = self._centres_for_sampling[self._centres_for_sampling[:, 4] == chosen_label]
        centre_idx = random.randint(0, valid_centres.shape[0] - 1)
centre = valid_centres[centre_idx]
area_data = self._datas[centre[3].int()]
cylinder_sampler = cT.CylinderSampling(self._radius, centre[:3], align_origin=False)
return cylinder_sampler(area_data)
def _load_data(self, path):
self._datas = torch.load(path)
if not isinstance(self._datas, list):
self._datas = [self._datas]
if self._sample_per_epoch > 0:
self._centres_for_sampling = []
for i, data in enumerate(self._datas):
                assert not hasattr(
                    data, cT.CylinderSampling.KDTREE_KEY
                )  # Make sure we don't have any out-of-date data in there
low_res = self._grid_sphere_sampling(data.clone())
centres = torch.empty((low_res.pos.shape[0], 5), dtype=torch.float)
centres[:, :3] = low_res.pos
centres[:, 3] = i
centres[:, 4] = low_res.y
self._centres_for_sampling.append(centres)
                # Build the KDTree on (x, y) only: a cylinder covers the full z extent
                tree = KDTree(np.asarray(data.pos[:, :-1]), leaf_size=10)
setattr(data, cT.CylinderSampling.KDTREE_KEY, tree)
self._centres_for_sampling = torch.cat(self._centres_for_sampling, 0)
            uni, uni_counts = np.unique(np.asarray(self._centres_for_sampling[:, -1]), return_counts=True)
            # Inverse square-root frequency weighting, normalised into a probability
            # distribution so that rare classes are sampled more often
            uni_counts = np.sqrt(uni_counts.mean() / uni_counts)
            self._label_counts = uni_counts / np.sum(uni_counts)
self._labels = uni
else:
grid_sampler = cT.GridCylinderSampling(self._radius, self._radius, center=False)
self._test_spheres = grid_sampler(self._datas)
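# Design note: S3DISCylinder only differs from S3DISSphere in its sampler and in
# building the KDTree on (x, y) alone, so a query gathers every point within
# `radius` of the centre in the horizontal plane over the room's full height.
# It is selected through the dataset config (see sampling_format below).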
class S3DISFusedDataset(BaseDataset):
""" Wrapper around S3DISSphere that creates train and test datasets.
http://buildingparser.stanford.edu/dataset.html
Parameters
----------
dataset_opt: omegaconf.DictConfig
Config dictionary that should contain
- dataroot
- fold: test_area parameter
- pre_collate_transform
- train_transforms
- test_transforms
"""
INV_OBJECT_LABEL = INV_OBJECT_LABEL
def __init__(self, dataset_opt):
super().__init__(dataset_opt)
sampling_format = dataset_opt.get("sampling_format", "sphere")
dataset_cls = S3DISCylinder if sampling_format == "cylinder" else S3DISSphere
self.train_dataset = dataset_cls(
self._data_path,
sample_per_epoch=3000,
test_area=self.dataset_opt.fold,
split="train",
pre_collate_transform=self.pre_collate_transform,
transform=self.train_transform,
)
self.val_dataset = dataset_cls(
self._data_path,
sample_per_epoch=-1,
test_area=self.dataset_opt.fold,
split="val",
pre_collate_transform=self.pre_collate_transform,
transform=self.val_transform,
)
self.test_dataset = dataset_cls(
self._data_path,
sample_per_epoch=-1,
test_area=self.dataset_opt.fold,
split="test",
pre_collate_transform=self.pre_collate_transform,
transform=self.test_transform,
)
if dataset_opt.class_weight_method:
self.add_weights(class_weight_method=dataset_opt.class_weight_method)
@property
def test_data(self):
return self.test_dataset[0].raw_test_data
@staticmethod
def to_ply(pos, label, file):
""" Allows to save s3dis predictions to disk using s3dis color scheme
Parameters
----------
pos : torch.Tensor
tensor that contains the positions of the points
label : torch.Tensor
predicted label
file : string
Save location
"""
to_ply(pos, label, file)
def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
"""Factory method for the tracker
Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using TensorBoard
Returns:
[BaseTracker] -- tracker
"""
from torch_points3d.metrics.s3dis_tracker import S3DISTracker
return S3DISTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)
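# Usage sketch (config values are illustrative; the exact schema comes from the
# torch-points3d dataset configs, and additional keys such as the transform
# definitions may be required):
#
#   from omegaconf import OmegaConf
#   dataset_opt = OmegaConf.create({
#       "dataroot": "/data",
#       "fold": 5,
#       "sampling_format": "cylinder",
#       "class_weight_method": None,
#   })
#   dataset = S3DISFusedDataset(dataset_opt)
#   tracker = dataset.get_tracker(wandb_log=False, tensorboard_log=False)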