From 956ec00098e37eb4f870828326e5468875ab5fa9 Mon Sep 17 00:00:00 2001 From: Krzysztof Wojtas Date: Tue, 5 Mar 2024 14:21:13 +0100 Subject: [PATCH] #61 complex_multiply --- src/solutions/chapter4/section2/exercise5.py | 16 ++++++++++ test/test_book/test_chapter4.py | 10 +++---- test/test_solutions/test_chapter4.py | 31 +++++++++++++------- 3 files changed, 40 insertions(+), 17 deletions(-) create mode 100644 src/solutions/chapter4/section2/exercise5.py diff --git a/src/solutions/chapter4/section2/exercise5.py b/src/solutions/chapter4/section2/exercise5.py new file mode 100644 index 0000000..faca433 --- /dev/null +++ b/src/solutions/chapter4/section2/exercise5.py @@ -0,0 +1,16 @@ +def complex_multiply(a: float, b: float, c: float, d: float) -> (float, float): + """Multiplies two complex numbers using only three multiplications of real numbers. + + Args: + a: the real part of the first complex number + b: the imaginary part of the first complex number + c: the real part of the second complex number + d: the imaginary part of the second complex number + + Returns: + The real part and the imaginary part of the product. + """ + alpha = a * c + beta = b * d + gamma = (a + b) * (c + d) + return alpha - beta, gamma - alpha - beta diff --git a/test/test_book/test_chapter4.py b/test/test_book/test_chapter4.py index 8c9b986..fb87e79 100644 --- a/test/test_book/test_chapter4.py +++ b/test/test_book/test_chapter4.py @@ -28,9 +28,8 @@ def test_matrix_multiply(self, data): matrix_multiply(A, B, C, n) - actual_product = create_matrix(numpy.matmul(elements1, elements2)) - - self.assertEqual(C, actual_product) + expected_product = create_matrix(numpy.matmul(elements1, elements2)) + self.assertEqual(C, expected_product) @given(st.data()) def test_matrix_multiply_recursive(self, data): @@ -48,6 +47,5 @@ def test_matrix_multiply_recursive(self, data): matrix_multiply_recursive(A, B, C, n) - actual_product = create_matrix(numpy.matmul(elements1, elements2)) - - self.assertEqual(C, actual_product) + expected_product = create_matrix(numpy.matmul(elements1, elements2)) + self.assertEqual(C, expected_product) diff --git a/test/test_solutions/test_chapter4.py b/test/test_solutions/test_chapter4.py index 5c51e34..b31dd8e 100644 --- a/test/test_solutions/test_chapter4.py +++ b/test/test_solutions/test_chapter4.py @@ -1,6 +1,7 @@ import numpy from hypothesis import given from hypothesis import strategies as st +from hypothesis.strategies import complex_numbers from hypothesis.strategies import integers from hypothesis.strategies import lists @@ -9,6 +10,7 @@ from solutions.chapter4.section1.exercise3 import matrix_multiply_recursive_by_copying from solutions.chapter4.section1.exercise4 import matrix_add_recursive from solutions.chapter4.section2.exercise2 import strassen +from solutions.chapter4.section2.exercise5 import complex_multiply from test_case import ClrsTestCase from test_util import create_matrix @@ -30,9 +32,8 @@ def test_matrix_multiply_recursive_general(self, data): matrix_multiply_recursive_general(A, B, C, n) - actual_product = create_matrix(numpy.matmul(elements1, elements2)) - - self.assertEqual(C, actual_product) + expected_product = create_matrix(numpy.matmul(elements1, elements2)) + self.assertEqual(C, expected_product) @given(st.data()) def test_matrix_multiply_recursive_by_copying(self, data): @@ -50,9 +51,8 @@ def test_matrix_multiply_recursive_by_copying(self, data): matrix_multiply_recursive_by_copying(A, B, C, n) - actual_product = create_matrix(numpy.matmul(elements1, elements2)) - - self.assertEqual(C, actual_product) + expected_product = create_matrix(numpy.matmul(elements1, elements2)) + self.assertEqual(C, expected_product) @given(st.data()) def test_matrix_add_recursive(self, data): @@ -70,9 +70,8 @@ def test_matrix_add_recursive(self, data): matrix_add_recursive(A, B, C, n) - actual_sum = create_matrix(numpy.add(elements1, elements2)) - - self.assertEqual(C, actual_sum) + expected_sum = create_matrix(numpy.add(elements1, elements2)) + self.assertEqual(C, expected_sum) @given(st.data()) def test_strassen(self, data): @@ -90,6 +89,16 @@ def test_strassen(self, data): strassen(A, B, C, n) - actual_product = create_matrix(numpy.matmul(elements1, elements2)) + expected_product = create_matrix(numpy.matmul(elements1, elements2)) + self.assertEqual(C, expected_product) + + @given(st.data()) + def test_complex_multiply(self, data): + z1 = data.draw(complex_numbers(max_magnitude=1000.0), label="First complex number") + z2 = data.draw(complex_numbers(max_magnitude=1000.0), label="Second complex number") + + real, imag = complex_multiply(z1.real, z1.imag, z2.real, z2.imag) - self.assertEqual(C, actual_product) + actual_product = complex(real, imag) + expected_product = z1 * z2 + self.assertAlmostEqual(actual_product, expected_product, places=7)