From 92693eda2623083343a932c78361a385bb9129a3 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Wed, 5 Jun 2024 11:58:41 +0200 Subject: [PATCH] More optimization --- src/admmsolver/matrix.py | 24 ++++++++++++++++++++---- test/test_matrix.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/admmsolver/matrix.py b/src/admmsolver/matrix.py index f9c62b7..9897f5a 100644 --- a/src/admmsolver/matrix.py +++ b/src/admmsolver/matrix.py @@ -282,7 +282,14 @@ def __matmul__(self, other: Union[MatrixBase, np.ndarray]) -> Union[MatrixBase, (self.shape[0], other.shape[1]) ) elif isinstance(other, PartialDiagonalMatrix): - return DenseMatrix(self._diagonals[:, None] * other.asmatrix()) + diags = self.diagonals.reshape(other.matrix.shape[0], -1) + if np.allclose(diags, diags[:, 0:1]): + return PartialDiagonalMatrix( + diags[:, 0][:, None] * other.matrix.asmatrix(), + other.rest_dims + ) + else: + return DenseMatrix(self._diagonals[:, None] * other.asmatrix()) elif isinstance(other, ScaledIdentityMatrix): return self @ other.to_diagonal_matrix() else: @@ -339,10 +346,13 @@ def __matmul__(self, other) -> Union[np.ndarray, MatrixBase]: assert isinstance(other, MatrixBase) or isinstance(other, np.ndarray) if isinstance(other, np.ndarray): return self.matvec(other) - elif isinstance(other, PartialDiagonalMatrix) and self.rest_dims == other.rest_dims: + + if isinstance(other, PartialDiagonalMatrix) and self.rest_dims == other.rest_dims: return PartialDiagonalMatrix(self.matrix@other.matrix, self.rest_dims) - else: - return DenseMatrix(self.asmatrix() @ other.asmatrix()) + if isinstance(other, ScaledIdentityMatrix) and other.is_diagonal(): + return PartialDiagonalMatrix(other.coeff * self.matrix, self.rest_dims) + + return DenseMatrix(self.asmatrix() @ other.asmatrix()) def __mul__(self, other) -> 'PartialDiagonalMatrix': if type(other) in [float, complex, np.float64, np.complex128]: @@ -459,6 +469,12 @@ def _add_DiagonalMatrix_PartialDiagonalMatrix(a, b): return DenseMatrix(a.asmatrix() + b.asmatrix()) +def _add_PartialDiagonalMatrix_PartialDiagonalMatrix(a, b): + if a.rest_dims == b.rest_dims: + return PartialDiagonalMatrix(a.matrix + b.matrix, a.rest_dims) + return DenseMatrix(a.asmatrix() + b.asmatrix()) + + def _add_DenseMatrix_DenseMatrix(a, b): return DenseMatrix(a.asmatrix() + b.asmatrix()) diff --git a/test/test_matrix.py b/test/test_matrix.py index 51e9f42..c19331b 100644 --- a/test/test_matrix.py +++ b/test/test_matrix.py @@ -8,7 +8,7 @@ def _randn_cmplx(*shape) -> np.ndarray: return np.random.randn(*shape) + 1j * np.random.randn(*shape) -def test_matmal(): +def test_matmul(): np.random.seed(100) # (12, 12) * (12, 4) @@ -33,7 +33,7 @@ def test_matmal(): np.testing.assert_allclose(lr.asmatrix(), l.asmatrix() @ r.asmatrix()) -def test_matmal2(): +def test_matmul(): np.random.seed(100) # (4, 12) * (12, 12) @@ -124,6 +124,32 @@ def test_DiagonalMatrix_PartialDiagonalMatrix(): np.testing.assert_allclose(ab.asmatrix(), a.asmatrix() + b.asmatrix()) +def test_PartialDiagonalMatrix_PartialDiagonalMatrix(): + np.random.seed(100) + n = 3 + a = PartialDiagonalMatrix(_randn_cmplx(n, n), (2, 2)) + b = PartialDiagonalMatrix(_randn_cmplx(n, n), (2, 2)) + ab = a + b + assert isinstance(ab, PartialDiagonalMatrix) + np.testing.assert_allclose(ab.asmatrix(), a.asmatrix() + b.asmatrix()) + + +def test_matmul_DiagonalMatrix_PartialDiagonalMatrix(): + np.random.seed(100) + + n = 3 + diags_ = np.random.randn(n) + diags = np.zeros((n, 4)) + for i in range(4): + diags[:, i] = diags_ + + a = DiagonalMatrix(diags.ravel()) + b = PartialDiagonalMatrix(_randn_cmplx(n, n), (2, 2)) + + ab = a @ b + assert isinstance(ab, PartialDiagonalMatrix) + np.testing.assert_allclose(ab.asmatrix(), a.asmatrix() @ b.asmatrix()) + def test_inv(): np.random.seed(100) @@ -207,7 +233,7 @@ def test_batched_matvec(n, m): np.testing.assert_allclose(mv, m.asmatrix()@vec) -def test_matmal_diagonal(): +def test_matmul_diagonal(): np.random.seed(100) a = DiagonalMatrix(np.random.randn(2), shape=(4,2)) b = DiagonalMatrix(np.random.randn(2), shape=(2,4))