Skip to content

Commit

Permalink
Matrix can now create submatrices as different views on a current matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
wojtask committed Aug 12, 2023
1 parent 75ff9d5 commit 5755b82
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 44 deletions.
17 changes: 5 additions & 12 deletions src/book/chapter4/section1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from typing import Union

from book.data_structures import Matrix
from book.data_structures import Submatrix
from util import range_of


Expand All @@ -23,11 +20,7 @@ def matrix_multiply(A: Matrix, B: Matrix, C: Matrix, n: int) -> None:
C[i, j] += A[i, k] * B[k, j]


def matrix_multiply_recursive(
A: Union[Matrix, Submatrix],
B: Union[Matrix, Submatrix],
C: Union[Matrix, Submatrix],
n: int) -> None:
def matrix_multiply_recursive(A: Matrix, B: Matrix, C: Matrix, n: int) -> None:
"""Recursively multiplies two square matrices and adds the result to the third square matrix.
Implements:
Expand Down Expand Up @@ -60,7 +53,7 @@ def __partition_matrices(A, B, C, n):


def __partition_matrix(M, n):
return Submatrix(M, range_of(1, to=n // 2), range_of(1, to=n // 2)), \
Submatrix(M, range_of(1, to=n // 2), range_of(n // 2 + 1, to=n)), \
Submatrix(M, range_of(n // 2 + 1, to=n), range_of(1, to=n // 2)), \
Submatrix(M, range_of(n // 2 + 1, to=n), range_of(n // 2 + 1, to=n))
return M.submatrix((1, n // 2), (1, n // 2)), \
M.submatrix((1, n // 2), (n // 2 + 1, n)), \
M.submatrix((n // 2 + 1, n), (1, n // 2)), \
M.submatrix((n // 2 + 1, n), (n // 2 + 1, n))
62 changes: 31 additions & 31 deletions src/book/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,46 @@ def __setitem__(self, index: int, value: Any) -> None:


class Matrix:
def __init__(self, rows: int, cols: int) -> None:
assert rows >= 0 and cols >= 0
def __init__(self, rows: Union[int, Tuple[int, int]], cols: Union[int, Tuple[int, int]], elements=None) -> None:
if isinstance(rows, int) and isinstance(cols, int) and elements is None:
assert rows >= 0
assert cols >= 0
self.__create_matrix(rows, cols)
elif isinstance(rows, tuple) and isinstance(cols, tuple) and elements is not None:
self.__create_submatrix(rows, cols, elements)
else:
raise TypeError('Invalid parameters')

def __create_matrix(self, rows, cols):
self.__start_row, self.__end_row = 1, rows
self.__start_col, self.__end_col = 1, cols
self.__elements = [[0] * cols for _ in range(rows)]

def __create_submatrix(self, rows, cols, elements):
self.__start_row, self.__end_row = rows[0], rows[1]
self.__start_col, self.__end_col = cols[0], cols[1]
self.__elements = elements

def submatrix(self, rows: Tuple[int, int], cols: Tuple[int, int]):
assert 1 <= rows[0] <= rows[1] <= len(self.__elements)
assert 1 <= cols[0] <= cols[1] <= len(self.__elements[0])
rows_shifted = rows[0] + self.__start_row - 1, rows[1] + self.__start_row - 1
cols_shifted = cols[0] + self.__start_col - 1, cols[1] + self.__start_col - 1
return Matrix(rows_shifted, cols_shifted, self.__elements)

def __getitem__(self, indices: Tuple[int, int]) -> Union[int, float]:
row = indices[0]
col = indices[1]
assert 1 <= row <= len(self.__elements)
assert 1 <= col <= len(self.__elements[0])
return self.__elements[row - 1][col - 1]
assert 1 <= row <= self.__end_row - self.__start_row + 1
assert 1 <= col <= self.__end_col - self.__start_col + 1
return self.__elements[self.__start_row - 1 + row - 1][self.__start_col - 1 + col - 1]

def __setitem__(self, indices: Tuple[int, int], value: Union[int, float]) -> None:
row = indices[0]
col = indices[1]
assert 1 <= row <= len(self.__elements)
assert 1 <= col <= len(self.__elements[0])
self.__elements[row - 1][col - 1] = value
assert 1 <= row <= self.__end_row - self.__start_row + 1
assert 1 <= col <= self.__end_col - self.__start_col + 1
self.__elements[self.__start_row - 1 + row - 1][self.__start_col - 1 + col - 1] = value

def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
Expand All @@ -56,26 +79,3 @@ def __eq__(self, other: Any) -> bool:
except AssertionError:
pass
return True


class Submatrix:
def __init__(self, matrix: Matrix, row_range: range, col_range: range) -> None:
self.__matrix = matrix
self.__start_row = row_range.start
self.__end_row = row_range.stop
self.__start_col = col_range.start
self.__end_col = col_range.stop

def __getitem__(self, indices: Tuple[int, int]) -> Union[int, float]:
i = indices[0]
j = indices[1]
assert 1 <= i <= self.__end_row - self.__start_row + 1
assert 1 <= j <= self.__end_col - self.__start_col + 1
return self.__matrix[self.__start_row + i - 1, self.__start_col + j - 1]

def __setitem__(self, indices: Tuple[int, int], value: Union[int, float]) -> None:
i = indices[0]
j = indices[1]
assert 1 <= i <= self.__end_row - self.__start_row + 1
assert 1 <= j <= self.__end_col - self.__start_col + 1
self.__matrix[self.__start_row + i - 1, self.__start_col + j - 1] = value
2 changes: 1 addition & 1 deletion test/test_book/test_chapter4.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TestChapter4(ClrsTestCase):

@given(st.data())
def test_matrix_multiply(self, data):
n = data.draw(integers(min_value=1, max_value=10), label="Matrices dimension")
n = data.draw(integers(min_value=1, max_value=15), label="Matrices dimension")
elements1 = data.draw(
lists(lists(integers(min_value=-1000, max_value=1000), min_size=n, max_size=n), min_size=n, max_size=n),
label="First matrix elements")
Expand Down

0 comments on commit 5755b82

Please sign in to comment.