Skip to content

Commit

Permalink
feat: average filter
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Mar 19, 2024
1 parent 7157aed commit 899e348
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions torchcomp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,42 @@
import torch
import torch.nn.functional as F
from typing import Union
from torchaudio.functional import lfilter

from .core import compressor_core

__all__ = ["compexp_gain", "limiter_gain", "ms2coef"]
__all__ = ["compexp_gain", "limiter_gain", "ms2coef", "avg"]

amp2db = lambda x: 20 * torch.log10(x)
db2amp = lambda x: 10 ** (x / 20)
ms2coef = lambda ms, sr: (1 - torch.exp(-2200 / ms / sr))
coef2ms = lambda coef, sr: -2200 / (sr * torch.log(1 - coef))


def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
"""Compute the running average of a signal.
Args:
rms (torch.Tensor): Input signal.
avg_coef (torch.Tensor): Coefficient for the average RMS.
Shape:
- rms: :math:`(B, T)` where :math:`B` is the batch size and :math:`T` is the number of samples.
- avg_coef: :math:`(B,)` or a scalar.
"""

avg_coef = torch.as_tensor(
avg_coef, dtype=rms.dtype, device=rms.device
).broadcast_to(rms.shape[0])
assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)

return lfilter(
rms,
torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1),
torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1),
False,
)


def compexp_gain(
Expand Down Expand Up @@ -64,7 +92,8 @@ def compexp_gain(
log_x_rms = amp2db(x_rms)
g = (
torch.minimum(
comp_slope * (comp_thresh - log_x_rms), exp_slope * (exp_thresh - log_x_rms)
comp_slope[:, None] * (comp_thresh[:, None] - log_x_rms),
exp_slope[:, None] * (exp_thresh[:, None] - log_x_rms),
)
.neg()
.relu()
Expand Down

0 comments on commit 899e348

Please sign in to comment.