Skip to content

Commit

Permalink
enhance test_matrix_power.py/cholesky.py/norm.py (#1038)
Browse files Browse the repository at this point in the history
* enhance test_matrix_power.py

* enhance test_ch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
XiaLuNV and pre-commit-ci[bot] authored Sep 1, 2023
1 parent 2c13393 commit 8dddd18
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 11 deletions.
3 changes: 1 addition & 2 deletions cunumeric/linalg/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@
#


class LinAlgError(Exception):
pass
from numpy.linalg.linalg import LinAlgError # noqa: F401
4 changes: 1 addition & 3 deletions cunumeric/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,6 @@ def matrix_power(a: ndarray, n: int) -> ndarray:
a = empty_like(a)
a[...] = eye(a.shape[-2], dtype=a.dtype)
return a
elif n == 1:
return a.copy()

# Invert if necessary
if n < 0:
Expand All @@ -219,7 +217,7 @@ def matrix_power(a: ndarray, n: int) -> ndarray:

# Fast paths
if n == 1:
return a
return a.copy()
elif n == 2:
return matmul(a, a)
elif n == 3:
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/test_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ def test_array_negative_3dim():
num.linalg.cholesky(arr)


def test_array_negative():
arr = num.random.randint(0, 9, size=(3, 2, 3))
expected_exc = ValueError
with pytest.raises(expected_exc):
num.linalg.cholesky(arr)
with pytest.raises(expected_exc):
np.linalg.cholesky(arr)


def test_diagonal():
a = num.eye(10) * 10.0
b = num.linalg.cholesky(a)
Expand Down
21 changes: 15 additions & 6 deletions tests/integration/test_matrix_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,37 @@ class TestMatrixPowerErrors:
def test_matrix_ndim_smaller_than_two(self, ndim):
shape = (3,) * ndim
a_num = mk_0to1_array(num, shape)
msg = "Expected at least 2d array"
with pytest.raises(num.linalg.LinAlgError, match=msg):
a_np = mk_0to1_array(np, shape)
expected_exc = num.linalg.LinAlgError
with pytest.raises(expected_exc):
num.linalg.matrix_power(a_num, 1)
with pytest.raises(expected_exc):
np.linalg.matrix_power(a_np, 1)

@pytest.mark.parametrize(
"shape", ((2, 1), (2, 2, 1)), ids=lambda shape: f"(shape={shape})"
)
def test_matrix_not_square(self, shape):
a_num = mk_0to1_array(num, shape)
msg = "Last 2 dimensions of the array must be square"
with pytest.raises(num.linalg.LinAlgError, match=msg):
a_np = mk_0to1_array(np, shape)
expected_exc = num.linalg.LinAlgError
with pytest.raises(expected_exc):
num.linalg.matrix_power(a_num, 1)
with pytest.raises(expected_exc):
np.linalg.matrix_power(a_np, 1)

@pytest.mark.parametrize(
"n", (-1.0, 1.0, [1], None), ids=lambda n: f"(n={n})"
)
def test_n_not_int(self, n):
shape = (2, 2)
a_num = mk_0to1_array(num, shape)
msg = "exponent must be an integer"
with pytest.raises(TypeError, match=msg):
a_np = mk_0to1_array(np, shape)
expected_exc = TypeError
with pytest.raises(expected_exc):
num.linalg.matrix_power(a_num, n)
with pytest.raises(expected_exc):
np.linalg.matrix_power(a_np, n)

def test_n_negative_int(self):
shape = (2, 2)
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ def test_axis_invalid_value(self, axis):
with pytest.raises(expected_exc):
num.linalg.norm(num_arrays[ndim], axis=axis)

def test_axis_out_of_bounds(self):
# raise ValueError("Improper number of dimensions to norm")
expected_exc = ValueError
ndim = 3

with pytest.raises(expected_exc):
np.linalg.norm(np_arrays[ndim], ord=1)

with pytest.raises(expected_exc):
num.linalg.norm(num_arrays[ndim], ord=1)

@pytest.mark.parametrize(
"ndim_axis",
((1, None), (2, 0)),
Expand Down

0 comments on commit 8dddd18

Please sign in to comment.