Skip to content

Commit

Permalink
FIX address PR comments for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 12, 2023
1 parent f53348c commit b7233e3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
20 changes: 19 additions & 1 deletion tests/test_1d/test_backward_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def func(points, targets):
dict(modeord=int(not fftshift), isign=isign),
)

assert gradcheck(func, inputs, atol=1.5e-4 * N)
assert gradcheck(func, inputs, atol=5e-3 * N)


@pytest.mark.parametrize("N", Ns)
Expand All @@ -208,3 +208,21 @@ def test_t2_backward_CPU_points(
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", True)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", True)

18 changes: 18 additions & 0 deletions tests/test_2d/test_backward_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,21 @@ def test_t2_backward_CPU_points(
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", True)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", True)

18 changes: 18 additions & 0 deletions tests/test_3d/test_backward_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,21 @@ def test_t2_backward_CPU_points(
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cpu", True)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_values(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", False)

@pytest.mark.parametrize("N", Ns)
@pytest.mark.parametrize("modifier", length_modifiers)
@pytest.mark.parametrize("fftshift", [False, True])
@pytest.mark.parametrize("isign", [-1, 1])
def test_t2_backward_CPU_points(
N: int, modifier: int, fftshift: bool, isign: int
) -> None:
check_t2_backward(N, modifier, fftshift, isign, "cuda", True)

0 comments on commit b7233e3

Please sign in to comment.