From 711a8fed949d0462a4b90cece9be94b2a96077d2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 18:46:14 +0000 Subject: [PATCH] fix and formatting --- .../librispeech_jax/workload.py | 58 +++++++++---------- .../librispeech_pytorch/workload.py | 6 +- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4fa2e9395..1ef38de82 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -15,7 +15,7 @@ BaseDeepspeechLibrispeechWorkload -class LibriSpeechDeepSpeechWorkload(LibrispeechWorkload): +class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): def init_model_fn( self, @@ -53,33 +53,33 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - variables = {'params': params, **model_state} - inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - is_train_mode = mode == spec.ForwardPassMode.TRAIN - if update_batch_norm or is_train_mode: - (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats']) - return (logits, logit_paddings), new_model_state - else: - logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) - return (logits, logit_paddings), model_state + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @@ -153,4 +153,4 @@ def _eval_model_on_split(self, computed_metrics = metrics_report.compute() - return computed_metrics \ No newline at end of file + return computed_metrics diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index a28d93312..52663b854 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -66,7 +66,7 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] - + def model_fn( self, params: spec.ParameterContainer, @@ -97,7 +97,7 @@ def model_fn( logits, logits_paddings = model(inputs.to(DEVICE), input_paddings.to(DEVICE)) return (logits, logits_paddings), None - + # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( @@ -137,7 +137,7 @@ def loss_fn( 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), 'per_example': per_example_losses, } - + def _eval_model_on_split(self, split: str, num_examples: int,