Skip to content

Commit

Permalink
Dev/main (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Mar 21, 2024
2 parents 53a31f7 + e223a28 commit b0df46e
Show file tree
Hide file tree
Showing 20 changed files with 499 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dacapo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .options import Options # noqa
from . import experiments # noqa
from . import experiments, utils # noqa
from .apply import apply # noqa
from .train import train # noqa
from .validate import validate # noqa
Expand Down
18 changes: 12 additions & 6 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

logger = logging.getLogger(__name__)

__SEPARATOR_CHARACTER = "&"


def is_zarr_group(file_name: str, dataset: str):
zarr_file = zarr.open(str(file_name))
Expand Down Expand Up @@ -187,6 +185,7 @@ def __init__(
min_training_volume_size=8_000, # 20**3
raw_min=0,
raw_max=255,
classes_separator_caracter="&",
):
self.name = name
self.datasets = datasets
Expand All @@ -208,6 +207,7 @@ def __init__(
self.min_training_volume_size = min_training_volume_size
self.raw_min = raw_min
self.raw_max = raw_max
self.classes_separator_caracter = classes_separator_caracter

def __str__(self) -> str:
return f"DataSplitGenerator:{self.name}_{self.segmentation_type}_{self.class_name}_{self.output_resolution[0]}nm"
Expand All @@ -226,7 +226,9 @@ def class_name(self, class_name):
self._class_name = class_name

def check_class_name(self, class_name):
datasets, classes = format_class_name(class_name)
datasets, classes = format_class_name(
class_name, self.classes_separator_caracter
)
if self.class_name is None:
self.class_name = classes
if self.targets is None:
Expand Down Expand Up @@ -268,8 +270,12 @@ def __generate_semantic_seg_datasplit(self):
gt_config=gt_config,
)
)
if type(self.class_name) == list:
classes = self.classes_separator_caracter.join(self.class_name)
else:
classes = self.class_name
return TrainValidateDataSplitConfig(
name=f"{self.name}_{self.segmentation_type}_{self.class_name}_{self.output_resolution[0]}nm",
name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
train_configs=train_dataset_configs,
validate_configs=validation_dataset_configs,
)
Expand Down Expand Up @@ -383,11 +389,11 @@ def generate_from_csv(
)


def format_class_name(class_name):
def format_class_name(class_name, separator_character="&"):
if "[" in class_name:
if "]" not in class_name:
raise ValueError(f"Invalid class name {class_name} missing ']'")
classes = class_name.split("[")[1].split("]")[0].split(__SEPARATOR_CHARACTER)
classes = class_name.split("[")[1].split("]")[0].split(separator_character)
base_class_name = class_name.split("[")[0]
return [f"{base_class_name}{c}" for c in classes], classes
else:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit b0df46e

Please sign in to comment.