Skip to content

Commit

Permalink
Add support for pygrain-based AirIO in t5x train.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571355859
  • Loading branch information
texasmichelle authored and t5-copybara committed Oct 11, 2023
1 parent 7416c8b commit a971c57
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 19 deletions.
14 changes: 9 additions & 5 deletions t5x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):
partitioner=partitioner,
data_layout=data_layout,
)

input_shapes = jax.tree_map(
lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec
lambda x: (data_layout.batch_size, *x.shape[1:]),
train_iter.element_spec,
)
input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec)

Expand Down Expand Up @@ -513,9 +513,13 @@ def _run_training_eval(first_run: bool = False):
}
)
logging.info('Computing training evaluation metrics.')
eval_batch_iters = {
task: ds.as_numpy_iterator() for task, ds in train_eval_datasets.items()
}
eval_batch_iters = {}
for task, ds in train_eval_datasets.items():
if isinstance(ds, tf.data.Dataset):
eval_batch_iters[task] = ds.as_numpy_iterator()
else:
eval_batch_iters[task] = ds

eval_summaries = trainer.eval(eval_batch_iters)
trainer.stop_training = run_actions(
trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray
Expand Down
82 changes: 68 additions & 14 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from absl import flags
from absl import logging
import airio
import clu.data
import flax
from flax import traverse_util
Expand Down Expand Up @@ -519,7 +520,9 @@ def restore(
class DatasetConfig:
"""Configuration for loading a dataset from a SeqIO Task or Mixture."""

mixture_or_task_name: Union[str, seqio.Task, seqio.Mixture]
mixture_or_task_name: Union[
str, seqio.Task, seqio.Mixture, airio.Task, airio.Mixture
]
task_feature_lengths: Mapping[str, int]
split: str
batch_size: int # Number of examples per batch.
Expand Down Expand Up @@ -717,6 +720,8 @@ def prepare_train_iter(
data_layout,
) -> clu.data.dataset_iterator.PeekableDatasetIterator:
"""Prepares the training input iterator."""
if isinstance(train_iter, airio.PyGrainDatasetIteratorWrapper):
return train_iter
if isinstance(train_iter, tf.data.Dataset):
train_iter = clu.data.dataset_iterator.TfDatasetIterator(
train_iter, checkpoint=True
Expand Down Expand Up @@ -789,13 +794,20 @@ def get_zeros_batch_like_spec(


def get_zeros_batch_like_dataset(
dataset: tf.data.Dataset, batch_size=None
dataset: Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper],
batch_size=None,
) -> Mapping[str, jnp.ndarray]:
"""Get zeros batch like the dataset spec."""
reshape = lambda s: (batch_size,) + s[1:] if batch_size else tuple(s)
batch_spec = {
k: jax.ShapeDtypeStruct(reshape(t.shape), t.dtype.as_numpy_dtype)
for k, t in dataset.element_spec.items()
}
batch_spec = {}
for key, val in dataset.element_spec.items(): # pytype: disable=attribute-error
if isinstance(dataset, tf.data.Dataset):
static_attributes = jax.ShapeDtypeStruct(
reshape(val.shape), val.dtype.as_numpy_dtype
)
else:
static_attributes = jax.ShapeDtypeStruct(val.shape, val.dtype)
batch_spec[key] = static_attributes
return get_zeros_batch_like_spec(batch_spec)


Expand Down Expand Up @@ -1819,11 +1831,22 @@ def get_vocabulary(
)
import_module(cfg.module)

if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
if isinstance(cfg.mixture_or_task_name, airio.DatasetProviderBase):
mixture_or_task = cfg.mixture_or_task_name
vocab_map = airio.get_vocabularies(mixture_or_task)
if not vocab_map:
raise ValueError(
f'No vocabularies found for AirIO task/mixture {mixture_or_task}'
)
all_vocabularies = list(vocab_map.values())
feature = seqio.Feature(vocabulary=all_vocabularies[0])
features = {'inputs': feature, 'targets': feature}
elif isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
mixture_or_task = cfg.mixture_or_task_name
features = mixture_or_task.output_features
else:
mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
features = mixture_or_task.output_features
features = mixture_or_task.output_features

if 'inputs' in features and 'targets' in features:
return (features['inputs'].vocabulary, features['targets'].vocabulary)
Expand Down Expand Up @@ -1900,7 +1923,7 @@ def get_dataset(
feature_converter_cls: Callable[..., seqio.FeatureConverter],
num_epochs: Optional[int] = None,
continue_from_last_checkpoint: bool = False,
) -> tf.data.Dataset:
) -> Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper]:
"""Returns a dataset from SeqIO based on a `DatasetConfig`."""
if continue_from_last_checkpoint:
raise ValueError(
Expand Down Expand Up @@ -1941,7 +1964,10 @@ def get_dataset_inner(
):
"""Internal fn to load a dataset from SeqIO based on a `DatasetConfig`."""
batch_size = cfg.batch_size // shard_info.num_shards
if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
if isinstance(
cfg.mixture_or_task_name,
(seqio.DatasetProviderBase, airio.DatasetProviderBase),
):
mixture_or_task = cfg.mixture_or_task_name
else:
mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
Expand Down Expand Up @@ -1993,7 +2019,7 @@ def __call__(
feature_converter_cls: Callable[..., seqio.FeatureConverter],
num_epochs: Optional[int] = None,
continue_from_last_checkpoint: bool = True,
) -> Union[clu.data.dataset_iterator.DatasetIterator, tf.data.Dataset]:
) -> Union[clu.data.dataset_iterator.DatasetIterator, tf.data.Dataset,]:
...


Expand All @@ -2007,7 +2033,9 @@ def __call__(
num_shards: int,
eval_steps: int,
feature_converter_cls: Callable[..., seqio.FeatureConverter],
) -> Mapping[str, tf.data.Dataset]:
) -> Mapping[
str, Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper]
]:
...


Expand All @@ -2020,9 +2048,19 @@ def get_training_eval_datasets(
deterministic: bool = False,
model_dir: Optional[str] = None,
start_step: int = 0,
) -> Mapping[str, tf.data.Dataset]:
) -> Mapping[
str,
Union[
tf.data.Dataset,
airio.PyGrainDatasetIteratorWrapper,
Iterable[Mapping[str, np.ndarray]],
],
]:
"""Returns a mapping from eval task name to its dataset."""
if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
if isinstance(
cfg.mixture_or_task_name,
(seqio.DatasetProviderBase, airio.DatasetProviderBase),
):
mixture_or_task = cfg.mixture_or_task_name
else:
mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
Expand All @@ -2034,6 +2072,22 @@ def get_training_eval_datasets(
get_deterministic_dataset, model_dir=model_dir, start_step=start_step
)

if isinstance(mixture_or_task, (airio.Task, airio.Mixture)):
data_iter = get_dataset_fn(
dataclasses.replace(cfg, batch_size=1),
shard_id=0,
num_shards=1,
feature_converter_cls=feature_converter_cls,
num_epochs=eval_steps * cfg.batch_size,
continue_from_last_checkpoint=False,
)
# TODO(b/304579895): Cannot use itertools.islice here to limit the number
# of records to eval_steps since peek() is required later. Instead, update
# t5x eval to not depend on the data pipeline and stop after eval_run
# steps.
datasets[mixture_or_task.name] = data_iter
return datasets

if cfg.batch_size % num_shards:
raise ValueError(
f'Batch size ({cfg.batch_size}) must be divisible by number of '
Expand Down

0 comments on commit a971c57

Please sign in to comment.