Skip to content

Commit

Permalink
#60 strassen, matrix_add, matrix_subtract
Browse files Browse the repository at this point in the history
  • Loading branch information
wojtask committed Mar 4, 2024
1 parent ab0603c commit 3cb385f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
Empty file.
70 changes: 70 additions & 0 deletions src/solutions/chapter4/section2/exercise2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from book.data_structures import Matrix
from util import range_of


def strassen(A: Matrix, B: Matrix, C: Matrix, n: int) -> None:
"""Multiplies two square matrices and adds the result to the third square matrix, using Strassen's algorithm.
Args:
A: the first square matrix to multiply
B: the second square matrix to multiply
C: the matrix to add the result of the matrix multiplication
n: the dimension of matrices A and B
"""
if n == 1:
C[1, 1] += A[1, 1] * B[1, 1]
return
(A11, A12, A21, A22), (B11, B12, B21, B22), (C11, C12, C21, C22) = __partition_matrices(A, B, C, n)
(S1, S2, S3, S4, S5, S6, S7, S8, S9, S10, P1, P2, P3, P4, P5, P6, P7) = __create_intermediate_matrices(n // 2)
matrix_subtract(B12, B22, S1, n // 2)
matrix_add(A11, A12, S2, n // 2)
matrix_add(A21, A22, S3, n // 2)
matrix_subtract(B21, B11, S4, n // 2)
matrix_add(A11, A22, S5, n // 2)
matrix_add(B11, B22, S6, n // 2)
matrix_subtract(A12, A22, S7, n // 2)
matrix_add(B21, B22, S8, n // 2)
matrix_subtract(A11, A21, S9, n // 2)
matrix_add(B11, B12, S10, n // 2)
strassen(A11, S1, P1, n // 2)
strassen(S2, B22, P2, n // 2)
strassen(S3, B11, P3, n // 2)
strassen(A22, S4, P4, n // 2)
strassen(S5, S6, P5, n // 2)
strassen(S7, S8, P6, n // 2)
strassen(S9, S10, P7, n // 2)
matrix_add(P5, P4, C11, n // 2)
matrix_subtract(P6, P2, C11, n // 2)
matrix_add(P1, P2, C12, n // 2)
matrix_add(P3, P4, C21, n // 2)
matrix_subtract(P5, P3, C22, n // 2)
matrix_subtract(P1, P7, C22, n // 2)


def __partition_matrices(A, B, C, n):
return __partition_matrix(A, n), \
__partition_matrix(B, n), \
__partition_matrix(C, n)


def __partition_matrix(M, n):
return M.submatrix((1, n // 2), (1, n // 2)), \
M.submatrix((1, n // 2), (n // 2 + 1, n)), \
M.submatrix((n // 2 + 1, n), (1, n // 2)), \
M.submatrix((n // 2 + 1, n), (n // 2 + 1, n))


def __create_intermediate_matrices(n):
return [Matrix(n, n) for _ in range_of(1, to=17)]


def matrix_add(A, B, C, n):
for i in range_of(1, to=n):
for j in range_of(1, to=n):
C[i, j] += A[i, j] + B[i, j]


def matrix_subtract(A, B, C, n):
for i in range_of(1, to=n):
for j in range_of(1, to=n):
C[i, j] += A[i, j] - B[i, j]
21 changes: 21 additions & 0 deletions test/test_solutions/test_chapter4.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from solutions.chapter4.section1.exercise1 import matrix_multiply_recursive_general
from solutions.chapter4.section1.exercise3 import matrix_multiply_recursive_by_copying
from solutions.chapter4.section1.exercise4 import matrix_add_recursive
from solutions.chapter4.section2.exercise2 import strassen
from test_case import ClrsTestCase
from test_util import create_matrix

Expand Down Expand Up @@ -72,3 +73,23 @@ def test_matrix_add_recursive(self, data):
actual_sum = create_matrix(numpy.add(elements1, elements2))

self.assertEqual(C, actual_sum)

@given(st.data())
def test_strassen(self, data):
k = data.draw(integers(min_value=0, max_value=4), label="Matrices dimension exponent")
n = 2 ** k
elements1 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="First matrix elements")
elements2 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="Second matrix elements")
A = create_matrix(elements1)
B = create_matrix(elements2)
C = Matrix(n, n)

strassen(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))

self.assertEqual(C, actual_product)

0 comments on commit 3cb385f

Please sign in to comment.