diff --git a/tests/test_all.py b/tests/test_all.py index 0d14ecf..26177a3 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -24,14 +24,10 @@ 'taper_asymmetric': True, 'downsample_factor': 4, 'padding': 'valid', - 'DEVICE_compute': 'cpu', - 'DEVICE_return': 'cpu', - 'batch_size': 1000, - 'return_complex': False, + 'take_abs': True, 'filters': None, 'verbose': False, 'plot_pref': False, - 'progressBar': False, } @@ -73,7 +69,7 @@ def test_peak_in_spectrogram_at_sine_wave_frequency( # Apply the VQT to this sine wave spectrogram, _, freqs = v(input_signal) # Convert freqs to a tensor for easier handling - freqs_tensor = torch.tensor(freqs, dtype=torch.float32) + freqs_tensor = torch.as_tensor(freqs, dtype=torch.float32) # Locate the peak in the spectrogram peak_index = torch.argmax(spectrogram[0], dim=0) # Assuming the output shape is (n_channels, n_freq_bins, time_bins) assert torch.all(peak_index == peak_index[0]), "Expected a single peak in the spectrogram" @@ -114,13 +110,8 @@ def test_constant_signal_transformation(): taper_asymmetric=st.booleans(), downsample_factor=st.integers(min_value=1, max_value=100), padding=st.sampled_from(['valid', 'same']), - # DEVICE_compute=st.sampled_from(['cpu', 'cuda']), - # DEVICE_return=st.sampled_from(['cpu', 'cuda']), - batch_size=st.integers(min_value=1, max_value=1000), - # return_complex=st.booleans(), # filters=st.none() | st.sampled_from([None]), # plot_pref=st.booleans(), - # progressBar=st.booleans(), n_channels=st.integers(min_value=1, max_value=100), n_dim=st.integers(min_value=1, max_value=2), @@ -137,13 +128,9 @@ def test_vqt_params( taper_asymmetric, downsample_factor, padding, - # DEVICE_compute, - # DEVICE_return, - batch_size, - # return_complex, + # take_abs, # filters, # plot_pref, - # progressBar, n_channels, n_dim, @@ -160,13 +147,9 @@ def test_vqt_params( 'taper_asymmetric': taper_asymmetric, 'downsample_factor': downsample_factor, 'padding': padding, - # 'DEVICE_compute': DEVICE_compute, - # 'DEVICE_return': DEVICE_return, - 'batch_size': batch_size, - # 'return_complex': return_complex, + # 'take_abs': take_abs, # 'filters': filters, # 'plot_pref': plot_pref, - # 'progressBar': progressBar, } params.update({k: v for k, v in params_vqt.items() if k not in params}) v = vqt.VQT(**params) @@ -182,9 +165,9 @@ def test_vqt_params( assert output is not None, "Output should not be None" # Check output shape assert output.shape[1] == params['n_freq_bins'], "VQT output shape does not match the number of frequency bins" - # Check all output is real if return_complex is False - if not params['return_complex']: - assert torch.all(torch.isreal(output)), "VQT output should be real if return_complex is False" + # Check all output is real if take_abs is True + if params['take_abs']: + assert torch.all(torch.isreal(output)), "VQT output should be real if take_abs is True" def test_vqt_filters(): @@ -202,7 +185,7 @@ def test_vqt_filters(): taper_asymmetric=params['taper_asymmetric'], plot_pref=params['plot_pref'] ) - filters, wins = filters.numpy(), wins.numpy() + filters, freqs, wins = filters.numpy(), freqs.numpy(), wins.numpy() ## Filters, freqs, and wins should all have the following properties: ### They are not all zeros, not all ones, and do not contain NaNs or infinities