diff --git a/configs/shared/datasets/dataset_cls_test.yaml b/configs/shared/datasets/dataset_cls_test.yaml index 21295f18..1010574e 100644 --- a/configs/shared/datasets/dataset_cls_test.yaml +++ b/configs/shared/datasets/dataset_cls_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.genericdataset.GenericDataset +_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset batch_size: 2 datainfo_listing: _target_: niceml.data.datainfolistings.clsdatainfolisting.DirClsDataInfoListing diff --git a/configs/shared/datasets/dataset_objdet_test.yaml b/configs/shared/datasets/dataset_objdet_test.yaml index 5b0ac65d..08537d5b 100644 --- a/configs/shared/datasets/dataset_objdet_test.yaml +++ b/configs/shared/datasets/dataset_objdet_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.genericdataset.GenericDataset +_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset batch_size: 2 datainfo_listing: _target_: niceml.data.datainfolistings.objdetdatainfolisting.ObjDetDataInfoListing diff --git a/configs/shared/datasets/dataset_reg_test.yaml b/configs/shared/datasets/dataset_reg_test.yaml index 268d547b..78c1b92d 100644 --- a/configs/shared/datasets/dataset_reg_test.yaml +++ b/configs/shared/datasets/dataset_reg_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.dfdataset.DfDataset +_target_: niceml.dlframeworks.keras.datasets.kerasdfdataset.KerasDfDataset id_key: identifier batch_size: 64 data_location: ${globals.data_location} diff --git a/configs/shared/datasets/dataset_semseg_test.yaml b/configs/shared/datasets/dataset_semseg_test.yaml index d13ad64e..c11f347e 100644 --- a/configs/shared/datasets/dataset_semseg_test.yaml +++ b/configs/shared/datasets/dataset_semseg_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.genericdataset.GenericDataset +_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset batch_size: 2 datainfo_listing: _target_: niceml.data.datainfolistings.semsegdatainfolisting.SemSegDataInfoListing diff --git a/niceml/dashboard/cam.py b/niceml/dashboard/cam.py deleted file mode 100644 index a161e5b9..00000000 --- a/niceml/dashboard/cam.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Module to with functions and utilities to calculate CAM on an image""" - -from pathlib import Path - -import numpy as np -import tensorflow as tf -from cv2 import cv2 -from tensorflow.keras import Model # pylint:disable= import-error -from tensorflow.keras.models import load_model # pylint:disable= import-error -from tensorflow.keras.utils import ( # pylint:disable= import-error - img_to_array, - load_img, -) - -from niceml.experiments.expdatalocalstorageloader import ( - create_expdata_from_local_storage, -) -from niceml.experiments.experimentdata import ExperimentData - - -# pylint:disable=c-extension-no-member -def run_cam_on_img(img_file: str, experiment_path: str): - """ - 1. Load image from file - 2. Load model from experiment output - 3. Extract Conv Layer from model - 4. Calc CAM - 5. Create Heatmap image - 6. 
Save image to experiment directory - Args: - experiment_path: path to experiment output - img_file: original image path - - """ - - exp: ExperimentData = create_expdata_from_local_storage(experiment_path) - - model = _load_model(exp) - preprocessed_img = load_preprocessed_img( - img_file=img_file, shape=model.input_shape[1:] - ) - heatmap = create_heatmap(model=model, img=preprocessed_img, class_idx=0) - cam_img = make_overlay(preprocessed_img, heatmap) - - output_image = cv2.addWeighted( - cv2.cvtColor(preprocessed_img.astype("uint8"), cv2.COLOR_RGB2BGR), - 0.5, - cam_img, - 1, - 0, - ) - - output_path = Path(experiment_path).joinpath(f"{Path(img_file).stem}_cam.png") - cv2.imwrite(str(output_path), output_image) - - -def _load_model(exp: ExperimentData): - """ - Load best model from experiment - - Args: - exp: experiment (ExperimentData) - - Returns: - keras model - """ - - model = load_model(exp.get_model_path()) - return model - - -def load_preprocessed_img(img_file: str, shape: tuple): - """ - - Args: - img_file: path to image file - shape: input shape of model - - Returns: - image as numpy array - """ - - img = load_img(img_file, target_size=shape) - return img_to_array(img) - - -# pylint:disable = too-many-locals,c-extension-no-member -def create_heatmap( - model: Model, img: np.ndarray, class_idx: int, layer_name: str = "conv" -) -> np.ndarray: - """ - generates heatmap using last conv layer "conv" specified in create_model section - - Args: - model: keras Model - img: image as numpy array - class_idx: class index of output tensor - layer_name: last conv layer name - DEFAULT: "conv" - - Returns: - heatmap as numpy array - """ - - grad_model = Model( - [model.input], [model.get_layer(name=layer_name).output, model.output] - ) - - with tf.GradientTape() as tape: - conv_outputs, predictions = grad_model(np.array([img])) - loss = predictions[:, class_idx] - - output = conv_outputs[0] - grads = tape.gradient(loss, conv_outputs)[0] - - tf.cast(output > 0, "float32") - tf.cast(grads > 0, "float32") - guided_grads = ( - tf.cast(output > 0, "float32") * tf.cast(grads > 0, "float32") * grads - ) - - weights = tf.reduce_mean(guided_grads, axis=(0, 1)) - - cam = np.ones(output.shape[0:2], dtype=np.float32) - - for index, weight in enumerate(weights): - cam += weight * output[:, :, index] - - cam = cv2.resize(cam.numpy(), (img.shape[0], img.shape[1])) - cam = np.maximum(cam, 0) - heatmap = (cam - cam.min()) // (cam.max() - cam.min()) - - return heatmap - - -# pylint:disable=c-extension-no-member -def make_overlay(img: np.ndarray, heatmap: np.ndarray) -> np.ndarray: - """ - generates a weighted image based on cam heatmap - - Args: - img: original image as numpy array - heatmap: heatmap as numpy array - - Returns: - cam image as numpy array - """ - - cam = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET) - return cv2.addWeighted( - cv2.cvtColor(img.astype("uint8"), cv2.COLOR_RGB2BGR), 0.5, cam, 1, 0 - ) diff --git a/niceml/data/datasets/dataset.py b/niceml/data/datasets/dataset.py index ff07eca9..b7ec4c1f 100644 --- a/niceml/data/datasets/dataset.py +++ b/niceml/data/datasets/dataset.py @@ -12,8 +12,12 @@ class Dataset(ABC): """Dataset to load, transform, shuffle the data before training""" @abstractmethod - def get_batch_size(self) -> int: - """Returns the current batch size""" + def get_item_count(self) -> int: + """Returns the current count of items in the dataset""" + + @abstractmethod + def get_items_per_epoch(self) -> int: + """Returns the items per epoch""" @abstractmethod def 
get_set_name(self) -> str:
@@ -31,6 +35,7 @@ def iter_with_info(self) -> Iterable:
 
     @abstractmethod
     def __getitem__(self, index: int):
+        """Returns the data of the item/batch at index"""
         pass
 
     @abstractmethod
@@ -39,11 +44,12 @@ def get_datainfo(self, batch_index: int) -> List[DataInfo]:
 
     @abstractmethod
     def __len__(self):
+        """Returns the number of batches/items"""
        pass
 
     def get_dataset_stats(self) -> dict:
         """Returns the dataset stats"""
-        return dict(size=len(self) * self.get_batch_size())
+        return dict(size=self.get_item_count())
 
     @abstractmethod
     def get_data_by_key(self, data_key):
diff --git a/niceml/data/datasets/dfdataset.py b/niceml/data/datasets/dfdataset.py
index 3966dedf..d4559e62 100644
--- a/niceml/data/datasets/dfdataset.py
+++ b/niceml/data/datasets/dfdataset.py
@@ -4,9 +4,6 @@
 
 import numpy as np
 import pandas as pd
-from tensorflow.keras.utils import (  # pylint: disable=import-error,no-name-in-module
-    Sequence,
-)
 
 from niceml.data.datadescriptions.regdatadescription import (
     RegDataDescription,
@@ -74,13 +71,12 @@ def __getattr__(self, item) -> Any:
         return self.data[item]
 
 
-class DfDataset(Dataset, Sequence):  # pylint: disable=too-many-instance-attributes
+class DfDataset(Dataset):  # pylint: disable=too-many-instance-attributes
     """Dataset for dataframes"""
 
     def __init__(  # ruff: noqa: PLR0913
         self,
         id_key: str,
-        batch_size: int,
         subset_name: str,
         data_location: Union[dict, LocationConfig],
         df_filename: str = ExperimentFilenames.SUBSET_NAME,
@@ -95,7 +91,6 @@ def __init__(  # ruff: noqa: PLR0913
 
         Args:
             id_key: Column name of the id column in your dataframe
-            batch_size: Size of a batch
             subset_name: Name of the dataset
             data_location: Location of the data used in the data set
             df_filename: Specify the file name of the dataframe
@@ -108,7 +103,6 @@ def __init__(  # ruff: noqa: PLR0913
         self.dataframe_filters = dataframe_filters or []
         self.df_path = df_filename
         self.data_location = data_location
-        self.batch_size = batch_size
         self.subset_name = subset_name
         self.id_key = id_key
         self.index_list = []
@@ -158,14 +152,13 @@ def initialize(
 
         self.on_epoch_end()
 
-    def get_batch_size(self) -> int:
-        """
-        The get_batch_size function returns the batch size of the dataset.
+    def get_item_count(self) -> int:
+        """Get the number of items in the dataset"""
+        return len(self.data)
 
-        Returns:
-            The batch size
-        """
-        return self.batch_size
+    def get_items_per_epoch(self) -> int:
+        """Get the number of items per epoch"""
+        return len(self.index_list)
 
     def get_set_name(self) -> str:
         """
@@ -235,33 +228,26 @@ def extract_data(self, cur_indexes: List[int], cur_input: dict):
 
     def __getitem__(self, index):
         """
-        The __getitem__ function returns the indexed data batch in the size of `self.batch_size`.
-        It is called when the DfDataset is accessed, using the notation self[`index`]
-        (while training a model).
+        The __getitem__ function returns the indexed data item.
 
         Args:
-            index: Specify `index` of the batch
+            index: Specify `index` of the item
 
         Returns:
-            A batch of input data and target data with the batch size `self.batch_size`
+            An item of input data and target data
         """
-        start_idx = index * self.batch_size
-        end_idx = min(len(self.index_list), (index + 1) * self.batch_size)
-        input_data, target_data = self.get_data(start_idx, end_idx)
+        input_data, target_data = self.get_data(index, index + 1)
 
         return input_data, target_data
 
     def __len__(self):
         """
-        The __len__ function is used to determine the number of batches in an epoch.
+        The __len__ function is used to determine the number of items per epoch.
Returns: - The number of batches in an epoch + The number of items """ - batch_count, rest = divmod(len(self.index_list), self.batch_size) - if rest > 0: - batch_count += 1 - return batch_count + return self.get_items_per_epoch() def on_epoch_end(self): """ @@ -286,35 +272,6 @@ def iter_with_info(self): """ return DataIterator(self) - def get_datainfo(self, batch_index) -> List[RegDataInfo]: - """ - The get_datainfo function is used to get the data information for a given batch. - - Args: - batch_index: Determine which batch of data (datainfo) to return - - Returns: - A list of `RegDataInfo` objects of the batch with index `batch_index` - """ - start_idx = batch_index * self.batch_size - end_idx = min(len(self.index_list), (batch_index + 1) * self.batch_size) - data_info_list: List[RegDataInfo] = [] - input_keys = [input_dict["key"] for input_dict in self.inputs] - target_keys = [target_dict["key"] for target_dict in self.targets] - data_subset = self.data[ - [self.id_key] + input_keys + target_keys + self.extra_key_list - ] - real_index_list = [self.index_list[idx] for idx in range(start_idx, end_idx)] - data_info_dicts: List[dict] = data_subset.iloc[real_index_list].to_dict( - "records" - ) - - for data_info_dict in data_info_dicts: - key = data_info_dict[self.id_key] - data_info_dict.pop(self.id_key) - data_info_list.append(RegDataInfo(key, data_info_dict)) - return data_info_list - def get_all_data_info(self) -> List[RegDataInfo]: """ The get_all_data_info function returns a list of `RegDataInfo` objects for diff --git a/niceml/data/datasets/genericdataset.py b/niceml/data/datasets/genericdataset.py index f02d865e..8ab5c9cc 100644 --- a/niceml/data/datasets/genericdataset.py +++ b/niceml/data/datasets/genericdataset.py @@ -1,6 +1,6 @@ +"""module for generic dataset implementation""" from typing import Dict, List, Optional -from tensorflow.keras.utils import Sequence from niceml.data.augmentation.augmentation import AugmentationProcessor from niceml.data.datadescriptions.datadescription import DataDescription @@ -20,10 +20,14 @@ ) -class GenericDataset(Sequence, Dataset): - def __init__( +class GenericDataset(Dataset): + """Generic dataset implementation. This is a flexible dataset for multiple + use cases. It can be used for classification, segmentation, object detection, etc. + For specific frameworks, there are subclasses of this class, e.g. KerasGenericDataset + """ + + def __init__( # noqa: PLR0913 self, - batch_size: int, set_name: str, datainfo_listing: DataInfoListing, data_loader: DataLoader, @@ -35,11 +39,24 @@ def __init__( augmentator: Optional[AugmentationProcessor] = None, net_data_logger: Optional[NetDataLogger] = None, ): + """ + Constructor of the GenericDataset + Args: + set_name: Name of the subset e.g. train + datainfo_listing: How to list the data + data_loader: How to load the data + target_transformer: How to transform the + target of the model (e.g. one-hot encoding) + input_transformer: How to transform the input of the model + shuffle: bool if the data should be shuffled + data_shuffler: A way of shuffling the data (e.g. 
random, sampled)
+            stats_generator: Write dataset stats
+            augmentator: Augment the data on the fly
+            net_data_logger: Stores the data in the way it is presented to the model
+        """
         super().__init__()
         self.net_data_logger = net_data_logger
         self.set_name = set_name
-        self.batch_size = batch_size
-        self.batch_count = None
         self.datainfo_listing: DataInfoListing = datainfo_listing
         self.data_loader: DataLoader = data_loader
         self.shuffle = shuffle
@@ -55,6 +72,7 @@ def __init__(
     def initialize(
         self, data_description: DataDescription, exp_context: ExperimentContext
     ):
+        """Initializes the dataset with the data description and context"""
         self.data_description = data_description
 
         self.data_loader.initialize(data_description)
@@ -76,56 +94,51 @@ def initialize(
 
         self.on_epoch_end()
 
-    def __getitem__(self, batch_index: int):
-        cur_data_infos = self.get_datainfo(batch_index)
-        dc_list: list = [self.data_loader.load_data(x) for x in cur_data_infos]
+    def get_item_count(self) -> int:
+        """Returns the current count of items in the dataset"""
+        return len(self.data_info_list)
+
+    def get_items_per_epoch(self) -> int:
+        """Returns the items per epoch"""
+        return len(self.index_list)
+
+    def __getitem__(self, item_index: int):
+        """Returns the data of the item at index"""
+        real_index = self.index_list[item_index]
+        data_info = self.data_info_list[real_index]
+        data_item = self.data_loader.load_data(data_info)
         if self.augmentator is not None:
-            dc_list = [self.augmentator(x) for x in dc_list]
-        net_inputs = self.input_transformer.get_net_inputs(dc_list)
-        net_targets = self.target_transformer.get_net_targets(dc_list)
+            data_item = self.augmentator(data_item)
+        net_inputs = self.input_transformer.get_net_inputs([data_item])
+        net_targets = self.target_transformer.get_net_targets([data_item])
         if self.net_data_logger is not None:
             self.net_data_logger.log_data(
                 net_inputs=net_inputs,
                 net_targets=net_targets,
-                data_info_list=cur_data_infos,
+                data_info_list=[data_info],
             )
         return net_inputs, net_targets
 
-    def get_batch_size(self) -> int:
-        return self.batch_size
-
     def get_set_name(self) -> str:
+        """Returns the name of the set e.g. 
train""" return self.set_name def __len__(self): - batch_count, rest = divmod(len(self.index_list), self.batch_size) - if rest > 0: - batch_count += 1 - if self.batch_count is not None: - batch_count = min(self.batch_count, batch_count) - return batch_count - - def get_datainfo(self, batch_index: int) -> List[DataInfo]: - start_idx = batch_index * self.batch_size - end_idx = min(len(self.index_list), (batch_index + 1) * self.batch_size) - data_info_list: List[DataInfo] = [] - for cur_idx in range(start_idx, end_idx): - real_index = self.index_list[cur_idx] - image_info = self.data_info_list[real_index] - data_info_list.append(image_info) - return data_info_list + """Returns the number of batches""" + return self.get_items_per_epoch() def get_data_by_key(self, data_key): + """Returns the data by the key (identifier of the data)""" data_info: DataInfo = self.data_info_dict[data_key] return self.data_loader.load_data(data_info) def get_dataset_stats(self) -> dict: + """Returns the dataset stats""" return self.data_stats_generator.generate_stats( self.data_info_list, self.index_list ) def on_epoch_end(self): + """Shuffles the data if required""" if self.shuffle: - self.index_list = self.data_shuffler.shuffle( - self.data_info_list, batch_size=self.batch_size - ) + self.index_list = self.data_shuffler.shuffle(self.data_info_list) diff --git a/niceml/dlframeworks/keras/datasets/__init__.py b/niceml/dlframeworks/keras/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/niceml/dlframeworks/keras/datasets/kerasdfdataset.py b/niceml/dlframeworks/keras/datasets/kerasdfdataset.py new file mode 100644 index 00000000..11caf80d --- /dev/null +++ b/niceml/dlframeworks/keras/datasets/kerasdfdataset.py @@ -0,0 +1,97 @@ +"""module for the KerasDfDataset class""" +from typing import List + +from niceml.data.datasets.dfdataset import DfDataset, RegDataInfo + +from keras.utils import Sequence + + +class KerasDfDataset(DfDataset, Sequence): + """Keras implementation of the DfDataset""" + + def __init__(self, batch_size: int, *args, **kwargs): + """ + Constructor of the KerasdfDataset + Args: + batch_size: Batch size + **kwargs: All arguments of the DfDataset + """ + super().__init__(*args, **kwargs) + self.batch_size = batch_size + + def __len__(self): + """ + The __len__ function is used to determine the number of batches in an epoch. + + Returns: + The number of batches in an epoch + """ + batch_count, rest = divmod(self.get_items_per_epoch(), self.batch_size) + if rest > 0: + batch_count += 1 + return batch_count + + def __getitem__(self, index): + """ + The __getitem__ function returns the indexed data batch in the size of `self.batch_size`. + It is called when the DfDataset is accessed, using the notation self[`index`] + (while training a model). + + Args: + index: Specify `index` of the batch + + Returns: + A batch of input data and target data with the batch size `self.batch_size` + """ + start_idx = index * self.batch_size + end_idx = min(len(self.index_list), (index + 1) * self.batch_size) + input_data, target_data = self.get_data(start_idx, end_idx) + + return input_data, target_data + + def on_epoch_end(self): + """ + Execute logic to be performed at the end of an epoch (e.g. 
shuffling the data)
+        """
+        if self.shuffle:
+            self.index_list = self.data_shuffler.shuffle(
+                data_infos=self.get_all_data_info(), batch_size=self.batch_size
+            )
+
+    def get_datainfo(self, batch_index) -> List[RegDataInfo]:
+        """
+        The get_datainfo function is used to get the data information for a given batch.
+
+        Args:
+            batch_index: Determine which batch of data (datainfo) to return
+
+        Returns:
+            A list of `RegDataInfo` objects of the batch with index `batch_index`
+        """
+        start_idx = batch_index * self.batch_size
+        end_idx = min(len(self.index_list), (batch_index + 1) * self.batch_size)
+        data_info_list: List[RegDataInfo] = []
+        input_keys = [input_dict["key"] for input_dict in self.inputs]
+        target_keys = [target_dict["key"] for target_dict in self.targets]
+        data_subset = self.data[
+            [self.id_key] + input_keys + target_keys + self.extra_key_list
+        ]
+        real_index_list = [self.index_list[idx] for idx in range(start_idx, end_idx)]
+        data_info_dicts: List[dict] = data_subset.iloc[real_index_list].to_dict(
+            "records"
+        )
+
+        for data_info_dict in data_info_dicts:
+            key = data_info_dict[self.id_key]
+            data_info_dict.pop(self.id_key)
+            data_info_list.append(RegDataInfo(key, data_info_dict))
+        return data_info_list
+
+    def get_batch_size(self) -> int:
+        """
+        The get_batch_size function returns the batch size of the dataset.
+
+        Returns:
+            The batch size
+        """
+        return self.batch_size
diff --git a/niceml/dlframeworks/keras/datasets/kerasgenericdataset.py b/niceml/dlframeworks/keras/datasets/kerasgenericdataset.py
new file mode 100644
index 00000000..05d12c20
--- /dev/null
+++ b/niceml/dlframeworks/keras/datasets/kerasgenericdataset.py
@@ -0,0 +1,73 @@
+"""module for the KerasGenericDataset class"""
+from typing import List
+
+from keras.utils import Sequence
+
+from niceml.data.datainfos.datainfo import DataInfo
+from niceml.data.datasets.genericdataset import GenericDataset
+
+
+class KerasGenericDataset(GenericDataset, Sequence):
+    """Keras implementation of the GenericDataset"""
+
+    def __init__(self, batch_size: int, **kwargs):
+        """
+        Constructor of the KerasGenericDataset
+        Args:
+            batch_size: Batch size
+            **kwargs: All arguments of the GenericDataset
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size
+
+    def __len__(self):
+        """
+        The __len__ function is used to determine the number of batches in an epoch.
+        Contrary to the __len__ function of the GenericDataset, which returns the
+        number of items per epoch, this function returns the number of batches.
+ """ + batch_count, rest = divmod(self.get_items_per_epoch(), self.batch_size) + if rest > 0: + batch_count += 1 + return batch_count + + def get_datainfo(self, batch_index: int) -> List[DataInfo]: + """ + Returns the datainfo for the batch at index + Args: + batch_index: index of the batch + + Returns: + List of DataInfo with regard to shuffling + """ + start_idx = batch_index * self.batch_size + end_idx = min(len(self.index_list), (batch_index + 1) * self.batch_size) + data_info_list: List[DataInfo] = [] + for cur_idx in range(start_idx, end_idx): + real_index = self.index_list[cur_idx] + image_info = self.data_info_list[real_index] + data_info_list.append(image_info) + return data_info_list + + def __getitem__(self, batch_index: int): + """Returns the data of the batch at index""" + cur_data_infos = self.get_datainfo(batch_index) + dc_list: list = [self.data_loader.load_data(x) for x in cur_data_infos] + if self.augmentator is not None: + dc_list = [self.augmentator(x) for x in dc_list] + net_inputs = self.input_transformer.get_net_inputs(dc_list) + net_targets = self.target_transformer.get_net_targets(dc_list) + if self.net_data_logger is not None: + self.net_data_logger.log_data( + net_inputs=net_inputs, + net_targets=net_targets, + data_info_list=cur_data_infos, + ) + return net_inputs, net_targets + + def on_epoch_end(self): + """Shuffles the data if shuffle is True""" + if self.shuffle: + self.index_list = self.data_shuffler.shuffle( + self.data_info_list, batch_size=self.batch_size + ) diff --git a/template/configs/shared/datasets/dataset_cls_test.yaml b/template/configs/shared/datasets/dataset_cls_test.yaml index 21295f18..1010574e 100644 --- a/template/configs/shared/datasets/dataset_cls_test.yaml +++ b/template/configs/shared/datasets/dataset_cls_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.genericdataset.GenericDataset +_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset batch_size: 2 datainfo_listing: _target_: niceml.data.datainfolistings.clsdatainfolisting.DirClsDataInfoListing diff --git a/template/configs/shared/datasets/dataset_objdet_test.yaml b/template/configs/shared/datasets/dataset_objdet_test.yaml index 5b0ac65d..08537d5b 100644 --- a/template/configs/shared/datasets/dataset_objdet_test.yaml +++ b/template/configs/shared/datasets/dataset_objdet_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.genericdataset.GenericDataset +_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset batch_size: 2 datainfo_listing: _target_: niceml.data.datainfolistings.objdetdatainfolisting.ObjDetDataInfoListing diff --git a/template/configs/shared/datasets/dataset_semseg_test.yaml b/template/configs/shared/datasets/dataset_semseg_test.yaml index d13ad64e..c11f347e 100644 --- a/template/configs/shared/datasets/dataset_semseg_test.yaml +++ b/template/configs/shared/datasets/dataset_semseg_test.yaml @@ -1,4 +1,4 @@ -_target_: niceml.data.datasets.genericdataset.GenericDataset +_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset batch_size: 2 datainfo_listing: _target_: niceml.data.datainfolistings.semsegdatainfolisting.SemSegDataInfoListing