chore: remove a last import of SHAPE_TYPE
CharlesGaydon committed Oct 4, 2023
1 parent: a0e442d · commit: 2e0d53a
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions myria3d/pctl/dataset/hdf5.py
@@ -12,7 +12,6 @@

from myria3d.pctl.dataset.utils import (
LAS_PATHS_BY_SPLIT_DICT_TYPE,
-SHAPE_TYPE,
SPLIT_TYPE,
pre_filter_below_n_points,
split_cloud_into_samples,
@@ -71,7 +70,9 @@ def __init__(
self._samples_hdf5_paths = None

if not las_paths_by_split_dict:
log.warning("No las_paths_by_split_dict given, pre-computed HDF5 dataset is therefore used.")
log.warning(
"No las_paths_by_split_dict given, pre-computed HDF5 dataset is therefore used."
)
return

# Add data for all LAS Files into a single hdf5 file.
@@ -151,7 +152,9 @@ def testdata(self):

def _get_split_subset(self, split: SPLIT_TYPE):
"""Get a sub-dataset of a specific (train/val/test) split."""
-indices = [idx for idx, p in enumerate(self.samples_hdf5_paths) if p.startswith(split)]
+indices = [
+    idx for idx, p in enumerate(self.samples_hdf5_paths) if p.startswith(split)
+]
return torch.utils.data.Subset(self, indices)

@property
@@ -164,7 +167,10 @@ def samples_hdf5_paths(self):
# Load as variable if already indexed in hdf5 file. Need to decode b-string.
with h5py.File(self.hdf5_file_path, "r") as hdf5_file:
if "samples_hdf5_paths" in hdf5_file:
-self._samples_hdf5_paths = [sample_path.decode("utf-8") for sample_path in hdf5_file["samples_hdf5_paths"]]
+self._samples_hdf5_paths = [
+    sample_path.decode("utf-8")
+    for sample_path in hdf5_file["samples_hdf5_paths"]
+]
return self._samples_hdf5_paths

# Otherwise, index samples, and add the index as an attribute to the HDF5 file.
@@ -175,7 +181,9 @@ def samples_hdf5_paths(self):
continue
for basename in hdf5_file[split].keys():
for sample_number in hdf5_file[split][basename].keys():
-self._samples_hdf5_paths.append(osp.join(split, basename, sample_number))
+self._samples_hdf5_paths.append(
+    osp.join(split, basename, sample_number)
+)

with h5py.File(self.hdf5_file_path, "a") as hdf5_file:
# special type to avoid silent string truncation in hdf5 datasets.
@@ -198,7 +206,6 @@ def create_hdf5(
subtile_overlap_train: Number = 0,
points_pre_transform: Callable = lidar_hd_pre_transform,
):
-
"""Create a HDF5 dataset file from las.
Args:
@@ -221,20 +228,24 @@
if split not in f:
f.create_group(split)
for las_path in tqdm(las_paths, desc=f"Preparing {split} set..."):
-
basename = os.path.basename(las_path)

# Delete dataset for incomplete LAS entry, to start from scratch.
# Useful in case data preparation was interrupted.
with h5py.File(hdf5_file_path, "a") as hdf5_file:
-if basename in hdf5_file[split] and "is_complete" not in hdf5_file[split][basename].attrs:
+if (
+    basename in hdf5_file[split]
+    and "is_complete" not in hdf5_file[split][basename].attrs
+):
del hdf5_file[split][basename]
# Parse and add subtiles to split group.
with h5py.File(hdf5_file_path, "a") as hdf5_file:
if basename in hdf5_file[split]:
continue

-subtile_overlap = subtile_overlap_train if split == "train" else 0 # No overlap at eval time.
+subtile_overlap = (
+    subtile_overlap_train if split == "train" else 0
+) # No overlap at eval time.
for sample_number, (sample_idx, sample_points) in enumerate(
split_cloud_into_samples(
las_path,
@@ -249,15 +260,19 @@
if pre_filter is not None and pre_filter(data):
# e.g. pre_filter spots situations where num_nodes is too small.
continue
-hdf5_path = os.path.join(split, basename, str(sample_number).zfill(5))
+hdf5_path = os.path.join(
+    split, basename, str(sample_number).zfill(5)
+)
hd5f_path_x = os.path.join(hdf5_path, "x")
hdf5_file.create_dataset(
hd5f_path_x,
data.x.shape,
dtype="f",
data=data.x,
)
-hdf5_file[hd5f_path_x].attrs["x_features_names"] = copy.deepcopy(data.x_features_names)
+hdf5_file[hd5f_path_x].attrs["x_features_names"] = copy.deepcopy(
+    data.x_features_names
+)
hdf5_file.create_dataset(
os.path.join(hdf5_path, "pos"),
data.pos.shape,
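For context, a minimal sketch of how the sample hierarchy written by create_hdf5 (as shown in the hunks above) could be traversed once preparation is done. It assumes the layout visible in the diff, i.e. groups nested as <split>/<basename>/<zero-padded sample number> with "x" and "pos" datasets inside, plus the usual train/val/test split names; the file path in the sketch is hypothetical and nothing here is part of the commit itself.

import os.path as osp

import h5py

# Hypothetical path: the real location is whatever hdf5_file_path was passed to create_hdf5.
HDF5_FILE_PATH = "path/to/dataset.hdf5"

with h5py.File(HDF5_FILE_PATH, "r") as hdf5_file:
    # Assumed split names; create_hdf5 iterates the keys of las_paths_by_split_dict.
    for split in ("train", "val", "test"):
        if split not in hdf5_file:
            continue
        for basename in hdf5_file[split].keys():
            for sample_number in hdf5_file[split][basename].keys():
                sample = hdf5_file[split][basename][sample_number]
                x = sample["x"][...]      # per-point features, written from data.x
                pos = sample["pos"][...]  # per-point coordinates, written from data.pos
                print(osp.join(split, basename, sample_number), x.shape, pos.shape)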
