import os
import os.path as osp
import shutil
import json
import torch
from glob import glob
import sys
import csv
import logging
import numpy as np
from plyfile import PlyData, PlyElement
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
import torch_geometric.transforms as T
import multiprocessing
import pandas as pd
import tempfile
import urllib
from urllib.request import urlopen
from torch_points3d.datasets.base_dataset import BaseDataset
import torch_points3d.core.data_transform as cT
from . import IGNORE_LABEL
log = logging.getLogger(__name__)
# Ref: https://github.com/xjwang-cs/TSDF_utils/blob/master/download-scannet.py
########################################################################################
# #
# Download script #
# #
########################################################################################
BASE_URL = "http://kaldir.vc.in.tum.de/scannet/"
TOS_URL = BASE_URL + "ScanNet_TOS.pdf"
FILETYPES = [
".aggregation.json",
".sens",
".txt",
"_vh_clean.ply",
"_vh_clean_2.0.010000.segs.json",
"_vh_clean_2.ply",
"_vh_clean.segs.json",
"_vh_clean.aggregation.json",
"_vh_clean_2.labels.ply",
"_2d-instance.zip",
"_2d-instance-filt.zip",
"_2d-label.zip",
"_2d-label-filt.zip",
]
FILETYPES_TEST = [".sens", ".txt", "_vh_clean.ply", "_vh_clean_2.ply"]
PREPROCESSED_FRAMES_FILE = ["scannet_frames_25k.zip", "5.6GB"]
TEST_FRAMES_FILE = ["scannet_frames_test.zip", "610MB"]
LABEL_MAP_FILES = ["scannetv2-labels.combined.tsv", "scannet-labels.combined.tsv"]
RELEASES = ["v2/scans", "v1/scans"]
RELEASES_TASKS = ["v2/tasks", "v1/tasks"]
RELEASES_NAMES = ["v2", "v1"]
RELEASE = RELEASES[0]
RELEASE_TASKS = RELEASES_TASKS[0]
RELEASE_NAME = RELEASES_NAMES[0]
LABEL_MAP_FILE = LABEL_MAP_FILES[0]
RELEASE_SIZE = "1.2TB"
V1_IDX = 1
NUM_CLASSES = 41
CLASS_LABELS = (
"wall",
"floor",
"cabinet",
"bed",
"chair",
"sofa",
"table",
"door",
"window",
"bookshelf",
"picture",
"counter",
"desk",
"curtain",
"refrigerator",
"shower curtain",
"toilet",
"sink",
"bathtub",
"otherfurniture",
)
URLS_METADATA = [
"https://raw.githubusercontent.com/facebookresearch/votenet/master/scannet/meta_data/scannetv2-labels.combined.tsv",
"https://raw.githubusercontent.com/facebookresearch/votenet/master/scannet/meta_data/scannetv2_train.txt",
"https://raw.githubusercontent.com/facebookresearch/votenet/master/scannet/meta_data/scannetv2_test.txt",
"https://raw.githubusercontent.com/facebookresearch/votenet/master/scannet/meta_data/scannetv2_val.txt",
]
VALID_CLASS_IDS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
SCANNET_COLOR_MAP = {
0: (0.0, 0.0, 0.0),
1: (174.0, 199.0, 232.0),
2: (152.0, 223.0, 138.0),
3: (31.0, 119.0, 180.0),
4: (255.0, 187.0, 120.0),
5: (188.0, 189.0, 34.0),
6: (140.0, 86.0, 75.0),
7: (255.0, 152.0, 150.0),
8: (214.0, 39.0, 40.0),
9: (197.0, 176.0, 213.0),
10: (148.0, 103.0, 189.0),
11: (196.0, 156.0, 148.0),
12: (23.0, 190.0, 207.0),
14: (247.0, 182.0, 210.0),
15: (66.0, 188.0, 102.0),
16: (219.0, 219.0, 141.0),
17: (140.0, 57.0, 197.0),
18: (202.0, 185.0, 52.0),
19: (51.0, 176.0, 203.0),
20: (200.0, 54.0, 131.0),
21: (92.0, 193.0, 61.0),
22: (78.0, 71.0, 183.0),
23: (172.0, 114.0, 82.0),
24: (255.0, 127.0, 14.0),
25: (91.0, 163.0, 138.0),
26: (153.0, 98.0, 156.0),
27: (140.0, 153.0, 101.0),
28: (158.0, 218.0, 229.0),
29: (100.0, 125.0, 154.0),
30: (178.0, 127.0, 135.0),
32: (146.0, 111.0, 194.0),
33: (44.0, 160.0, 44.0),
34: (112.0, 128.0, 144.0),
35: (96.0, 207.0, 209.0),
36: (227.0, 119.0, 194.0),
37: (213.0, 92.0, 176.0),
38: (94.0, 106.0, 211.0),
39: (82.0, 84.0, 163.0),
40: (100.0, 85.0, 144.0),
}
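# Hedged helper sketch (not part of the original module): map per-point nyu40
# label ids to an (N, 3) color array in [0, 1] for visualization, using
# SCANNET_COLOR_MAP above. The name `labels_to_colors` is hypothetical; unknown
# ids fall back to the id-0 color (black).
def labels_to_colors(label_ids):
    default = SCANNET_COLOR_MAP[0]
    return np.array([SCANNET_COLOR_MAP.get(int(l), default) for l in label_ids]) / 255.0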
SPLITS = ["train", "val", "test"]
MAX_NUM_POINTS = 1200000
def get_release_scans(release_file):
scan_lines = urlopen(release_file)
scans = []
for scan_line in scan_lines:
scan_id = scan_line.decode("utf8").rstrip("\n")
scans.append(scan_id)
return scans
def download_release(release_scans, out_dir, file_types, use_v1_sens):
if len(release_scans) == 0:
return
log.info("Downloading ScanNet " + RELEASE_NAME + " release to " + out_dir + "...")
failed = []
for scan_id in release_scans:
scan_out_dir = os.path.join(out_dir, scan_id)
try:
download_scan(scan_id, scan_out_dir, file_types, use_v1_sens)
        except Exception:
            failed.append(scan_id)
log.info("Downloaded ScanNet " + RELEASE_NAME + " release.")
if len(failed):
log.warning("Failed downloads: {}".format(failed))
def download_file(url, out_file):
out_dir = os.path.dirname(out_file)
if not os.path.isdir(out_dir):
os.makedirs(out_dir)
if not os.path.isfile(out_file):
log.info("\t" + url + " > " + out_file)
        # Download to a temp file in the same directory, then atomically rename so
        # a partially downloaded file never appears under its final name.
        fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
        os.close(fh)  # mkstemp opens the fd; close it so urlretrieve can reuse the path
        urllib.request.urlretrieve(url, out_file_tmp)
        os.rename(out_file_tmp, out_file)
    # Existing files are skipped silently.
def download_scan(scan_id, out_dir, file_types, use_v1_sens):
# log.info("Downloading ScanNet " + RELEASE_NAME + " scan " + scan_id + " ...")
if not os.path.isdir(out_dir):
os.makedirs(out_dir)
for ft in file_types:
v1_sens = use_v1_sens and ft == ".sens"
url = (
BASE_URL + RELEASE + "/" + scan_id + "/" + scan_id + ft
if not v1_sens
else BASE_URL + RELEASES[V1_IDX] + "/" + scan_id + "/" + scan_id + ft
)
out_file = out_dir + "/" + scan_id + ft
download_file(url, out_file)
# log.info("Downloaded scan " + scan_id)
def download_label_map(out_dir):
log.info("Downloading ScanNet " + RELEASE_NAME + " label mapping file...")
files = [LABEL_MAP_FILE]
for file in files:
url = BASE_URL + RELEASE_TASKS + "/" + file
localpath = os.path.join(out_dir, file)
localdir = os.path.dirname(localpath)
if not os.path.isdir(localdir):
os.makedirs(localdir)
download_file(url, localpath)
log.info("Downloaded ScanNet " + RELEASE_NAME + " label mapping file.")
# REFERENCE TO https://github.com/facebookresearch/votenet/blob/master/scannet/load_scannet_data.py
########################################################################################
# #
# UTILS #
# #
########################################################################################
def represents_int(s):
""" if string s represents an int. """
try:
int(s)
return True
except ValueError:
return False
def read_label_mapping(filename, label_from="raw_category", label_to="nyu40id"):
assert os.path.isfile(filename)
mapping = dict()
with open(filename) as csvfile:
reader = csv.DictReader(csvfile, delimiter="\t")
for row in reader:
mapping[row[label_from]] = int(row[label_to])
if represents_int(list(mapping.keys())[0]):
mapping = {int(k): v for k, v in mapping.items()}
return mapping
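# Hedged usage sketch: map a raw ScanNet category string to its nyu40 id using
# the metadata TSV (assumes the file has already been downloaded). The helper
# name and default path are illustrative, not part of the original module.
def _example_label_mapping(label_map_file="scannetv2-labels.combined.tsv"):
    mapping = read_label_mapping(label_map_file, label_from="raw_category", label_to="nyu40id")
    return mapping.get("chair")  # e.g. the nyu40 id assigned to "chair"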
def read_mesh_vertices(filename):
"""read XYZ for each vertex."""
assert os.path.isfile(filename)
with open(filename, "rb") as f:
plydata = PlyData.read(f)
num_verts = plydata["vertex"].count
vertices = np.zeros(shape=[num_verts, 3], dtype=np.float32)
vertices[:, 0] = plydata["vertex"].data["x"]
vertices[:, 1] = plydata["vertex"].data["y"]
vertices[:, 2] = plydata["vertex"].data["z"]
return vertices
def read_mesh_vertices_rgb(filename):
"""read XYZ RGB for each vertex.
Note: RGB values are in 0-255
"""
assert os.path.isfile(filename)
with open(filename, "rb") as f:
plydata = PlyData.read(f)
num_verts = plydata["vertex"].count
vertices = np.zeros(shape=[num_verts, 6], dtype=np.float32)
vertices[:, 0] = plydata["vertex"].data["x"]
vertices[:, 1] = plydata["vertex"].data["y"]
vertices[:, 2] = plydata["vertex"].data["z"]
vertices[:, 3] = plydata["vertex"].data["red"]
vertices[:, 4] = plydata["vertex"].data["green"]
vertices[:, 5] = plydata["vertex"].data["blue"]
return vertices
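# Hedged usage sketch: load one (hypothetical) scan mesh and normalize RGB to
# [0, 1], mirroring what Scannet.read_one_scan does further below.
def _example_read_mesh(mesh_file="scans/scene0000_00/scene0000_00_vh_clean_2.ply"):
    verts = read_mesh_vertices_rgb(mesh_file)
    xyz, rgb = verts[:, :3], verts[:, 3:] / 255.0
    return xyz, rgb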
def read_aggregation(filename):
assert os.path.isfile(filename)
object_id_to_segs = {}
label_to_segs = {}
with open(filename) as f:
data = json.load(f)
num_objects = len(data["segGroups"])
for i in range(num_objects):
object_id = data["segGroups"][i]["objectId"] + 1 # instance ids should be 1-indexed
label = data["segGroups"][i]["label"]
segs = data["segGroups"][i]["segments"]
object_id_to_segs[object_id] = segs
if label in label_to_segs:
label_to_segs[label].extend(segs)
else:
label_to_segs[label] = segs
return object_id_to_segs, label_to_segs
def read_segmentation(filename):
assert os.path.isfile(filename)
seg_to_verts = {}
with open(filename) as f:
data = json.load(f)
num_verts = len(data["segIndices"])
for i in range(num_verts):
seg_id = data["segIndices"][i]
if seg_id in seg_to_verts:
seg_to_verts[seg_id].append(i)
else:
seg_to_verts[seg_id] = [i]
return seg_to_verts, num_verts
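# Hedged sketch of how the aggregation and segmentation files combine into
# per-vertex instance ids; this mirrors the inner loop of `export` below.
# File paths are hypothetical.
def _example_instance_ids(agg_file, seg_file):
    object_id_to_segs, _ = read_aggregation(agg_file)
    seg_to_verts, num_verts = read_segmentation(seg_file)
    instance_ids = np.zeros(num_verts, dtype=np.uint32)  # 0 = unannotated
    for object_id, segs in object_id_to_segs.items():
        for seg in segs:
            instance_ids[seg_to_verts[seg]] = object_id
    return instance_ids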
def export(mesh_file, agg_file, seg_file, meta_file, label_map_file, output_file=None):
"""points are XYZ RGB (RGB in 0-255),
semantic label as nyu40 ids,
instance label as 1-#instance,
box as (cx,cy,cz,dx,dy,dz,semantic_label)
"""
label_map = read_label_mapping(label_map_file, label_from="raw_category", label_to="nyu40id")
mesh_vertices = read_mesh_vertices_rgb(mesh_file)
# Load scene axis alignment matrix
lines = open(meta_file).readlines()
for line in lines:
if "axisAlignment" in line:
            # str.strip("axisAlignment = ") strips a *character set* and is fragile;
            # split on "=" instead for lines of the form "axisAlignment = m00 ... m33".
            axis_align_matrix = [float(x) for x in line.rstrip().split("=")[1].split()]
break
axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
pts = np.ones((mesh_vertices.shape[0], 4))
pts[:, 0:3] = mesh_vertices[:, 0:3]
pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4
mesh_vertices[:, 0:3] = pts[:, 0:3]
# Load semantic and instance labels
object_id_to_segs, label_to_segs = read_aggregation(agg_file)
seg_to_verts, num_verts = read_segmentation(seg_file)
label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
object_id_to_label_id = {}
for label, segs in label_to_segs.items():
label_id = label_map[label]
for seg in segs:
verts = seg_to_verts[seg]
label_ids[verts] = label_id
instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
num_instances = len(np.unique(list(object_id_to_segs.keys())))
for object_id, segs in object_id_to_segs.items():
for seg in segs:
verts = seg_to_verts[seg]
instance_ids[verts] = object_id
if object_id not in object_id_to_label_id:
object_id_to_label_id[object_id] = label_ids[verts][0]
instance_bboxes = np.zeros((num_instances, 7))
for obj_id in object_id_to_segs:
label_id = object_id_to_label_id[obj_id]
obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
if len(obj_pc) == 0:
continue
# Compute axis aligned box
# An axis aligned bounding box is parameterized by
# (cx,cy,cz) and (dx,dy,dz) and label id
# where (cx,cy,cz) is the center point of the box,
# dx is the x-axis length of the box.
xmin = np.min(obj_pc[:, 0])
ymin = np.min(obj_pc[:, 1])
zmin = np.min(obj_pc[:, 2])
xmax = np.max(obj_pc[:, 0])
ymax = np.max(obj_pc[:, 1])
zmax = np.max(obj_pc[:, 2])
bbox = np.array(
[
(xmin + xmax) / 2.0,
(ymin + ymax) / 2.0,
(zmin + zmax) / 2.0,
xmax - xmin,
ymax - ymin,
zmax - zmin,
label_id,
]
)
        # NOTE: this assumes obj_id is in 1, 2, 3, ..., NUM_INSTANCES
instance_bboxes[obj_id - 1, :] = bbox
return (
mesh_vertices.astype(np.float32),
        label_ids.astype(np.int64),  # np.int was removed in NumPy >= 1.24
        instance_ids.astype(np.int64),
instance_bboxes.astype(np.float32),
object_id_to_label_id,
)
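# Hedged usage sketch for `export`: the four per-scan files live side by side
# under <scannet_dir>/<scan_name>/. The default root, scan name, and helper
# name are hypothetical.
def _example_export(scannet_dir="scans", scan_name="scene0000_00", label_map_file="scannetv2-labels.combined.tsv"):
    prefix = osp.join(scannet_dir, scan_name, scan_name)
    return export(
        prefix + "_vh_clean_2.ply",
        prefix + ".aggregation.json",
        prefix + "_vh_clean_2.0.010000.segs.json",
        prefix + ".txt",
        label_map_file,
    )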
########################################################################################
# #
# SCANNET InMemoryDataset DATASET #
# #
########################################################################################
class Scannet(InMemoryDataset):
"""Scannet dataset, you will have to agree to terms and conditions by hitting enter
so that it downloads the dataset.
http://www.scan-net.org/
Parameters
----------
root : str
Path to the data
split : str, optional
Split used (train, val or test)
transform (callable, optional):
A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
pre_transform (callable, optional):
A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before being saved to disk.
pre_filter (callable, optional):
A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the final dataset.
version : str, optional
version of scannet, by default "v2"
    use_instance_labels : bool, optional
        Whether to use instance labels, by default False
    use_instance_bboxes : bool, optional
        Whether to use bounding box labels, by default False
    donotcare_class_ids : list, optional
        Class ids to be discarded
    max_num_point : int, optional
        Max number of points to keep during the preprocessing step
    use_multiprocessing : bool, optional
        Whether to use multiprocessing
    process_workers : int, optional
        Number of process workers
    normalize_rgb : bool, optional
        Normalize rgb values, by default True
"""
CLASS_LABELS = CLASS_LABELS
URLS_METADATA = URLS_METADATA
VALID_CLASS_IDS = VALID_CLASS_IDS
SCANNET_COLOR_MAP = SCANNET_COLOR_MAP
SPLITS = SPLITS
def __init__(
self,
root,
split="train",
transform=None,
pre_transform=None,
pre_filter=None,
version="v2",
use_instance_labels=False,
use_instance_bboxes=False,
donotcare_class_ids=[],
max_num_point=None,
process_workers=4,
types=[".txt", "_vh_clean_2.ply", "_vh_clean_2.0.010000.segs.json", ".aggregation.json"],
normalize_rgb=True,
is_test=False,
):
assert self.SPLITS == ["train", "val", "test"]
        if not isinstance(donotcare_class_ids, list):
            raise ValueError("donotcare_class_ids should be a list of class ids to ignore")
self.donotcare_class_ids = donotcare_class_ids
self.valid_class_idx = [idx for idx in self.VALID_CLASS_IDS if idx not in donotcare_class_ids]
assert version in ["v2", "v1"], "The version should be either v1 or v2"
self.version = version
self.max_num_point = max_num_point
self.use_instance_labels = use_instance_labels
self.use_instance_bboxes = use_instance_bboxes
self.use_multiprocessing = process_workers > 1
self.process_workers = process_workers
self.types = types
self.normalize_rgb = normalize_rgb
self.is_test = is_test
super().__init__(root, transform, pre_transform, pre_filter)
if split == "train":
path = self.processed_paths[0]
elif split == "val":
path = self.processed_paths[1]
elif split == "test":
path = self.processed_paths[2]
else:
raise ValueError((f"Split {split} found, but expected either " "train, val, or test"))
self.data, self.slices = torch.load(path)
if split != "test":
if not use_instance_bboxes:
delattr(self.data, "instance_bboxes")
if not use_instance_labels:
delattr(self.data, "instance_labels")
self.data.y = self._remap_labels(self.data.y)
self.has_labels = True
else:
self.has_labels = False
self.read_from_metadata()
def get_raw_data(self, id_scan, remap_labels=True) -> Data:
"""Grabs the raw data associated with a scan index
Parameters
----------
id_scan : int or torch.Tensor
id of the scan
        remap_labels : bool, optional
            If True, labels are remapped to the range [0; num_labels - 1] and ignored
            classes are mapped to IGNORE_LABEL. Set remap_labels to True when comparing
            ground truth against predictions.
        """
stage = self.name
if torch.is_tensor(id_scan):
id_scan = int(id_scan.item())
assert stage in self.SPLITS
mapping_idx_to_scan_names = getattr(self, "MAPPING_IDX_TO_SCAN_{}_NAMES".format(stage.upper()))
scan_name = mapping_idx_to_scan_names[id_scan]
path_to_raw_scan = os.path.join(
self.processed_raw_paths[self.SPLITS.index(stage.lower())], "{}.pt".format(scan_name)
)
data = torch.load(path_to_raw_scan)
data.scan_name = scan_name
data.path_to_raw_scan = path_to_raw_scan
if self.has_labels and remap_labels:
data.y = self._remap_labels(data.y)
return data
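    # Hedged note: e.g. `dataset.get_raw_data(0)` returns the untransformed Data
    # object for the first scan of the current split, with `scan_name` attached.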
@property
def raw_file_names(self):
return ["metadata", "scans", "scannetv2-labels.combined.tsv"]
@property
def processed_file_names(self):
        return ["{}.pt".format(s) for s in Scannet.SPLITS]
@property
def processed_raw_paths(self):
processed_raw_paths = [os.path.join(self.processed_dir, "raw_{}".format(s)) for s in Scannet.SPLITS]
for p in processed_raw_paths:
if not os.path.exists(p):
os.makedirs(p)
return processed_raw_paths
@property
def path_to_submission(self):
root = os.getcwd()
path_to_submission = os.path.join(root, "submission_labels")
if not os.path.exists(path_to_submission):
os.makedirs(path_to_submission)
return path_to_submission
@property
def num_classes(self):
return len(Scannet.VALID_CLASS_IDS)
def download_scans(self):
release_file = BASE_URL + RELEASE + ".txt"
release_scans = get_release_scans(release_file)
# release_scans = ["scene0191_00","scene0191_01", "scene0568_00", "scene0568_01"]
file_types = FILETYPES
release_test_file = BASE_URL + RELEASE + "_test.txt"
release_test_scans = get_release_scans(release_test_file)
file_types_test = FILETYPES_TEST
out_dir_scans = os.path.join(self.raw_dir, "scans")
out_dir_test_scans = os.path.join(self.raw_dir, "scans_test")
if self.types: # download file type
file_types = self.types
for file_type in file_types:
if file_type not in FILETYPES:
log.error("ERROR: Invalid file type: " + file_type)
return
file_types_test = []
for file_type in file_types:
if file_type in FILETYPES_TEST:
file_types_test.append(file_type)
download_label_map(self.raw_dir)
log.info("WARNING: You are downloading all ScanNet " + RELEASE_NAME + " scans of type " + file_types[0])
log.info(
"Note that existing scan directories will be skipped. Delete partially downloaded directories to re-download."
)
log.info("***")
log.info("Press any key to continue, or CTRL-C to exit.")
input("")
if self.version == "v2" and ".sens" in file_types:
            log.info(
                "Note: ScanNet v2 uses the same .sens files as ScanNet v1. Enter 'n' to skip downloading .sens files for each scan."
            )
key = input("")
if key.strip().lower() == "n":
file_types.remove(".sens")
download_release(release_scans, out_dir_scans, file_types, use_v1_sens=True)
if self.version == "v2":
download_label_map(self.raw_dir)
download_release(release_test_scans, out_dir_test_scans, file_types_test, use_v1_sens=True)
# download_file(os.path.join(BASE_URL, RELEASE_TASKS, TEST_FRAMES_FILE[0]), os.path.join(out_dir_tasks, TEST_FRAMES_FILE[0]))
def download(self):
if self.is_test:
return
log.info(
"By pressing any key to continue you confirm that you have agreed to the ScanNet terms of use as described at:"
)
log.info(TOS_URL)
log.info("***")
log.info("Press any key to continue, or CTRL-C to exit.")
input("")
self.download_scans()
metadata_path = osp.join(self.raw_dir, "metadata")
if not os.path.exists(metadata_path):
os.makedirs(metadata_path)
for url in self.URLS_METADATA:
_ = download_url(url, metadata_path)
@staticmethod
def read_one_test_scan(scannet_dir, scan_name, normalize_rgb):
mesh_file = osp.join(scannet_dir, scan_name, scan_name + "_vh_clean_2.ply")
mesh_vertices = read_mesh_vertices_rgb(mesh_file)
data = {}
data["pos"] = torch.from_numpy(mesh_vertices[:, :3])
data["rgb"] = torch.from_numpy(mesh_vertices[:, 3:])
if normalize_rgb:
data["rgb"] /= 255.0
return Data(**data)
@staticmethod
def read_one_scan(
scannet_dir,
scan_name,
label_map_file,
donotcare_class_ids,
max_num_point,
obj_class_ids,
normalize_rgb,
):
mesh_file = osp.join(scannet_dir, scan_name, scan_name + "_vh_clean_2.ply")
agg_file = osp.join(scannet_dir, scan_name, scan_name + ".aggregation.json")
seg_file = osp.join(scannet_dir, scan_name, scan_name + "_vh_clean_2.0.010000.segs.json")
meta_file = osp.join(
scannet_dir, scan_name, scan_name + ".txt"
) # includes axisAlignment info for the train set scans.
mesh_vertices, semantic_labels, instance_labels, instance_bboxes, instance2semantic = export(
mesh_file, agg_file, seg_file, meta_file, label_map_file, None
)
# Discard unwanted classes
mask = np.logical_not(np.in1d(semantic_labels, donotcare_class_ids))
mesh_vertices = mesh_vertices[mask, :]
semantic_labels = semantic_labels[mask]
instance_labels = instance_labels[mask]
bbox_mask = np.in1d(instance_bboxes[:, -1], obj_class_ids)
instance_bboxes = instance_bboxes[bbox_mask, :]
# Subsample
N = mesh_vertices.shape[0]
if max_num_point:
if N > max_num_point:
choices = np.random.choice(N, max_num_point, replace=False)
mesh_vertices = mesh_vertices[choices, :]
semantic_labels = semantic_labels[choices]
instance_labels = instance_labels[choices]
# Build data container
data = {}
data["pos"] = torch.from_numpy(mesh_vertices[:, :3])
data["rgb"] = torch.from_numpy(mesh_vertices[:, 3:])
if normalize_rgb:
data["rgb"] /= 255.0
data["y"] = torch.from_numpy(semantic_labels)
data["x"] = None
data["instance_labels"] = torch.from_numpy(instance_labels)
data["instance_bboxes"] = torch.from_numpy(instance_bboxes)
return Data(**data)
def read_from_metadata(self):
metadata_path = osp.join(self.raw_dir, "metadata")
self.label_map_file = osp.join(metadata_path, LABEL_MAP_FILE)
split_files = ["scannetv2_{}.txt".format(s) for s in Scannet.SPLITS]
self.scan_names = []
for sf in split_files:
f = open(osp.join(metadata_path, sf))
self.scan_names.append(sorted([line.rstrip() for line in f]))
f.close()
for idx_split, split in enumerate(Scannet.SPLITS):
idx_mapping = {idx: scan_name for idx, scan_name in enumerate(self.scan_names[idx_split])}
setattr(self, "MAPPING_IDX_TO_SCAN_{}_NAMES".format(split.upper()), idx_mapping)
@staticmethod
def process_func(
id_scan,
total,
scannet_dir,
scan_name,
label_map_file,
donotcare_class_ids,
max_num_point,
obj_class_ids,
normalize_rgb,
split,
):
if split == "test":
data = Scannet.read_one_test_scan(scannet_dir, scan_name, normalize_rgb)
else:
data = Scannet.read_one_scan(
scannet_dir,
scan_name,
label_map_file,
donotcare_class_ids,
max_num_point,
obj_class_ids,
normalize_rgb,
)
log.info("{}/{}| scan_name: {}, data: {}".format(id_scan, total, scan_name, data))
data["id_scan"] = torch.tensor([id_scan])
return cT.SaveOriginalPosId()(data)
def process(self):
if self.is_test:
return
self.read_from_metadata()
scannet_dir = osp.join(self.raw_dir, "scans")
for i, (scan_names, split) in enumerate(zip(self.scan_names, self.SPLITS)):
if not os.path.exists(self.processed_paths[i]):
mapping_idx_to_scan_names = getattr(self, "MAPPING_IDX_TO_SCAN_{}_NAMES".format(split.upper()))
scannet_dir = osp.join(self.raw_dir, "scans" if split in ["train", "val"] else "scans_test")
total = len(scan_names)
args = [
(
id,
total,
scannet_dir,
scan_name,
self.label_map_file,
self.donotcare_class_ids,
self.max_num_point,
self.VALID_CLASS_IDS,
self.normalize_rgb,
split,
)
for id, scan_name in enumerate(scan_names)
]
if self.use_multiprocessing:
with multiprocessing.get_context("spawn").Pool(processes=self.process_workers) as pool:
datas = pool.starmap(Scannet.process_func, args)
else:
datas = []
for arg in args:
data = Scannet.process_func(*arg)
datas.append(data)
for data in datas:
id_scan = int(data.id_scan.item())
scan_name = mapping_idx_to_scan_names[id_scan]
path_to_raw_scan = os.path.join(self.processed_raw_paths[i], "{}.pt".format(scan_name))
torch.save(data, path_to_raw_scan)
if self.pre_transform:
datas = [self.pre_transform(data) for data in datas]
log.info("SAVING TO {}".format(self.processed_paths[i]))
torch.save(self.collate(datas), self.processed_paths[i])
def _remap_labels(self, semantic_label):
"""Remaps labels to [0 ; num_labels -1]. Can be overriden."""
new_labels = semantic_label.clone()
mapping_dict = {indice: idx for idx, indice in enumerate(self.valid_class_idx)}
for idx in range(NUM_CLASSES):
if idx not in mapping_dict:
mapping_dict[idx] = IGNORE_LABEL
for idx in self.donotcare_class_ids:
mapping_dict[idx] = IGNORE_LABEL
for source, target in mapping_dict.items():
mask = semantic_label == source
new_labels[mask] = target
broken_labels = new_labels >= len(self.valid_class_idx)
new_labels[broken_labels] = IGNORE_LABEL
return new_labels
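    # Worked example of _remap_labels with the default VALID_CLASS_IDS and no
    # donotcare ids: nyu40 id 1 (wall) -> 0, 2 (floor) -> 1, ..., 39
    # (otherfurniture) -> 19; every other id (e.g. 0, 13, 15, 40) -> IGNORE_LABEL.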
def __repr__(self):
return "{}({})".format(self.__class__.__name__, len(self))
class ScannetDataset(BaseDataset):
"""Wrapper around Scannet that creates train and test datasets.
Parameters
----------
dataset_opt: omegaconf.DictConfig
Config dictionary that should contain
- dataroot
- version
- max_num_point (optional)
- use_instance_labels (optional)
- use_instance_bboxes (optional)
- donotcare_class_ids (optional)
- pre_transforms (optional)
- train_transforms (optional)
- val_transforms (optional)
"""
SPLITS = SPLITS
def __init__(self, dataset_opt):
super().__init__(dataset_opt)
        use_instance_labels: bool = dataset_opt.use_instance_labels
        use_instance_bboxes: bool = dataset_opt.use_instance_bboxes
        donotcare_class_ids = list(dataset_opt.get("donotcare_class_ids", []))
        max_num_point = dataset_opt.get("max_num_point", None)
        process_workers: int = getattr(dataset_opt, "process_workers", 0)
        is_test: bool = dataset_opt.get("is_test", False)
self.train_dataset = Scannet(
self._data_path,
split="train",
pre_transform=self.pre_transform,
transform=self.train_transform,
version=dataset_opt.version,
use_instance_labels=use_instance_labels,
use_instance_bboxes=use_instance_bboxes,
donotcare_class_ids=donotcare_class_ids,
max_num_point=max_num_point,
process_workers=process_workers,
is_test=is_test,
)
self.val_dataset = Scannet(
self._data_path,
split="val",
transform=self.val_transform,
pre_transform=self.pre_transform,
version=dataset_opt.version,
use_instance_labels=use_instance_labels,
use_instance_bboxes=use_instance_bboxes,
donotcare_class_ids=donotcare_class_ids,
max_num_point=max_num_point,
process_workers=process_workers,
is_test=is_test,
)
self.test_dataset = Scannet(
self._data_path,
split="test",
transform=self.val_transform,
pre_transform=self.pre_transform,
version=dataset_opt.version,
use_instance_labels=use_instance_labels,
use_instance_bboxes=use_instance_bboxes,
donotcare_class_ids=donotcare_class_ids,
max_num_point=max_num_point,
process_workers=process_workers,
is_test=is_test,
)
@property
def path_to_submission(self):
return self.train_dataset.path_to_submission
    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """Factory method for the tracker
        Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using TensorBoard
        Returns:
            [BaseTracker] -- tracker
        """
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker
return ScannetSegmentationTracker(
self, wandb_log=wandb_log, use_tensorboard=tensorboard_log, ignore_label=IGNORE_LABEL
)
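if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module): constructing the
    # dataset triggers the interactive download / preprocessing pipeline.
    # "data/scannet" is a hypothetical root directory; adjust to your layout.
    dataset = Scannet("data/scannet", split="val", process_workers=1)
    print(dataset, "num_classes:", dataset.num_classes)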