diff --git a/tensorflow_io/python/ops/audio_ops.py b/tensorflow_io/python/ops/audio_ops.py index 891b5d243..9997bcf9f 100644 --- a/tensorflow_io/python/ops/audio_ops.py +++ b/tensorflow_io/python/ops/audio_ops.py @@ -17,7 +17,7 @@ import sys import tensorflow as tf - +import math from tensorflow_io.python.ops import core_ops @@ -372,55 +372,113 @@ def fade(input, fade_in, fade_out, mode, name=None): return factor_in * factor_out * input -def resample(input, rate_in, rate_out, name=None): - """Resample audio. +def _get_sinc_resample_kernel(rate_in, rate_out, lowpass_filter_width): + assert lowpass_filter_width > 0 + rate_in=tf.cast(rate_in,tf.float32) + rate_out=tf.cast(rate_out,tf.float32) + base_freq = tf.minimum(rate_in, rate_out) + # This will perform antialiasing filtering by removing the highest frequencies. + # At first I thought I only needed this when downsampling, but when upsampling + # you will get edge artifacts without this, as the edge is equivalent to zero padding, + # which will add high freq artifacts. + base_freq *= 0.99 + + # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) + # using the sinc interpolation formula: + # x(t) = sum_i x[i] sinc(pi * rate_in * (i / rate_in - t)) + # We can then sample the function x(t) with a different sample rate: + # y[j] = x(j / rate_out) + # or, + # y[j] = sum_i x[i] sinc(pi * rate_in * (i / rate_in - j / rate_out)) + + # We see here that y[j] is the convolution of x[i] with a specific filter, for which + # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing. + # But y[j+1] is going to have a different set of weights and so on, until y[j + rate_out]. + # Indeed: + # y[j + rate_out] = sum_i x[i] sinc(pi * rate_in * ((i / rate_in - (j + rate_out) / rate_out)) + # = sum_i x[i] sinc(pi * rate_in * ((i - rate_in) / rate_in - j / rate_out)) + # = sum_i x[i + rate_in] sinc(pi * rate_in * (i / rate_in - j / rate_out)) + # so y[j+rate_out] uses the same filter as y[j], but on a shifted version of x by `rate_in`. + # This will explain the F.conv1d after, with a stride of rate_in. + width = tf.experimental.numpy.ceil(lowpass_filter_width * rate_in / base_freq) + # If rate_in is still big after GCD reduction, most filters will be very unbalanced, i.e., + # they will have a lot of almost zero values to the left or to the right... + # There is probably a way to evaluate those filters more efficiently, but this is kept for + # future work. + idx = tf.range(-width, width + rate_in, dtype=tf.float32) + idx = tf.repeat(tf.expand_dims(idx, axis=-1), tf.cast(rate_out,tf.int32), axis=-1) + aux_i = tf.expand_dims(tf.range(rate_out, dtype=tf.float32), axis=0) + kernels = (-aux_i / rate_out + idx / rate_in) * base_freq + + kernels = tf.clip_by_value(kernels, -lowpass_filter_width, lowpass_filter_width) + kernels *= math.pi + + window = tf.math.cos(kernels / lowpass_filter_width / 2) ** 2 + kernels = tf.where( + kernels == 0, tf.ones_like(kernels), tf.math.sin(kernels) / kernels + ) + kernels *= window + + scale = base_freq / rate_in + return tf.expand_dims(kernels, axis=1) * scale, width + + +def resample(input, rate_in, rate_out, lowpass_filter_width=6): + """Resamples the input at the new frequency. This matches Kaldiā€™s OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e the output signal has a frequency of rate_out). It uses sinc/bandlimited interpolation to upsample/downsample the signal. Args: - input: A 1-D (`[samples]`) or 2-D (`[samples, channels]`) or 3-D - (`[batch, samples, channels]`) `Tensor` of type - `int16` or `float`. Audio input. + input: A 1-D (`[samples]`) or 2-D (`[samples, channels]`) or 3-D (`[batch, samples, channels]`) `Tensor` of type `float`. Audio input. rate_in: The rate of the audio input. rate_out: The rate of the audio output. - name: A name for the operation (optional). + lowpass_filter_width: Controls the sharpness of the filter, more == sharper but less efficient. We suggest around 4 to 10 for normal use. (Default: 6) Returns: output: Resampled audio. """ - rank = tf.rank(input) - - def f1(): - return tf.expand_dims(tf.expand_dims(input, -1), 0) - - def f2(): - return tf.expand_dims(input, 0) - - def f3(): - return input - - input = tf.case( - [(tf.math.equal(rank, 1), f1), (tf.math.equal(rank, 2), f2)], default=f3 - ) - - def f(i): - return core_ops.io_audio_resample( - i, rate_in=rate_in, rate_out=rate_out, name=name + waveform = input + + if rate_in == rate_out: + return waveform + rate_in = tf.cast(rate_in,tf.int32) + rate_out = tf.cast(rate_out,tf.int32) + gcd = tf.experimental.numpy.gcd(rate_in, rate_out) + rate_in = rate_in // gcd + rate_out = rate_out // gcd + + kernel, width = _get_sinc_resample_kernel(rate_in, rate_out, lowpass_filter_width) + width=tf.cast(width,tf.int32) + + ori_shape = waveform.shape + ori_shape_len = len(ori_shape) + if ori_shape_len == 1: + waveform = tf.expand_dims(waveform, axis=0) + elif ori_shape_len == 2: + waveform = tf.transpose(waveform, [1, 0]) + elif ori_shape_len == 3: + waveform = tf.transpose(waveform, [0, 2, 1]) + waveform = tf.reshape(waveform, [ori_shape[0] * ori_shape[2], ori_shape[1]]) + + waveform = tf.expand_dims(waveform, axis=-1) + + num_wavs, length, _ = waveform.shape + + waveform = tf.pad(waveform, [[0, 0], [width, width + rate_in], [0, 0]]) + resampled = tf.nn.conv1d(waveform, kernel, stride=tf.reshape(rate_in,[1,]), padding="VALID") + resampled = tf.reshape(resampled, [num_wavs, -1]) + target_length = tf.cast(tf.experimental.numpy.ceil(rate_out * length / rate_in),tf.int32) + if ori_shape_len == 1: + return resampled[0, :target_length] + elif ori_shape_len == 2: + return tf.transpose(resampled[:, :target_length], [1, 0]) + elif ori_shape_len == 3: + return tf.transpose( + tf.reshape( + resampled[:, :target_length], + [ori_shape[0], ori_shape[2], target_length], + ), + [0, 2, 1], ) - value = tf.vectorized_map(f, input) - - def g1(): - return tf.squeeze(value, [0, -1]) - - def g2(): - return tf.squeeze(value, [0]) - - def g3(): - return value - - return tf.case( - [(tf.math.equal(rank, 1), g1), (tf.math.equal(rank, 2), g2)], default=g3 - ) - def decode_wav( input, shape=None, dtype=None, name=None