black reformat
durson committed Aug 24, 2023
1 parent 519cfbd commit 8342aab
Showing 4 changed files with 37 additions and 90 deletions.
4 changes: 1 addition & 3 deletions fast_rnnt/python/fast_rnnt/mutual_information.py
@@ -390,9 +390,7 @@ def joint_mutual_information_recursion(
     p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

     # note, tot_probs is without grad.
-    tot_probs = _fast_rnnt.mutual_information_forward(
-        px_tot, py_tot, boundary, p
-    )
+    tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)

     # this is a kind of "fake gradient" that we use, in effect to compute
     # occupation probabilities. The backprop will work regardless of the
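For context on what mutual_information_forward computes here: p[b][s][t] accumulates the log-sum over all monotonic paths from (0, 0) to (s, t), and the return value is p at the terminal corner. A minimal single-sequence sketch of that recursion (naive_mutual_information is an illustrative name, not a library function; the real kernel also honors boundary and runs batched on GPU):

import torch

def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: [S, T+1], log-prob of emitting symbol s at lattice point (s, t);
    # py: [S+1, T], log-prob of advancing one frame at lattice point (s, t).
    S, T = px.shape[0], py.shape[1]
    neg_inf = torch.tensor(float("-inf"), dtype=px.dtype)
    p = torch.full((S + 1, T + 1), float("-inf"), dtype=px.dtype)
    p[0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            from_s = p[s - 1, t] + px[s - 1, t] if s > 0 else neg_inf
            from_t = p[s, t - 1] + py[s, t - 1] if t > 0 else neg_inf
            p[s, t] = torch.logaddexp(from_s, from_t)
    return p[S, T]  # the per-sequence value mutual_information_forward returns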
77 changes: 25 additions & 52 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
@@ -203,9 +203,7 @@ def get_rnnt_logprobs(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..

-    px_lm = torch.gather(
-        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
-    )  # [B][S][1]
+    px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1))  # [B][S][1]

     px = px_am + px_lm  # [B][S][T+1], last slice with indexes out of
     # boundary is -inf
@@ -314,9 +312,9 @@ def rnnt_loss_simple(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(
-            T0, device=px.device
-        ).reshape(1, 1, T0)
+        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
+            1, 1, T0
+        )
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)
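To unpack the delay-penalty arithmetic: offset is the midpoint (T - 1) / 2 of each sequence, so frames earlier than the midpoint get a positive bonus and later frames a negative one, scaled by delay_penalty, which nudges the model toward emitting symbols earlier. A small standalone sketch of the broadcasting (values are illustrative):

import torch

B, T0, delay_penalty = 2, 5, 0.1
# boundary rows are [s_begin, t_begin, s_end, t_end]; boundary[:, 3] is the frame count
boundary = torch.tensor([[0, 0, 3, 5], [0, 0, 2, 4]])
offset = (boundary[:, 3] - 1) / 2                                       # [B]
penalty = offset.reshape(B, 1, 1) - torch.arange(T0).reshape(1, 1, T0)  # [B][1][T0]
penalty = penalty * delay_penalty
print(penalty[0, 0])  # tensor([ 0.2000,  0.1000,  0.0000, -0.1000, -0.2000])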

@@ -424,18 +422,14 @@ def get_rnnt_logprobs_joint(
     px = torch.cat(
         (
             px,
-            torch.full(
-                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
-            ),
+            torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
         ),
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..

     px[:, :, :T] -= normalizers[:, :S, :]

-    py = (
-        logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
-    )  # [B][S+1][T]
+    py = logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()  # [B][S+1][T]
     py -= normalizers

     if rnnt_type == "regular":
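The py line above pulls the termination (blank) scores out of the joint logits: indexing the last axis of a [B][T][S+1][C] tensor leaves [B][T][S+1], and permute((0, 2, 1)) reorders that to the [B][S+1][T] layout the recursion expects. A minimal shape check (dimensions are arbitrary):

import torch

B, T, S, C = 2, 4, 3, 10
termination_symbol = 0
logits = torch.randn(B, T, S + 1, C)
py = logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
assert py.shape == (B, S + 1, T)
# py[b, s, t]: unnormalized log-score of emitting blank at lattice point (s, t)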
@@ -521,9 +515,9 @@ def rnnt_loss(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(
-            T0, device=px.device
-        ).reshape(1, 1, T0)
+        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
+            1, 1, T0
+        )
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

@@ -578,9 +572,7 @@ def _monotonic_lower_bound(x: Tensor) -> Tensor:
     return x


-def _adjust_pruning_lower_bound(
-    s_begin: Tensor, s_range: int
-) -> Tensor:
+def _adjust_pruning_lower_bound(s_begin: Tensor, s_range: int) -> Tensor:
     """Adjust s_begin (pruning lower bounds) to make it satisfy the following
     constraints
@@ -617,17 +609,13 @@ def _adjust_pruning_lower_bound(
     (B, T) = s_begin.shape
     s_begin = _monotonic_lower_bound(s_begin)
     # do the magic transformation
-    s_begin = -(
-        s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
-    )
+    s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device))
     # make the transformed tensor to be non-decreasing
     s_begin = _monotonic_lower_bound(s_begin)
     # make start symbol to be zero.
     s_begin = torch.clamp(s_begin, min=0)
     # do the magic transformation again to recover s_begin
-    s_begin = -(
-        s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
-    )
+    s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device))
     return s_begin
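The "magic transformation" is worth a worked example. Mapping x[t] to (s_range - 1) * t - x[t] turns the step constraint s_begin[t + 1] - s_begin[t] <= s_range - 1 into plain monotonicity, which a monotonic lower bound can enforce; undoing the map recovers the adjusted bounds. A one-row sketch, with monotonic_lower_bound re-implemented as a right-to-left running minimum (an assumption about _monotonic_lower_bound's behavior, whose body is truncated above):

import torch

def monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
    # largest non-decreasing sequence lying at or below x
    x = x.clone()
    for i in range(x.numel() - 2, -1, -1):
        x[i] = torch.minimum(x[i], x[i + 1])
    return x

s_range = 3
s_begin = torch.tensor([0, 5, 5, 6])  # the jump 0 -> 5 exceeds s_range - 1 = 2
t = torch.arange(s_begin.numel())

s_begin = monotonic_lower_bound(s_begin)  # already non-decreasing here
s_begin = -(s_begin - (s_range - 1) * t)  # magic transformation
s_begin = monotonic_lower_bound(s_begin)  # now enforces step <= s_range - 1
s_begin = torch.clamp(s_begin, min=0)     # equivalently s_begin[t] <= (s_range - 1) * t
s_begin = -(s_begin - (s_range - 1) * t)  # transform back
print(s_begin)  # tensor([0, 2, 4, 6]): non-decreasing, steps <= 2, starts at 0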


@@ -773,9 +761,7 @@ def get_rnnt_prune_ranges(
     return ranges


-def do_rnnt_pruning(
-    am: Tensor, lm: Tensor, ranges: Tensor
-) -> Tuple[Tensor, Tensor]:
+def do_rnnt_pruning(am: Tensor, lm: Tensor, ranges: Tensor) -> Tuple[Tensor, Tensor]:
     """Prune the output of encoder(am) and prediction network(lm) with ranges
     generated by `get_rnnt_prune_ranges`.
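As a usage sketch (the docstring is truncated above, so the shapes here are inferred from how the function is used elsewhere in fast_rnnt and may not match it exactly): am is the encoder output [B][T][E], lm the prediction-network output [B][S+1][E], and ranges [B][T][s_range] lists which s_range symbol positions survive at each frame, giving a pair of [B][T][s_range][E] tensors ready for a pruned joiner:

import torch
import fast_rnnt

B, T, S, E, s_range = 2, 10, 6, 512, 4
am = torch.randn(B, T, E)
lm = torch.randn(B, S + 1, E)
# normally ranges comes from get_rnnt_prune_ranges(...); build a toy one here
s_begin = torch.zeros(B, T, dtype=torch.int64)
ranges = s_begin.unsqueeze(-1) + torch.arange(s_range)  # [B][T][s_range]
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
assert am_pruned.shape == (B, T, s_range, E)
assert lm_pruned.shape == (B, T, s_range, E)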
@@ -840,10 +826,7 @@ def _roll_by_shifts(src: Tensor, shifts: torch.LongTensor):
     assert shifts.shape == (B, T), shifts.shape

     index = (
-        torch.arange(S, device=src.device)
-        .view((1, S))
-        .repeat((T, 1))
-        .repeat((B, 1, 1))
+        torch.arange(S, device=src.device).view((1, S)).repeat((T, 1)).repeat((B, 1, 1))
     )
     index = (index - shifts.reshape(B, T, 1)) % S
     return torch.gather(src, 2, index)
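The index arithmetic implements a per-row right-rotation of the last axis: out[b, t, s] = src[b, t, (s - shifts[b, t]) mod S]. A small standalone check of that behavior, inlining the same three steps:

import torch

src = torch.arange(10).reshape(1, 2, 5)  # B=1, T=2, S=5
shifts = torch.tensor([[1, 2]])          # rotate row 0 by 1, row 1 by 2
B, T, S = src.shape
index = torch.arange(S).view(1, S).repeat(T, 1).repeat(B, 1, 1)
index = (index - shifts.reshape(B, T, 1)) % S
print(torch.gather(src, 2, index))
# tensor([[[4, 0, 1, 2, 3],
#          [8, 9, 5, 6, 7]]])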
@@ -979,9 +962,7 @@ def get_rnnt_logprobs_pruned(
     px = torch.cat(
         (
             px,
-            torch.full(
-                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
-            ),
+            torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
         ),
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
@@ -1101,9 +1082,9 @@ def rnnt_loss_pruned(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(
-            T0, device=px.device
-        ).reshape(1, 1, T0)
+        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
+            1, 1, T0
+        )
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

@@ -1272,9 +1253,7 @@ def get_rnnt_logprobs_smoothed(
         + torch.finfo(lm_probs.dtype).tiny
     )  # [1][1][C]
     amonly_normalizers = (
-        torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C))
-        .reshape(B, T, 1)
-        .log()
+        torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)).reshape(B, T, 1).log()
         + am_max
     )  # [B][T][1]
     amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
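What this computes is a numerically stable log of the unigram-weighted sum of acoustic probabilities: am_max was subtracted before exponentiation (so the exps cannot overflow) and is added back after the log, giving log sum_c p_am(c) * p_unigram(c). A self-contained sketch of the identity (variable names mirror the diff, but this is a reconstruction, not the library code):

import torch

B, T, C = 1, 2, 4
am = torch.randn(B, T, C)
unigram_lm = torch.softmax(torch.randn(C), dim=0)

am_max, _ = torch.max(am, dim=2, keepdim=True)  # [B][T][1]
am_probs = (am - am_max).exp()                  # safe: entries in (0, 1]
stable = (
    torch.mv(am_probs.reshape(-1, C), unigram_lm).reshape(B, T, 1).log() + am_max
)
naive = (am.exp() * unigram_lm).sum(dim=2, keepdim=True).log()
assert torch.allclose(stable, naive, atol=1e-6)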
@@ -1308,9 +1287,7 @@ def get_rnnt_logprobs_smoothed(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..

-    px_lm = torch.gather(
-        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
-    )  # [B][S][1]
+    px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1))  # [B][S][1]
     px_lm_unigram = torch.gather(
         unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
     )  # [B][S][1]
@@ -1343,14 +1320,10 @@ def get_rnnt_logprobs_smoothed(
         am_only_scale = 1.0e-20

     px_interp = (
-        px * combined_scale
-        + px_lmonly * lm_only_scale
-        + px_amonly * am_only_scale
+        px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale
     )
     py_interp = (
-        py * combined_scale
-        + py_lmonly * lm_only_scale
-        + py_amonly * am_only_scale
+        py * combined_scale + py_lmonly * lm_only_scale + py_amonly * am_only_scale
    )

     if rnnt_type == "regular":
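The interpolation is a weighted sum of three log-prob lattices: the full joint term, an LM-only term, and an AM-only term, with combined_scale presumably 1 - lm_only_scale - am_only_scale; the tiny 1.0e-20 visible above suggests a floor that keeps the AM-only branch alive in the autograd graph without affecting the value. A toy illustration (assumed scales, arbitrary tensors):

import torch

lm_only_scale, am_only_scale = 0.1, 1.0e-20
combined_scale = 1.0 - lm_only_scale - am_only_scale

px, px_lmonly, px_amonly = torch.randn(3, 2, 4, 5).unbind(0)
px_interp = (
    px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale
)
# dominated by the joint term; the LM-only term smooths toward a
# language-model-only distribution, and the AM-only term is numerically
# negligible but keeps its gradient path defined.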
@@ -1463,9 +1436,9 @@ def rnnt_loss_smoothed(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(
-            T0, device=px.device
-        ).reshape(1, 1, T0)
+        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
+            1, 1, T0
+        )
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

29 changes: 7 additions & 22 deletions fast_rnnt/python/tests/mutual_information_test.py
@@ -64,12 +64,8 @@ def test_mutual_information_basic(self):
                 if random_boundary:

                     def get_boundary_row():
-                        this_S = random.randint(
-                            0, S
-                        )  # allow empty sequence
-                        this_T = random.randint(
-                            this_S if modified else 1, T
-                        )
+                        this_S = random.randint(0, S)  # allow empty sequence
+                        this_T = random.randint(this_S if modified else 1, T)
                         s_begin = random.randint(0, S - this_S)
                         t_begin = random.randint(0, T - this_T)
                         s_end = s_begin + this_S
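The rows produced here follow the boundary convention used throughout fast_rnnt: [s_begin, t_begin, s_end, t_end], a sub-rectangle of the full [0, S] x [0, T] grid (with this_T >= this_S when modified, since the modified topology consumes a frame per symbol). A condensed, runnable version of the generator (reconstructed from the visible lines; the t_end handling is assumed):

import random
import torch

B, S, T, modified = 4, 5, 8, False

def get_boundary_row():
    this_S = random.randint(0, S)  # allow an empty symbol sequence
    this_T = random.randint(this_S if modified else 1, T)
    s_begin = random.randint(0, S - this_S)
    t_begin = random.randint(0, T - this_T)
    return [s_begin, t_begin, s_begin + this_S, t_begin + this_T]

boundary = torch.tensor([get_boundary_row() for _ in range(B)], dtype=torch.int64)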
@@ -118,14 +114,10 @@ def get_boundary_row():
                     px += 15.0
                 if random_py:
                     # log of an odds ratio
-                    py = torch.randn(B, S + 1, T, dtype=dtype).to(
-                        device
-                    )
+                    py = torch.randn(B, S + 1, T, dtype=dtype).to(device)
                 else:
                     # log of an odds ratio
-                    py = torch.zeros(B, S + 1, T, dtype=dtype).to(
-                        device
-                    )
+                    py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)
                 if big_py:
                     py += 15.0

@@ -206,14 +198,11 @@ def test_mutual_information_deriv(self):

         for dtype in self.dtypes:
             for device in self.devices:
-
                 if random_boundary:

                     def get_boundary_row():
                         this_S = random.randint(1, S)
-                        this_T = random.randint(
-                            this_S if modified else 1, T
-                        )
+                        this_T = random.randint(this_S if modified else 1, T)
                         s_begin = random.randint(0, S - this_S)
                         t_begin = random.randint(0, T - this_T)
                         s_end = s_begin + this_S
@@ -257,14 +246,10 @@ def get_boundary_row():
                     px += 15.0
                 if random_py:
                     # log of an odds ratio
-                    py = torch.randn(B, S + 1, T, dtype=dtype).to(
-                        device
-                    )
+                    py = torch.randn(B, S + 1, T, dtype=dtype).to(device)
                 else:
                     # log of an odds ratio
-                    py = torch.zeros(B, S + 1, T, dtype=dtype).to(
-                        device
-                    )
+                    py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)
                 if big_py:
                     py += 15.0
                 else:
17 changes: 4 additions & 13 deletions fast_rnnt/python/tests/rnnt_loss_test.py
@@ -90,9 +90,7 @@ def test_rnnt_loss_basic(self):
             assert px.shape == (B, S, T + 1)
             assert py.shape == (B, S + 1, T)
             assert symbols.shape == (B, S)
-            m = fast_rnnt.mutual_information_recursion(
-                px=px, py=py, boundary=None
-            )
+            m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)

             if device == torch.device("cpu"):
                 expected = -m
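The sign flip is the key relationship here: the RNN-T loss is the negated output of the mutual-information recursion over the (px, py) lattice. A condensed sketch of the same pipeline (assuming the (lm, am, symbols, termination_symbol) argument order for get_rnnt_logprobs, as suggested by the diff above):

import torch
import fast_rnnt

B, S, T, C = 2, 3, 4, 10
lm = torch.randn(B, S + 1, C).log_softmax(dim=-1)
am = torch.randn(B, T, C).log_softmax(dim=-1)
symbols = torch.randint(1, C, (B, S))
px, py = fast_rnnt.get_rnnt_logprobs(lm, am, symbols, termination_symbol=0)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)  # [B]
loss = -m  # per-sequence RNN-T loss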
@@ -225,9 +223,7 @@ def test_rnnt_loss_random(self):
                 rnnt_type=rnnt_type,
             )
             assert (
-                px.shape == (B, S, T)
-                if rnnt_type != "regular"
-                else (B, S, T + 1)
+                px.shape == (B, S, T) if rnnt_type != "regular" else (B, S, T + 1)
             )
             assert py.shape == (B, S + 1, T)
             assert symbols.shape == (B, S)
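A side note on the first assertion above (pre-existing; the reformat only rewraps it): Python's conditional expression binds more loosely than ==, so it parses as (px.shape == (B, S, T)) if rnnt_type != "regular" else (B, S, T + 1), and the else branch yields a non-empty tuple, which is always truthy. A parenthesization that checks the intended shape in both cases would be:

assert px.shape == ((B, S, T) if rnnt_type != "regular" else (B, S, T + 1))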
@@ -343,7 +339,6 @@ def test_rnnt_loss_gradient(self):
         boundary_[:, 3] = frames

         for device in self.devices:
-
             # lm: [B][S+1][C]
             lm = lm_.to(device)
             # am: [B][T][C]
@@ -374,13 +369,9 @@ def test_rnnt_loss_gradient(self):
             torch_grad = torch.autograd.grad(torch_loss, logits2)
             torch_grad = torch_grad[0]

-            assert torch.allclose(
-                fast_loss, torch_loss, atol=1e-2, rtol=1e-2
-            )
+            assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)

-            assert torch.allclose(
-                fast_grad, torch_grad, atol=1e-2, rtol=1e-2
-            )
+            assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)

     def test_rnnt_loss_smoothed(self):
         B = 1
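On the gradient check above: losses and gradients from the fast implementation are compared against a reference with loose tolerances (atol=rtol=1e-2), which is typical when comparing different kernels in float32. The torch.autograd.grad pattern used there is, in isolation:

import torch

logits = torch.randn(2, 4, 3, 5, requires_grad=True)
loss = logits.log_softmax(dim=-1)[..., 0].sum()  # stand-in for a scalar loss
(grad,) = torch.autograd.grad(loss, logits)      # same pattern as torch_grad above
assert grad.shape == logits.shape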
