-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#57 matrix_multiply_recursive_general
- Loading branch information
Showing
4 changed files
with
74 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from book.chapter4.section1 import matrix_multiply_recursive | ||
from book.data_structures import Matrix | ||
from util import range_of | ||
|
||
|
||
def matrix_multiply_recursive_general(A: Matrix, B: Matrix, C: Matrix, n: int) -> None: | ||
"""A generalized version of the recursive algorithm for multiplying two square matrices, supporting arbitrary matrix | ||
dimensions. | ||
Args: | ||
A: the first square matrix to multiply | ||
B: the second square matrix to multiply | ||
C: the matrix to add to the result of the matrix multiplication | ||
n: the dimension of matrices A and B | ||
""" | ||
m = __next_power_of_2(n) | ||
A_ = __extend_matrix(A, n, m) | ||
B_ = __extend_matrix(B, n, m) | ||
C_ = __extend_matrix(C, n, m) | ||
matrix_multiply_recursive(A_, B_, C_, m) | ||
__copy_result(C_, C, n) | ||
|
||
|
||
def __next_power_of_2(n: int) -> int: | ||
res = 1 | ||
while res < n: | ||
res *= 2 | ||
return res | ||
|
||
|
||
def __extend_matrix(source: Matrix, source_size: int, extended_size: int) -> Matrix: | ||
extended = Matrix(extended_size, extended_size) | ||
for i in range_of(1, to=source_size): | ||
for j in range_of(1, to=source_size): | ||
extended[i, j] = source[i, j] | ||
return extended | ||
|
||
|
||
def __copy_result(extended: Matrix, original: Matrix, original_size: int) -> None: | ||
for i in range_of(1, to=original_size): | ||
for j in range_of(1, to=original_size): | ||
original[i, j] = extended[i, j] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import numpy | ||
from hypothesis import given | ||
from hypothesis import strategies as st | ||
from hypothesis.strategies import integers | ||
from hypothesis.strategies import lists | ||
|
||
from book.data_structures import Matrix | ||
from solutions.chapter4.section1.exercise1 import matrix_multiply_recursive_general | ||
from test_case import ClrsTestCase | ||
from test_util import create_matrix | ||
|
||
|
||
class TestChapter4(ClrsTestCase): | ||
|
||
@given(st.data()) | ||
def test_matrix_multiply_recursive_general(self, data): | ||
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension") | ||
elements1 = data.draw( | ||
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n), | ||
label="First matrix elements") | ||
elements2 = data.draw( | ||
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n), | ||
label="Second matrix elements") | ||
A = create_matrix(elements1) | ||
B = create_matrix(elements2) | ||
C = Matrix(n, n) | ||
|
||
matrix_multiply_recursive_general(A, B, C, n) | ||
|
||
actual_product = create_matrix(numpy.matmul(elements1, elements2)) | ||
|
||
self.assertEqual(C, actual_product) |