diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 5f5fc9e64b..1c481bffa5 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -30,13 +30,15 @@ def __init__( input_img_mode='RGB', transform=None, target_transform=None, + **kwargs, ): if reader is None or isinstance(reader, str): reader = create_reader( reader or '', root=root, split=split, - class_map=class_map + class_map=class_map, + **kwargs, ) self.reader = reader self.load_bytes = load_bytes diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py index 3dd6dd5d26..7784660676 100644 --- a/timm/data/readers/reader_hfds.py +++ b/timm/data/readers/reader_hfds.py @@ -35,7 +35,7 @@ def __init__( root: Optional[str] = None, split: str = 'train', class_map: dict = None, - image_key: str = 'image', + input_key: str = 'image', target_key: str = 'label', download: bool = False, ): @@ -50,9 +50,9 @@ def __init__( cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path ) # leave decode for caller, plus we want easy access to original path names... - self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False)) + self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False)) - self.image_key = image_key + self.image_key = input_key self.label_key = target_key self.remap_class = False if class_map: