diff --git a/t5x/train.py b/t5x/train.py index 831e6e64b..61682edc8 100644 --- a/t5x/train.py +++ b/t5x/train.py @@ -61,9 +61,11 @@ def run_actions( - mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType, + mode: trainer_lib.ActionMode, + actions: trainer_lib.ActionMapType, train_state: train_state_lib.TrainState, - metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool: + metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType], +) -> bool: """Invokes all actions on the given mode on host 0, then broadcasts to all. Args: @@ -214,7 +216,7 @@ def train( # checkpoint period or the full training. # We compute here to ensure that the eval period and checkpoint period are # divisible by this number, otherwise we fail. - eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg) + eval_enabled = train_eval_dataset_cfg or infer_eval_dataset_cfg eval_period = eval_period if eval_enabled else 0 checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0 checkpoint_steps = ( @@ -222,14 +224,20 @@ def train( ) if eval_period or checkpoint_period or gc_period: - steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf, - gc_period or np.inf) + steps_per_epoch = min( + eval_period or np.inf, checkpoint_period or np.inf, gc_period or np.inf + ) else: steps_per_epoch = total_steps stats_period = stats_period or steps_per_epoch - if (eval_period and eval_period % steps_per_epoch or - checkpoint_period and checkpoint_period % steps_per_epoch or - gc_period and gc_period % steps_per_epoch): + if ( + eval_period + and eval_period % steps_per_epoch + or checkpoint_period + and checkpoint_period % steps_per_epoch + or gc_period + and gc_period % steps_per_epoch + ): raise ValueError( f'Checkpoint period ({checkpoint_period}), eval ' f'period ({eval_period}), and GC period ({gc_period}) must all be ' @@ -238,7 +246,8 @@ def train( if use_hardware_rng or random_seed is None: logging.info( - 'Using fast RngBitGenerator PRNG for initialization and dropout.') + 'Using fast RngBitGenerator PRNG for initialization and dropout.' + ) if random_seed is None: random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) @@ -247,12 +256,14 @@ def train( logging.warning( 'When using hardware RNG with a fixed seed, repeatability is only ' 'guaranteed for fixed hardware and partitioning schemes and for a ' - 'fixed version of this code and its dependencies.') + 'fixed version of this code and its dependencies.' + ) utils.set_hardware_rng_ops() rng = random.PRNGKey(random_seed) else: - logging.info('Using seed for initialization and dropout RNG: %d', - random_seed) + logging.info( + 'Using seed for initialization and dropout RNG: %d', random_seed + ) rng = random.PRNGKey(random_seed) init_rng, trainer_rng = random.split(rng, 2) @@ -261,13 +272,15 @@ def train( # Initialize datasets # --------------------------------------------------------------------------- - if (train_dataset_cfg.seed and - not (checkpoint_cfg.save and checkpoint_cfg.save.save_dataset)): + if train_dataset_cfg.seed and not ( + checkpoint_cfg.save and checkpoint_cfg.save.save_dataset + ): logging.warning( 'Providing a random seed for the train dataset with ' '`checkpoint_train_ds=False` is dangerous since each ' 'preemption/restart will cause the dataset to deterministically replay ' - 'from the beginning.') + 'from the beginning.' + ) data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) ds_shard_id = data_layout.shard_id @@ -281,27 +294,36 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): _verify_matching_vocabs(train_dataset_cfg) - train_iter = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards, - model.FEATURE_CONVERTER_CLS) + train_iter = get_dataset_fn( + train_dataset_cfg, ds_shard_id, num_ds_shards, model.FEATURE_CONVERTER_CLS + ) train_iter = utils.prepare_train_iter( train_iter, checkpoint_cfg=checkpoint_cfg, partitioner=partitioner, - data_layout=data_layout) + data_layout=data_layout, + ) - input_shapes = jax.tree_map(lambda x: (data_layout.batch_size, *x.shape[1:]), - train_iter.element_spec) + input_shapes = jax.tree_map( + lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec + ) input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) if train_eval_dataset_cfg: _verify_matching_vocabs(train_eval_dataset_cfg) train_eval_datasets = train_eval_get_dataset_fn( - train_eval_dataset_cfg, ds_shard_id, num_ds_shards, eval_steps, - model.FEATURE_CONVERTER_CLS) # type: Mapping[str, tf.data.Dataset] + train_eval_dataset_cfg, + ds_shard_id, + num_ds_shards, + eval_steps, + model.FEATURE_CONVERTER_CLS, + ) # type: Mapping[str, tf.data.Dataset] if not train_eval_datasets: logging.warning( 'No train_eval datasets loaded from config `train_eval_dataset_cfg`: ' - '%s', train_eval_dataset_cfg) + '%s', + train_eval_dataset_cfg, + ) else: train_eval_datasets = {} @@ -325,17 +347,21 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): mode='latest', dtype=checkpoint_cfg.save.dtype if checkpoint_cfg.save else 'float32', checkpointer_cls=checkpoint_cfg.save.checkpointer_cls - if checkpoint_cfg.save else checkpoints.Checkpointer, + if checkpoint_cfg.save + else checkpoints.Checkpointer, # Restore dataset state if it is being saved. - restore_dataset=(checkpoint_cfg.save and - checkpoint_cfg.save.save_dataset), - state_transformation_fns=state_transforms_for_restore) + restore_dataset=( + checkpoint_cfg.save and checkpoint_cfg.save.save_dataset + ), + state_transformation_fns=state_transforms_for_restore, + ) ] # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set. if checkpoint_cfg.restore: if checkpoint_cfg.restore.mode == 'all': raise ValueError( - "Restore checkpoint mode 'all' is not supported in training.") + "Restore checkpoint mode 'all' is not supported in training." + ) # TODO(dhgarrette): Split "restore" behavior into separate configurations # for the initial restoration for a new run, vs resuming a stopped run. @@ -346,7 +372,8 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): pass else: raise ValueError( - 'Restore checkpoint config may only have a single path in training.') + 'Restore checkpoint config may only have a single path in training.' + ) init_or_restore_tick = time.time() train_state_initializer = train_state_initializer_cls( @@ -354,7 +381,8 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): init_fn=model.get_initial_variables, input_shapes=input_shapes, input_types=input_types, - partitioner=partitioner) + partitioner=partitioner, + ) # May be None, empty valid_restore_cfg, restore_paths = ( @@ -388,14 +416,15 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): train_state = train_state or train_state_initializer.from_scratch(init_rng) train_state_axes = train_state_initializer.train_state_axes init_or_restore_secs = time.time() - init_or_restore_tick - logging.info('Initialize/restore complete (%.2f seconds).', - init_or_restore_secs) + logging.info( + 'Initialize/restore complete (%.2f seconds).', init_or_restore_secs + ) # Log the variable shapes information and write to a file. log_file = os.path.join(model_dir, 'model-info.txt') - utils.log_model_info(log_file, - train_state_initializer.global_train_state_shape, - partitioner) + utils.log_model_info( + log_file, train_state_initializer.global_train_state_shape, partitioner + ) # Restore step from last checkpoint or set to 0 if training from scratch. host_step = int(utils.get_local_data(train_state.step)) # pytype: disable=attribute-error @@ -411,14 +440,16 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): train_state_axes=train_state_axes, eval_names=train_eval_datasets.keys(), summary_dir=model_dir, - rng=trainer_rng) + rng=trainer_rng, + ) del train_state train_metrics = trainer.train_metrics_manager summarize_config_fn(model_dir, train_metrics.summary_writer, host_step) - train_metrics.write_scalar('timing/init_or_restore_seconds', - init_or_restore_secs, host_step) + train_metrics.write_scalar( + 'timing/init_or_restore_seconds', init_or_restore_secs, host_step + ) # ---------------------------------------------------------------------------- # SeqIO (inference-based) evaluation setup @@ -432,7 +463,8 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): model=model, partitioner=partitioner, log_dir=model_dir, - verify_matching_vocabs_fn=verify_matching_vocabs_fn) + verify_matching_vocabs_fn=verify_matching_vocabs_fn, + ) if not evaluator.eval_tasks: # Skip evaluation. evaluator = None @@ -441,20 +473,30 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): actions = {} if set(actions.keys()).difference(_ACTION_KEYS): - raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : ' - f'{actions.keys()}') + raise ValueError( + f'actions keys must be one of {_ACTION_KEYS}, but got : ' + f'{actions.keys()}' + ) # Transform the string key into proper ActionMode enum. actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()} - if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL, - None) is not None: - logging.warning('Actions for INFER_EVAL will not be triggered when async ' - 'metrics computation is enabled') - if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN, - None) is not None: - logging.warning('Actions for TRAIN will not be triggered when async ' - 'metrics computation is enabled') + if ( + concurrent_metrics + and actions.get(trainer_lib.ActionMode.INFER_EVAL, None) is not None + ): + logging.warning( + 'Actions for INFER_EVAL will not be triggered when async ' + 'metrics computation is enabled' + ) + if ( + concurrent_metrics + and actions.get(trainer_lib.ActionMode.TRAIN, None) is not None + ): + logging.warning( + 'Actions for TRAIN will not be triggered when async ' + 'metrics computation is enabled' + ) # ---------------------------------------------------------------------------- # Setup Eval Utility Functions @@ -462,19 +504,23 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): def _run_training_eval(first_run: bool = False): if first_run: logging.info('Compiling training eval loop.') - trainer.compile_eval({ # pytype: disable=wrong-arg-types # jax-ndarray - task: utils.get_zeros_batch_like_dataset(ds) - for task, ds in train_eval_datasets.items() - }) + trainer.compile_eval( + { # pytype: disable=wrong-arg-types # jax-ndarray + task: utils.get_zeros_batch_like_dataset(ds) + for task, ds in train_eval_datasets.items() + } + ) logging.info('Computing training evaluation metrics.') eval_batch_iters = { - task: ds.as_numpy_iterator() - for task, ds in train_eval_datasets.items() + task: ds.as_numpy_iterator() for task, ds in train_eval_datasets.items() } eval_summaries = trainer.eval(eval_batch_iters) - trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray - actions, trainer.train_state, - eval_summaries) + trainer.stop_training = run_actions( + trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray + actions, + trainer.train_state, + eval_summaries, + ) def _run_inference_eval(): """Run prediction based inference eval.""" @@ -486,11 +532,15 @@ def _run_inference_eval(): if not concurrent_metrics: # Ensure metrics are finished being computed. all_metrics_done = all_metrics.result() or {} - trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL, - actions, trainer.train_state, - all_metrics_done) - train_metrics.write_scalar('timing/evaluate_seconds', - time.time() - evaluate_tick, host_step) + trainer.stop_training = run_actions( + trainer_lib.ActionMode.INFER_EVAL, + actions, + trainer.train_state, + all_metrics_done, + ) + train_metrics.write_scalar( + 'timing/evaluate_seconds', time.time() - evaluate_tick, host_step + ) # Optionally run teacher-forcing training eval and SeqIO inference-base eval # before training. Useful for testing how much a model knows before any @@ -543,12 +593,15 @@ def _cleanup() -> None: if total_steps < first_step: raise ValueError( f'Unexpected total_steps ({total_steps}) < checkpoint step ' - f' ({first_step}).') + f' ({first_step}).' + ) elif total_steps == first_step: logging.warning( 'Total training steps and checkpoint step were both %d, so no training ' 'will be done. If you are only doing evaluation, this is expected. ' - 'Stopping now.', total_steps) + 'Stopping now.', + total_steps, + ) _cleanup() return host_step, trainer.train_state @@ -557,9 +610,11 @@ def _cleanup() -> None: steps_per_epoch = min(steps_per_epoch, total_steps) first_epoch = first_step // steps_per_epoch num_epochs = first_epoch + math.ceil( - (total_steps - first_step) / steps_per_epoch) - logging.info('Training with artificial "epochs" of %d steps.', - steps_per_epoch) + (total_steps - first_step) / steps_per_epoch + ) + logging.info( + 'Training with artificial "epochs" of %d steps.', steps_per_epoch + ) logging.info('Compiling train loop.') logging.flush() @@ -577,8 +632,10 @@ def _as_gda(spec): # Construct dummy batch for compiling the model. dummy_batch = jax.tree_map(_as_gda, train_iter.element_spec) if not isinstance(dummy_batch, Mapping): - raise ValueError('Training loop expects batches to have type ' - f'Mapping[str, np.ndarray] but got {type(dummy_batch)}.') + raise ValueError( + 'Training loop expects batches to have type ' + f'Mapping[str, np.ndarray] but got {type(dummy_batch)}.' + ) assert isinstance(dummy_batch, Mapping) trainer.compile_train(dummy_batch) @@ -592,13 +649,15 @@ def _as_gda(spec): # model compilation above. We just measure the additional time needed. first_batch_ready.result() train_iter_warmup_tock = time.time() - train_metrics.write_scalar('timing/train_iter_warmup', - train_iter_warmup_tock - train_iter_warmup_tick, - host_step) + train_metrics.write_scalar( + 'timing/train_iter_warmup', + train_iter_warmup_tock - train_iter_warmup_tick, + host_step, + ) jax.monitoring.record_event_duration_secs( - '/jax/t5x/train/time_before_first_step_secs', - time.time() - _IMPORT_TIME) + '/jax/t5x/train/time_before_first_step_secs', time.time() - _IMPORT_TIME + ) # Current index within checkpoint_steps list for faster lookup runtime and # for creating a checkpoint if needed between stats_period iterations. @@ -627,8 +686,9 @@ def _as_gda(spec): trainer.train_state, checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error ) - logging.info('Stopping training loop early since `stop_training` is ' - 'requested.') + logging.info( + 'Stopping training loop early since `stop_training` is requested.' + ) break inner_num_steps = min(epoch_end_step - host_step, stats_period) @@ -710,11 +770,15 @@ def _as_gda(spec): checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error ) checkpoint_tock = time.time() - train_metrics.write_scalar('timing/checkpoint_seconds', - checkpoint_tock - checkpoint_tick, host_step) + train_metrics.write_scalar( + 'timing/checkpoint_seconds', + checkpoint_tock - checkpoint_tick, + host_step, + ) - is_eval_epoch = eval_period and (final_epoch or - step_offset % eval_period == 0) + is_eval_epoch = eval_period and ( + final_epoch or step_offset % eval_period == 0 + ) # Training Evaluation (i.e., with teacher forcing). if is_eval_epoch and train_eval_datasets: @@ -752,45 +816,61 @@ def _as_gda(spec): flags.DEFINE_multi_string( 'gin_file', default=None, - help='Path to gin configuration file. Multiple paths may be passed and ' - 'will be imported in the given order, with later configurations ' - 'overriding earlier ones.') + help=( + 'Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.' + ), + ) flags.DEFINE_multi_string( - 'gin_bindings', default=[], help='Individual gin bindings.') + 'gin_bindings', default=[], help='Individual gin bindings.' + ) flags.DEFINE_list( 'gin_search_paths', default=['.'], - help='Comma-separated list of gin config path prefixes to be prepended ' - 'to suffixes given via `--gin_file`. If a file appears in. Only the ' - 'first prefix that produces a valid path for each suffix will be ' - 'used.') + help=( + 'Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.' + ), + ) flags.DEFINE_string( - 'tfds_data_dir', None, + 'tfds_data_dir', + None, 'If set, this directory will be used to store datasets prepared by ' 'TensorFlow Datasets that are not available in the public TFDS GCS ' 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' - 'all `Task`s.') + 'all `Task`s.', + ) flags.DEFINE_list( - 'seqio_additional_cache_dirs', [], - 'Directories to search for cached Tasks in addition to defaults.') + 'seqio_additional_cache_dirs', + [], + 'Directories to search for cached Tasks in addition to defaults.', + ) flags.DEFINE_boolean( 'multiprocess_gpu', False, - help='Initialize JAX distributed system for multi-host GPU, using ' - '`coordinator_address`, `process_count`, and `process_index`.') + help=( + 'Initialize JAX distributed system for multi-host GPU, using ' + '`coordinator_address`, `process_count`, and `process_index`.' + ), + ) flags.DEFINE_string( 'coordinator_address', None, - help='IP address:port for multi-host GPU coordinator.') + help='IP address:port for multi-host GPU coordinator.', + ) flags.DEFINE_integer( - 'process_count', None, help='Number of processes for multi-host GPU.') + 'process_count', None, help='Number of processes for multi-host GPU.' + ) flags.DEFINE_integer('process_index', None, help='Index of this process.') flags.DEFINE_integer(