Skip to content

Commit

Permalink
fix: randn arguments
Browse files Browse the repository at this point in the history
* Uses positional arguments instead of `n` as int | tuple, which is the correct usage with `np.random.randn`
* Corrects input/output types
  • Loading branch information
cako committed Aug 5, 2024
1 parent 694e2ff commit 1890d16
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions pylops/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,29 +610,29 @@ def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
return y


def randn(n: Union[int, tuple], backend: str = "numpy") -> Callable:
def randn(*n: int, backend: str = "numpy") -> NDArray:
"""Returns randomly generated number
Parameters
----------
n : :obj:`int` or :obj:`tuple`
Number of samples to generate
*n : :obj:`int`
Number of samples to generate in each dimension
backend : :obj:`str`, optional
Backend used for dot test computations (``numpy`` or ``cupy``). This
parameter will be used to choose how to create the random vectors.
Returns
-------
x : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`
x : :obj:`numpy.ndarray` or :obj:`jax.Array`
Generated array
"""
if backend == "numpy":
x = np.random.randn(n)
x = np.random.randn(*n)
elif backend == "cupy":
x = cp.random.randn(n)
x = cp.random.randn(*n)
elif backend == "jax":
x = jnp.array(np.random.randn(n))
x = jnp.array(np.random.randn(*n))
else:
raise ValueError("backend must be numpy, cupy, or jax")
return x
8 changes: 4 additions & 4 deletions pylops/utils/dottest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def dottest(
# make u and v vectors
rdtype = np.ones(1, Op.dtype).real.dtype

u = randn(nc, backend).astype(rdtype)
u = randn(nc, backend=backend).astype(rdtype)
if complexflag not in (0, 2):
u = u + 1j * randn(nc, backend).astype(rdtype)
u = u + 1j * randn(nc, backend=backend).astype(rdtype)

v = randn(nr, backend).astype(rdtype)
v = randn(nr, backend=backend).astype(rdtype)
if complexflag not in (0, 1):
v = v + 1j * randn(nr, backend).astype(rdtype)
v = v + 1j * randn(nr, backend=backend).astype(rdtype)

y = Op.matvec(u) # Op * u
x = Op.rmatvec(v) # Op'* v
Expand Down

0 comments on commit 1890d16

Please sign in to comment.