Skip to content

Commit

Permalink
#61 complex_multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
wojtask committed Mar 5, 2024
1 parent 5b19414 commit 956ec00
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
16 changes: 16 additions & 0 deletions src/solutions/chapter4/section2/exercise5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def complex_multiply(a: float, b: float, c: float, d: float) -> (float, float):
"""Multiplies two complex numbers using only three multiplications of real numbers.
Args:
a: the real part of the first complex number
b: the imaginary part of the first complex number
c: the real part of the second complex number
d: the imaginary part of the second complex number
Returns:
The real part and the imaginary part of the product.
"""
alpha = a * c
beta = b * d
gamma = (a + b) * (c + d)
return alpha - beta, gamma - alpha - beta
10 changes: 4 additions & 6 deletions test/test_book/test_chapter4.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def test_matrix_multiply(self, data):

matrix_multiply(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))

self.assertEqual(C, actual_product)
expected_product = create_matrix(numpy.matmul(elements1, elements2))
self.assertEqual(C, expected_product)

@given(st.data())
def test_matrix_multiply_recursive(self, data):
Expand All @@ -48,6 +47,5 @@ def test_matrix_multiply_recursive(self, data):

matrix_multiply_recursive(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))

self.assertEqual(C, actual_product)
expected_product = create_matrix(numpy.matmul(elements1, elements2))
self.assertEqual(C, expected_product)
31 changes: 20 additions & 11 deletions test/test_solutions/test_chapter4.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import complex_numbers
from hypothesis.strategies import integers
from hypothesis.strategies import lists

Expand All @@ -9,6 +10,7 @@
from solutions.chapter4.section1.exercise3 import matrix_multiply_recursive_by_copying
from solutions.chapter4.section1.exercise4 import matrix_add_recursive
from solutions.chapter4.section2.exercise2 import strassen
from solutions.chapter4.section2.exercise5 import complex_multiply
from test_case import ClrsTestCase
from test_util import create_matrix

Expand All @@ -30,9 +32,8 @@ def test_matrix_multiply_recursive_general(self, data):

matrix_multiply_recursive_general(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))

self.assertEqual(C, actual_product)
expected_product = create_matrix(numpy.matmul(elements1, elements2))
self.assertEqual(C, expected_product)

@given(st.data())
def test_matrix_multiply_recursive_by_copying(self, data):
Expand All @@ -50,9 +51,8 @@ def test_matrix_multiply_recursive_by_copying(self, data):

matrix_multiply_recursive_by_copying(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))

self.assertEqual(C, actual_product)
expected_product = create_matrix(numpy.matmul(elements1, elements2))
self.assertEqual(C, expected_product)

@given(st.data())
def test_matrix_add_recursive(self, data):
Expand All @@ -70,9 +70,8 @@ def test_matrix_add_recursive(self, data):

matrix_add_recursive(A, B, C, n)

actual_sum = create_matrix(numpy.add(elements1, elements2))

self.assertEqual(C, actual_sum)
expected_sum = create_matrix(numpy.add(elements1, elements2))
self.assertEqual(C, expected_sum)

@given(st.data())
def test_strassen(self, data):
Expand All @@ -90,6 +89,16 @@ def test_strassen(self, data):

strassen(A, B, C, n)

actual_product = create_matrix(numpy.matmul(elements1, elements2))
expected_product = create_matrix(numpy.matmul(elements1, elements2))
self.assertEqual(C, expected_product)

@given(st.data())
def test_complex_multiply(self, data):
z1 = data.draw(complex_numbers(max_magnitude=1000.0), label="First complex number")
z2 = data.draw(complex_numbers(max_magnitude=1000.0), label="Second complex number")

real, imag = complex_multiply(z1.real, z1.imag, z2.real, z2.imag)

self.assertEqual(C, actual_product)
actual_product = complex(real, imag)
expected_product = z1 * z2
self.assertAlmostEqual(actual_product, expected_product, places=7)

0 comments on commit 956ec00

Please sign in to comment.