Skip to content

Commit

Permalink
support for custom feature encoding/decoding (#7284)
Browse files Browse the repository at this point in the history
* support for custom feature encoding/decoding

* Update src/datasets/features/features.py

---------

Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
alex-hh and lhoestq authored Nov 21, 2024
1 parent 2049c00 commit 17f17b3
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
return list(obj)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
elif hasattr(schema, "encode_example"):
return schema.encode_example(obj) if obj is not None else None
# Other objects should be directly convertible to a native Arrow type (like Translation)
return obj
Expand Down Expand Up @@ -1399,10 +1399,9 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
else:
return decode_nested_example([schema.feature], obj)
# Object with special decoding:
elif isinstance(schema, (Audio, Image, Video)):
elif hasattr(schema, "decode_example") and getattr(schema, "decode", True):
# we pass the token to read and decode files from private repositories in streaming mode
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) if obj is not None else None
return obj


Expand Down Expand Up @@ -1629,7 +1628,9 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
elif isinstance(feature, Sequence):
return require_decoding(feature.feature)
else:
return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True)
return hasattr(feature, "decode_example") and (
getattr(feature, "decode", True) if not ignore_decode_attribute else True
)


def require_storage_cast(feature: FeatureType) -> bool:
Expand Down

0 comments on commit 17f17b3

Please sign in to comment.