Skip to content

Commit

Permalink
Merge pull request #7 from DiffAPF:fix-avg_coef-shape
Browse files Browse the repository at this point in the history
fix: correct tensor broadcasting in avg function
  • Loading branch information
yoyolicoris authored Nov 27, 2024
2 parents 70b6590 + 7ab8ed1 commit 03a6ecc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchcomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)

return sample_wise_lpc(
rms * avg_coef,
rms * avg_coef.unsqueeze(1),
avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
)

Expand Down

0 comments on commit 03a6ecc

Please sign in to comment.