Skip to content

Commit

Permalink
add seperate model_fn for deepspeech jax without use_running_average_bn
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Oct 18, 2024
1 parent 087fd5c commit c5c36c2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def model_fn(
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
print(type(use_running_average_bn))
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
variables,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@ def init_model_fn(
model_state = jax_utils.replicate(model_state)
params = jax_utils.replicate(params)
return params, model_state

def model_fn(
self,
params: spec.ParameterContainer,
augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
variables,
inputs,
input_paddings,
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'])
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
variables,
inputs,
input_paddings,
train=False,
mutable=False)
return (logits, logit_paddings), model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
Expand Down
1 change: 1 addition & 0 deletions tests/reference_algorithm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def _test_submission(workload_name,
workload_path=workload_metadata['workload_path'],
workload_class_name=workload_metadata['workload_class_name'],
return_class=True)
print(f'Workload class for {workload_name} is {workload_class}')

submission_module_path = workloads.convert_filepath_to_module(submission_path)
submission_module = importlib.import_module(submission_module_path)
Expand Down

0 comments on commit c5c36c2

Please sign in to comment.