Commit: fix and formatting

priyakasimbeg committed Oct 2, 2023
1 parent f7789db commit 711a8fe
Showing 2 changed files with 32 additions and 32 deletions.
@@ -15,7 +15,7 @@
    BaseDeepspeechLibrispeechWorkload


-class LibriSpeechDeepSpeechWorkload(LibrispeechWorkload):
+class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload):

  def init_model_fn(
      self,
@@ -53,33 +53,33 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
    return param_key == 'Dense_0'

  def model_fn(
      self,
      params: spec.ParameterContainer,
      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
      model_state: spec.ModelAuxiliaryState,
      mode: spec.ForwardPassMode,
      rng: spec.RandomState,
      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
    variables = {'params': params, **model_state}
    inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
    is_train_mode = mode == spec.ForwardPassMode.TRAIN
    if update_batch_norm or is_train_mode:
      (logits, logit_paddings), new_model_state = self._model.apply(
          variables,
          inputs,
          input_paddings,
          train=True,
          rngs={'dropout': rng},
          mutable=['batch_stats'])
      return (logits, logit_paddings), new_model_state
    else:
      logits, logit_paddings = self._model.apply(
          variables,
          inputs,
          input_paddings,
          train=False,
          mutable=False)
      return (logits, logit_paddings), model_state
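The two branches above follow the standard Flax pattern for batch-norm state: in train mode, apply() is called with mutable=['batch_stats'] and returns the updated statistics alongside the output, while in eval mode nothing is mutable and only the output comes back. A minimal self-contained sketch of that pattern (TinyBNModel and its shapes are made up for illustration, not the workload's real DeepSpeech model):

import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyBNModel(nn.Module):
  """Illustrative module with batch norm, mirroring the mutable-state split."""

  @nn.compact
  def __call__(self, x, train):
    x = nn.Dense(4)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    return x

model = TinyBNModel()
x = jnp.ones((8, 16))
variables = model.init(jax.random.PRNGKey(0), x, train=False)

# Train mode: batch_stats is mutable, so apply() also returns the new state.
y, new_state = model.apply(variables, x, train=True, mutable=['batch_stats'])

# Eval mode: nothing is mutable; apply() returns only the output.
y_eval = model.apply(variables, x, train=False, mutable=False)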

  # Does NOT apply regularization, which is left to the submitter to do in
  # `update_params`.
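Because of this, a submission that wants weight decay has to add it on top of the returned loss inside its own update_params. A minimal sketch of one way to do that in JAX (the helper name and coefficient are hypothetical, not part of the workload API):

import jax
import jax.numpy as jnp

def l2_regularization(params, l2_coeff=1e-4):
  # Hypothetical helper: l2_coeff and the flat sum over all parameter
  # leaves are illustrative choices, not prescribed by the benchmark.
  leaves = jax.tree_util.tree_leaves(params)
  return l2_coeff * sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)

# Inside update_params, the regularizer would be added to the data loss
# before differentiating, e.g.:
#   total_loss = ctc_loss + l2_regularization(params)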
@@ -153,4 +153,4 @@ def _eval_model_on_split(self,

    computed_metrics = metrics_report.compute()

    return computed_metrics
@@ -66,7 +66,7 @@ def init_model_fn(

  def is_output_params(self, param_key: spec.ParameterKey) -> bool:
    return param_key in ['lin.weight', 'lin.bias']

  def model_fn(
      self,
      params: spec.ParameterContainer,
@@ -97,7 +97,7 @@ def model_fn
      logits, logits_paddings = model(inputs.to(DEVICE),
                                      input_paddings.to(DEVICE))
    return (logits, logits_paddings), None
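Both backends expose is_output_params ('Dense_0' in the Flax file above, 'lin.weight'/'lin.bias' here) so a submission can treat the output layer specially. A hedged sketch of one possible use, exempting output parameters from weight decay:

def non_output_keys(workload, params):
  # Hypothetical helper: assumes `params` behaves like a flat
  # {param_key: tensor} mapping, which is an illustrative simplification.
  return [key for key in params if not workload.is_output_params(key)]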

  # Does NOT apply regularization, which is left to the submitter to do in
  # `update_params`.
  def loss_fn(
@@ -137,7 +137,7 @@ def loss_fn(
        'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
        'per_example': per_example_losses,
    }

  def _eval_model_on_split(self,
                           split: str,
                           num_examples: int,
