Commit fe9a577

documentation
RichieHakim committed Apr 1, 2024
1 parent 474a13d commit fe9a577
Showing 1 changed file with 118 additions and 104 deletions.
222 changes: 118 additions & 104 deletions vqt/vqt.py
@@ -9,6 +9,90 @@
from . import helpers

class VQT(torch.nn.Module):
"""
Variable Q Transform. Class for applying the variable Q transform to
signals. \n
This class works differently from the VQT in librosa or nnAudio:
it does not use iterative lowpass filtering. \n
If fft_conv is False, then it uses a fixed set of filters, a Hilbert
transform to compute the analytic signal, and then takes the magnitude. \n
If fft_conv is True, then it uses FFT convolution to compute the transform.
\n
Uses Pytorch for GPU acceleration, and allows gradients to pass through. \n
Q: quality factor; roughly corresponds to the number of cycles in a
filter. Here, Q is similar to the number of cycles within 4 sigma (95%)
of a Gaussian window. \n
For running batches on GPU, transferring back to CPU tends to be the slowest
part. \n
RH 2022-2024
Args:
Fs_sample (float):
Sampling frequency of the signal.
Q_lowF (float):
Q factor to use for the lowest frequency.
Q_highF (float):
Q factor to use for the highest frequency.
F_min (float):
Lowest frequency to use.
F_max (float):
Highest frequency to use.
n_freq_bins (int):
Number of frequency bins to use.
win_size (int, None):
Size of the window to use, in samples. \n
If None, will be set to the next odd number after Q_lowF * (Fs_sample / F_min).
window_type (str, np.ndarray, list, tuple):
Window to use for the mother wavelet. \n
* If string: Will be passed to scipy.signal.windows.get_window.
See that documentation for options. For 'gaussian', pass just
the string 'gaussian' without any additional arguments.
* If array-like: Will be used as the window directly.
symmetry (str):
Whether to use a symmetric window or a single-sided window. \n
* 'center': Use a symmetric / centered / 'two-sided' window. \n
* 'left': Use a one-sided, left-half window. Only the left half
of the filter will be nonzero. \n
* 'right': Use a one-sided, right-half window. Only the right half
of the filter will be nonzero. \n
taper_asymmetric (bool):
Only used if symmetry is not 'center'. Whether to taper the
center of the window by multiplying the center sample of the
window by 0.5.
downsample_factor (int):
Factor to downsample the signal by. If the length of the input
signal is not divisible by downsample_factor, the signal will be
zero-padded at the end so that it is.
padding (str):
Padding mode to use: \n
* If fft_conv==False: ['valid', 'same'] \n
* If fft_conv==True: ['full', 'valid', 'same'] \n
fft_conv (bool):
Whether to use FFT convolution. This is faster, but may be less
accurate. If False, uses torch's conv1d.
fast_length (bool):
Whether to use scipy.fftpack.next_fast_len to
find the next fast length for the FFT.
This may be faster, but uses more memory.
take_abs (bool):
Whether to take the absolute value of the transform. If True,
returns the absolute value (envelope) of the transform. If
False, returns the complex transform.
filters (Torch tensor):
Filters to use. If None, will make new filters. Should be
complex sinusoids. shape: (n_freq_bins, win_size)
verbose (int):
Verbosity. True to print warnings.
plot_pref (bool):
Whether to plot the filters.
"""
def __init__(
self,
Fs_sample: Union[int, float]=1000,
@@ -30,88 +114,6 @@ def __init__(
verbose: Union[int, bool]=True,
plot_pref: bool=False,
):
"""
Variable Q Transform. Class for applying the variable Q transform to
signals.
This function works differently than the VQT from librosa or nnAudio.
This one does not use iterative lowpass filtering. \n If fft_conv is
False, then it uses a fixed set of filters, a Hilbert transform to
compute the analytic signal, and then takes the magnitude. \n If
fft_conv is True, then it uses FFT convolution to compute the transform.
\n
Uses Pytorch for GPU acceleration, and allows gradients to pass through.
\n
Q: quality factor; roughly corresponds to the number of cycles in a
filter. Here, Q is similar to the number of cycles within 4 sigma (95%)
of a gaussian window. \n
RH 2022-2024
Args:
Fs_sample (float):
Sampling frequency of the signal.
Q_lowF (float):
Q factor to use for the lowest frequency.
Q_highF (float):
Q factor to use for the highest frequency.
F_min (float):
Lowest frequency to use.
F_max (float):
Highest frequency to use.
n_freq_bins (int):
Number of frequency bins to use.
win_size (int, None):
Size of the window to use, in samples. \n
If None, will be set to the next odd number after Q_lowF * (Fs_sample / F_min).
window_type (str, np.ndarray, list, tuple):
Window to use for the mother wavelet. \n
* If string: Will be passed to scipy.signal.windows.get_window.
See that documentation for options. Except for 'gaussian',
which you should just pass the string 'gaussian' without any
other arguments.
* If array-like: Will be used as the window directly.
symmetry (str):
Whether to use a symmetric window or a single-sided window. \n
* 'center': Use a symmetric / centered / 'two-sided' window.
\n
* 'left': Use a one-sided, left-half window. Only left half
of the filter will be nonzero. \n * 'right': Use a
one-sided, right-half window. Only right half of the filter
will be nonzero. \n
taper_asymmetric (bool):
Only used if symmetry is not 'center'. Whether to taper the
center of the window by multiplying center sample of window by
0.5.
downsample_factor (int):
Factor to downsample the signal by. If the length of the input
signal is not divisible by downsample_factor, the signal will be
zero-padded at the end so that it is.
padding (str):
Padding mode to use: \n
* If fft_conv==False: ['valid', 'same'] \n
* If fft_conv==True: ['full', 'valid', 'same'] \n
fft_conv (bool):
Whether to use FFT convolution. This is faster, but may be less
accurate. If False, uses torch's conv1d.
fast_length (bool):
Whether to use scipy.fftpack.next_fast_len to
find the next fast length for the FFT.
This may be faster, but uses more memory.
take_abs (bool):
Whether to return the complex version of the transform. If
True, then returns the absolute value (envelope) of the
transform. If False, returns the complex transform.
filters (Torch tensor):
Filters to use. If None, will make new filters. Should be
complex sinusoids. shape: (n_freq_bins, win_size)
verbose (int):
Verbosity. True to print warnings.
plot_pref (bool):
Whether to plot the filters.
"""
super().__init__()
## Prepare filters
self.using_custom_filters = filters is not None
@@ -184,7 +186,7 @@ def __init__(
def forward(
self,
X: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of VQT.
@@ -197,11 +199,6 @@ def forward(
Spectrogram (Torch tensor):
Spectrogram of the input signal.
shape: (n_channels, n_samples_ds, n_freq_bins)
"""
assert isinstance(X, torch.Tensor), "X should be a torch tensor"
X = X.type(torch.float32)
@@ -230,11 +227,22 @@ def forward(

return specs

def get_freqs(self) -> torch.Tensor:
"""
Get the frequencies of the spectrogram.
Args:
None
Returns:
torch.Tensor:
Frequencies of the spectrogram. \n
shape: (n_freq_bins,)
"""
assert hasattr(self, 'freqs'), "freqs not found. This should not happen."
return self.freqs

def get_xAxis(self, n_samples: int) -> torch.Tensor:
"""
Get the x-axis for the spectrogram. \n
RH 2024
@@ -245,7 +253,8 @@ def get_xAxis(self, n_samples: int):
Returns:
torch.Tensor:
x-axis for the spectrogram in units of samples. \n
shape: (n_samples_ds,)
"""
## Make x_axis
x_axis = make_conv_xAxis(
@@ -265,13 +274,13 @@ def __repr__(self):
for k, v in self.__dict__.items():
if (k not in ['filters', 'freqs', 'wins']) and (not k.startswith('_')) and (not callable(v)):
attributes_to_print.append(k)
return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k in attributes_to_print])[:-2]}"
return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k in attributes_to_print])[:-2]}"


def downsample(
X: torch.Tensor,
ds_factor: int=4,
) -> torch.Tensor:
"""
Downsample a signal using average pooling. \n
If X is complex, it will be split into magnitude and phase, downsampled,
@@ -313,11 +322,11 @@ def downsample(

else:
raise ValueError("X should be a torch tensor of type float or complex")
def _helper_polarReal_to_imag(arr: torch.Tensor) -> torch.Tensor:
return arr[0] * torch.exp(1j * arr[1])
def _helper_imag_to_polarReal(arr: torch.Tensor) -> torch.Tensor:
return torch.stack([torch.abs(arr), torch.angle(arr)], dim=0)
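## Sketch of the polar round-trip the helpers above implement (illustrative):
##   z = torch.tensor([3.0 + 4.0j])
##   p = _helper_imag_to_polarReal(z)   ## stacks [magnitude, phase] on dim 0
##   _helper_polarReal_to_imag(p)       ## recovers tensor([3.+4.j])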
def _helper_ds(arr: torch.Tensor, ds_factor: int) -> torch.Tensor:
return torch.nn.functional.avg_pool1d(
arr,
kernel_size=[int(ds_factor)],
@@ -334,7 +343,7 @@ def convolve(
padding: str='same',
fft_conv: bool=False,
fast_length: bool=False,
) -> torch.Tensor:
"""
Convolve a signal with a set of kernels. \n
@@ -361,7 +370,8 @@ def convolve(
Returns:
torch.Tensor:
Result of the convolution. \n
``shape``: (n_channels, n_samples, n_kernels)
"""
assert all(isinstance(arg, torch.Tensor) for arg in [arr, kernels]), "arr and kernels should be torch tensors"

@@ -408,7 +418,7 @@ def fftconvolve(
y: torch.Tensor,
mode: str='valid',
fast_length: bool=False,
) -> torch.Tensor:
"""
Convolution using the FFT method. \n
This is adapted from torchaudio.functional.fftconvolve and handles
Expand Down Expand Up @@ -436,7 +446,9 @@ def fftconvolve(
Returns:
torch.Tensor:
Result of the convolution. \n
Padding applied to last dimension. \n
``shape``: (..., n_samples)
"""
## Compute the convolution
n_original = x.shape[-1] + y.shape[-1] - 1
@@ -454,7 +466,7 @@

## For some reason jit is slower here
# @torch.jit.script
def next_fast_len(size: int) -> int:
"""
Taken from PyTorch Forecasting:
Returns the next largest number ``n >= size`` whose prime factors are all
@@ -510,7 +522,8 @@ def apply_padding_mode(
Returns:
torch.Tensor:
Result of the convolution with the specified padding mode. \n
``shape``: (..., n_samples)
"""
n = x_length + y_length - 1
valid_convolve_modes = ["full", "valid", "same"]
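## For reference, the standard output lengths for these modes:
##   'full'  -> x_length + y_length - 1
##   'valid' -> max(x_length, y_length) - min(x_length, y_length) + 1
##   'same'  -> x_length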
@@ -534,7 +547,7 @@ def make_conv_xAxis(
padding: str='same',
downsample_factor: int=4,
device: torch.device='cpu',
) -> torch.Tensor:
"""
Make the x-axis for the result of a convolution.
This is adapted from torchaudio.functional._make_conv_xAxis.
@@ -555,7 +568,8 @@ def make_conv_xAxis(
Returns:
torch.Tensor:
x-axis for the result of a convolution. \n
``shape``: (n_samples_ds,)
"""

## If n_k is odd, then no offset, if even, then offset by 0.5
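## e.g. a kernel of length 5 centers exactly on a sample (no offset), while
## one of length 4 centers between two samples, hence the 0.5-sample offset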
