Skip to content

Commit

Permalink
#62 matrix_multiply_by_squaring
Browse files Browse the repository at this point in the history
  • Loading branch information
wojtask committed Mar 5, 2024
1 parent 956ec00 commit 4ab534d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/solutions/chapter4/section2/exercise6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Callable

from book.data_structures import Matrix
from solutions.chapter4.section2.exercise2 import matrix_add
from util import range_of


def matrix_multiply_by_squaring(A: Matrix, B: Matrix, C: Matrix, matrix_square: Callable[[Matrix, Matrix, int], None],
n: int) -> None:
"""Multiplies two square matrices and adds the result to the third square matrix, using a function for squaring
matrices.
Args:
A: the first square matrix to multiply
B: the second square matrix to multiply
C: the matrix to add the result of the matrix multiplication
matrix_square: a function for squaring matrices.
The first argument is the square matrix to square, the second argument is the matrix to accumulate the
result, and the third argument is the dimension of the input matrix.
n: the dimension of matrices A and B
"""
D = __create_padded_input_matrix(A, B, n)
E = Matrix(2 * n, 2 * n)
matrix_square(D, E, 2 * n)
matrix_add(C, E.submatrix((n + 1, 2 * n), (1, n)), C, n)


def __create_padded_input_matrix(A, B, n):
M = Matrix(2 * n, 2 * n)
for i in range_of(1, to=n):
for j in range_of(1, to=n):
M[i, j] = B[i, j]
M[n + i, j] = A[i, j]
return M
23 changes: 23 additions & 0 deletions test/test_solutions/test_chapter4.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from hypothesis.strategies import integers
from hypothesis.strategies import lists

from book.chapter4.section1 import matrix_multiply
from book.data_structures import Matrix
from solutions.chapter4.section1.exercise1 import matrix_multiply_recursive_general
from solutions.chapter4.section1.exercise3 import matrix_multiply_recursive_by_copying
from solutions.chapter4.section1.exercise4 import matrix_add_recursive
from solutions.chapter4.section2.exercise2 import strassen
from solutions.chapter4.section2.exercise5 import complex_multiply
from solutions.chapter4.section2.exercise6 import matrix_multiply_by_squaring
from test_case import ClrsTestCase
from test_util import create_matrix

Expand Down Expand Up @@ -102,3 +104,24 @@ def test_complex_multiply(self, data):
actual_product = complex(real, imag)
expected_product = z1 * z2
self.assertAlmostEqual(actual_product, expected_product, places=7)

@given(st.data())
def test_matrix_multiply_by_squaring(self, data):
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension")
elements1 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="First matrix elements")
elements2 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="Second matrix elements")
A = create_matrix(elements1)
B = create_matrix(elements2)
C = Matrix(n, n)

def matrix_square_function(D, E, m):
matrix_multiply(D, D, E, m)

matrix_multiply_by_squaring(A, B, C, matrix_square_function, n)

expected_product = create_matrix(numpy.matmul(elements1, elements2))
self.assertEqual(C, expected_product)

0 comments on commit 4ab534d

Please sign in to comment.