Skip to content

Commit

Permalink
fix: Sorting bug and add pre-sorting for segmentation datamodule (#110)
Browse files Browse the repository at this point in the history
* fix: sorting bug and add pre-sorting

* build: Update version and changelog

* fix: Revert base segmentation experiment
  • Loading branch information
AlessandroPolidori authored Mar 5, 2024
1 parent 539cac5 commit 8a40419
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
# Changelog
All notable changes to this project will be documented in this file.

### [2.0.4]

#### Fixed

- Fix segmentation num_data_train sorting

#### Added

- Add default presorting to segmentation samples

### [2.0.3]

#### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "quadra"
version = "2.0.3"
version = "2.0.4"
description = "Deep Learning experiment orchestration library"
authors = [
"Federico Belotti <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion quadra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.0.3"
__version__ = "2.0.4"


def get_version():
Expand Down
22 changes: 21 additions & 1 deletion quadra/datamodules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,14 +590,34 @@ def _prepare_data(self) -> None:
masks_train = samples_and_masks_train[:, 0, 1]
masks_val = samples_and_masks_val[:, 0, 1]

# Pre-ordering train and val samples for determinism
# They will be shuffled (with a seed) during training
sorting_indices_train = np.argsort(list(samples_train))
samples_train = [samples_train[i] for i in sorting_indices_train]
targets_train = [targets_train[i] for i in sorting_indices_train]
masks_train = [masks_train[i] for i in sorting_indices_train]

sorting_indices_val = np.argsort(samples_val)
samples_val = [samples_val[i] for i in sorting_indices_val]
targets_val = [targets_val[i] for i in sorting_indices_val]
masks_val = [masks_val[i] for i in sorting_indices_val]

if self.exclude_good:
samples_train = list(np.array(samples_train)[np.array(targets_train)[:, 0] == 0])
masks_train = list(np.array(masks_train)[np.array(targets_train)[:, 0] == 0])
targets_train = list(np.array(targets_train)[np.array(targets_train)[:, 0] == 0])

if self.num_data_train is not None:
# Generate a random permutation
random_permutation = list(range(len(samples_train)))
random.seed(self.seed)
random.shuffle(samples_train)
random.shuffle(random_permutation)

# Shuffle samples_train, targets_train, and masks_train using the same permutation
samples_train = [samples_train[i] for i in random_permutation]
targets_train = [targets_train[i] for i in random_permutation]
masks_train = [masks_train[i] for i in random_permutation]

samples_train = np.array(samples_train)[: self.num_data_train]
targets_train = np.array(targets_train)[: self.num_data_train]
masks_train = np.array(masks_train)[: self.num_data_train]
Expand Down

0 comments on commit 8a40419

Please sign in to comment.