Fix black
pkufool committed Aug 24, 2023
1 parent 0ef9457 commit bebaa2c
Showing 4 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/style_check.yml
@@ -66,4 +66,4 @@ jobs:
shell: bash
working-directory: ${{github.workspace}}
run: |
- black --check --diff .
+ black -l 80 --check --diff .
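The only change in this workflow is that the CI check now pins black's line length to 80 columns. For context, a minimal sketch of the same 80-column check done from Python rather than the CLI, assuming a recent black release that exposes `Mode` and `format_str` (the sample source string is made up):

# Sketch: report whether a snippet already satisfies black at a line length of 80.
# Assumes `pip install black`; Mode(line_length=...) and format_str() are used here.
import black

source = "x = compute(alpha_value, beta_value, gamma_value, delta_value)\n"

formatted = black.format_str(source, mode=black.Mode(line_length=80))
if formatted == source:
    print("OK: already formatted")  # what `black -l 80 --check` would accept
else:
    print("would be reformatted to:")
    print(formatted, end="")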
2 changes: 0 additions & 2 deletions fast_rnnt/python/fast_rnnt/__init__.py
@@ -13,5 +13,3 @@
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
-
-
18 changes: 11 additions & 7 deletions fast_rnnt/python/fast_rnnt/mutual_information.py
@@ -160,7 +160,8 @@ def forward(
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
- px, py, boundary, p, ans_grad)
+ px, py, boundary, p, ans_grad
+ )
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
@@ -290,8 +291,9 @@ def mutual_information_recursion(
px, py = px.contiguous(), py.contiguous()

pxy_grads = [None, None]
- scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
- boundary, return_grad)
+ scores = MutualInformationRecursionFunction.apply(
+ px, py, pxy_grads, boundary, return_grad
+ )
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores

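The call re-wrapped above sits inside `mutual_information_recursion`, the public entry point of this module. A minimal usage sketch, with shapes taken from this file's docstrings as I read them (the batch and sequence sizes and the random inputs are made up): px is (B, S, T+1), py is (B, S+1, T), and with `return_grad=True` the returned gradients can be read as occupation probabilities.

# Sketch: calling mutual_information_recursion directly.
# Assumed shapes: px (B, S, T+1), py (B, S+1, T); boundary may be None.
import torch
import fast_rnnt

B, S, T = 2, 5, 10
px = torch.randn(B, S, T + 1)
py = torch.randn(B, S + 1, T)

# scores has shape (B,); px_grad/py_grad mirror px/py and act as occupation
# probabilities because the upstream gradient is a vector of ones.
scores, (px_grad, py_grad) = fast_rnnt.mutual_information_recursion(
    px, py, boundary=None, return_grad=True
)
print(scores.shape, px_grad.shape, py_grad.shape)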
@@ -388,16 +390,18 @@ def joint_mutual_information_recursion(
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

# note, tot_probs is without grad.
- tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)
+ tot_probs = _fast_rnnt.mutual_information_forward(
+ px_tot, py_tot, boundary, p
+ )

# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

- (px_grad,
- py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p,
- ans_grad)
+ (px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
+ px_tot, py_tot, boundary, p, ans_grad
+ )

px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
24 changes: 15 additions & 9 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
@@ -170,7 +170,7 @@ def get_rnnt_logprobs(
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
- ) # (B, S, T)
+ ) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
@@ -291,7 +291,9 @@ def rnnt_loss_simple(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
- (T - 1) / 2, dtype=px.dtype, device=px.device,
+ (T - 1) / 2,
+ dtype=px.dtype,
+ device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
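The re-wrapped lines above are part of `rnnt_loss_simple`, which computes the RNN-T loss from separate label (lm) and frame (am) scores without a joint network; the offset appears to feed the optional delay penalty. A rough usage sketch following the argument names in this repository's README (vocabulary size, shapes, and blank id 0 are illustrative):

# Sketch: the "simple" RNN-T loss on separate am/lm scores.
import torch
import fast_rnnt

B, S, T, C = 2, 10, 50, 100             # batch, label length, frames, vocab
am = torch.randn(B, T, C)               # per-frame scores from the encoder
lm = torch.randn(B, S + 1, C)           # per-token scores from the prediction network
symbols = torch.randint(1, C, (B, S))
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

loss = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,               # blank id, assumed here
    boundary=boundary,
    reduction="sum",
)
print(loss)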
@@ -495,7 +497,9 @@ def rnnt_loss(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
- (T - 1) / 2, dtype=px.dtype, device=px.device,
+ (T - 1) / 2,
+ dtype=px.dtype,
+ device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
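`rnnt_loss` in this hunk is the un-pruned variant that consumes the full joint-network output instead of separate am/lm scores; the change is the same offset re-wrap as above. A hedged sketch of a typical call, assuming the function is exported at package level like the other losses and that logits has shape (B, T, S+1, C):

# Sketch: the full (un-pruned) RNN-T loss on joint-network logits.
import torch
import fast_rnnt

B, S, T, C = 2, 10, 50, 100
logits = torch.randn(B, T, S + 1, C)    # joiner output over all (t, s) pairs
symbols = torch.randint(1, C, (B, S))
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

loss = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=0,               # blank id, assumed here
    boundary=boundary,
    reduction="sum",
)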
@@ -770,9 +774,7 @@ def do_rnnt_pruning(
lm_pruning = torch.gather(
lm,
dim=1,
- index=ranges.reshape(B, T * s_range, 1).expand(
- (B, T * s_range, C)
- ),
+ index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, C)),
).reshape(B, T, s_range, C)
return am_pruning, lm_pruning

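`do_rnnt_pruning` above gathers the (B, T, s_range, C) slices of the encoder and prediction-network outputs selected by `ranges`; those slices are what `rnnt_loss_pruned` later consumes. A rough end-to-end sketch of the pruned pipeline following the pattern in this repository's README (a plain `am_pruned + lm_pruned` stands in for a real joint network, and all sizes are illustrative):

# Sketch: pruned RNN-T pipeline (simple loss -> prune ranges -> pruned loss).
import torch
import fast_rnnt

B, S, T, C = 2, 10, 50, 100
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)
symbols = torch.randint(1, C, (B, S))
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# 1) The simple loss can also return occupation-probability gradients.
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=0,
    boundary=boundary, return_grad=True, reduction="sum",
)

# 2) Keep s_range symbol positions per frame, chosen from those gradients.
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=5
)

# 3) Gather the pruned am/lm slices; both come back as (B, T, s_range, C).
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

# 4) A real system would run a joint network here; a sum keeps the sketch
#    self-contained.
logits = am_pruned + lm_pruned

pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges,
    termination_symbol=0, boundary=boundary, reduction="sum",
)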
@@ -1057,7 +1059,9 @@ def rnnt_loss_pruned(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
- (T - 1) / 2, dtype=px.dtype, device=px.device,
+ (T - 1) / 2,
+ dtype=px.dtype,
+ device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
@@ -1248,7 +1252,7 @@ def get_rnnt_logprobs_smoothed(
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
- ) # (B, S, T)
+ ) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
@@ -1413,7 +1417,9 @@ def rnnt_loss_smoothed(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
- (T - 1) / 2, dtype=px.dtype, device=px.device,
+ (T - 1) / 2,
+ dtype=px.dtype,
+ device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
