Skip to content

Commit

Permalink
fix bugs in test_all
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Feb 9, 2024
1 parent 0d275d8 commit 00c9734
Showing 1 changed file with 8 additions and 25 deletions.
33 changes: 8 additions & 25 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@
'taper_asymmetric': True,
'downsample_factor': 4,
'padding': 'valid',
'DEVICE_compute': 'cpu',
'DEVICE_return': 'cpu',
'batch_size': 1000,
'return_complex': False,
'take_abs': True,
'filters': None,
'verbose': False,
'plot_pref': False,
'progressBar': False,
}


Expand Down Expand Up @@ -73,7 +69,7 @@ def test_peak_in_spectrogram_at_sine_wave_frequency(
# Apply the VQT to this sine wave
spectrogram, _, freqs = v(input_signal)
# Convert freqs to a tensor for easier handling
freqs_tensor = torch.tensor(freqs, dtype=torch.float32)
freqs_tensor = torch.as_tensor(freqs, dtype=torch.float32)
# Locate the peak in the spectrogram
peak_index = torch.argmax(spectrogram[0], dim=0) # Assuming the output shape is (n_channels, n_freq_bins, time_bins)
assert torch.all(peak_index == peak_index[0]), "Expected a single peak in the spectrogram"
Expand Down Expand Up @@ -114,13 +110,8 @@ def test_constant_signal_transformation():
taper_asymmetric=st.booleans(),
downsample_factor=st.integers(min_value=1, max_value=100),
padding=st.sampled_from(['valid', 'same']),
# DEVICE_compute=st.sampled_from(['cpu', 'cuda']),
# DEVICE_return=st.sampled_from(['cpu', 'cuda']),
batch_size=st.integers(min_value=1, max_value=1000),
# return_complex=st.booleans(),
# filters=st.none() | st.sampled_from([None]),
# plot_pref=st.booleans(),
# progressBar=st.booleans(),
n_channels=st.integers(min_value=1, max_value=100),
n_dim=st.integers(min_value=1, max_value=2),
Expand All @@ -137,13 +128,9 @@ def test_vqt_params(
taper_asymmetric,
downsample_factor,
padding,
# DEVICE_compute,
# DEVICE_return,
batch_size,
# return_complex,
# take_abs,
# filters,
# plot_pref,
# progressBar,

n_channels,
n_dim,
Expand All @@ -160,13 +147,9 @@ def test_vqt_params(
'taper_asymmetric': taper_asymmetric,
'downsample_factor': downsample_factor,
'padding': padding,
# 'DEVICE_compute': DEVICE_compute,
# 'DEVICE_return': DEVICE_return,
'batch_size': batch_size,
# 'return_complex': return_complex,
# 'take_abs': take_abs,
# 'filters': filters,
# 'plot_pref': plot_pref,
# 'progressBar': progressBar,
}
params.update({k: v for k, v in params_vqt.items() if k not in params})
v = vqt.VQT(**params)
Expand All @@ -182,9 +165,9 @@ def test_vqt_params(
assert output is not None, "Output should not be None"
# Check output shape
assert output.shape[1] == params['n_freq_bins'], "VQT output shape does not match the number of frequency bins"
# Check all output is real if return_complex is False
if not params['return_complex']:
assert torch.all(torch.isreal(output)), "VQT output should be real if return_complex is False"
# Check all output is real if take_abs is True
if params['take_abs']:
assert torch.all(torch.isreal(output)), "VQT output should be real if take_abs is True"


def test_vqt_filters():
Expand All @@ -202,7 +185,7 @@ def test_vqt_filters():
taper_asymmetric=params['taper_asymmetric'],
plot_pref=params['plot_pref']
)
filters, wins = filters.numpy(), wins.numpy()
filters, freqs, wins = filters.numpy(), freqs.numpy(), wins.numpy()

## Filters, freqs, and wins should all have the following properties:
### They are not all zeros, not all ones, and do not contain NaNs or infinities
Expand Down

0 comments on commit 00c9734

Please sign in to comment.