Replace the ema_conv1d kernel weights and add EMA regression tests.

With W = conv_width, the kernel built in ema_conv1d changes from
    [(1 - gamma)^k for k in 0..W-2] + [gamma]
to
    [(1 - gamma)^(W-1)] + [gamma * (1 - gamma)^k for k in W-2..0]
The new weights sum to 1: the first tap carries the weight left on the
initial state after W-1 steps of
    state = (1 - gamma) * state + gamma * x
and the remaining taps are the per-sample weights of that recursion.

The added tests fold this recursion over the inputs with functools.reduce
and compare the result with the last output of audio_utils.ema and of
audio_utils.ema_conv1d (called with conv_width=-1).

diff --git a/chirp/audio_utils.py b/chirp/audio_utils.py
index 4d98a165..ee294659 100644
--- a/chirp/audio_utils.py
+++ b/chirp/audio_utils.py
@@ -366,7 +366,8 @@ def ema_conv1d(
   padded_inp = jnp.concatenate([left_pad, xs], axis=1)
 
   kernel = jnp.array(
-      [(1.0 - gamma) ** k for k in range(conv_width - 1)] + [gamma]
+      [(1.0 - gamma) ** (conv_width - 1)]
+      + [gamma * (1.0 - gamma) ** k for k in range(conv_width - 2, -1, -1)]
   ).astype(xs.dtype)
   if isinstance(gamma, float) or gamma.ndim == 0:
     kernel = kernel[jnp.newaxis, jnp.newaxis, :]
diff --git a/chirp/tests/audio_utils_test.py b/chirp/tests/audio_utils_test.py
index ac53ad03..f5b6663c 100644
--- a/chirp/tests/audio_utils_test.py
+++ b/chirp/tests/audio_utils_test.py
@@ -15,6 +15,7 @@
 
 """Tests for audio utilities."""
 
+import functools
 import os
 
 from chirp import audio_utils
@@ -119,6 +120,22 @@ def test_pcen(self):
 
     np.testing.assert_allclose(out, librosa_out, rtol=5e-2)
 
+  def test_ema(self):
+    rng = np.random.default_rng(seed=0)
+    inputs = rng.normal(size=(128, 3))
+    gamma = 0.9
+    outputs, _ = audio_utils.ema(inputs, gamma)
+    ref = functools.reduce(lambda x, y: (1 - gamma) * x + gamma * y, inputs)
+    np.testing.assert_allclose(outputs[-1], ref, rtol=1e-6)
+
+  def test_ema_conv1d(self):
+    rng = np.random.default_rng(seed=0)
+    inputs = rng.normal(size=(128, 3))
+    gamma = 0.9
+    outputs = audio_utils.ema_conv1d(inputs[None], gamma, conv_width=-1)[0]
+    ref = functools.reduce(lambda x, y: (1 - gamma) * x + gamma * y, inputs)
+    np.testing.assert_allclose(outputs[-1], ref, rtol=1e-6)
+
   @parameterized.product(
       # NOTE: TF and JAX have different outputs when nperseg is odd.
       nperseg=(256, 230),