Commit

added return_grad for all types of rnnt loss (#29)
* added return_grad for all types of rnnt loss

* lifted the T >= S requirement for the regular case

* black reformat

* black -l80 reformat

* fixed s_range adjustment rule
durson authored Aug 25, 2023
1 parent 6a4b834 commit 83e4637
Showing 3 changed files with 182 additions and 40 deletions.
124 changes: 86 additions & 38 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
@@ -22,6 +22,26 @@
from .mutual_information import mutual_information_recursion


def validate_st_lengths(
S: int,
T: int,
is_rnnt_type_regular: bool,
boundary: Optional[Tensor] = None,
):
if boundary is None:
assert S >= 1, S
assert (
is_rnnt_type_regular or T >= S
), f"Modified transducer requires T >= S, but got T={T} and S={S}"
else:
Ss = boundary[:, 2]
Ts = boundary[:, 3]
assert (Ss >= 1).all(), Ss
assert (
is_rnnt_type_regular or (Ts >= Ss).all()
), f"Modified transducer requires T >= S, but got T={Ts} and S={Ss}"


def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
@@ -145,8 +165,8 @@ def get_rnnt_logprobs(
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S), symbols.shape
assert S >= 1, S
assert T >= S, (T, S)

validate_st_lengths(S, T, rnnt_type == "regular", boundary)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

# subtracting am_max and lm_max is to ensure the probs are in a good range
@@ -391,8 +411,8 @@ def get_rnnt_logprobs_joint(
(B, T, S1, C) = logits.shape
S = S1 - 1
assert symbols.shape == (B, S), symbols.shape
assert S >= 1, S
assert T >= S, (T, S)

validate_st_lengths(S, T, rnnt_type == "regular", boundary)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

normalizers = torch.logsumexp(logits, dim=3)
@@ -437,6 +457,7 @@ def rnnt_loss(
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
i.e. a 4 dimensions tensor.
@@ -509,20 +530,24 @@
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)

negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
return -negated_loss
loss = -negated_loss
elif reduction == "mean":
return -torch.mean(negated_loss)
loss = -torch.mean(negated_loss)
elif reduction == "sum":
return -torch.sum(negated_loss)
loss = -torch.sum(negated_loss)
else:
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
return (loss, scores_and_grads[1]) if return_grad else loss
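
A hedged sketch of the new call pattern for the full-joiner loss (shapes and values below are illustrative, not taken from the commit): with return_grad=True, rnnt_loss now returns a (loss, (px_grad, py_grad)) pair, mirroring what rnnt_loss_simple already exposes, so the gradients can be fed to get_rnnt_prune_ranges.

import torch
import fast_rnnt

B, T, S, C = 2, 8, 4, 10
logits = torch.randn(B, T, S + 1, C)
symbols = torch.randint(0, C, (B, S))

loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=C - 1,
    reduction="sum",
    return_grad=True,
)
# px_grad and py_grad are the grads of px and py from
# mutual_information_recursion, which get_rnnt_prune_ranges consumes.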


def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
def _monotonic_lower_bound(x: Tensor) -> Tensor:
"""Compute a monotonically increasing lower bound of the tensor `x` on the
last dimension. The basic idea is: we traverse the tensor in reverse order,
and update current element with the following statement,
@@ -556,9 +581,7 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
return x
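
Assuming the reverse-traversal behaviour sketched in the docstring (the function body itself is folded out of this diff), a tiny illustration with invented values, shown as comments only:

# x = torch.tensor([0, 4, 3, 2, 5])
# Each element becomes the minimum over the suffix starting at it, which is
# a monotonically non-decreasing, element-wise lower bound of x:
# _monotonic_lower_bound(x) -> tensor([0, 2, 2, 2, 5])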


def _adjust_pruning_lower_bound(
s_begin: torch.Tensor, s_range: int
) -> torch.Tensor:
def _adjust_pruning_lower_bound(s_begin: Tensor, s_range: int) -> Tensor:
"""Adjust s_begin (pruning lower bounds) to make it satisfy the following
constraints
@@ -613,11 +636,11 @@ def _adjust_pruning_lower_bound(
# chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
# (https://arxiv.org/pdf/2206.13236.pdf)
def get_rnnt_prune_ranges(
px_grad: torch.Tensor,
py_grad: torch.Tensor,
boundary: torch.Tensor,
px_grad: Tensor,
py_grad: Tensor,
boundary: Tensor,
s_range: int,
) -> torch.Tensor:
) -> Tensor:
"""Get the pruning ranges of normal rnnt loss according to the grads
of px and py returned by mutual_information_recursion.
@@ -661,28 +684,44 @@ def get_rnnt_prune_ranges(
"""
(B, S, T1) = px_grad.shape
T = py_grad.shape[-1]

is_regular = T1 != T

assert T1 in [T, T + 1], T1
S1 = S + 1
assert py_grad.shape == (B, S + 1, T), py_grad.shape
assert boundary.shape == (B, 4), boundary.shape

assert S >= 1, S
assert T >= S, (T, S)
validate_st_lengths(S, T, is_regular, boundary)

# In the regular case, s_range should be no less than the minimum
# integer satisfying `(s_range - 1) * t + 1 >= s + 1`
if is_regular:
Ss = boundary[:, 2]
Ts = boundary[:, 3]
s_range_min = (
Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
)
if s_range < s_range_min:
print(
f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
f"for boundaries S={Ss}, T={Ts}. Adjusting to {s_range_min}"
)
s_range = s_range_min
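
To make the fixed adjustment rule concrete (numbers invented for illustration): with S = 10 symbols but only T = 2 frames, the regular topology needs (s_range - 1) * T + 1 >= S + 1, so the smallest usable value is (10 - 1) // 2 + 2 = 6, and a requested s_range of 2 would be raised to 6. The same arithmetic as in the code above:

import torch

Ss = torch.tensor([10])  # per-sequence symbol counts (invented)
Ts = torch.tensor([2])   # per-sequence frame counts (invented)
s_range_min = Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
assert s_range_min == 6  # satisfies (6 - 1) * 2 + 1 >= 10 + 1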

# s_range > S means we won't prune out any symbols. To make indexing with
# ranges run normally, s_range should be equal to or less than ``S + 1``.
if s_range > S:
s_range = S + 1

if T1 == T:
assert (
s_range >= 1
), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."

else:
if is_regular:
assert (
s_range >= 2
), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
else:
assert (
s_range >= 1
), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."

(B_stride, S_stride, T_stride) = py_grad.stride()
blk_grad = torch.as_strided(
@@ -739,8 +778,8 @@ def get_rnnt_prune_ranges(


def do_rnnt_pruning(
am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
am: Tensor, lm: Tensor, ranges: Tensor
) -> Tuple[Tensor, Tensor]:
"""Prune the output of encoder(am) and prediction network(lm) with ranges
generated by `get_rnnt_prune_ranges`.
@@ -779,7 +818,7 @@ def do_rnnt_pruning(
return am_pruning, lm_pruning


def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
def _roll_by_shifts(src: Tensor, shifts: torch.LongTensor):
"""Roll tensor with different shifts for each row.
Note:
@@ -819,7 +858,7 @@ def get_rnnt_logprobs_pruned(
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor,
boundary: Optional[Tensor] = None,
rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.
@@ -888,10 +927,14 @@
# ranges (B, T, s_range)
assert logits.ndim == 4, logits.ndim
(B, T, s_range, C) = logits.shape
assert ranges.shape == (B, T, s_range), ranges.shape
assert ranges.shape == (
B,
T,
s_range,
), f"{ranges.shape} == ({B}, {T}, {s_range})"
(B, S) = symbols.shape
assert S >= 1, S
assert T >= S, (T, S)

validate_st_lengths(S, T, rnnt_type == "regular", boundary)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

normalizers = torch.logsumexp(logits, dim=3)
@@ -986,10 +1029,11 @@ def rnnt_loss_pruned(
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor = None,
boundary: Optional[Tensor] = None,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Tensor:
"""A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -1071,17 +1115,21 @@ def rnnt_loss_pruned(
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)

negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
return -negated_loss
loss = -negated_loss
elif reduction == "mean":
return -torch.mean(negated_loss)
loss = -torch.mean(negated_loss)
elif reduction == "sum":
return -torch.sum(negated_loss)
loss = -torch.sum(negated_loss)
else:
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
return (loss, scores_and_grads[1]) if return_grad else loss


def get_rnnt_logprobs_smoothed(
@@ -1202,8 +1250,8 @@ def get_rnnt_logprobs_smoothed(
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S), symbols.shape
assert S >= 1, S
assert T >= S, (T, S)

validate_st_lengths(S, T, rnnt_type == "regular", boundary)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

# Caution: some parts of this code are a little less clear than they could
1 change: 0 additions & 1 deletion fast_rnnt/python/tests/mutual_information_test.py
@@ -206,7 +206,6 @@ def test_mutual_information_deriv(self):

for dtype in self.dtypes:
for device in self.devices:

if random_boundary:

def get_boundary_row():
97 changes: 96 additions & 1 deletion fast_rnnt/python/tests/rnnt_loss_test.py
File mode changed: 100644 → 100755
@@ -343,7 +343,6 @@ def test_rnnt_loss_gradient(self):
boundary_[:, 3] = frames

for device in self.devices:

# lm: [B][S+1][C]
lm = lm_.to(device)
# am: [B][T][C]
@@ -609,6 +608,102 @@ def test_rnnt_loss_pruned_small_symbols_number(self):
)
print(f"Pruned loss with range {r} : {pruned_loss}")

# Test low s_range values with large S and small T; in this case,
# s_range would not be enough to cover the whole sequence length
# (in regular rnnt mode) and would result in an inf loss.
def test_rnnt_loss_pruned_small_s_range(self):
B = 2
T = 2
S = 10
C = 10

frames = torch.randint(1, T, (B,))
seq_lengths = torch.randint(1, S, (B,))
T = torch.max(frames)
S = torch.max(seq_lengths)

am_ = torch.randn((B, T, C), dtype=torch.float64)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
symbols_ = torch.randint(0, C, (B, S))
terminal_symbol = C - 1

boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_lengths
boundary_[:, 3] = frames

print(f"B = {B}, T = {T}, S = {S}, C = {C}")

for rnnt_type in ["regular"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
lm = lm_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)

logits = am.unsqueeze(2) + lm.unsqueeze(1)
logits = logits.float()

# nonlinear transform
logits = torch.sigmoid(logits)

loss = fast_rnnt.rnnt_loss(
logits=logits,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
)

print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")

# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)

S0 = 2

for r in range(S0, S + 2):
ranges = fast_rnnt.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=r,
)
# (B, T, r, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)

logits = pruned_am + pruned_lm

# nonlinear transform
logits = torch.sigmoid(logits)

pruned_loss = fast_rnnt.rnnt_loss_pruned(
logits=logits,
symbols=symbols,
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
)
assert (
not pruned_loss.isinf().any()
), f"Pruned loss is inf for r={r}, S={S}, T={T}: {pruned_loss}"
print(f"Pruned loss with range {r} : {pruned_loss}")


if __name__ == "__main__":
unittest.main()
