From 3cb385f23f40bc6c229d544cb1cb3c5ff3ce0840 Mon Sep 17 00:00:00 2001 From: Krzysztof Wojtas Date: Mon, 4 Mar 2024 22:59:50 +0100 Subject: [PATCH] #60 strassen, matrix_add, matrix_subtract --- src/solutions/chapter4/section2/__init__.py | 0 src/solutions/chapter4/section2/exercise2.py | 70 ++++++++++++++++++++ test/test_solutions/test_chapter4.py | 21 ++++++ 3 files changed, 91 insertions(+) create mode 100644 src/solutions/chapter4/section2/__init__.py create mode 100644 src/solutions/chapter4/section2/exercise2.py diff --git a/src/solutions/chapter4/section2/__init__.py b/src/solutions/chapter4/section2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/solutions/chapter4/section2/exercise2.py b/src/solutions/chapter4/section2/exercise2.py new file mode 100644 index 0000000..12d1b4d --- /dev/null +++ b/src/solutions/chapter4/section2/exercise2.py @@ -0,0 +1,70 @@ +from book.data_structures import Matrix +from util import range_of + + +def strassen(A: Matrix, B: Matrix, C: Matrix, n: int) -> None: + """Multiplies two square matrices and adds the result to the third square matrix, using Strassen's algorithm. + + Args: + A: the first square matrix to multiply + B: the second square matrix to multiply + C: the matrix to add the result of the matrix multiplication + n: the dimension of matrices A and B + """ + if n == 1: + C[1, 1] += A[1, 1] * B[1, 1] + return + (A11, A12, A21, A22), (B11, B12, B21, B22), (C11, C12, C21, C22) = __partition_matrices(A, B, C, n) + (S1, S2, S3, S4, S5, S6, S7, S8, S9, S10, P1, P2, P3, P4, P5, P6, P7) = __create_intermediate_matrices(n // 2) + matrix_subtract(B12, B22, S1, n // 2) + matrix_add(A11, A12, S2, n // 2) + matrix_add(A21, A22, S3, n // 2) + matrix_subtract(B21, B11, S4, n // 2) + matrix_add(A11, A22, S5, n // 2) + matrix_add(B11, B22, S6, n // 2) + matrix_subtract(A12, A22, S7, n // 2) + matrix_add(B21, B22, S8, n // 2) + matrix_subtract(A11, A21, S9, n // 2) + matrix_add(B11, B12, S10, n // 2) + strassen(A11, S1, P1, n // 2) + strassen(S2, B22, P2, n // 2) + strassen(S3, B11, P3, n // 2) + strassen(A22, S4, P4, n // 2) + strassen(S5, S6, P5, n // 2) + strassen(S7, S8, P6, n // 2) + strassen(S9, S10, P7, n // 2) + matrix_add(P5, P4, C11, n // 2) + matrix_subtract(P6, P2, C11, n // 2) + matrix_add(P1, P2, C12, n // 2) + matrix_add(P3, P4, C21, n // 2) + matrix_subtract(P5, P3, C22, n // 2) + matrix_subtract(P1, P7, C22, n // 2) + + +def __partition_matrices(A, B, C, n): + return __partition_matrix(A, n), \ + __partition_matrix(B, n), \ + __partition_matrix(C, n) + + +def __partition_matrix(M, n): + return M.submatrix((1, n // 2), (1, n // 2)), \ + M.submatrix((1, n // 2), (n // 2 + 1, n)), \ + M.submatrix((n // 2 + 1, n), (1, n // 2)), \ + M.submatrix((n // 2 + 1, n), (n // 2 + 1, n)) + + +def __create_intermediate_matrices(n): + return [Matrix(n, n) for _ in range_of(1, to=17)] + + +def matrix_add(A, B, C, n): + for i in range_of(1, to=n): + for j in range_of(1, to=n): + C[i, j] += A[i, j] + B[i, j] + + +def matrix_subtract(A, B, C, n): + for i in range_of(1, to=n): + for j in range_of(1, to=n): + C[i, j] += A[i, j] - B[i, j] diff --git a/test/test_solutions/test_chapter4.py b/test/test_solutions/test_chapter4.py index 4d704ec..5c51e34 100644 --- a/test/test_solutions/test_chapter4.py +++ b/test/test_solutions/test_chapter4.py @@ -8,6 +8,7 @@ from solutions.chapter4.section1.exercise1 import matrix_multiply_recursive_general from solutions.chapter4.section1.exercise3 import matrix_multiply_recursive_by_copying from solutions.chapter4.section1.exercise4 import matrix_add_recursive +from solutions.chapter4.section2.exercise2 import strassen from test_case import ClrsTestCase from test_util import create_matrix @@ -72,3 +73,23 @@ def test_matrix_add_recursive(self, data): actual_sum = create_matrix(numpy.add(elements1, elements2)) self.assertEqual(C, actual_sum) + + @given(st.data()) + def test_strassen(self, data): + k = data.draw(integers(min_value=0, max_value=4), label="Matrices dimension exponent") + n = 2 ** k + elements1 = data.draw( + lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n), + label="First matrix elements") + elements2 = data.draw( + lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n), + label="Second matrix elements") + A = create_matrix(elements1) + B = create_matrix(elements2) + C = Matrix(n, n) + + strassen(A, B, C, n) + + actual_product = create_matrix(numpy.matmul(elements1, elements2)) + + self.assertEqual(C, actual_product)