Skip to content

Commit

Permalink
refactor: add parameterized test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
anujsinha3 committed Oct 25, 2023
1 parent 0cb9d69 commit 0858325
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions tests/test_whiten.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
import scipy
from scipy.fftpack import next_fast_len

Expand Down Expand Up @@ -99,16 +100,16 @@ def whiten_original(data, fft_para: ConfigParameters):
# it is not expected that the smoothed version returns the same, so currently no test for that
# (would be good to add one based on some expected outcome)

fft_para = ConfigParameters()
fft_para.samp_freq = 1.0
fft_para.freqmin = 0.01
fft_para.freqmax = 0.2
fft_para.smooth_N = 1
fft_para.freq_norm = FreqNorm.PHASE_ONLY


def whiten1d():
def whiten1d(freq_norm: FreqNorm):
# 1 D case
fft_para = ConfigParameters()
fft_para.samp_freq = 1.0
fft_para.freqmin = 0.01
fft_para.freqmax = 0.2
fft_para.smooth_N = 1
fft_para.freq_norm = freq_norm

data = np.random.random(1000)
white_original = whiten_original(data, fft_para)
white_new = whiten(data, fft_para)
Expand All @@ -119,8 +120,15 @@ def whiten1d():
return white_original, white_new


def whiten2d():
def whiten2d(freq_norm: FreqNorm):
# 2 D case
fft_para = ConfigParameters()
fft_para.samp_freq = 1.0
fft_para.freqmin = 0.01
fft_para.freqmax = 0.2
fft_para.smooth_N = 1
fft_para.freq_norm = freq_norm

data = np.random.random((5, 1000))
white_original = whiten_original(data, fft_para)
white_new = whiten(data, fft_para)
Expand Down Expand Up @@ -158,12 +166,14 @@ def plot_2d(white_original, white_new):


# Use wrappers since test functions are not supposed to return values
def test_whiten1d():
_, _ = whiten1d()
@pytest.mark.parametrize("freqNorm", [FreqNorm.PHASE_ONLY, FreqNorm.RMA])
def test_whiten1d(freqNorm: FreqNorm):
_, _ = whiten1d(freqNorm)


def test_whiten2d():
_, _ = whiten2d()
@pytest.mark.parametrize("freqNorm", [FreqNorm.PHASE_ONLY, FreqNorm.RMA])
def test_whiten2d(freqNorm: FreqNorm):
_, _ = whiten2d(freqNorm)


if __name__ == "__main__":
Expand Down

0 comments on commit 0858325

Please sign in to comment.