Skip to content

Commit

Permalink
Fix pass through of input / target keys so ImageDataset readers so ar…
Browse files Browse the repository at this point in the history
…gs work with hfds instead of just hfids (iterable)
  • Loading branch information
rwightman committed Jul 17, 2024
1 parent 3196d6b commit 34c9fee
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion timm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(
input_img_mode='RGB',
transform=None,
target_transform=None,
**kwargs,
):
if reader is None or isinstance(reader, str):
reader = create_reader(
reader or '',
root=root,
split=split,
class_map=class_map
class_map=class_map,
**kwargs,
)
self.reader = reader
self.load_bytes = load_bytes
Expand Down
6 changes: 3 additions & 3 deletions timm/data/readers/reader_hfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
root: Optional[str] = None,
split: str = 'train',
class_map: dict = None,
image_key: str = 'image',
input_key: str = 'image',
target_key: str = 'label',
download: bool = False,
):
Expand All @@ -50,9 +50,9 @@ def __init__(
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))
self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))

self.image_key = image_key
self.image_key = input_key
self.label_key = target_key
self.remap_class = False
if class_map:
Expand Down

0 comments on commit 34c9fee

Please sign in to comment.