diff --git a/t5x/train.py b/t5x/train.py index d6bfd6ab0..0ab63bdaf 100644 --- a/t5x/train.py +++ b/t5x/train.py @@ -305,9 +305,9 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): partitioner=partitioner, data_layout=data_layout, ) - input_shapes = jax.tree_map( - lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec + lambda x: (data_layout.batch_size, *x.shape[1:]), + train_iter.element_spec, ) input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) @@ -513,9 +513,13 @@ def _run_training_eval(first_run: bool = False): } ) logging.info('Computing training evaluation metrics.') - eval_batch_iters = { - task: ds.as_numpy_iterator() for task, ds in train_eval_datasets.items() - } + eval_batch_iters = {} + for task, ds in train_eval_datasets.items(): + if isinstance(ds, tf.data.Dataset): + eval_batch_iters[task] = ds.as_numpy_iterator() + else: + eval_batch_iters[task] = ds + eval_summaries = trainer.eval(eval_batch_iters) trainer.stop_training = run_actions( trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray diff --git a/t5x/utils.py b/t5x/utils.py index d56a35274..6db7ac017 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -30,6 +30,7 @@ from absl import flags from absl import logging +import airio import clu.data import flax from flax import traverse_util @@ -519,7 +520,9 @@ def restore( class DatasetConfig: """Configuration for loading a dataset from a SeqIO Task or Mixture.""" - mixture_or_task_name: Union[str, seqio.Task, seqio.Mixture] + mixture_or_task_name: Union[ + str, seqio.Task, seqio.Mixture, airio.Task, airio.Mixture + ] task_feature_lengths: Mapping[str, int] split: str batch_size: int # Number of examples per batch. @@ -717,6 +720,8 @@ def prepare_train_iter( data_layout, ) -> clu.data.dataset_iterator.PeekableDatasetIterator: """Prepares the training input iterator.""" + if isinstance(train_iter, airio.PyGrainDatasetIteratorWrapper): + return train_iter if isinstance(train_iter, tf.data.Dataset): train_iter = clu.data.dataset_iterator.TfDatasetIterator( train_iter, checkpoint=True @@ -789,13 +794,20 @@ def get_zeros_batch_like_spec( def get_zeros_batch_like_dataset( - dataset: tf.data.Dataset, batch_size=None + dataset: Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper], + batch_size=None, ) -> Mapping[str, jnp.ndarray]: + """Get zeros batch like the dataset spec.""" reshape = lambda s: (batch_size,) + s[1:] if batch_size else tuple(s) - batch_spec = { - k: jax.ShapeDtypeStruct(reshape(t.shape), t.dtype.as_numpy_dtype) - for k, t in dataset.element_spec.items() - } + batch_spec = {} + for key, val in dataset.element_spec.items(): # pytype: disable=attribute-error + if isinstance(dataset, tf.data.Dataset): + static_attributes = jax.ShapeDtypeStruct( + reshape(val.shape), val.dtype.as_numpy_dtype + ) + else: + static_attributes = jax.ShapeDtypeStruct(val.shape, val.dtype) + batch_spec[key] = static_attributes return get_zeros_batch_like_spec(batch_spec) @@ -1819,11 +1831,22 @@ def get_vocabulary( ) import_module(cfg.module) - if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase): + if isinstance(cfg.mixture_or_task_name, airio.DatasetProviderBase): + mixture_or_task = cfg.mixture_or_task_name + vocab_map = airio.get_vocabularies(mixture_or_task) + if not vocab_map: + raise ValueError( + f'No vocabularies found for AirIO task/mixture {mixture_or_task}' + ) + all_vocabularies = list(vocab_map.values()) + feature = seqio.Feature(vocabulary=all_vocabularies[0]) + features = {'inputs': feature, 'targets': feature} + elif isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase): mixture_or_task = cfg.mixture_or_task_name + features = mixture_or_task.output_features else: mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) - features = mixture_or_task.output_features + features = mixture_or_task.output_features if 'inputs' in features and 'targets' in features: return (features['inputs'].vocabulary, features['targets'].vocabulary) @@ -1900,7 +1923,7 @@ def get_dataset( feature_converter_cls: Callable[..., seqio.FeatureConverter], num_epochs: Optional[int] = None, continue_from_last_checkpoint: bool = False, -) -> tf.data.Dataset: +) -> Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper]: """Returns a dataset from SeqIO based on a `DatasetConfig`.""" if continue_from_last_checkpoint: raise ValueError( @@ -1941,7 +1964,10 @@ def get_dataset_inner( ): """Internal fn to load a dataset from SeqIO based on a `DatasetConfig`.""" batch_size = cfg.batch_size // shard_info.num_shards - if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase): + if isinstance( + cfg.mixture_or_task_name, + (seqio.DatasetProviderBase, airio.DatasetProviderBase), + ): mixture_or_task = cfg.mixture_or_task_name else: mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) @@ -1993,7 +2019,7 @@ def __call__( feature_converter_cls: Callable[..., seqio.FeatureConverter], num_epochs: Optional[int] = None, continue_from_last_checkpoint: bool = True, - ) -> Union[clu.data.dataset_iterator.DatasetIterator, tf.data.Dataset]: + ) -> Union[clu.data.dataset_iterator.DatasetIterator, tf.data.Dataset,]: ... @@ -2007,7 +2033,9 @@ def __call__( num_shards: int, eval_steps: int, feature_converter_cls: Callable[..., seqio.FeatureConverter], - ) -> Mapping[str, tf.data.Dataset]: + ) -> Mapping[ + str, Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper] + ]: ... @@ -2020,9 +2048,19 @@ def get_training_eval_datasets( deterministic: bool = False, model_dir: Optional[str] = None, start_step: int = 0, -) -> Mapping[str, tf.data.Dataset]: +) -> Mapping[ + str, + Union[ + tf.data.Dataset, + airio.PyGrainDatasetIteratorWrapper, + Iterable[Mapping[str, np.ndarray]], + ], +]: """Returns a mapping from eval task name to its dataset.""" - if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase): + if isinstance( + cfg.mixture_or_task_name, + (seqio.DatasetProviderBase, airio.DatasetProviderBase), + ): mixture_or_task = cfg.mixture_or_task_name else: mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) @@ -2034,6 +2072,22 @@ def get_training_eval_datasets( get_deterministic_dataset, model_dir=model_dir, start_step=start_step ) + if isinstance(mixture_or_task, (airio.Task, airio.Mixture)): + data_iter = get_dataset_fn( + dataclasses.replace(cfg, batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=feature_converter_cls, + num_epochs=eval_steps * cfg.batch_size, + continue_from_last_checkpoint=False, + ) + # TODO(b/304579895): Cannot use itertools.islice here to limit the number + # of records to eval_steps since peek() is required later. Instead, update + # t5x eval to not depend on the data pipeline and stop after eval_run + # steps. + datasets[mixture_or_task.name] = data_iter + return datasets + if cfg.batch_size % num_shards: raise ValueError( f'Batch size ({cfg.batch_size}) must be divisible by number of '