Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Oct 25, 2024
1 parent b646cf1 commit 791def5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
8 changes: 4 additions & 4 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ def matrix_rank(A, tol=None, hermitian=False, *, rtol=None):
----------
A : {(M,), (..., M, N)} {dpnp.ndarray, usm_ndarray}
Input vector or stack of matrices.
tol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
tol : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
Threshold below which SVD values are considered zero. Only `tol` or
`rtol` can be set at a time. If none of them are provided, defaults
to ``S.max() * max(M, N) * eps`` where `S` is an array with singular
Expand All @@ -1083,7 +1083,7 @@ def matrix_rank(A, tol=None, hermitian=False, *, rtol=None):
If ``True``, `A` is assumed to be Hermitian (symmetric if real-valued),
enabling a more efficient method for finding singular values.
Default: ``False``.
rtol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
rtol : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
Parameter for the relative tolerance component. Only `tol` or `rtol`
can be set at a time. If none of them are provided, defaults to
``max(M, N) * eps`` where `eps` is the epsilon value for datatype
Expand Down Expand Up @@ -1479,7 +1479,7 @@ def pinv(a, rcond=None, hermitian=False, *, rtol=None):
----------
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
Matrix or stack of matrices to be pseudo-inverted.
rcond : (...) {float, dpnp.ndarray, usm_ndarray}, optional
rcond : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
Cutoff for small singular values.
Singular values less than or equal to ``rcond * largest_singular_value``
are set to zero. Broadcasts against the stack of matrices.
Expand All @@ -1490,7 +1490,7 @@ def pinv(a, rcond=None, hermitian=False, *, rtol=None):
If ``True``, a is assumed to be Hermitian (symmetric if real-valued),
enabling a more efficient method for finding singular values.
Default: ``False``.
rtol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
rtol : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
Same as `rcond`, but it's an Array API compatible parameter name.
Only `rcond` or `rtol` can be set at a time. If none of them are
provided, defaults to ``max(M, N) * dpnp.finfo(a.dtype).eps``.
Expand Down
5 changes: 4 additions & 1 deletion dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,10 +2234,13 @@ def dpnp_matrix_rank(A, tol=None, hermitian=False, rtol=None):
if rtol is None:
rtol = max(A.shape[-2:]) * dpnp.finfo(S.dtype).eps
elif not dpnp.isscalar(rtol):
# Add a new axis to make it broadcastable against S
# needed for S > tol comparison below
rtol = rtol[..., None]
tol = S.max(axis=-1, keepdims=True) * rtol
elif not dpnp.isscalar(tol):
# Add a new axis to match NumPy's output
# Add a new axis to make it broadcastable against S,
# needed for S > tol comparison below
tol = tol[..., None]

return dpnp.count_nonzero(S > tol, axis=-1)
Expand Down
27 changes: 15 additions & 12 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,25 +2100,28 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
# rtol kwarg was added in numpy 2.0
@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize(
"rtol",
[0.99e-6, numpy.array(1.01e-6), numpy.array([0.99e-6])],
"tol",
[0.99e-6, numpy.array(1.01e-6), numpy.ones(4) * [0.99e-6]],
ids=["float", "0-D array", "1-D array"],
)
def test_matrix_rank_rtol(self, rtol):
a = numpy.eye(4)
a[-1, -1] = 1e-6
def test_matrix_rank_tol(self, tol):
a = numpy.zeros((4, 3, 2))
a_dp = inp.array(a)

if isinstance(rtol, numpy.ndarray):
dp_rtol = inp.array(
rtol, usm_type=a_dp.usm_type, sycl_queue=a_dp.sycl_queue
if isinstance(tol, numpy.ndarray):
dp_tol = inp.array(
tol, usm_type=a_dp.usm_type, sycl_queue=a_dp.sycl_queue
)
else:
dp_rtol = rtol
dp_tol = tol

expected = numpy.linalg.matrix_rank(a, rtol=rtol)
result = inp.linalg.matrix_rank(a_dp, rtol=dp_rtol)
assert expected == result
expected = numpy.linalg.matrix_rank(a, rtol=tol)
result = inp.linalg.matrix_rank(a_dp, rtol=dp_tol)
assert_dtype_allclose(result, expected)

expected = numpy.linalg.matrix_rank(a, tol=tol)
result = inp.linalg.matrix_rank(a_dp, tol=dp_tol)
assert_dtype_allclose(result, expected)

def test_matrix_rank_errors(self):
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")
Expand Down

0 comments on commit 791def5

Please sign in to comment.