Skip to content

Commit

Permalink
copy jax and pytorch loss_fn, model_fn and _eval_model_on_split to de…
Browse files Browse the repository at this point in the history
…epspeech classes
  • Loading branch information
priyakasimbeg committed Sep 30, 2023
1 parent 0c67481 commit 1e44349
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
BaseDeepspeechLibrispeechWorkload


class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload):
class LibriSpeechDeepSpeechWorkload(LibrispeechWorkload):

def init_model_fn(
self,
Expand Down Expand Up @@ -50,3 +50,106 @@ def init_model_fn(

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'

def model_fn(
self,
params: spec.ParameterContainer,
augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
variables,
inputs,
input_paddings,
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'])
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
variables,
inputs,
input_paddings,
train=False,
mutable=False)
return (logits, logit_paddings), model_state

# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
self,
label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
del label_smoothing
logits, logit_paddings = logits_batch
targets, target_paddings = label_batch
logprobs = nn.log_softmax(logits)
per_example_losses = self.ctc_loss(logprobs,
logit_paddings,
targets,
target_paddings)
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
mask_batch = jnp.logical_and(mask_batch, 1 - target_paddings)
else:
mask_batch = 1 - target_paddings
n_valid_examples = jnp.maximum(mask_batch.sum(), 1)
summed_loss = per_example_losses.sum()
return {
'summed': summed_loss,
'n_valid_examples': n_valid_examples,
'per_example': per_example_losses,
}

def _eval_model_on_split(self,
split: str,
num_examples: int,
global_batch_size: int,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str,
global_step: int = 0) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
if model_state is not None:
# Sync batch statistics across replicas before evaluating.
model_state = self.sync_batch_stats(model_state)

num_batches = int(math.ceil(num_examples / global_batch_size))
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
rng, split, data_dir, global_batch_size, num_batches=num_batches)

metrics_report = None
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
computed_metrics = self.eval_step_pmapped(params,
eval_batch,
model_state,
rng).unreplicate()

if metrics_report is None:
metrics_report = computed_metrics
else:
# `merge` aggregates the metrics across batches.
metrics_report = metrics_report.merge(computed_metrics)

computed_metrics = metrics_report.compute()

return computed_metrics
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,143 @@ def init_model_fn(

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']

def model_fn(
self,
params: spec.ParameterContainer,
augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng

model = params
if mode == spec.ForwardPassMode.EVAL:
model.eval()
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
functools.partial(
pytorch_utils.update_batch_norm_fn,
update_batch_norm=update_batch_norm))

contexts = {
spec.ForwardPassMode.EVAL: torch.no_grad,
spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
logits, logits_paddings = model(inputs.to(DEVICE),
input_paddings.to(DEVICE))
return (logits, logits_paddings), None

# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
self,
label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
del label_smoothing
targets, target_paddings = label_batch
logits, logit_paddings = logits_batch
logprobs = torch.log_softmax(logits, dim=-1)
input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long()
target_lengths = torch.einsum('bh->b', 1 - target_paddings).long()
per_example_losses = self.ctc_loss(
logprobs.permute(1, 0, 2),
targets.long(),
input_lengths,
target_lengths)
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
mask_batch = torch.logical_and(mask_batch, target_lengths)
else:
mask_batch = target_lengths
n_valid_examples = mask_batch.sum().to(per_example_losses)
summed_loss = per_example_losses.sum()
n_valid_examples = max(n_valid_examples, 1)
return {
'summed': summed_loss,
'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
'per_example': per_example_losses,
}

def _eval_model_on_split(self,
split: str,
num_examples: int,
global_batch_size: int,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str,
global_step: int = 0) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
# These iterators repeat indefinitely.
self._eval_iters[split] = (
self._build_input_queue(
data_rng, split, data_dir, global_batch_size=global_batch_size))

total_metrics = {
'loss': torch.tensor(0., device=DEVICE),
'lengths': torch.tensor(0., device=DEVICE),
'word_errors': torch.tensor(0., device=DEVICE),
'num_words': torch.tensor(0., device=DEVICE),
}
num_batches = int(math.ceil(num_examples / global_batch_size))
if self.requires_sync_before_eval:
self.sync_sd(params)
for _ in range(num_batches):
batch = next(self._eval_iters[split])

(logits, logits_padding), _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
model_rng,
update_batch_norm=False)
decoded, decoded_paddings = self.greedy_decode(logits, logits_padding)
targets, target_paddings = batch['targets']
word_errors, num_words = metrics.compute_wer(
decoded=decoded.cpu().numpy(),
decoded_paddings=decoded_paddings.cpu().numpy(),
targets=targets.cpu().numpy(),
target_paddings=target_paddings.cpu().numpy(),
tokenizer=self.tokenizer)
loss = self.loss_fn((targets, target_paddings), (logits, logits_padding))
summed_loss = loss['summed']
lengths = loss['n_valid_examples']
batch_metrics = {
'loss': summed_loss,
'lengths': lengths,
'word_errors': word_errors,
'num_words': num_words,
}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
return {
'ctc_loss':
float(total_metrics['loss'].item() /
total_metrics['lengths'].item()),
'wer':
float(total_metrics['word_errors'].item() /
total_metrics['num_words'].item()),
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from algorithmic_efficiency.workloads.librispeech_conformer import workload


class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload):
class BaseDeepspeechLibrispeechWorkload(workload.LibrispeechConformerWorkload):

@property
def validation_target_value(self) -> float:
Expand Down

0 comments on commit 1e44349

Please sign in to comment.