Skip to content

Commit

Permalink
Standardize inheritance of DatasetProviderBase. This is mostly trying…
Browse files Browse the repository at this point in the history
… to align the inherited function get_dataset for all dataset providers.

All changes should be non-destructive, and thus not affect current clients.

PiperOrigin-RevId: 507072854
  • Loading branch information
broken authored and copybara-github committed Feb 4, 2023
1 parent f8e59d9 commit 9549a10
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion optformer/data/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import seqio
import t5.data
import tensorflow as tf
import tensorflow_datasets as tfds

Study = converters.Study

Expand Down Expand Up @@ -272,10 +273,14 @@ def supports_arbitrary_sharding(self) -> bool:

def get_dataset(
self,
split: str,
split: str = tfds.Split.TRAIN,
shuffle: bool = True,
seed: Optional[int] = None,
shard_info: Optional[seqio.ShardInfo] = None,
*,
sequence_length: Optional[Mapping[str, int]] = None, # Unused
use_cached: bool = False, # Unused
num_epochs: Optional[int] = 1, # Unused
) -> tf.data.Dataset:
raise NotImplementedError

Expand Down

0 comments on commit 9549a10

Please sign in to comment.