diff --git a/torchcomp/__init__.py b/torchcomp/__init__.py index ffbf225..105003f 100644 --- a/torchcomp/__init__.py +++ b/torchcomp/__init__.py @@ -1,14 +1,42 @@ import torch import torch.nn.functional as F from typing import Union +from torchaudio.functional import lfilter from .core import compressor_core -__all__ = ["compexp_gain", "limiter_gain", "ms2coef"] +__all__ = ["compexp_gain", "limiter_gain", "ms2coef", "avg"] amp2db = lambda x: 20 * torch.log10(x) db2amp = lambda x: 10 ** (x / 20) ms2coef = lambda ms, sr: (1 - torch.exp(-2200 / ms / sr)) +coef2ms = lambda coef, sr: -2200 / (sr * torch.log(1 - coef)) + + +def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]): + """Compute the running average of a signal. + + Args: + rms (torch.Tensor): Input signal. + avg_coef (torch.Tensor): Coefficient for the average RMS. + + Shape: + - rms: :math:`(B, T)` where :math:`B` is the batch size and :math:`T` is the number of samples. + - avg_coef: :math:`(B,)` or a scalar. + + """ + + avg_coef = torch.as_tensor( + avg_coef, dtype=rms.dtype, device=rms.device + ).broadcast_to(rms.shape[0]) + assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1) + + return lfilter( + rms, + torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1), + torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1), + False, + ) def compexp_gain( @@ -64,7 +92,8 @@ def compexp_gain( log_x_rms = amp2db(x_rms) g = ( torch.minimum( - comp_slope * (comp_thresh - log_x_rms), exp_slope * (exp_thresh - log_x_rms) + comp_slope[:, None] * (comp_thresh[:, None] - log_x_rms), + exp_slope[:, None] * (exp_thresh[:, None] - log_x_rms), ) .neg() .relu()