From 75ff9d5fd31e47b414553d0f986bf5ce3408775a Mon Sep 17 00:00:00 2001 From: Krzysztof Wojtas Date: Thu, 10 Aug 2023 00:44:15 +0200 Subject: [PATCH] fix types hints --- src/book/chapter4/section1.py | 8 +++++++- src/book/data_structures.py | 10 ++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/book/chapter4/section1.py b/src/book/chapter4/section1.py index 652c300..2f178ca 100644 --- a/src/book/chapter4/section1.py +++ b/src/book/chapter4/section1.py @@ -1,3 +1,5 @@ +from typing import Union + from book.data_structures import Matrix from book.data_structures import Submatrix from util import range_of @@ -21,7 +23,11 @@ def matrix_multiply(A: Matrix, B: Matrix, C: Matrix, n: int) -> None: C[i, j] += A[i, k] * B[k, j] -def matrix_multiply_recursive(A: Matrix | Submatrix, B: Matrix | Submatrix, C: Matrix | Submatrix, n: int) -> None: +def matrix_multiply_recursive( + A: Union[Matrix, Submatrix], + B: Union[Matrix, Submatrix], + C: Union[Matrix, Submatrix], + n: int) -> None: """Recursively multiplies two square matrices and adds the result to the third square matrix. Implements: diff --git a/src/book/data_structures.py b/src/book/data_structures.py index 0af2748..9b01cc0 100644 --- a/src/book/data_structures.py +++ b/src/book/data_structures.py @@ -1,5 +1,7 @@ from builtins import len from typing import Any +from typing import Tuple +from typing import Union class Array: @@ -22,14 +24,14 @@ def __init__(self, rows: int, cols: int) -> None: assert rows >= 0 and cols >= 0 self.__elements = [[0] * cols for _ in range(rows)] - def __getitem__(self, indices: tuple) -> int | float: + def __getitem__(self, indices: Tuple[int, int]) -> Union[int, float]: row = indices[0] col = indices[1] assert 1 <= row <= len(self.__elements) assert 1 <= col <= len(self.__elements[0]) return self.__elements[row - 1][col - 1] - def __setitem__(self, indices: tuple, value: int | float) -> None: + def __setitem__(self, indices: Tuple[int, int], value: Union[int, float]) -> None: row = indices[0] col = indices[1] assert 1 <= row <= len(self.__elements) @@ -64,14 +66,14 @@ def __init__(self, matrix: Matrix, row_range: range, col_range: range) -> None: self.__start_col = col_range.start self.__end_col = col_range.stop - def __getitem__(self, indices: tuple) -> int | float: + def __getitem__(self, indices: Tuple[int, int]) -> Union[int, float]: i = indices[0] j = indices[1] assert 1 <= i <= self.__end_row - self.__start_row + 1 assert 1 <= j <= self.__end_col - self.__start_col + 1 return self.__matrix[self.__start_row + i - 1, self.__start_col + j - 1] - def __setitem__(self, indices: tuple, value: int | float) -> None: + def __setitem__(self, indices: Tuple[int, int], value: Union[int, float]) -> None: i = indices[0] j = indices[1] assert 1 <= i <= self.__end_row - self.__start_row + 1