feature: separate datasets to keras datasets for including other frameworks
dstalzjohn authored and aiakide committed Dec 18, 2023
1 parent f38536f commit 2242721
Showing 14 changed files with 248 additions and 253 deletions.
2 changes: 1 addition & 1 deletion configs/shared/datasets/dataset_cls_test.yaml
@@ -1,4 +1,4 @@
-_target_: niceml.data.datasets.genericdataset.GenericDataset
+_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset
 batch_size: 2
 datainfo_listing:
   _target_: niceml.data.datainfolistings.clsdatainfolisting.DirClsDataInfoListing
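The `_target_` key follows Hydra's object-instantiation convention, so retargeting the dotted path is all it takes to swap in the Keras-specific dataset class; the three config diffs that follow make the same one-line change for the object-detection, regression, and semantic-segmentation test sets. A minimal sketch of how such a config resolves to an object (assuming niceml delegates to `hydra.utils.instantiate`; the standalone loading shown here is illustrative and ignores interpolations like `${globals.data_location}`):

```python
# Hypothetical resolution of a `_target_` dataset config via Hydra's
# instantiate; niceml's actual config-loading pipeline is not shown in this diff.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/shared/datasets/dataset_cls_test.yaml")
dataset = instantiate(cfg)  # builds a KerasGenericDataset after this commit
```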
2 changes: 1 addition & 1 deletion configs/shared/datasets/dataset_objdet_test.yaml
@@ -1,4 +1,4 @@
-_target_: niceml.data.datasets.genericdataset.GenericDataset
+_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset
 batch_size: 2
 datainfo_listing:
   _target_: niceml.data.datainfolistings.objdetdatainfolisting.ObjDetDataInfoListing
2 changes: 1 addition & 1 deletion configs/shared/datasets/dataset_reg_test.yaml
@@ -1,4 +1,4 @@
-_target_: niceml.data.datasets.dfdataset.DfDataset
+_target_: niceml.dlframeworks.keras.datasets.kerasdfdataset.KerasDfDataset
 id_key: identifier
 batch_size: 64
 data_location: ${globals.data_location}
2 changes: 1 addition & 1 deletion configs/shared/datasets/dataset_semseg_test.yaml
@@ -1,4 +1,4 @@
-_target_: niceml.data.datasets.genericdataset.GenericDataset
+_target_: niceml.dlframeworks.keras.datasets.kerasgenericdataset.KerasGenericDataset
 batch_size: 2
 datainfo_listing:
   _target_: niceml.data.datainfolistings.semsegdatainfolisting.SemSegDataInfoListing
151 changes: 0 additions & 151 deletions niceml/dashboard/cam.py

This file was deleted.

12 changes: 9 additions & 3 deletions niceml/data/datasets/dataset.py
@@ -12,8 +12,12 @@ class Dataset(ABC):
     """Dataset to load, transform, shuffle the data before training"""

     @abstractmethod
-    def get_batch_size(self) -> int:
-        """Returns the current batch size"""
+    def get_item_count(self) -> int:
+        """Returns the current count of items in the dataset"""
+
+    @abstractmethod
+    def get_items_per_epoch(self) -> int:
+        """Returns the items per epoch"""

     @abstractmethod
     def get_set_name(self) -> str:
@@ -31,6 +35,7 @@ def iter_with_info(self) -> Iterable:

     @abstractmethod
     def __getitem__(self, index: int):
+        """Returns the data of the item/batch at index"""
         pass

     @abstractmethod
@@ -39,11 +44,12 @@ def get_datainfo(self, batch_index: int) -> List[DataInfo]:

     @abstractmethod
     def __len__(self):
+        """Returns the number of batches/items"""
         pass

     def get_dataset_stats(self) -> dict:
         """Returns the dataset stats"""
-        return dict(size=len(self) * self.get_batch_size())
+        return dict(size=self.get_item_count())

     @abstractmethod
     def get_data_by_key(self, data_key):
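Net effect of these hunks: the batch-oriented `get_batch_size` hook is replaced by two item-oriented ones, and `get_dataset_stats` no longer multiplies `len(self)` by a batch size, so the abstract base class makes no batching assumptions at all. A condensed view of how the class plausibly reads after the commit (reconstructed from the hunks above; the import list is an assumption and methods the diff elides are omitted):

```python
# Reconstructed from the diff hunks above; only the methods visible
# in the diff are shown.
from abc import ABC, abstractmethod


class Dataset(ABC):
    """Dataset to load, transform, shuffle the data before training"""

    @abstractmethod
    def get_item_count(self) -> int:
        """Returns the current count of items in the dataset"""

    @abstractmethod
    def get_items_per_epoch(self) -> int:
        """Returns the items per epoch"""

    @abstractmethod
    def __getitem__(self, index: int):
        """Returns the data of the item/batch at index"""

    @abstractmethod
    def __len__(self):
        """Returns the number of batches/items"""

    def get_dataset_stats(self) -> dict:
        """Returns the dataset stats"""
        return dict(size=self.get_item_count())
```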
71 changes: 14 additions & 57 deletions niceml/data/datasets/dfdataset.py
@@ -4,9 +4,6 @@

 import numpy as np
 import pandas as pd
-from tensorflow.keras.utils import (  # pylint: disable=import-error,no-name-in-module
-    Sequence,
-)

 from niceml.data.datadescriptions.regdatadescription import (
     RegDataDescription,
@@ -74,13 +71,12 @@ def __getattr__(self, item) -> Any:
         return self.data[item]


-class DfDataset(Dataset, Sequence):  # pylint: disable=too-many-instance-attributes
+class DfDataset(Dataset):  # pylint: disable=too-many-instance-attributes
     """Dataset for dataframes"""

     def __init__(  # ruff: noqa: PLR0913
         self,
         id_key: str,
-        batch_size: int,
         subset_name: str,
         data_location: Union[dict, LocationConfig],
         df_filename: str = ExperimentFilenames.SUBSET_NAME,
@@ -95,7 +91,6 @@ def __init__(  # ruff: noqa: PLR0913
         Args:
             id_key: Column name of the id column in your dataframe
-            batch_size: Size of a batch
             subset_name: Name of the dataset
             data_location: Location of the data used in the data set
             df_filename: Specify the file name of the dataframe
@@ -108,7 +103,6 @@ def __init__(  # ruff: noqa: PLR0913
         self.dataframe_filters = dataframe_filters or []
         self.df_path = df_filename
         self.data_location = data_location
-        self.batch_size = batch_size
         self.subset_name = subset_name
         self.id_key = id_key
         self.index_list = []
@@ -158,14 +152,13 @@ def initialize(

         self.on_epoch_end()

-    def get_batch_size(self) -> int:
-        """
-        The get_batch_size function returns the batch size of the dataset.
-
-        Returns:
-            The batch size
-        """
-        return self.batch_size
+    def get_item_count(self) -> int:
+        """Get the number of items in the dataset"""
+        return len(self.data)
+
+    def get_items_per_epoch(self) -> int:
+        """Get the number of items per epoch"""
+        return len(self.index_list)

     def get_set_name(self) -> str:
         """
@@ -235,33 +228,26 @@ def extract_data(self, cur_indexes: List[int], cur_input: dict):

     def __getitem__(self, index):
         """
-        The __getitem__ function returns the indexed data batch in the size of `self.batch_size`.
-        It is called when the DfDataset is accessed, using the notation self[`index`]
-        (while training a model).
+        The __getitem__ function returns the indexed data item.

         Args:
-            index: Specify `index` of the batch
+            index: Specify `index` of the item

         Returns:
-            A batch of input data and target data with the batch size `self.batch_size`
+            An item of input data and target data
         """
-        start_idx = index * self.batch_size
-        end_idx = min(len(self.index_list), (index + 1) * self.batch_size)
-        input_data, target_data = self.get_data(start_idx, end_idx)
+        input_data, target_data = self.get_data(index, index + 1)

         return input_data, target_data

     def __len__(self):
         """
-        The __len__ function is used to determine the number of batches in an epoch.
+        The __len__ function is used to determine the number of steps in a dataset.

         Returns:
-            The number of batches in an epoch
+            The number of items
         """
-        batch_count, rest = divmod(len(self.index_list), self.batch_size)
-        if rest > 0:
-            batch_count += 1
-        return batch_count
+        return self.get_items_per_epoch()

     def on_epoch_end(self):
         """
@@ -286,35 +272,6 @@ def iter_with_info(self):
         """
         return DataIterator(self)

-    def get_datainfo(self, batch_index) -> List[RegDataInfo]:
-        """
-        The get_datainfo function is used to get the data information for a given batch.
-
-        Args:
-            batch_index: Determine which batch of data (datainfo) to return
-
-        Returns:
-            A list of `RegDataInfo` objects of the batch with index `batch_index`
-        """
-        start_idx = batch_index * self.batch_size
-        end_idx = min(len(self.index_list), (batch_index + 1) * self.batch_size)
-        data_info_list: List[RegDataInfo] = []
-        input_keys = [input_dict["key"] for input_dict in self.inputs]
-        target_keys = [target_dict["key"] for target_dict in self.targets]
-        data_subset = self.data[
-            [self.id_key] + input_keys + target_keys + self.extra_key_list
-        ]
-        real_index_list = [self.index_list[idx] for idx in range(start_idx, end_idx)]
-        data_info_dicts: List[dict] = data_subset.iloc[real_index_list].to_dict(
-            "records"
-        )
-
-        for data_info_dict in data_info_dicts:
-            key = data_info_dict[self.id_key]
-            data_info_dict.pop(self.id_key)
-            data_info_list.append(RegDataInfo(key, data_info_dict))
-        return data_info_list

     def get_all_data_info(self) -> List[RegDataInfo]:
         """
         The get_all_data_info function returns a list of `RegDataInfo` objects for
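With `Sequence` and `batch_size` stripped out, `DfDataset` is now purely item-oriented: `__getitem__` fetches a single row via `get_data(index, index + 1)`, and `__len__` returns `get_items_per_epoch()` (the length of the shuffled index list, as opposed to `get_item_count()`, which counts the full dataframe). The batching logic removed here, including the rounded-up batch count and the batch-wise `get_datainfo`, presumably moves into the Keras-side subclasses the configs now target. A plausible sketch of that wrapper (the real `KerasDfDataset` in `niceml/dlframeworks/keras/datasets/` is not part of this excerpt):

```python
# Hypothetical sketch of the Keras-side wrapper referenced by the updated
# configs; the actual KerasDfDataset implementation is not shown in this diff.
from tensorflow.keras.utils import Sequence

from niceml.data.datasets.dfdataset import DfDataset


class KerasDfDataset(DfDataset, Sequence):
    """DfDataset that serves items in batches for Keras training loops."""

    def __init__(self, batch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Batches per epoch, rounding the last partial batch up
        # (mirrors the divmod logic removed from DfDataset above).
        items = self.get_items_per_epoch()
        return (items + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        # Map a batch index onto DfDataset's item-range get_data.
        start = index * self.batch_size
        end = min(self.get_items_per_epoch(), start + self.batch_size)
        return self.get_data(start, end)
```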
(diffs for the remaining changed files not loaded)
