Skip to content

Commit

Permalink
#57 matrix_multiply_recursive_general
Browse files Browse the repository at this point in the history
  • Loading branch information
wojtask committed Nov 1, 2023
1 parent 5755b82 commit 9f3bedb
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 0 deletions.
Empty file.
Empty file.
42 changes: 42 additions & 0 deletions src/solutions/chapter4/section1/exercise1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from book.chapter4.section1 import matrix_multiply_recursive
from book.data_structures import Matrix
from util import range_of


def matrix_multiply_recursive_general(A: Matrix, B: Matrix, C: Matrix, n: int) -> None:
"""A generalized version of the recursive algorithm for multiplying two square matrices, supporting arbitrary matrix
dimensions.
Args:
A: the first square matrix to multiply
B: the second square matrix to multiply
C: the matrix to add to the result of the matrix multiplication
n: the dimension of matrices A and B
"""
m = __next_power_of_2(n)
A_ = __extend_matrix(A, n, m)
B_ = __extend_matrix(B, n, m)
C_ = __extend_matrix(C, n, m)
matrix_multiply_recursive(A_, B_, C_, m)
__copy_result(C_, C, n)


def __next_power_of_2(n: int) -> int:
res = 1
while res < n:
res *= 2
return res


def __extend_matrix(source: Matrix, source_size: int, extended_size: int) -> Matrix:
extended = Matrix(extended_size, extended_size)
for i in range_of(1, to=source_size):
for j in range_of(1, to=source_size):
extended[i, j] = source[i, j]
return extended


def __copy_result(extended: Matrix, original: Matrix, original_size: int) -> None:
for i in range_of(1, to=original_size):
for j in range_of(1, to=original_size):
original[i, j] = extended[i, j]
32 changes: 32 additions & 0 deletions test/test_solutions/test_chapter4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import integers
from hypothesis.strategies import lists

from book.data_structures import Matrix
from solutions.chapter4.section1.exercise1 import matrix_multiply_recursive_general
from test_case import ClrsTestCase
from test_util import create_matrix


class TestChapter4(ClrsTestCase):

@given(st.data())
def test_matrix_multiply_recursive_general(self, data):
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension")
elements1 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="First matrix elements")
elements2 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="Second matrix elements")
A = create_matrix(elements1)
B = create_matrix(elements2)
C = Matrix(n, n)

matrix_multiply_recursive_general(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))

self.assertEqual(C, actual_product)

0 comments on commit 9f3bedb

Please sign in to comment.