Skip to content

Commit

Permalink
fix Toeplitz convolution2d test
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Jan 30, 2024
1 parent 149e858 commit dd5f4ab
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions bnpm/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ def test_toeplitz_convolution2d():

try:
t = Toeplitz_convolution2d(x_shape=(stt[0][ii], stt[1][ii]), k=k, mode=mode, dtype=None)
out_t2d = t(x, batching=True, mode=mode)
out_t2d_s = t(scipy.sparse.csr_matrix(x), batching=True, mode=mode).toarray()
print(type(out_t2d), type(out_t2d_s))
out_sp = np.stack([scipy.signal.convolve2d(x_i.reshape(stt[0][ii], stt[1][ii]), k, mode=mode) for x_i in x], axis=0)
out_t2d = t(x, batching=True, mode=mode).reshape(3, out_sp.shape[1], out_sp.shape[2])
out_t2d_s = t(scipy.sparse.csr_matrix(x), batching=True, mode=mode).toarray().reshape(3, out_sp.shape[1], out_sp.shape[2])
except Exception as e:
if mode == 'valid' and (stt[0][ii] < stt[2][ii] or stt[1][ii] < stt[3][ii]):
if 'x must be larger than k' in str(e):
Expand All @@ -93,6 +92,12 @@ def test_toeplitz_convolution2d():

else:
print(f'C) test failed with batching==False, shapes: x: {x.shape}, k: {k.shape} and mode: {mode}')
print(f"Failure analysis: \n")
print(f"Shapes: x: {x.shape}, k: {k.shape}, out_t2d: {out_t2d.shape}, out_t2d_s: {out_t2d_s.shape}, out_sp: {out_sp.shape}")
print(f"out_t2d: {out_t2d}")
print(f"out_t2d_s: {out_t2d_s}")
print(f"out_sp: {out_sp}")

success = False
break
print(f'success with all shapes and modes') if success else None
Expand Down

0 comments on commit dd5f4ab

Please sign in to comment.