diff --git a/src/book/chapter4/section1.py b/src/book/chapter4/section1.py index 2f178ca..8a21846 100644 --- a/src/book/chapter4/section1.py +++ b/src/book/chapter4/section1.py @@ -1,7 +1,4 @@ -from typing import Union - from book.data_structures import Matrix -from book.data_structures import Submatrix from util import range_of @@ -23,11 +20,7 @@ def matrix_multiply(A: Matrix, B: Matrix, C: Matrix, n: int) -> None: C[i, j] += A[i, k] * B[k, j] -def matrix_multiply_recursive( - A: Union[Matrix, Submatrix], - B: Union[Matrix, Submatrix], - C: Union[Matrix, Submatrix], - n: int) -> None: +def matrix_multiply_recursive(A: Matrix, B: Matrix, C: Matrix, n: int) -> None: """Recursively multiplies two square matrices and adds the result to the third square matrix. Implements: @@ -60,7 +53,7 @@ def __partition_matrices(A, B, C, n): def __partition_matrix(M, n): - return Submatrix(M, range_of(1, to=n // 2), range_of(1, to=n // 2)), \ - Submatrix(M, range_of(1, to=n // 2), range_of(n // 2 + 1, to=n)), \ - Submatrix(M, range_of(n // 2 + 1, to=n), range_of(1, to=n // 2)), \ - Submatrix(M, range_of(n // 2 + 1, to=n), range_of(n // 2 + 1, to=n)) + return M.submatrix((1, n // 2), (1, n // 2)), \ + M.submatrix((1, n // 2), (n // 2 + 1, n)), \ + M.submatrix((n // 2 + 1, n), (1, n // 2)), \ + M.submatrix((n // 2 + 1, n), (n // 2 + 1, n)) diff --git a/src/book/data_structures.py b/src/book/data_structures.py index 9b01cc0..6f889b1 100644 --- a/src/book/data_structures.py +++ b/src/book/data_structures.py @@ -20,23 +20,46 @@ def __setitem__(self, index: int, value: Any) -> None: class Matrix: - def __init__(self, rows: int, cols: int) -> None: - assert rows >= 0 and cols >= 0 + def __init__(self, rows: Union[int, Tuple[int, int]], cols: Union[int, Tuple[int, int]], elements=None) -> None: + if isinstance(rows, int) and isinstance(cols, int) and elements is None: + assert rows >= 0 + assert cols >= 0 + self.__create_matrix(rows, cols) + elif isinstance(rows, tuple) and isinstance(cols, tuple) and elements is not None: + self.__create_submatrix(rows, cols, elements) + else: + raise TypeError('Invalid parameters') + + def __create_matrix(self, rows, cols): + self.__start_row, self.__end_row = 1, rows + self.__start_col, self.__end_col = 1, cols self.__elements = [[0] * cols for _ in range(rows)] + def __create_submatrix(self, rows, cols, elements): + self.__start_row, self.__end_row = rows[0], rows[1] + self.__start_col, self.__end_col = cols[0], cols[1] + self.__elements = elements + + def submatrix(self, rows: Tuple[int, int], cols: Tuple[int, int]): + assert 1 <= rows[0] <= rows[1] <= len(self.__elements) + assert 1 <= cols[0] <= cols[1] <= len(self.__elements[0]) + rows_shifted = rows[0] + self.__start_row - 1, rows[1] + self.__start_row - 1 + cols_shifted = cols[0] + self.__start_col - 1, cols[1] + self.__start_col - 1 + return Matrix(rows_shifted, cols_shifted, self.__elements) + def __getitem__(self, indices: Tuple[int, int]) -> Union[int, float]: row = indices[0] col = indices[1] - assert 1 <= row <= len(self.__elements) - assert 1 <= col <= len(self.__elements[0]) - return self.__elements[row - 1][col - 1] + assert 1 <= row <= self.__end_row - self.__start_row + 1 + assert 1 <= col <= self.__end_col - self.__start_col + 1 + return self.__elements[self.__start_row - 1 + row - 1][self.__start_col - 1 + col - 1] def __setitem__(self, indices: Tuple[int, int], value: Union[int, float]) -> None: row = indices[0] col = indices[1] - assert 1 <= row <= len(self.__elements) - assert 1 <= col <= len(self.__elements[0]) - self.__elements[row - 1][col - 1] = value + assert 1 <= row <= self.__end_row - self.__start_row + 1 + assert 1 <= col <= self.__end_col - self.__start_col + 1 + self.__elements[self.__start_row - 1 + row - 1][self.__start_col - 1 + col - 1] = value def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): @@ -56,26 +79,3 @@ def __eq__(self, other: Any) -> bool: except AssertionError: pass return True - - -class Submatrix: - def __init__(self, matrix: Matrix, row_range: range, col_range: range) -> None: - self.__matrix = matrix - self.__start_row = row_range.start - self.__end_row = row_range.stop - self.__start_col = col_range.start - self.__end_col = col_range.stop - - def __getitem__(self, indices: Tuple[int, int]) -> Union[int, float]: - i = indices[0] - j = indices[1] - assert 1 <= i <= self.__end_row - self.__start_row + 1 - assert 1 <= j <= self.__end_col - self.__start_col + 1 - return self.__matrix[self.__start_row + i - 1, self.__start_col + j - 1] - - def __setitem__(self, indices: Tuple[int, int], value: Union[int, float]) -> None: - i = indices[0] - j = indices[1] - assert 1 <= i <= self.__end_row - self.__start_row + 1 - assert 1 <= j <= self.__end_col - self.__start_col + 1 - self.__matrix[self.__start_row + i - 1, self.__start_col + j - 1] = value diff --git a/test/test_book/test_chapter4.py b/test/test_book/test_chapter4.py index 8bd8452..8c9b986 100644 --- a/test/test_book/test_chapter4.py +++ b/test/test_book/test_chapter4.py @@ -15,7 +15,7 @@ class TestChapter4(ClrsTestCase): @given(st.data()) def test_matrix_multiply(self, data): - n = data.draw(integers(min_value=1, max_value=10), label="Matrices dimension") + n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension") elements1 = data.draw( lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n), label="First matrix elements")