Skip to content

Commit

Permalink
Merge branch 'fix/dagster_docs' of github.com:codecentric-oss/niceml …
Browse files Browse the repository at this point in the history
…into fix/dagster_docs
  • Loading branch information
ankeko committed Nov 8, 2023
2 parents cd0b901 + 7e82fbe commit fdc969b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions niceml/data/datainfolistings/clsdatainfolisting.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ def __init__(
label_suffix: str = ".json",
image_suffixes: Optional[List[str]] = None,
):
"""Init method of LabelClsDataInfoListing"""
self.sub_dir = sub_dir
self.data_location = data_location
self.label_suffix = label_suffix
self.image_suffixes = image_suffixes or [".png", ".jpg", ".jpeg"]

def list(self, data_description: DataDescription) -> List[ClsDataInfo]:
"""Lists all data infos"""
output_data_description: OutputVectorDataDescription = check_instance(
data_description, OutputVectorDataDescription
)
Expand Down Expand Up @@ -73,6 +75,11 @@ def list(self, data_description: DataDescription) -> List[ClsDataInfo]:
return new_data_info_list


def _default_class_extractor(input_str: str) -> str:
"""Default class extractor for DirClsDataInfoListing"""
return splitext(input_str)[0].rsplit("_", maxsplit=1)[-1]


class DirClsDataInfoListing(
DataInfoListing
): # pylint: disable=too-few-public-methods, too-many-arguments
Expand All @@ -85,14 +92,14 @@ def __init__(
class_extractor: Optional[Callable] = None,
image_suffixes: Optional[List[str]] = None,
):
"""Init method of DirClsDataInfoListing"""
self.sub_dir = sub_dir
self.location = location
self.class_extractor = class_extractor or (
lambda x: splitext(x)[0].rsplit("_", maxsplit=1)[-1]
)
self.class_extractor = class_extractor or _default_class_extractor
self.image_suffixes = image_suffixes or [".png", ".jpg", ".jpeg"]

def list(self, data_description: DataDescription) -> List[ClsDataInfo]:
"""Lists all data infos"""
output_data_description: OutputVectorDataDescription = check_instance(
data_description, OutputVectorDataDescription
)
Expand Down

0 comments on commit fdc969b

Please sign in to comment.