From a253147402d1a96049094cf6df6e3c6a7a8c745a Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Thu, 8 Feb 2024 00:54:43 -0500 Subject: [PATCH] make VQT a torch nn module and play nice with jit script --- vqt/helpers.py | 8 +- vqt/vqt.py | 252 +++++++++++++++++++++++++------------------------ 2 files changed, 137 insertions(+), 123 deletions(-) diff --git a/vqt/helpers.py b/vqt/helpers.py index 6b63940..dc3140b 100644 --- a/vqt/helpers.py +++ b/vqt/helpers.py @@ -1,3 +1,5 @@ +import math + import torch import numpy as np import scipy.signal @@ -48,7 +50,7 @@ def make_batches( l = length if batch_size is None: - batch_size = np.int64(np.ceil(l / num_batches)) + batch_size = int(math.ceil(l / num_batches)) for start in range(idx_start, l, batch_size): end = min(start + batch_size, l) @@ -332,7 +334,7 @@ def make_VQT_filters( mother_wave = scipy.signal.windows.get_window(window=window_type, Nx=resolution, fftbins=False) wins, xs = make_scaled_wave_basis(mother_wave, lens_waves=scales, lens_windows=win_size) - wins = torch.tensor(np.stack(wins, axis=0), dtype=torch.float32) + wins = torch.as_tensor(np.stack(wins, axis=0), dtype=torch.float32) elif isinstance(window_type, (np.ndarray, list, tuple)): mother_wave = np.array(window_type, dtype=np.float32) @@ -361,6 +363,8 @@ def make_VQT_filters( ## Normalize filters to have unit magnitude filts_complex = filts_complex / torch.linalg.norm(filts_complex, ord=2, dim=1, keepdim=True) + freqs = torch.as_tensor(freqs, dtype=torch.float32) + ## Plot if plot_pref: plt.figure() diff --git a/vqt/vqt.py b/vqt/vqt.py index e126249..07ad8cb 100644 --- a/vqt/vqt.py +++ b/vqt/vqt.py @@ -1,38 +1,34 @@ +from typing import Union, List, Tuple, Optional, Dict, Any, Sequence, Iterable, Type, Callable import math import torch -import numpy as np -import scipy.signal +# import scipy.signal from tqdm import tqdm -import scipy.fftpack +# import scipy.fftpack from . import helpers -class VQT(): +class VQT(torch.nn.Module): def __init__( self, - Fs_sample=1000, - Q_lowF=3, - Q_highF=20, - F_min=10, - F_max=400, - n_freq_bins=55, - win_size=501, - window_type='gaussian', - symmetry='center', - taper_asymmetric=True, - downsample_factor=4, - padding='valid', - fft_conv=True, - fast_length=True, - DEVICE_compute='cpu', - DEVICE_return='cpu', - batch_size=1000, - take_abs=True, - filters=None, - verbose=True, - plot_pref=False, - progressBar=True, + Fs_sample: Union[int, float]=1000, + Q_lowF: Union[int, float]=1, + Q_highF: Union[int, float]=20, + F_min: Union[int, float]=1, + F_max: Union[int, float]=400, + n_freq_bins: int=50, + win_size: int=501, + window_type: Union[str, torch.Tensor]='gaussian', + symmetry: str='center', + taper_asymmetric: bool=True, + downsample_factor: int=4, + padding: str='same', + fft_conv: bool=True, + fast_length: bool=True, + take_abs: bool=True, + filters: Optional[torch.Tensor]=None, + verbose: Union[int, bool]=True, + plot_pref: bool=False, ): """ Variable Q Transform. Class for applying the variable Q transform to @@ -104,13 +100,6 @@ def __init__( Whether to use scipy.fftpack.next_fast_len to find the next fast length for the FFT. This may be faster, but uses more memory. - DEVICE_compute (str): - Device to use for computation. - DEVICE_return (str): - Device to use for returning the results. - batch_size (int): - Number of signals to process at once. Use a smaller batch size - if you run out of memory. take_abs (bool): Whether to return the complex version of the transform. 
If True, then returns the absolute value (envelope) of the @@ -122,17 +111,16 @@ def __init__( Verbosity. True to print warnings. plot_pref (bool): Whether to plot the filters. - progressBar (bool): - Whether to show a progress bar. """ + super().__init__() ## Prepare filters + self.using_custom_filters = True if filters is not None else False + self.filters = filters ## This line here is just for torch.jit.script to work. Delete it if you want to forget about jit. if filters is not None: ## Use provided filters - self.using_custom_filters = True self.filters = filters else: ## Make new filters - self.using_custom_filters = False self.filters, self.freqs, self.wins = helpers.make_VQT_filters( Fs_sample=Fs_sample, Q_lowF=Q_lowF, @@ -147,6 +135,9 @@ def __init__( plot_pref=plot_pref, ) + ## Make filters the parameters of the model + self.filters = torch.nn.Parameter(self.filters, requires_grad=False) + ## Gather parameters from arguments ( self.Fs_sample, @@ -160,12 +151,8 @@ def __init__( self.padding, self.fft_conv, self.fast_length, - self.DEVICE_compute, - self.DEVICE_return, - self.batch_size, self.take_abs, self.plot_pref, - self.progressBar, ) = ( Fs_sample, Q_lowF, @@ -178,12 +165,8 @@ def __init__( padding, fft_conv, fast_length, - DEVICE_compute, - DEVICE_return, - batch_size, take_abs, plot_pref, - progressBar, ) ## Warnings @@ -195,7 +178,10 @@ def __init__( if win_size > 1024 and fft_conv == False: print(f"Warning: win_size is {win_size}, which is large for conv1d. Consider using fft_conv=True for faster computation.") - def __call__(self, X): + def forward( + self, + X: torch.Tensor, + ): """ Forward pass of VQT. @@ -214,8 +200,11 @@ def __call__(self, X): self.freqs (Torch tensor): Frequencies of the spectrogram. """ - if type(X) is not torch.Tensor: - X = torch.as_tensor(X, dtype=torch.float32, device=self.DEVICE_compute) + assert isinstance(X, torch.Tensor), "X should be a torch tensor" + X = X.type(torch.float32) + + ## Check that X and filters are on the same device + assert X.device == self.filters.device, "X and filters should be on the same device" if X.ndim==1: X = X[None,:] @@ -223,25 +212,18 @@ def __call__(self, X): assert X.ndim==2, "X should be 2D" ## (n_channels, n_samples) assert self.filters.ndim==2, "Filters should be 2D" ## (n_freq_bins, win_size) - ## Make iterator for batches - batches = helpers.make_batches(X, batch_size=self.batch_size, length=X.shape[0]) - - ## Make function to apply to signals - fn_vqt = lambda arr, filters: downsample( + ## Make spectrograms + specs = downsample( X=convolve( - arr=arr, - kernels=filters, + arr=X, + kernels=self.filters, take_abs=self.take_abs, fft_conv=self.fft_conv, padding=self.padding, fast_length=self.fast_length, - DEVICE=self.DEVICE_compute ), ds_factor=self.downsample_factor, - ).to(self.DEVICE_return) - - ## Make spectrograms - specs = torch.cat([fn_vqt(arr=arr, filters=self.filters) for arr in tqdm(batches, disable=(self.progressBar==False), leave=True, total=int(np.ceil(X.shape[0]/self.batch_size)))]) + ) ## Make x_axis x_axis = make_conv_xAxis( @@ -249,8 +231,7 @@ def __call__(self, X): n_k=self.filters.shape[-1], padding=self.padding, downsample_factor=self.downsample_factor, - DEVICE_compute=self.DEVICE_compute, - DEVICE_return=self.DEVICE_return, + device=X.device, ) return specs, x_axis, self.freqs @@ -259,7 +240,14 @@ def __repr__(self): if self.using_custom_filters: return f"VQT with custom filters" else: - return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k, v in 
self.__dict__.items() if k not in ['filters', 'freqs', 'wins']])[:-2]}" + # return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k, v in self.__dict__.items() if k not in ['filters', 'freqs', 'wins']])[:-2]}" + ## Below lines are because torch.jit.script doesn't allow comprehension if statements + attributes_to_print = [] + for k, v in self.__dict__.items(): + if k not in ['filters', 'freqs', 'wins']: + attributes_to_print.append(k) + return f"VQT object with parameters: {''.join([f'{k}={getattr(self, k)}, ' for k in attributes_to_print])[:-2]}" + def downsample( @@ -294,39 +282,40 @@ def downsample( X = X[None,:] assert X.ndim in [2, 3], "X should be 2D or 3D" ## (n_channels, n_samples) - fn_ds = lambda arr: torch.nn.functional.avg_pool1d( - arr, - kernel_size=[int(ds_factor)], - stride=ds_factor, - ceil_mode=True, - # padding=0, - count_include_pad=False, ## Prevents edge effects - ) - ## Check is X is complex if X.is_complex() == False: - return fn_ds(X) + return _helper_ds(X, ds_factor=ds_factor) elif X.is_complex() == True: ## Unfortunately, torch.nn.functional.avg_pool1d does not support complex numbers. So we have to split it up into ## phases and magnitudes (convert imaginary to polar, split, downsample, recombine with polar to complex conversion) - ## Also, avg_pool1d also only supports 2D or 3D input tensors, so to be safe, we just flatten all the dimensions except - ## the new 'magnitude / phase' dimension, run avg_pool1d, and then reshape it back to the original shape. - fn_imag_to_polarReal = lambda arr: torch.stack([torch.abs(arr), torch.angle(arr)], dim=0) - fn_polarReal_to_imag = lambda arr: arr[0] * torch.exp(1j * arr[1]) - shape_original = X.shape - return fn_polarReal_to_imag(fn_ds(fn_imag_to_polarReal(X).reshape(2, -1))).reshape(*shape_original[:-1], -1) + out = _helper_imag_to_polarReal(X).reshape(2, -1) + out = torch.stack([_helper_ds(out[ii], ds_factor=ds_factor) for ii in range(2)], dim=0) + out = _helper_polarReal_to_imag(out) + return out else: raise ValueError("X should be a torch tensor of type float or complex") +def _helper_polarReal_to_imag(arr: torch.Tensor): + return arr[0] * torch.exp(1j * arr[1]) +def _helper_imag_to_polarReal(arr: torch.Tensor): + return torch.stack([torch.abs(arr), torch.angle(arr)], dim=0) +def _helper_ds(arr: torch.Tensor, ds_factor: int): + return torch.nn.functional.avg_pool1d( + arr, + kernel_size=[int(ds_factor)], + stride=ds_factor, + ceil_mode=True, + # padding=0, + count_include_pad=False, ## Prevents edge effects + ) def convolve( - arr, - kernels, - take_abs=False, - padding='same', - fft_conv=False, - fast_length=False, - DEVICE='cpu', + arr: torch.Tensor, + kernels: torch.Tensor, + take_abs: bool=False, + padding: str='same', + fft_conv: bool=False, + fast_length: bool=False, ): """ Convolve a signal with a set of kernels. \n @@ -351,8 +340,6 @@ def convolve( fast_length (bool): Whether to use scipy.fftpack.next_fast_len to find the next fast length for the FFT. - DEVICE (str): - Device to use for computation. 
Returns: torch.Tensor: @@ -363,8 +350,8 @@ def convolve( arr = arr[None,:] if arr.ndim==1 else arr kernels = kernels[None,:] if kernels.ndim==1 else kernels - arr = arr.to(DEVICE)[:,None,:] ## Shape: (n_channels, 1, n_samples) - kernels = kernels.to(DEVICE) ## Shape: (n_kernels, win_size) + arr = arr[:,None,:] ## Shape: (n_channels, 1, n_samples) + # kernels = kernels ## Shape: (n_kernels, win_size) if fft_conv: out = fftconvolve( @@ -378,18 +365,18 @@ def convolve( kernels = torch.flip(kernels, dims=[-1,])[:,None,:] ## Flip because torch's conv1d uses cross-correlation, not convolution. if flag_kernels_complex: - kernels = [torch.real(kernels), torch.imag(kernels)] + kernels_list = [torch.real(kernels), torch.imag(kernels)] else: - kernels = [kernels,] + kernels_list = [kernels,] out_conv = [torch.nn.functional.conv1d( input=arr, - weight=kernels, + weight=k, padding=padding, - ) for kernels in kernels] + ) for k in kernels_list] if flag_kernels_complex: - out = torch.complex(*out_conv) + out = torch.complex(out_conv[0], out_conv[1]) else: out = out_conv[0] @@ -400,10 +387,10 @@ def convolve( def fftconvolve( - x, - y, - mode='valid', - fast_length=False, + x: torch.Tensor, + y: torch.Tensor, + mode: str='valid', + fast_length: bool=False, ): """ Convolution using the FFT method. \n @@ -428,29 +415,57 @@ def fftconvolve( fast_length (bool): Whether to use scipy.fftpack.next_fast_len to find the next fast length for the FFT. + Set to False if you want to use backpropagation. Returns: torch.Tensor: Result of the convolution. """ - ## only if both are real, then use rfft - if x.is_complex() == False and y.is_complex() == False: - fft, ifft = torch.fft.rfft, torch.fft.irfft - else: - fft, ifft = torch.fft.fft, torch.fft.ifft - ## Compute the convolution n_original = x.shape[-1] + y.shape[-1] - 1 - n = scipy.fftpack.next_fast_len(n_original) if fast_length else n_original + # n = scipy.fftpack.next_fast_len(n_original) if fast_length else n_original + n = next_fast_len(n_original) if fast_length else n_original + n = n_original - f = fft(x, n=n, dim=-1) * fft(y, n=n, dim=-1) + if x.is_complex() == False and y.is_complex() == False: + f = torch.fft.rfft(x, n=n, dim=-1) * torch.fft.fft(y, n=n, dim=-1) + fftconv_xy = torch.fft.irfft(f, n=n, dim=-1) + else: + f = torch.fft.fft(x, n=n, dim=-1) * torch.fft.fft(y, n=n, dim=-1) + fftconv_xy = torch.fft.ifft(f, n=n, dim=-1) return apply_padding_mode( - conv_result=ifft(f, n=n, dim=-1), + conv_result=fftconv_xy, x_length=x.shape[-1], y_length=y.shape[-1], mode=mode, ) +## For some reason jit is slower here +# @torch.jit.script +def next_fast_len(size: int): + """ + Taken from PyTorch Forecasting: + Returns the next largest number ``n >= size`` whose prime factors are all + 2, 3, or 5. These sizes are efficient for fast fourier transforms. + Equivalent to :func:`scipy.fftpack.next_fast_len`. + + Implementation from pyro + + :param int size: A positive number. + :returns: A possibly larger number. 
+ :rtype int: + """ + assert isinstance(size, int) and size > 0 + next_size = size + while True: + remaining = next_size + for n in (2, 3, 5): + while remaining % n == 0: + remaining = remaining // n + if remaining == 1: + return next_size + next_size += 1 + def apply_padding_mode( conv_result: torch.Tensor, @@ -504,10 +519,9 @@ def apply_padding_mode( def make_conv_xAxis( n_s: int, n_k: int, - padding: str, - downsample_factor: int, - DEVICE_compute: str, - DEVICE_return: str, + padding: str='same', + downsample_factor: int=4, + device: torch.device='cpu', ): """ Make the x-axis for the result of a convolution. @@ -524,27 +538,23 @@ def make_conv_xAxis( Padding mode to use. downsample_factor (int): Factor to downsample the signal by. - DEVICE_compute (str): - Device to use for computation. - DEVICE_return (str): - Device to use for returning the results. + device (str): + Device to use. Returns: torch.Tensor: x-axis for the result of a convolution. """ - if downsample_factor == 1: - DEVICE_compute = DEVICE_return ## If n_k is odd, then no offset, if even, then offset by 0.5 ### PyTorch's conv1d and for 'same' pads more to the right, so on the first index of the output the kernel is centered at an offset of 0.5 - offset = 0.5 if n_k % 2 == 0 else 0 + offset = 0.5 if n_k % 2 == 0 else 0.0 x_axis_full = torch.arange( -(n_k-1)//2 + offset, n_s + (n_k-1)//2 + offset, dtype=torch.float32, - device=DEVICE_compute, + device=device, ) ### Then, apply padding mode to it x_axis_padModed = apply_padding_mode( @@ -557,6 +567,6 @@ def make_conv_xAxis( x_axis = downsample( X=x_axis_padModed[None,None,:], ds_factor=downsample_factor, - ).squeeze().to(DEVICE_return) + ).squeeze().to(device) return x_axis \ No newline at end of file
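A minimal usage sketch of the refactored module follows (an illustration, not part of the patch; it assumes the package is importable as `vqt` and uses a synthetic random signal). The point it shows: now that VQT is a torch.nn.Module with its filters registered as a non-trainable Parameter, device placement is handled by `.to(device)` instead of the removed `DEVICE_compute` / `DEVICE_return` / `batch_size` arguments, and `forward` asserts that the input and the filters share a device.

import torch
from vqt.vqt import VQT

## Build the transform; these keyword arguments all exist in the new __init__ signature.
model = VQT(
    Fs_sample=1000,
    F_min=10,
    F_max=400,
    n_freq_bins=50,
    downsample_factor=4,
    take_abs=True,
)

## Moving the module also moves the filter Parameter.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

## (Optional) TorchScript compilation, which the jit-friendly changes are meant to support:
# model = torch.jit.script(model)

X = torch.randn(4, 10_000, device=device)  ## (n_channels, n_samples); must be on the same device as the filters
specs, x_axis, freqs = model(X)            ## specs: (n_channels, n_freq_bins, ~n_samples / downsample_factor)
print(specs.shape, x_axis.shape, freqs.shape)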
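And a quick sanity sketch for the new pure-Python next_fast_len, comparing it against the scipy function it replaces (scipy is imported here only for the comparison; the patch itself comments out the scipy imports in vqt.py):

import scipy.fftpack
from vqt.vqt import next_fast_len

## Both return the next 5-smooth number (prime factors only 2, 3, or 5) that is >= size.
for size in (1, 13, 97, 1000):
    assert next_fast_len(size) == scipy.fftpack.next_fast_len(size)
print(next_fast_len(97))  ## -> 100 (2**2 * 5**2)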