diff --git a/vqt/vqt.py b/vqt/vqt.py index 7480403..e126249 100644 --- a/vqt/vqt.py +++ b/vqt/vqt.py @@ -18,7 +18,7 @@ def __init__( F_max=400, n_freq_bins=55, win_size=501, - window_type='hann', + window_type='gaussian', symmetry='center', taper_asymmetric=True, downsample_factor=4, @@ -28,7 +28,7 @@ def __init__( DEVICE_compute='cpu', DEVICE_return='cpu', batch_size=1000, - return_complex=False, + take_abs=True, filters=None, verbose=True, plot_pref=False, @@ -111,10 +111,10 @@ def __init__( batch_size (int): Number of signals to process at once. Use a smaller batch size if you run out of memory. - return_complex (bool): + take_abs (bool): Whether to return the complex version of the transform. If - False, then returns the absolute value (envelope) of the - transform. downsample_factor must be 1 if this is True. + True, then returns the absolute value (envelope) of the + transform. If False, returns the complex transform. filters (Torch tensor): Filters to use. If None, will make new filters. Should be complex sinusoids. shape: (n_freq_bins, win_size) @@ -163,7 +163,7 @@ def __init__( self.DEVICE_compute, self.DEVICE_return, self.batch_size, - self.return_complex, + self.take_abs, self.plot_pref, self.progressBar, ) = ( @@ -181,7 +181,7 @@ def __init__( DEVICE_compute, DEVICE_return, batch_size, - return_complex, + take_abs, plot_pref, progressBar, ) @@ -231,14 +231,13 @@ def __call__(self, X): X=convolve( arr=arr, kernels=filters, - take_abs=(self.return_complex==False), + take_abs=self.take_abs, fft_conv=self.fft_conv, padding=self.padding, fast_length=self.fast_length, DEVICE=self.DEVICE_compute ), ds_factor=self.downsample_factor, - return_complex=self.return_complex, ).to(self.DEVICE_return) ## Make spectrograms @@ -266,10 +265,11 @@ def __repr__(self): def downsample( X: torch.Tensor, ds_factor: int=4, - return_complex: bool=False, ): """ Downsample a signal using average pooling. \n + If X is complex, it will be split into magnitude and phase, downsampled, + and then recombined. \n RH 2024 @@ -279,8 +279,6 @@ def downsample( ``shape``: (..., n_samples) ds_factor (int): Factor to downsample the signal by. - return_complex (bool): - Whether to return the complex version of the signal. Returns: torch.Tensor: @@ -289,22 +287,37 @@ def downsample( if ds_factor == 1: return X + ## Assert X is a torch tensor + assert isinstance(X, torch.Tensor), "X should be a torch tensor" + ## Ensure X.ndim in [2, 3] + if X.ndim==1: + X = X[None,:] + assert X.ndim in [2, 3], "X should be 2D or 3D" ## (n_channels, n_samples) + fn_ds = lambda arr: torch.nn.functional.avg_pool1d( arr, kernel_size=[int(ds_factor)], stride=ds_factor, ceil_mode=True, + # padding=0, + count_include_pad=False, ## Prevents edge effects ) - if return_complex == False: + + ## Check is X is complex + if X.is_complex() == False: return fn_ds(X) - elif return_complex == True: - ## Unfortunately, torch.nn.functional.avg_pool1d does not support complex numbers. So we have to split it up. - ### Split X, shape: (..., n_samples) into real and imaginary parts, shape: (..., n_samples, 2), permute, ds, unpermute, and recombine. - dims = np.arange(X.ndim + 1) - dims_to = list(np.roll(dims, 1)) - dims_from = list(np.roll(dims, -1)) - return torch.view_as_complex(fn_ds(torch.view_as_real(X).permute(*dims_to)).permute(*dims_from).contiguous()) + elif X.is_complex() == True: + ## Unfortunately, torch.nn.functional.avg_pool1d does not support complex numbers. So we have to split it up into + ## phases and magnitudes (convert imaginary to polar, split, downsample, recombine with polar to complex conversion) + ## Also, avg_pool1d also only supports 2D or 3D input tensors, so to be safe, we just flatten all the dimensions except + ## the new 'magnitude / phase' dimension, run avg_pool1d, and then reshape it back to the original shape. + fn_imag_to_polarReal = lambda arr: torch.stack([torch.abs(arr), torch.angle(arr)], dim=0) + fn_polarReal_to_imag = lambda arr: arr[0] * torch.exp(1j * arr[1]) + shape_original = X.shape + return fn_polarReal_to_imag(fn_ds(fn_imag_to_polarReal(X).reshape(2, -1))).reshape(*shape_original[:-1], -1) + else: + raise ValueError("X should be a torch tensor of type float or complex") def convolve( arr, @@ -544,7 +557,6 @@ def make_conv_xAxis( x_axis = downsample( X=x_axis_padModed[None,None,:], ds_factor=downsample_factor, - return_complex=False, ).squeeze().to(DEVICE_return) return x_axis \ No newline at end of file