From 98766105a1567a59e35fab4234a59fc0a407118c Mon Sep 17 00:00:00 2001 From: lschoe Date: Fri, 18 Nov 2022 18:10:05 +0100 Subject: [PATCH] Extend mpc.random.shuffle() to lists of lists. Consistent with mpc.sorted(), mpc.if_else(), mpc.if_swap(), mpc.min(), mpc.argmax() etc., which also work for lists of (all same length) lists. --- mpyc/random.py | 30 ++++++++++++++++++++++++------ tests/test_random.py | 5 ++++- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/mpyc/random.py b/mpyc/random.py index f410407b..6295611c 100644 --- a/mpyc/random.py +++ b/mpyc/random.py @@ -192,17 +192,35 @@ def shuffle(sectype, x): """Shuffle list x secretly in-place, and return None. Given list x may contain public or secret elements. + Elements of x are all numbers or all lists (of the same length) of numbers. """ n = len(x) - if not isinstance(x[0], sectype): # assume same type for all elts of x - for i in range(len(x)): - x[i] = sectype(x[i]) + # assume same type for all elts of x + x_i_is_list = isinstance(x[0], list) + if not x_i_is_list: + # elements of x are numbers + if not isinstance(x[0], sectype): + for i in range(n): + x[i] = sectype(x[i]) + for i in range(n-1): + u = random_unit_vector(sectype, n - i) + x_u = runtime.in_prod(x[i:], u) + d = runtime.scalar_mul(x[i] - x_u, u) + x[i] = x_u + x[i:] = runtime.vector_add(x[i:], d) + return + + # elements of x are lists of numbers + for j in range(len(x[0])): + if not isinstance(x[0][j], sectype): + for i in range(n): + x[i][j] = sectype(x[i][j]) for i in range(n-1): u = random_unit_vector(sectype, n - i) - x_u = runtime.in_prod(x[i:], u) - d = runtime.scalar_mul(x[i] - x_u, u) + x_u = runtime.matrix_prod([u], x[i:])[0] + d = runtime.matrix_prod([[a] for a in u], [runtime.vector_sub(x[i], x_u)]) x[i] = x_u - x[i:] = runtime.vector_add(x[i:], d) + x[i:] = runtime.matrix_add(x[i:], d) def random_permutation(sectype, x): diff --git a/tests/test_random.py b/tests/test_random.py index 598f8f45..9d410ffe 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -38,9 +38,12 @@ def test_secint(self): x = list(range(8)) shuffle(secint, x) - shuffle(secint, x) x = mpc.run(mpc.output(x)) self.assertSetEqual(set(x), set(range(8))) + x = list(map(list, zip(range(8), range(0, -8, -1)))) + shuffle(secint, x) + a = mpc.run(mpc.output(x[0])) + self.assertEqual(a[1], -a[0]) x = mpc.run(mpc.output(random_permutation(secint, 8))) self.assertSetEqual(set(x), set(range(8))) x = mpc.run(mpc.output(random_derangement(secint, 2)))