Skip to content

Commit

Permalink
minor: fixed some error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Jul 4, 2024
1 parent e27faa0 commit eae806f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
2 changes: 0 additions & 2 deletions pylops/avo/poststack.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,12 @@ def _PoststackLinearModelling(
D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * ncp.ones(nt0 - 1, dtype=dtype), -1
)
# D[0] = D[-1] = 0.
D = inplace_set(ncp.array(0.0), D, 0)
D = inplace_set(ncp.array(0.0), D, -1)
else:
D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
ncp.ones(nt0, dtype=dtype), k=0
)
# D[-1] = 0.
D = inplace_set(ncp.array(0.0), D, -1)

# Create wavelet operator
Expand Down
21 changes: 13 additions & 8 deletions pylops/jaxoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ def _rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
Parameters
----------
x : :obj:`jaxlib.xla_extension.ArrayImpl`
Input array
Input array for forward
y : :obj:`jaxlib.xla_extension.ArrayImpl`
Output array (where to store the
Vector-Jacobian product)
Input array for adjoint
Returns
----------
y : :obj:`jaxlib.xla_extension.ArrayImpl`
-------
xadj : :obj:`jaxlib.xla_extension.ArrayImpl`
Output array
"""
_, f_vjp = jax.vjp(self._matvec, x)
return jax.jit(f_vjp)(y)[0]
xadj = jax.jit(f_vjp)(y)[0]
return xadj

def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
"""Adjoint matrix-vector multiplication with AD
Expand All @@ -84,7 +84,9 @@ def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
M, N = self.shape

if x.shape != (M,) and x.shape != (M, 1):
raise ValueError("dimension mismatch")
raise ValueError(
f"Dimension mismatch. Got {x.shape}, but expected {(M, 1)} or {(M,)}."
)

y = self._rmatvecad(x, y)

Expand All @@ -93,5 +95,8 @@ def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
elif x.ndim == 2:
y = y.reshape(N, 1)
else:
raise ValueError("invalid shape returned by user-defined rmatvecad()")
raise ValueError(
f"Invalid shape returned by user-defined rmatvecad(). "
f"Expected 2-d ndarray or matrix, not {x.ndim}-d ndarray"
)
return y
20 changes: 12 additions & 8 deletions pylops/linearoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ def matvec(self, x: NDArray) -> NDArray:
M, N = self.shape

if x.shape != (N,) and x.shape != (N, 1):
raise ValueError("dimension mismatch")
raise ValueError(
f"Dimension mismatch. Got {x.shape}, but expected {(M, 1)} or {(M,)}."
)

y = self._matvec(x)

Expand All @@ -517,7 +519,7 @@ def matvec(self, x: NDArray) -> NDArray:
elif x.ndim == 2:
y = y.reshape(M, 1)
else:
raise ValueError("invalid shape returned by user-defined matvec()")
raise ValueError("Invalid shape returned by user-defined matvec()")
return y

@count(forward=False)
Expand All @@ -542,7 +544,9 @@ def rmatvec(self, x: NDArray) -> NDArray:
M, N = self.shape

if x.shape != (M,) and x.shape != (M, 1):
raise ValueError("dimension mismatch")
raise ValueError(
f"Dimension mismatch. Got {x.shape}, but expected {(M, 1)} or {(M,)}."
)

y = self._rmatvec(x)

Expand All @@ -551,7 +555,7 @@ def rmatvec(self, x: NDArray) -> NDArray:
elif x.ndim == 2:
y = y.reshape(N, 1)
else:
raise ValueError("invalid shape returned by user-defined rmatvec()")
raise ValueError("Invalid shape returned by user-defined rmatvec()")
return y

@count(forward=True, matmat=True)
Expand All @@ -574,9 +578,9 @@ def matmat(self, X: NDArray) -> NDArray:
"""
if X.ndim != 2:
raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim)
raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray")
if X.shape[0] != self.shape[1]:
raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape))
raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}")
Y = self._matmat(X)
return Y

Expand All @@ -600,9 +604,9 @@ def rmatmat(self, X: NDArray) -> NDArray:
"""
if X.ndim != 2:
raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim)
raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray")
if X.shape[0] != self.shape[0]:
raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape))
raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}")
Y = self._rmatmat(X)
return Y

Expand Down

0 comments on commit eae806f

Please sign in to comment.