Skip to content

Commit

Permalink
Merge pull request #1441 from vincent-ehrmanntraut/simplified_array_m…
Browse files Browse the repository at this point in the history
…atrix_mult

Add to_row_matrix and to_column_matrix to Array
  • Loading branch information
mkskeller authored Jul 3, 2024
2 parents 41999a3 + 70d65d3 commit b0dc2b3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
22 changes: 22 additions & 0 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6322,6 +6322,28 @@ def sort(self, n_threads=None, batcher=False, n_bits=None):
from . import sorting
sorting.radix_sort(self, self, n_bits=n_bits)

def to_row_matrix(self):
"""
Returns the array as 1xN matrix.
Warning: This operation is in-place (without copying data), i.e., all changes to the values of the matrix will also affect the original array.
:return: Matrix
"""
assert self.value_type.n_elements() == 1 and \
self.value_type.mem_size() == 1
return Matrix(1, self.length, self.value_type, address=self.address)

def to_column_matrix(self):
"""
Returns the array as Nx1 matrix.
Warning: This operation is in-place (without copying data), i.e., all changes to the values of the matrix will also affect the original array.
:return: Matrix
"""
assert self.value_type.n_elements() == 1 and \
self.value_type.mem_size() == 1
return Matrix(self.length, 1, self.value_type, address=self.address)

def Array(self, size):
# compatibility with registers
return Array(size, self.value_type)
Expand Down
19 changes: 8 additions & 11 deletions Programs/Source/test_dot.mpc
Original file line number Diff line number Diff line change
Expand Up @@ -49,35 +49,32 @@ def test_matrix(expected, actual):

crash()

break_point()
def hacky_array_dot_matrix(arr, mat):
# Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying memory addresses.
tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address)
result = tmp.dot(mat)
return sint.Array(mat.shape[1], result.address)

start_timer(3)

e3 = hacky_array_dot_matrix(a, c)
e3 = a.to_row_matrix().dot(c).to_array()
# b[0] = e3[0]
f3 = hacky_array_dot_matrix(b, d)
f3 = b.to_row_matrix().dot(d).to_array()
g3 = c.dot(b.to_column_matrix()).to_array()

stop_timer(3)

e3 = e3.reveal()
f3 = f3.reveal()
g3 = g3.reveal()

e3.print_reveal_nested()
f3.print_reveal_nested()
g3.print_reveal_nested()

test_array([70, 80, 90], e3)
test_array([56, 50, 44, 38], f3)
test_array([10, 28, 46, 64], g3)

start_timer(4)

e4 = hacky_array_dot_matrix(a, c)
e4 = a.to_row_matrix().dot(c).to_array()
b[-1] = e4[0]
f4 = hacky_array_dot_matrix(b, d)
f4 = b.to_row_matrix().dot(d).to_array()

stop_timer(4)

Expand Down

0 comments on commit b0dc2b3

Please sign in to comment.