diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index 77467615..d88ec830 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -610,29 +610,29 @@ def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: return y -def randn(n: Union[int, tuple], backend: str = "numpy") -> Callable: +def randn(*n: int, backend: str = "numpy") -> NDArray: """Returns randomly generated number Parameters ---------- - n : :obj:`int` or :obj:`tuple` - Number of samples to generate + *n : :obj:`int` + Number of samples to generate in each dimension backend : :obj:`str`, optional Backend used for dot test computations (``numpy`` or ``cupy``). This parameter will be used to choose how to create the random vectors. Returns ------- - x : :obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`jax.Array` Generated array """ if backend == "numpy": - x = np.random.randn(n) + x = np.random.randn(*n) elif backend == "cupy": - x = cp.random.randn(n) + x = cp.random.randn(*n) elif backend == "jax": - x = jnp.array(np.random.randn(n)) + x = jnp.array(np.random.randn(*n)) else: raise ValueError("backend must be numpy, cupy, or jax") return x diff --git a/pylops/utils/dottest.py b/pylops/utils/dottest.py index 7aefca95..c8a198ca 100644 --- a/pylops/utils/dottest.py +++ b/pylops/utils/dottest.py @@ -93,13 +93,13 @@ def dottest( # make u and v vectors rdtype = np.ones(1, Op.dtype).real.dtype - u = randn(nc, backend).astype(rdtype) + u = randn(nc, backend=backend).astype(rdtype) if complexflag not in (0, 2): - u = u + 1j * randn(nc, backend).astype(rdtype) + u = u + 1j * randn(nc, backend=backend).astype(rdtype) - v = randn(nr, backend).astype(rdtype) + v = randn(nr, backend=backend).astype(rdtype) if complexflag not in (0, 1): - v = v + 1j * randn(nr, backend).astype(rdtype) + v = v + 1j * randn(nr, backend=backend).astype(rdtype) y = Op.matvec(u) # Op * u x = Op.rmatvec(v) # Op'* v