From fe9a57741688bda3f95d117f49c42fbfdc9e8a5a Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Mon, 1 Apr 2024 03:14:14 -0400 Subject: [PATCH] documentation --- vqt/vqt.py | 222 ++++++++++++++++++++++++++++------------------------- 1 file changed, 118 insertions(+), 104 deletions(-) diff --git a/vqt/vqt.py b/vqt/vqt.py index e236b37..c5f6dfb 100644 --- a/vqt/vqt.py +++ b/vqt/vqt.py @@ -9,6 +9,90 @@ from . import helpers class VQT(torch.nn.Module): + """ + Variable Q Transform. Class for applying the variable Q transform to + signals. \n + + This function works differently than the VQT from librosa or nnAudio. + This one does not use iterative lowpass filtering. \n + If fft_conv is False, then it uses a fixed set of filters, a Hilbert + transform to compute the analytic signal, and then takes the magnitude. \n + If fft_conv is True, then it uses FFT convolution to compute the transform. + \n + + Uses Pytorch for GPU acceleration, and allows gradients to pass through. \n + + Q: quality factor; roughly corresponds to the number of cycles in a + filter. Here, Q is similar to the number of cycles within 4 sigma (95%) + of a gaussian window. \n + + For running batches on GPU, transferring back to CPU tends to be the slowest + part. \n + + RH 2022-2024 + + Args: + Fs_sample (float): + Sampling frequency of the signal. + Q_lowF (float): + Q factor to use for the lowest frequency. + Q_highF (float): + Q factor to use for the highest frequency. + F_min (float): + Lowest frequency to use. + F_max (float): + Highest frequency to use. + n_freq_bins (int): + Number of frequency bins to use. + win_size (int, None): + Size of the window to use, in samples. \n + If None, will be set to the next odd number after Q_lowF * (Fs_sample / F_min). + window_type (str, np.ndarray, list, tuple): + Window to use for the mother wavelet. \n + * If string: Will be passed to scipy.signal.windows.get_window. + See that documentation for options. Except for 'gaussian', + which you should just pass the string 'gaussian' without any + other arguments. + * If array-like: Will be used as the window directly. + symmetry (str): + Whether to use a symmetric window or a single-sided window. \n + * 'center': Use a symmetric / centered / 'two-sided' window. + \n + * 'left': Use a one-sided, left-half window. Only left half + of the filter will be nonzero. \n * 'right': Use a + one-sided, right-half window. Only right half of the filter + will be nonzero. \n + taper_asymmetric (bool): + Only used if symmetry is not 'center'. Whether to taper the + center of the window by multiplying center sample of window by + 0.5. + downsample_factor (int): + Factor to downsample the signal by. If the length of the input + signal is not divisible by downsample_factor, the signal will be + zero-padded at the end so that it is. + padding (str): + Padding mode to use: \n + * If fft_conv==False: ['valid', 'same'] \n + * If fft_conv==True: ['full', 'valid', 'same'] \n + fft_conv (bool): + Whether to use FFT convolution. This is faster, but may be less + accurate. If False, uses torch's conv1d. + fast_length (bool): + Whether to use scipy.fftpack.next_fast_len to + find the next fast length for the FFT. + This may be faster, but uses more memory. + take_abs (bool): + Whether to return the complex version of the transform. If + True, then returns the absolute value (envelope) of the + transform. If False, returns the complex transform. + filters (Torch tensor): + Filters to use. If None, will make new filters. Should be + complex sinusoids. shape: (n_freq_bins, win_size) + verbose (int): + Verbosity. True to print warnings. + plot_pref (bool): + Whether to plot the filters. + """ def __init__( self, Fs_sample: Union[int, float]=1000, @@ -30,88 +114,6 @@ def __init__( verbose: Union[int, bool]=True, plot_pref: bool=False, ): - """ - Variable Q Transform. Class for applying the variable Q transform to - signals. - - This function works differently than the VQT from librosa or nnAudio. - This one does not use iterative lowpass filtering. \n If fft_conv is - False, then it uses a fixed set of filters, a Hilbert transform to - compute the analytic signal, and then takes the magnitude. \n If - fft_conv is True, then it uses FFT convolution to compute the transform. - \n - - Uses Pytorch for GPU acceleration, and allows gradients to pass through. - \n - - Q: quality factor; roughly corresponds to the number of cycles in a - filter. Here, Q is similar to the number of cycles within 4 sigma (95%) - of a gaussian window. \n - - RH 2022-2024 - - Args: - Fs_sample (float): - Sampling frequency of the signal. - Q_lowF (float): - Q factor to use for the lowest frequency. - Q_highF (float): - Q factor to use for the highest frequency. - F_min (float): - Lowest frequency to use. - F_max (float): - Highest frequency to use. - n_freq_bins (int): - Number of frequency bins to use. - win_size (int, None): - Size of the window to use, in samples. \n - If None, will be set to the next odd number after Q_lowF * (Fs_sample / F_min). - window_type (str, np.ndarray, list, tuple): - Window to use for the mother wavelet. \n - * If string: Will be passed to scipy.signal.windows.get_window. - See that documentation for options. Except for 'gaussian', - which you should just pass the string 'gaussian' without any - other arguments. - * If array-like: Will be used as the window directly. - symmetry (str): - Whether to use a symmetric window or a single-sided window. \n - * 'center': Use a symmetric / centered / 'two-sided' window. - \n - * 'left': Use a one-sided, left-half window. Only left half - of the filter will be nonzero. \n * 'right': Use a - one-sided, right-half window. Only right half of the filter - will be nonzero. \n - taper_asymmetric (bool): - Only used if symmetry is not 'center'. Whether to taper the - center of the window by multiplying center sample of window by - 0.5. - downsample_factor (int): - Factor to downsample the signal by. If the length of the input - signal is not divisible by downsample_factor, the signal will be - zero-padded at the end so that it is. - padding (str): - Padding mode to use: \n - * If fft_conv==False: ['valid', 'same'] \n - * If fft_conv==True: ['full', 'valid', 'same'] \n - fft_conv (bool): - Whether to use FFT convolution. This is faster, but may be less - accurate. If False, uses torch's conv1d. - fast_length (bool): - Whether to use scipy.fftpack.next_fast_len to - find the next fast length for the FFT. - This may be faster, but uses more memory. - take_abs (bool): - Whether to return the complex version of the transform. If - True, then returns the absolute value (envelope) of the - transform. If False, returns the complex transform. - filters (Torch tensor): - Filters to use. If None, will make new filters. Should be - complex sinusoids. shape: (n_freq_bins, win_size) - verbose (int): - Verbosity. True to print warnings. - plot_pref (bool): - Whether to plot the filters. - """ super().__init__() ## Prepare filters self.using_custom_filters = True if filters is not None else False @@ -184,7 +186,7 @@ def __init__( def forward( self, X: torch.Tensor, - ): + ) -> torch.Tensor: """ Forward pass of VQT. @@ -197,11 +199,6 @@ def forward( Spectrogram (Torch tensor): Spectrogram of the input signal. shape: (n_channels, n_samples_ds, n_freq_bins) - x_axis (Torch tensor): - New x-axis for the spectrogram in units of samples. - Get units of time by dividing by Fs_sample. - self.freqs (Torch tensor): - Frequencies of the spectrogram. """ assert isinstance(X, torch.Tensor), "X should be a torch tensor" X = X.type(torch.float32) @@ -230,11 +227,22 @@ def forward( return specs - def get_freqs(self): + def get_freqs(self) -> torch.Tensor: + """ + Get the frequencies of the spectrogram. + + Args: + None + + Returns: + torch.Tensor: + Frequencies of the spectrogram. \n + shape: (n_freq_bins,) + """ assert hasattr(self, 'freqs'), "freqs not found. This should not happen." return self.freqs - def get_xAxis(self, n_samples: int): + def get_xAxis(self, n_samples: int) -> torch.Tensor: """ Get the x-axis for the spectrogram. \n RH 2024 @@ -245,7 +253,8 @@ def get_xAxis(self, n_samples: int): Returns: torch.Tensor: - x-axis for the spectrogram in units of samples. + x-axis for the spectrogram in units of samples. \n + shape: (n_samples_ds,) """ ## Make x_axis x_axis = make_conv_xAxis( @@ -265,13 +274,13 @@ def __repr__(self): for k, v in self.__dict__.items(): if (k not in ['filters', 'freqs', 'wins']) and (not k.startswith('_')) and (not callable(v)): attributes_to_print.append(k) - return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k in attributes_to_print])[:-2]}" + return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k in attributes_to_print])[:-2]}" def downsample( X: torch.Tensor, ds_factor: int=4, -): +) -> torch.Tensor: """ Downsample a signal using average pooling. \n If X is complex, it will be split into magnitude and phase, downsampled, @@ -313,11 +322,11 @@ def downsample( else: raise ValueError("X should be a torch tensor of type float or complex") -def _helper_polarReal_to_imag(arr: torch.Tensor): +def _helper_polarReal_to_imag(arr: torch.Tensor) -> torch.Tensor: return arr[0] * torch.exp(1j * arr[1]) -def _helper_imag_to_polarReal(arr: torch.Tensor): +def _helper_imag_to_polarReal(arr: torch.Tensor) -> torch.Tensor: return torch.stack([torch.abs(arr), torch.angle(arr)], dim=0) -def _helper_ds(arr: torch.Tensor, ds_factor: int): +def _helper_ds(arr: torch.Tensor, ds_factor: int) -> torch.Tensor: return torch.nn.functional.avg_pool1d( arr, kernel_size=[int(ds_factor)], @@ -334,7 +343,7 @@ def convolve( padding: str='same', fft_conv: bool=False, fast_length: bool=False, -): +) -> torch.Tensor: """ Convolve a signal with a set of kernels. \n @@ -361,7 +370,8 @@ def convolve( Returns: torch.Tensor: - Result of the convolution. + Result of the convolution. \n + ``shape``: (n_channels, n_samples, n_kernels) """ assert all(isinstance(arg, torch.Tensor) for arg in [arr, kernels]), "arr and kernels should be torch tensors" @@ -408,7 +418,7 @@ def fftconvolve( y: torch.Tensor, mode: str='valid', fast_length: bool=False, -): +) -> torch.Tensor: """ Convolution using the FFT method. \n This is adapted from of torchaudio.functional.fftconvolve that handles @@ -436,7 +446,9 @@ def fftconvolve( Returns: torch.Tensor: - Result of the convolution. + Result of the convolution. \n + Padding applied to last dimension. \n + ``shape``: (..., n_samples) """ ## Compute the convolution n_original = x.shape[-1] + y.shape[-1] - 1 @@ -454,7 +466,7 @@ def fftconvolve( ## For some reason jit is slower here # @torch.jit.script -def next_fast_len(size: int): +def next_fast_len(size: int) -> int: """ Taken from PyTorch Forecasting: Returns the next largest number ``n >= size`` whose prime factors are all @@ -510,7 +522,8 @@ def apply_padding_mode( Returns: torch.Tensor: - Result of the convolution with the specified padding mode. + Result of the convolution with the specified padding mode. \n + ``shape``: (..., n_samples) """ n = x_length + y_length - 1 valid_convolve_modes = ["full", "valid", "same"] @@ -534,7 +547,7 @@ def make_conv_xAxis( padding: str='same', downsample_factor: int=4, device: torch.device='cpu', -): +) -> torch.Tensor: """ Make the x-axis for the result of a convolution. This is adapted from torchaudio.functional._make_conv_xAxis. @@ -555,7 +568,8 @@ def make_conv_xAxis( Returns: torch.Tensor: - x-axis for the result of a convolution. + x-axis for the result of a convolution. \n + ``shape``: (n_samples_ds,) """ ## If n_k is odd, then no offset, if even, then offset by 0.5