Skip to content

Commit

Permalink
torch: slightly changed test_autograd_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
D1rk123 committed Jul 20, 2022
1 parent 595aab3 commit 12e95f6
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions tests/test_torch_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,12 @@ def test_autograd_shape():
x2 = torch.ones(2, 3, *A.domain_shape, dtype=torch.float32, requires_grad=True)

y1 = A_ag(x1)
assert(y1.size()[0] == 1)
assert(y1.size()[1] == 1)

y2 = A_ag(x2)
assert(y2.size()[0] == 2)
assert(y2.size()[1] == 3)

y1.backward(y1)
assert(x1.grad.size()[0] == 1)
assert(x1.grad.size()[1] == 1)

y2.backward(y2)
assert(x2.grad.size()[0] == 2)
assert(x2.grad.size()[1] == 3)

assert(torch.equal(torch.tensor(y1.size()), torch.tensor([1, 1, *A.range_shape])))
assert(torch.equal(torch.tensor(y2.size()), torch.tensor([2, 3, *A.range_shape])))
assert(torch.equal(torch.tensor(x1.grad.size()), torch.tensor([1, 1, *A.domain_shape])))
assert(torch.equal(torch.tensor(x2.grad.size()), torch.tensor([2, 3, *A.domain_shape])))

0 comments on commit 12e95f6

Please sign in to comment.