From 107f5b05b4e1b6b6561ce7b238f8fedbd23709f1 Mon Sep 17 00:00:00 2001
From: Michelle Casbon
Date: Wed, 20 Sep 2023 12:04:26 -0700
Subject: [PATCH] Resolve lint issues in interactive_model.

Should be a no-op with formatting changes only.

PiperOrigin-RevId: 567039673
---
 t5x/interactive_model.py | 281 +++++++++++++++++++++++++--------------
 1 file changed, 180 insertions(+), 101 deletions(-)

diff --git a/t5x/interactive_model.py b/t5x/interactive_model.py
index ea757002a..399b93b8f 100644
--- a/t5x/interactive_model.py
+++ b/t5x/interactive_model.py
@@ -46,8 +46,9 @@ import tensorflow as tf
 import tensorflow_datasets as tfds
 
-BatchesType = Union[Sequence[Mapping[str, str]],
-                    Sequence[Sequence[Mapping[str, str]]]]
+BatchesType = Union[
+    Sequence[Mapping[str, str]], Sequence[Sequence[Mapping[str, str]]]
+]
 
 
 class InferenceType(enum.Enum):
@@ -140,7 +141,8 @@ def __init__(
     # --------------------------------------------------------------------------
     self._init_random_seed = init_random_seed
     random_seed = multihost_utils.broadcast_one_to_all(
-        np.int32(self._init_random_seed))
+        np.int32(self._init_random_seed)
+    )
     utils.set_hardware_rng_ops()
 
     rng = random.PRNGKey(random_seed)
@@ -158,7 +160,8 @@ def __init__(
           "The number of devices available must be a multiple of the number of",
           f" partitions. There are {jax.device_count()} devices available, but",
          f" the number of partitions is set to {num_partitions}. Please",
-          " provide a different number of partitions.")
+          " provide a different number of partitions.",
+      )
     self._partitioner = partitioner
 
     # --------------------------------------------------------------------------
@@ -171,12 +174,12 @@ def __init__(
     self._input_types = input_types
     # Save the model vocabulary as features.
     output_features = {
-        "inputs":
-            seqio.Feature(
-                vocabulary=self._model.input_vocabulary, add_eos=add_eos),
-        "targets":
-            seqio.Feature(
-                vocabulary=self._model.output_vocabulary, add_eos=add_eos)
+        "inputs": seqio.Feature(
+            vocabulary=self._model.input_vocabulary, add_eos=add_eos
+        ),
+        "targets": seqio.Feature(
+            vocabulary=self._model.output_vocabulary, add_eos=add_eos
+        ),
     }
     self._features = dict(sorted(output_features.items()))
@@ -195,7 +198,8 @@ def __init__(
         init_fn=self._model.get_initial_variables,
         input_shapes=self._input_shapes,
         input_types=self._input_types,
-        partitioner=self._partitioner)
+        partitioner=self._partitioner,
+    )
 
     # Initialize checkpoint manager.
     self._checkpoint_manager = utils.LegacyCheckpointManager(
@@ -233,40 +237,52 @@ def get_state(rng):
             path=self._output_dir,
             mode="latest",
             dtype=self._save_checkpoint_cfg.dtype
-            if self._save_checkpoint_cfg else "float32",
+            if self._save_checkpoint_cfg
+            else "float32",
             checkpointer_cls=self._save_checkpoint_cfg.checkpointer_cls
-            if self._save_checkpoint_cfg else checkpoints.Checkpointer,
+            if self._save_checkpoint_cfg
+            else checkpoints.Checkpointer,
             # Restore dataset state if it is being saved.
-            restore_dataset=(self._save_checkpoint_cfg and
-                             self._save_checkpoint_cfg.save_dataset),
-            state_transformation_fns=state_transforms_for_restore))
+            restore_dataset=(
+                self._save_checkpoint_cfg
+                and self._save_checkpoint_cfg.save_dataset
+            ),
+            state_transformation_fns=state_transforms_for_restore,
+        )
+    )
 
     # Restore the model using a checkpoint.
     valid_restore_cfg, restore_paths = (
         utils.get_first_valid_restore_config_and_paths(restore_cfgs)
     )
     self._train_state = self._checkpoint_manager.restore(
-        restore_paths, valid_restore_cfg,
-        utils.get_fallback_state(valid_restore_cfg, get_state, self._init_rng))
+        restore_paths,
+        valid_restore_cfg,
+        utils.get_fallback_state(valid_restore_cfg, get_state, self._init_rng),
+    )
 
     # 3. If no checkpoint to restore, init from scratch.
     if self._train_state is None:
       self._train_state = self._train_state_initializer.from_scratch(
-          self._init_rng)
+          self._init_rng
+      )
     self._train_state_axes = self._train_state_initializer.train_state_axes
 
     # Log the variable shapes information and write to a file.
     log_file = os.path.join(self._output_dir, "model-info.txt")
-    utils.log_model_info(log_file,
-                         self._train_state_initializer.global_train_state_shape,
-                         self._partitioner)
+    utils.log_model_info(
+        log_file,
+        self._train_state_initializer.global_train_state_shape,
+        self._partitioner,
+    )
 
     # --------------------------------------------------------------------------
     # Trainer
     # --------------------------------------------------------------------------
     if isinstance(self._train_state, Sequence):
       raise ValueError(
-          "Expected a single train state, but instead received a Sequence.")
+          "Expected a single train state, but instead received a Sequence."
+      )
     self._trainer = trainer_lib.Trainer(
         model=self._model,
         train_state=self._train_state,
@@ -307,7 +323,8 @@ def train_summary(self):
   def step(self):
     if isinstance(self._train_state, Sequence):
       raise ValueError(
-          "Expected a single train state, but instead received a Sequence.")
+          "Expected a single train state, but instead received a Sequence."
+      )
     return int(self._train_state.step)
 
   def train_step(self, examples: Sequence[Union[str, dict[str, str]]]):
@@ -331,11 +348,14 @@ def train_step(self, examples: Sequence[Union[str, dict[str, str]]]):
         seqio.preprocessors.append_eos,
     ]
     self.train_step_with_preprocessors(
-        examples=examples, preprocessors=preprocessors)
+        examples=examples, preprocessors=preprocessors
+    )
 
   def train_step_with_preprocessors(
-      self, examples: Sequence[Union[str, dict[str, str]]],
-      preprocessors: Sequence[Callable[..., tf.data.Dataset]]):
+      self,
+      examples: Sequence[Union[str, dict[str, str]]],
+      preprocessors: Sequence[Callable[..., tf.data.Dataset]],
+  ):
     """Train function.
 
     Args:
@@ -360,19 +380,24 @@ def train_step_with_preprocessors(
     if len(examples) < self._batch_size:
       raise ValueError(
           "At least one batch of data must be provided. Please decrease the "
-          "batch_size or provide more examples.")
+          "batch_size or provide more examples."
+      )
     train_dataset = get_dataset_from_natural_text_examples(
         examples,
         preprocessors=preprocessors,
         task_feature_lengths=self._task_feature_lengths,
-        features=self._features)
+        features=self._features,
+    )
     train_dataset = self._feature_converter(
-        train_dataset, task_feature_lengths=self._task_feature_lengths)
+        train_dataset, task_feature_lengths=self._task_feature_lengths
+    )
     train_dataset = train_dataset.padded_batch(
-        self._batch_size, drop_remainder=True)
+        self._batch_size, drop_remainder=True
+    )
     train_iter = clu.data.dataset_iterator.TfDatasetIterator(
-        train_dataset, checkpoint=True)
+        train_dataset, checkpoint=True
+    )
 
     # --------------------------------------------------------------------------
     # Take 1 train step.
     # --------------------------------------------------------------------------
@@ -380,7 +405,8 @@
     # `stop_training` is requested, break out the main loop immediately.
     if self._trainer.stop_training:
       logging.info(
-          "Stopping training early since `stop_training` is requested.")
+          "Stopping training early since `stop_training` is requested."
+      )
       return
 
     try:
@@ -452,8 +478,10 @@ def infer_with_preprocessors(
     elif mode == InferenceType.SCORE:
       infer_step = self._model.score_batch
     else:
-      raise ValueError("Mode must be `predict_with_aux`, or `score`,"
-                       f" but instead was {mode}.")
+      raise ValueError(
+          "Mode must be `predict_with_aux`, or `score`,"
+          f" but instead was {mode}."
+      )
     key_array = seqio.utils.flatten_dict(inference_kwargs)
     key_array["mode"] = mode
     infer_fn_key = tuple(key_array.items())
@@ -476,16 +504,20 @@ def infer_with_preprocessors(
         examples,
         preprocessors=preprocessors,
         task_feature_lengths=self._task_feature_lengths,
-        features=self._features)
+        features=self._features,
+    )
     model_dataset = self._feature_converter(
-        dataset, task_feature_lengths=self._task_feature_lengths)
+        dataset, task_feature_lengths=self._task_feature_lengths
+    )
     # Zip task and model features.
     infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))
     # Create batches and index them.
     infer_dataset = infer_dataset.padded_batch(
-        self._batch_size, drop_remainder=False).enumerate()
+        self._batch_size, drop_remainder=False
+    ).enumerate()
     infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(
-        infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))
+        infer_dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    )
 
     # --------------------------------------------------------------------------
     # Run inference
@@ -502,15 +534,18 @@ def infer_with_preprocessors(
 
       # Unzip chunk dataset in to pretokenized and model datasets.
       task_dataset = chunk_dataset.map(
-          lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+          lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE
+      )
       model_dataset = chunk_dataset.map(
-          lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+          lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE
+      )
 
       # Get a chunk-specific RNG key.
      chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)
 
       inferences = _extract_tokens_and_aux_values(
-          infer_fn(model_dataset.enumerate(), rng=chunk_rng))
+          infer_fn(model_dataset.enumerate(), rng=chunk_rng)
+      )
 
       predictions, aux_values = inferences
       accumulated_inferences = []
@@ -518,8 +553,11 @@ def infer_with_preprocessors(
         prediction = predictions[idx]
         # Decode predictions if applicable.
         if mode == InferenceType.PREDICT_WITH_AUX:
-          prediction = self._features["targets"].vocabulary.decode_tf(
-              tf.constant(prediction)).numpy()
+          prediction = (
+              self._features["targets"]
+              .vocabulary.decode_tf(tf.constant(prediction))
+              .numpy()
+          )
         accumulated_inferences.append((inputs, prediction))
       all_inferences += accumulated_inferences
       # Accumulate aux values over batches.
@@ -532,7 +570,8 @@ def infer_with_preprocessors(
     return all_inferences, all_aux_values
 
   def predict_with_aux(
-      self, examples: Sequence[Union[str, dict[str, str]]]) -> _Inferences:
+      self, examples: Sequence[Union[str, dict[str, str]]]
+  ) -> _Inferences:
     """Predict with auxiliary values method."""
     # By default, only tokenize and append EOS.
     preprocessors = [
@@ -542,10 +581,12 @@ def predict_with_aux(
     return self.infer_with_preprocessors(
         mode=InferenceType.PREDICT_WITH_AUX,
         examples=examples,
-        preprocessors=preprocessors)
+        preprocessors=preprocessors,
+    )
 
-  def score(self, examples: Sequence[Union[str, dict[str,
-                                                      str]]]) -> Sequence[Any]:
+  def score(
+      self, examples: Sequence[Union[str, dict[str, str]]]
+  ) -> Sequence[Any]:
     """Score method."""
     # By default, only tokenize and append EOS.
     preprocessors = [
@@ -554,18 +595,22 @@ def score(self, examples: Sequence[Union[str, dict[str,
     ]
     # Ignore auxiliary values.
     scores, _ = self.infer_with_preprocessors(
-        mode=InferenceType.SCORE,
-        examples=examples,
-        preprocessors=preprocessors)
+        mode=InferenceType.SCORE, examples=examples, preprocessors=preprocessors
+    )
     return scores
 
   def _compute_metrics(
-      self, targets: Sequence[Any], predictions: Sequence[Any],
-      aux_values: Sequence[Any], scores: Sequence[Any],
+      self,
+      targets: Sequence[Any],
+      predictions: Sequence[Any],
+      aux_values: Sequence[Any],
+      scores: Sequence[Any],
       predict_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
       predict_with_aux_metric_fns: Sequence[
-          seqio.dataset_providers.MetricFnCallable],
-      score_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable]):
+          seqio.dataset_providers.MetricFnCallable
+      ],
+      score_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
+  ):
     """Computes the metrics specified in the metric_fns lists."""
     # Only compute metrics once
     if jax.process_index() != 0:
@@ -574,22 +619,31 @@ def _compute_metrics(
     def compute_metrics_fn():
       task_metrics = []
       if predict_metric_fns:
-        task_metrics.extend([
-            metric_fn(targets, predictions) for metric_fn in predict_metric_fns
-        ])
+        task_metrics.extend(
+            [
+                metric_fn(targets, predictions)
+                for metric_fn in predict_metric_fns
+            ]
+        )
       if predict_with_aux_metric_fns:
-        task_metrics.extend([
-            metric_fn(targets, predictions, aux_values)
-            for metric_fn in predict_with_aux_metric_fns
-        ])
+        task_metrics.extend(
+            [
+                metric_fn(targets, predictions, aux_values)
+                for metric_fn in predict_with_aux_metric_fns
+            ]
+        )
       if score_metric_fns:
        is_tuple = isinstance(scores, tuple)
-        if ((not is_tuple and len(targets) != len(scores)) or
-            (is_tuple and len(targets) != len(scores[0]))):
-          raise ValueError(f"len(targets)({len(targets)}) != "
-                           f"len(output_scores)({len(scores)})")
+        if (not is_tuple and len(targets) != len(scores)) or (
+            is_tuple and len(targets) != len(scores[0])
+        ):
+          raise ValueError(
+              f"len(targets)({len(targets)}) != "
+              f"len(output_scores)({len(scores)})"
+          )
         task_metrics.extend(
-            [metric_fn(targets, scores) for metric_fn in score_metric_fns])
+            [metric_fn(targets, scores) for metric_fn in score_metric_fns]
+        )
 
       all_metrics = {}
       for k, v in itertools.chain(*[m.items() for m in task_metrics]):
@@ -647,7 +701,8 @@ def evaluate(
         examples=examples,
         preprocessors=preprocessors,
         metric_fns=metric_fns,
-        postprocessor=None)
+        postprocessor=None,
+    )
 
   def evaluate_with_preprocessors(
       self,
@@ -687,8 +742,10 @@ def evaluate_with_preprocessors(
     score_metric_fns = []
     for metric_fn in metric_fns:
      pos_args = tuple(
-          key for key, param in inspect.signature(metric_fn).parameters.items()
-          if param.default == inspect.Parameter.empty)
+          key
+          for key, param in inspect.signature(metric_fn).parameters.items()
+          if param.default == inspect.Parameter.empty
+      )
       if pos_args == ("targets", "scores"):
         score_metric_fns.append(metric_fn)
       elif pos_args == ("targets", "predictions"):
@@ -700,7 +757,8 @@ def evaluate_with_preprocessors(
             "Metric functions must have positional arguments matching either "
             "('targets', 'scores'), ('targets', 'predictions') or "
            "('targets', 'predictions', 'aux_values'). "
-            f"Got: {pos_args}")
+            f"Got: {pos_args}"
+        )
 
     # ------------------------------------------------------------------------
     # Get targets, predictions, and scores
     # ------------------------------------------------------------------------
@@ -709,7 +767,8 @@ def evaluate_with_preprocessors(
     dataset = get_dataset_from_natural_text_examples(
         examples,
        preprocessors=preprocessors,
        task_feature_lengths=self._task_feature_lengths,
-        features=self._features)
+        features=self._features,
+    )
 
    # Get targets.
    def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
@@ -724,7 +783,9 @@ def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
          postprocess_fn(
              decoded_model_output=ex["targets_pretokenized"],
              example=ex,
-              is_target=True))
+              is_target=True,
+          )
+      )
 
    # Get predictions.
    predictions = []
@@ -732,7 +793,8 @@ def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
      predictions, aux_values = self.infer_with_preprocessors(
          mode=InferenceType.PREDICT_WITH_AUX,
          examples=examples,
-          preprocessors=preprocessors)
+          preprocessors=preprocessors,
+      )
      predictions = [
          prediction.decode("utf-8") for example, prediction in predictions
      ]
@@ -742,7 +804,8 @@ def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
      scores, _ = self.infer_with_preprocessors(
          mode=InferenceType.SCORE,
          examples=examples,
-          preprocessors=preprocessors)
+          preprocessors=preprocessors,
+      )
      scores = [score for example, score in scores]
 
    return self._compute_metrics(
@@ -752,7 +815,8 @@ def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
        scores,  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
        predict_metric_fns,
        predict_with_aux_metric_fns,
-        score_metric_fns)
+        score_metric_fns,
+    )
 
  def train_loop(
      self,
@@ -762,8 +826,9 @@ def train_loop(
      num_steps: int,
      eval_period: Optional[int] = 1,
      train_batches: Optional[BatchesType] = None,
      predict_batches: Optional[BatchesType] = None,
      score_batches: Optional[BatchesType] = None,
      eval_batches: Optional[BatchesType] = None,
-      metrics_fns: Optional[Sequence[
-          seqio.dataset_providers.MetricFnCallable]] = None,
+      metrics_fns: Optional[
+          Sequence[seqio.dataset_providers.MetricFnCallable]
+      ] = None,
  ):
    """Runs training, inference, and evaluation for `num_steps`.
@@ -832,7 +897,8 @@ def get_dataset_from_natural_text_examples(
    examples: Sequence[Union[str, dict[str, str]]],
    preprocessors: Sequence[Callable[..., tf.data.Dataset]],
    task_feature_lengths: Mapping[str, int],
-    features: Mapping[str, Any]) -> tf.data.Dataset:
+    features: Mapping[str, Any],
+) -> tf.data.Dataset:
  """Returns a tf.data.Dataset from a list of examples.
 
  Args:
@@ -903,7 +969,8 @@ def _validate_preprocessing(dataset: tf.data.Dataset) -> tf.data.Dataset:
        if feat_spec.required:
          raise ValueError(
              "Task dataset is missing expected output feature after "
-              f"preprocessing: {feat}")
+              f"preprocessing: {feat}"
+          )
        else:
          # It's ok that this feature does not exist.
          continue
@@ -912,12 +979,14 @@ def _validate_preprocessing(dataset: tf.data.Dataset) -> tf.data.Dataset:
        raise ValueError(
            f"Task dataset has incorrect type for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.dtype.name}, expected "
-            f"{feat_spec.dtype.name}")
+            f"{feat_spec.dtype.name}"
+        )
      if feat_spec.rank != actual_spec.shape.rank:
        raise ValueError(
            f"Task dataset has incorrect rank for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.shape.rank}, expected "
-            f"{feat_spec.rank}")
+            f"{feat_spec.rank}"
+        )
    return dataset
 
@@ -926,8 +995,9 @@ def _validate_preprocessing(dataset: tf.data.Dataset) -> tf.data.Dataset:
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)
 
 
-def _get_equal_length_batches(batches: BatchesType,
-                              length: int) -> Sequence[Any]:
+def _get_equal_length_batches(
+    batches: BatchesType, length: int
+) -> Sequence[Any]:
  """Produces a list of batches that is `length` batches long.
 
  Given a single batch, repeat the batch `length` times.
@@ -953,7 +1023,8 @@ def _get_equal_length_batches(batches: BatchesType,
    batches = batches * (length // len(batches))
  # If multiple batches are provided, only use the first `length` batches.
  logging.warning(
-      "We will only use the first %s batches provided for training.", length)
+      "We will only use the first %s batches provided for training.", length
+  )
  return batches[:length]
 
 
@@ -964,7 +1035,8 @@ def get_batches_from_seqio(
    num_batches: int,
    get_pretokenized_examples: bool = True,
    sequence_length: Optional[Mapping[str, int]] = None,
-    **get_dataset_kwargs) -> Sequence[Sequence[Mapping[str, str]]]:
+    **get_dataset_kwargs,
+) -> Sequence[Sequence[Mapping[str, str]]]:
  """Returns a batch of examples from a provided SeqIO task.
 
  Args:
@@ -1028,11 +1100,13 @@ def get_batches_from_seqio(
      current_batch = []
 
  if total_examples_seen < total_examples_requested:
-    raise ValueError("Not enough examples in Task/Mixture. User requested "
-                     f"{num_batches} batches of size {batch_size} for a total "
-                     f"of {total_examples_requested} examples. Only "
-                     f"{total_examples_seen} available in "
-                     "Task/Mixture.")
+    raise ValueError(
+        "Not enough examples in Task/Mixture. User requested "
+        f"{num_batches} batches of size {batch_size} for a total "
+        f"of {total_examples_requested} examples. Only "
+        f"{total_examples_seen} available in "
+        "Task/Mixture."
+    )
 
  return all_batches
 
@@ -1080,10 +1154,12 @@ def dataset_fn(split, shuffle_files):
        examples,
        preprocessors=[],
        task_feature_lengths=interactive_model._task_feature_lengths,  # pylint: disable=protected-access
-        features={})
+        features={},
+    )
 
  data_source = seqio.FunctionDataSource(
-      dataset_fn=dataset_fn, splits=["train", "validation"])
+      dataset_fn=dataset_fn, splits=["train", "validation"]
+  )
 
  if add_to_registry:
    seqio.TaskRegistry.add(
@@ -1091,19 +1167,22 @@ def dataset_fn(split, shuffle_files):
        data_source,
        preprocessors=preprocessors,
        output_features=interactive_model._features,  # pylint: disable=protected-access
-        metric_fns=metric_fns)
+        metric_fns=metric_fns,
+    )
 
  return seqio.get_mixture_or_task(task_name)  # pylint: disable=protected-access
 
 
-def get_gin_config_from_interactive_model(interactive_model: InteractiveModel,
-                                          script_type: T5XScriptType,
-                                          task_name: str,
-                                          partitioner_config_str: str,
-                                          model_config_str: str,
-                                          train_steps: int = 1,
-                                          imports_str: str = ""):
+def get_gin_config_from_interactive_model(
+    interactive_model: InteractiveModel,
+    script_type: T5XScriptType,
+    task_name: str,
+    partitioner_config_str: str,
+    model_config_str: str,
+    train_steps: int = 1,
+    imports_str: str = "",
+):
  """Converts an InteractiveModel instance into a Gin config string.
 
  This function will be used to graduate people to the T5X/SeqIO-based