From 66ae60495e5d1639e61dbd47f61e3eabcf3a9a60 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Sun, 4 Feb 2024 13:43:42 -0500 Subject: [PATCH] DELETE VQT and make_VQT_filters methods, UPDATE --- bnpm/spectral.py | 362 ----------------------------------------------- 1 file changed, 362 deletions(-) diff --git a/bnpm/spectral.py b/bnpm/spectral.py index 11398e2..832e63d 100644 --- a/bnpm/spectral.py +++ b/bnpm/spectral.py @@ -351,365 +351,3 @@ def torch_hilbert(x, N=None, dim=0): return torch.fft.ifft(xf * m, dim=dim) - -def make_VQT_filters( - Fs_sample=1000, - Q_lowF=3, - Q_highF=20, - F_min=10, - F_max=400, - n_freq_bins=55, - win_size=501, - symmetry='center', - taper_asymmetric=True, - plot_pref=False -): - """ - Creates a set of filters for use in the VQT algorithm. - - Set Q_lowF and Q_highF to be the same value for a - Constant Q Transform (CQT) filter set. - Varying these values will varying the Q factor - logarithmically across the frequency range. - - RH 2022 - - Args: - Fs_sample (float): - Sampling frequency of the signal. - Q_lowF (float): - Q factor to use for the lowest frequency. - Q_highF (float): - Q factor to use for the highest frequency. - F_min (float): - Lowest frequency to use. - F_max (float): - Highest frequency to use (inclusive). - n_freq_bins (int): - Number of frequency bins to use. - win_size (int): - Size of the window to use, in samples. - symmetry (str): - Whether to use a symmetric window or a single-sided window. - - 'center': Use a symmetric / centered / 'two-sided' window. - - 'left': Use a one-sided, left-half window. Only left half of the - filter will be nonzero. - - 'right': Use a one-sided, right-half window. Only right half of the - filter will be nonzero. - taper_asymmetric (bool): - Only used if symmetry is not 'center'. - Whether to taper the center of the window by multiplying center - sample of window by 0.5. - plot_pref (bool): - Whether to plot the filters. - - Returns: - filters (Torch ndarray): - Array of complex sinusoid filters. - shape: (n_freq_bins, win_size) - freqs (Torch array): - Array of frequencies corresponding to the filters. - wins (Torch ndarray): - Array of window functions (gaussians) - corresponding to each filter. - shape: (n_freq_bins, win_size) - """ - - assert win_size%2==1, "RH Error: win_size should be an odd integer" - - ## Make frequencies. Use a geometric spacing. - freqs = np.geomspace( - start=F_min, - stop=F_max, - num=n_freq_bins, - endpoint=True, - dtype=np.float32, - ) - - periods = 1 / freqs - periods_inSamples = Fs_sample * periods - - ## Make sigmas for gaussian windows. Use a geometric spacing. - sigma_all = np.geomspace( - start=Q_lowF, - stop=Q_highF, - num=n_freq_bins, - endpoint=True, - dtype=np.float32, - ) - sigma_all = sigma_all * periods_inSamples / 4 - - ## Make windows - ### Make windows gaussian - wins = torch.stack([math_functions.gaussian(torch.arange(-win_size//2, win_size//2), 0, sig=sigma) for sigma in sigma_all]) - ### Make windows symmetric or asymmetric - if symmetry=='center': - pass - else: - heaviside = (torch.arange(win_size) <= win_size//2).float() - if symmetry=='left': - pass - elif symmetry=='right': - heaviside = torch.flip(heaviside, dims=[0]) - else: - raise ValueError("symmetry must be 'center', 'left', or 'right'") - wins *= heaviside - ### Taper the center of the window by multiplying center sample of window by 0.5 - if taper_asymmetric: - wins[:, win_size//2] = wins[:, win_size//2] * 0.5 - - - filts = torch.stack([torch.cos(torch.linspace(-np.pi, np.pi, win_size) * freq * (win_size/Fs_sample)) * win for freq, win in zip(freqs, wins)], dim=0) - filts_complex = torch_hilbert(filts.T, dim=0).T - - ## Normalize filters to have unit magnitude - filts_complex = filts_complex / torch.sum(torch.abs(filts_complex), dim=1, keepdims=True) - - ## Plot - if plot_pref: - plt.figure() - plt.plot(freqs) - plt.xlabel('filter num') - plt.ylabel('frequency (Hz)') - - plt.figure() - plt.imshow(wins / torch.max(wins, 1, keepdims=True)[0], aspect='auto') - plt.title('windows (gaussian)') - - plt.figure() - plt.plot(sigma_all) - plt.xlabel('filter num') - plt.ylabel('window width (sigma of gaussian)') - - plt.figure() - plt.imshow(torch.real(filts_complex) / torch.max(torch.real(filts_complex), 1, keepdims=True)[0], aspect='auto', cmap='bwr', vmin=-1, vmax=1) - plt.title('filters (real component)') - - - worN=win_size*4 - filts_freq = np.array([scipy.signal.freqz( - b=filt, - fs=Fs_sample, - worN=worN, - )[1] for filt in filts_complex]) - - filts_freq_xAxis = scipy.signal.freqz( - b=filts_complex[0], - worN=worN, - fs=Fs_sample - )[0] - - plt.figure() - plt.plot(filts_freq_xAxis, np.abs(filts_freq.T)); - plt.xscale('log') - plt.xlabel('frequency (Hz)') - plt.ylabel('magnitude') - - return filts_complex, freqs, wins - -class VQT(): - def __init__( - self, - Fs_sample=1000, - Q_lowF=3, - Q_highF=20, - F_min=10, - F_max=400, - n_freq_bins=55, - win_size=501, - symmetry='center', - taper_asymmetric=True, - downsample_factor=4, - padding='valid', - DEVICE_compute='cpu', - DEVICE_return='cpu', - batch_size=1000, - return_complex=False, - filters=None, - plot_pref=False, - progressBar=True, - ): - """ - Variable Q Transform. - Class for applying the variable Q transform to signals. - - This function works differently than the VQT from - librosa or nnAudio. This one does not use iterative - lowpass filtering. Instead, it uses a fixed set of - filters, and a Hilbert transform to compute the analytic - signal. It can then take the envelope and downsample. - - Uses Pytorch for GPU acceleration, and allows gradients - to pass through. - - Q: quality factor; roughly corresponds to the number - of cycles in a filter. Here, Q is the number of cycles - within 4 sigma (95%) of a gaussian window. - - RH 2022 - - Args: - Fs_sample (float): - Sampling frequency of the signal. - Q_lowF (float): - Q factor to use for the lowest frequency. - Q_highF (float): - Q factor to use for the highest frequency. - F_min (float): - Lowest frequency to use. - F_max (float): - Highest frequency to use. - n_freq_bins (int): - Number of frequency bins to use. - win_size (int): - Size of the window to use, in samples. - symmetry (str): - Whether to use a symmetric window or a single-sided window. - - 'center': Use a symmetric / centered / 'two-sided' window. - - 'left': Use a one-sided, left-half window. Only left half of the - filter will be nonzero. - - 'right': Use a one-sided, right-half window. Only right half of the - filter will be nonzero. - taper_asymmetric (bool): - Only used if symmetry is not 'center'. - Whether to taper the center of the window by multiplying center - sample of window by 0.5. - downsample_factor (int): - Factor to downsample the signal by. - If the length of the input signal is not - divisible by downsample_factor, the signal - will be zero-padded at the end so that it is. - padding (str): - Padding to use for the signal. - 'same' will pad the signal so that the output - signal is the same length as the input signal. - 'valid' will not pad the signal. So the output - signal will be shorter than the input signal. - DEVICE_compute (str): - Device to use for computation. - DEVICE_return (str): - Device to use for returning the results. - batch_size (int): - Number of signals to process at once. - Use a smaller batch size if you run out of memory. - return_complex (bool): - Whether to return the complex version of - the transform. If False, then returns the - absolute value (envelope) of the transform. - downsample_factor must be 1 if this is True. - filters (Torch tensor): - Filters to use. If None, will make new filters. - Should be complex sinusoids. - shape: (n_freq_bins, win_size) - plot_pref (bool): - Whether to plot the filters. - progressBar (bool): - Whether to show a progress bar. - """ - ## Prepare filters - if filters is not None: - ## Use provided filters - self.using_custom_filters = True - self.filters = filters - else: - ## Make new filters - self.using_custom_filters = False - self.filters, self.freqs, self.wins = make_VQT_filters( - Fs_sample=Fs_sample, - Q_lowF=Q_lowF, - Q_highF=Q_highF, - F_min=F_min, - F_max=F_max, - n_freq_bins=n_freq_bins, - win_size=win_size, - symmetry=symmetry, - taper_asymmetric=taper_asymmetric, - plot_pref=plot_pref, - ) - ## Gather parameters from arguments - self.Fs_sample, self.Q_lowF, self.Q_highF, self.F_min, self.F_max, self.n_freq_bins, self.win_size, self.downsample_factor, self.padding, self.DEVICE_compute, \ - self.DEVICE_return, self.batch_size, self.return_complex, self.plot_pref, self.progressBar = \ - Fs_sample, Q_lowF, Q_highF, F_min, F_max, n_freq_bins, win_size, downsample_factor, padding, DEVICE_compute, DEVICE_return, batch_size, return_complex, plot_pref, progressBar - - def _helper_ds(self, X: torch.Tensor, ds_factor: int=4, return_complex: bool=False): - if ds_factor == 1: - return X - elif return_complex == False: - return torch.nn.functional.avg_pool1d(X, kernel_size=[int(ds_factor)], stride=ds_factor, ceil_mode=True) - elif return_complex == True: - ## Unfortunately, torch.nn.functional.avg_pool1d does not support complex numbers. So we have to split it up. - ### Split X, shape: (batch_size, n_freq_bins, n_samples) into real and imaginary parts, shape: (batch_size, n_freq_bins, n_samples, 2) - Y = torch.view_as_real(X) - ### Downsample each part separately, then stack them and make them complex again. - Z = torch.view_as_complex(torch.stack([torch.nn.functional.avg_pool1d(y, kernel_size=[int(ds_factor)], stride=ds_factor, ceil_mode=True) for y in [Y[...,0], Y[...,1]]], dim=-1)) - return Z - - def _helper_conv(self, arr, filters, take_abs, DEVICE): - out = torch.complex( - torch.nn.functional.conv1d(input=arr.to(DEVICE)[:,None,:], weight=torch.real(filters.T).to(DEVICE).T[:,None,:], padding=self.padding), - torch.nn.functional.conv1d(input=arr.to(DEVICE)[:,None,:], weight=-torch.imag(filters.T).to(DEVICE).T[:,None,:], padding=self.padding) - ) - if take_abs: - return torch.abs(out) - else: - return out - - def __call__(self, X): - """ - Forward pass of VQT. - - Args: - X (Torch tensor): - Input signal. - shape: (n_channels, n_samples) - - Returns: - Spectrogram (Torch tensor): - Spectrogram of the input signal. - shape: (n_channels, n_samples_ds, n_freq_bins) - x_axis (Torch tensor): - New x-axis for the spectrogram in units of samples. - Get units of time by dividing by self.Fs_sample. - self.freqs (Torch tensor): - Frequencies of the spectrogram. - """ - if type(X) is not torch.Tensor: - X = torch.as_tensor(X, dtype=torch.float32, device=self.DEVICE_compute) - - if X.ndim==1: - X = X[None,:] - - ## Make iterator for batches - batches = indexing.make_batches(X, batch_size=self.batch_size, length=X.shape[0]) - - ## Make spectrograms - specs = [self._helper_ds( - X=self._helper_conv( - arr=arr, - filters=self.filters, - take_abs=(self.return_complex==False), - DEVICE=self.DEVICE_compute - ), - ds_factor=self.downsample_factor, - return_complex=self.return_complex, - ).to(self.DEVICE_return) for arr in tqdm(batches, disable=(self.progressBar==False), leave=True, total=int(np.ceil(X.shape[0]/self.batch_size)))] - specs = torch.cat(specs, dim=0) - - ## Make x_axis - x_axis = torch.nn.functional.avg_pool1d( - torch.nn.functional.conv1d( - input=torch.arange(0, X.shape[-1], dtype=torch.float32)[None,None,:], - weight=torch.ones(1,1,self.filters.shape[-1], dtype=torch.float32) / self.filters.shape[-1], - padding=self.padding - ), - kernel_size=[int(self.downsample_factor)], - stride=self.downsample_factor, ceil_mode=True, - ).squeeze() - - return specs, x_axis, self.freqs - - def __repr__(self): - if self.using_custom_filters: - return f"VQT with custom filters" - else: - return f"VQT object with parameters: Fs_sample={self.Fs_sample}, Q_lowF={self.Q_lowF}, Q_highF={self.Q_highF}, F_min={self.F_min}, F_max={self.F_max}, n_freq_bins={self.n_freq_bins}, win_size={self.win_size}, downsample_factor={self.downsample_factor}, DEVICE_compute={self.DEVICE_compute}, DEVICE_return={self.DEVICE_return}, batch_size={self.batch_size}, return_complex={self.return_complex}, plot_pref={self.plot_pref}"