Skip to content

Commit

Permalink
chore: simplify testing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cako committed Aug 5, 2024
1 parent 734776d commit f8e3323
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions pytests/test_torchoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,62 +21,66 @@ def test_TorchOperator(par):
"""
# temporarily, skip tests on mac as torch seems not to recognized
# numpy when v2 is installed
if platform.system() != "Darwin":
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
Top = TorchOperator(Dop, batch=False)
if platform.system() == "Darwin":
return

x = np.random.normal(0.0, 1.0, par["nx"])
xt = torch.from_numpy(x).view(-1)
xt.requires_grad = True
v = torch.randn(par["ny"])
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
Top = TorchOperator(Dop, batch=False)

# pylops operator
y = Dop * x
xadj = Dop.H * v
x = np.random.normal(0.0, 1.0, par["nx"])
xt = torch.from_numpy(x).view(-1)
xt.requires_grad = True
v = torch.randn(par["ny"])

# torch operator
yt = Top.apply(xt)
yt.backward(v, retain_graph=True)
# pylops operator
y = Dop * x
xadj = Dop.H * v

assert_array_equal(y, yt.detach().cpu().numpy())
assert_array_equal(xadj, xt.grad.cpu().numpy())
# torch operator
yt = Top.apply(xt)
yt.backward(v, retain_graph=True)

assert_array_equal(y, yt.detach().cpu().numpy())
assert_array_equal(xadj, xt.grad.cpu().numpy())


@pytest.mark.parametrize("par", [(par1)])
def test_TorchOperator_batch(par):
"""Apply forward for input with multiple samples (= batch) and flattened arrays"""
# temporarily, skip tests on mac as torch seems not to recognized
# numpy when v2 is installed
if platform.system() != "Darwin":
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
Top = TorchOperator(Dop, batch=True)
if platform.system() == "Darwin":
return

Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
Top = TorchOperator(Dop, batch=True)

x = np.random.normal(0.0, 1.0, (4, par["nx"]))
xt = torch.from_numpy(x)
xt.requires_grad = True
x = np.random.normal(0.0, 1.0, (4, par["nx"]))
xt = torch.from_numpy(x)
xt.requires_grad = True

y = Dop.matmat(x.T).T
yt = Top.apply(xt)
y = Dop.matmat(x.T).T
yt = Top.apply(xt)

assert_array_equal(y, yt.detach().cpu().numpy())
assert_array_equal(y, yt.detach().cpu().numpy())


@pytest.mark.parametrize("par", [(par1)])
def test_TorchOperator_batch_nd(par):
"""Apply forward for input with multiple samples (= batch) and nd-arrays"""
# temporarily, skip tests on mac as torch seems not to recognized
# numpy when v2 is installed
if platform.system() != "Darwin":
Dop = MatrixMult(
np.random.normal(0.0, 1.0, (par["ny"], par["nx"])), otherdims=(2,)
)
Top = TorchOperator(Dop, batch=True, flatten=False)
if platform.system() == "Darwin":
return

Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])), otherdims=(2,))
Top = TorchOperator(Dop, batch=True, flatten=False)

x = np.random.normal(0.0, 1.0, (4, par["nx"], 2))
xt = torch.from_numpy(x)
xt.requires_grad = True
x = np.random.normal(0.0, 1.0, (4, par["nx"], 2))
xt = torch.from_numpy(x)
xt.requires_grad = True

y = (Dop @ x.transpose(1, 2, 0)).transpose(2, 0, 1)
yt = Top.apply(xt)
y = (Dop @ x.transpose(1, 2, 0)).transpose(2, 0, 1)
yt = Top.apply(xt)

assert_array_equal(y, yt.detach().cpu().numpy())
assert_array_equal(y, yt.detach().cpu().numpy())

0 comments on commit f8e3323

Please sign in to comment.