diff --git a/.travis.yml b/.travis.yml
index eea95181..133dffbd 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,6 +4,7 @@ python:
 - 3.8
 - 3.9
 - 3.10
+- 3.11
 - pypy3.8-7.3.9
 install:
 - pip install --upgrade pip
diff --git a/demos/lpsolver.py b/demos/lpsolver.py
index 1d91f2dd..5a625ff2 100644
--- a/demos/lpsolver.py
+++ b/demos/lpsolver.py
@@ -153,13 +153,13 @@ async def main():
     args = parser.parse_args()

     settings = [('uvlp', 8, 1, 2),
-                ('wiki', 6, 1, 2),
+                ('wiki', 6, 1, 1),
                 ('tb2x2', 6, 1, 2),
                 ('woody', 8, 1, 3),
-                ('LPExample_R20', 70, 1, 5),
+                ('LPExample_R20', 70, 1, 9),
                 ('sc50b', 104, 10, 55),
-                ('kb2', 536, 100000, 106),
-                ('LPExample', 110, 1, 178)]
+                ('kb2', 560, 100000, 154),
+                ('LPExample', 110, 1, 175)]
     name, bit_length, scale, n_iter = settings[args.dataset]
     if args.bit_length:
         bit_length = args.bit_length
@@ -200,11 +200,9 @@
     previous_pivot = secint(1)

     iteration = 0
-    while True:
+    while await mpc.output((arg_min := argmin_int(T[0][:-1]))[1] < 0):
         # find index of pivot column
-        p_col_index, minimum = argmin_int(T[0][:-1])
-        if await mpc.output(minimum >= 0):
-            break  # maximum reached
+        p_col_index = arg_min[0]

         # find index of pivot row
         p_col = mpc.matrix_prod([p_col_index], T, True)[0]
diff --git a/demos/lpsolverfxp.py b/demos/lpsolverfxp.py
index 6ee5af35..01df4c99 100644
--- a/demos/lpsolverfxp.py
+++ b/demos/lpsolverfxp.py
@@ -87,12 +87,9 @@ async def main():
     basis = [secfxp(n + i) for i in range(m)]

     iteration = 0
-    while True:
+    while await mpc.output((arg_min := argmin_int(T[0][:-1]))[1] < 0):
         # find index of pivot column
-        p_col_index, minimum = argmin_int(T[0][:-1])
-
-        if await mpc.output(minimum >= 0):
-            break  # maximum reached
+        p_col_index = arg_min[0]

         # find index of pivot row
         p_col = mpc.matrix_prod([p_col_index], T, True)[0]
diff --git a/demos/np_bnnmnist.py b/demos/np_bnnmnist.py
index 7d39cc3c..3294700a 100644
--- a/demos/np_bnnmnist.py
+++ b/demos/np_bnnmnist.py
@@ -240,8 +240,7 @@ async def main():
     if args.no_legendre:
         secint.bit_length = 14
     for i in range(batch_size):
-        prediction = int(await mpc.output(mpc.argmax(L[i].tolist())[0]))
-
+        prediction = await mpc.output(np.argmax(L[i]))
         error = '******* ERROR *******' if prediction != labels[i] else ''
         print(f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}')
         print(await mpc.output(L[i]))
diff --git a/demos/np_cnnmnist.py b/demos/np_cnnmnist.py
index 2bd1e39b..36a5d7f3 100644
--- a/demos/np_cnnmnist.py
+++ b/demos/np_cnnmnist.py
@@ -177,7 +177,7 @@
         secnum.bit_length = 37
     for i in range(batch_size):
-        prediction = int(await mpc.output(mpc.argmax(x[i].tolist())[0]))
+        prediction = int(await mpc.output(np.argmax(x[i])))
         error = '******* ERROR *******' if prediction != labels[i] else ''
         print(f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}')
         print(await mpc.output(x[i]))
diff --git a/demos/np_id3gini.py b/demos/np_id3gini.py
index 85ce3e4c..57077881 100644
--- a/demos/np_id3gini.py
+++ b/demos/np_id3gini.py
@@ -6,7 +6,6 @@
 See id3gini.py for background information on decision tree learning and ID3.
""" -# TODO: vectorize mpc.argmax() import os import logging @@ -20,7 +19,7 @@ @mpc.coroutine async def id3(T, R) -> asyncio.Future: sizes = S[C] @ T - i, mx = mpc.argmax(sizes) + i, mx = sizes.argmax(raw=False) sizeT = sizes.sum() stop = (sizeT <= int(args.epsilon * len(T))) + (mx == sizeT) if not (R and await mpc.is_zero_public(stop)): @@ -29,7 +28,8 @@ async def id3(T, R) -> asyncio.Future: tree = i else: T_SC = (T * S[C]).T - k = mpc.argmax([GI(S[A] @ T_SC) for A in R], key=SecureFraction)[0] + CT = np.stack(tuple(GI(S[A] @ T_SC) for A in R)) + k = CT.argmax(key=SecureFraction, raw=False, raw2=False) A = list(R)[await mpc.output(k)] logging.info(f'Attribute node {A}') T_SA = T * S[A] @@ -46,15 +46,15 @@ def GI(x): y = args.alpha * np.sum(x, axis=1) + 1 # NB: alternatively, use s + (s == 0) D = mpc.prod(y.tolist()) G = np.sum(np.sum(x * x, axis=1) / y) - return [D * G, D] # numerator, denominator + return mpc.np_fromlist([D * G, D]) # numerator, denominator class SecureFraction: def __init__(self, a): - self.n, self.d = a # numerator, denominator + self.a = a # numerator, denominator def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort() - return mpc.in_prod([self.n, -self.d], [other.d, other.n]) < 0 + return self.a[:, 0] * other.a[:, 1] < self.a[:, 1] * other.a[:, 0] depth = lambda tree: 0 if isinstance(tree, int) else max(map(depth, tree[1])) + 1 diff --git a/demos/np_lpsolver.py b/demos/np_lpsolver.py index f50bfa08..b8994359 100644 --- a/demos/np_lpsolver.py +++ b/demos/np_lpsolver.py @@ -16,69 +16,12 @@ from mpyc.runtime import mpc -# TODO: unify approaches for argmin etc. with secure NumPy arrays (also see lpsolverfxp.py) +class SecureFraction: + def __init__(self, a): + self.a = a # numerator, denominator - -def argmin_int(x): - secintarr = type(x) - n = len(x) - if n == 1: - return (secintarr(np.array([1])), x[0]) - - if n == 2: - b = x[0] < x[1] - arg = mpc.np_fromlist([b, 1 - b]) - min = b * (x[0] - x[1]) + x[1] - return arg, min - - a = x[:n//2] ## split even odd? 
-    b = x[(n+1)//2:]
-    c = a < b
-    m = c * (a - b) + b
-    if n%2 == 1:
-        m = np.concatenate((m, x[n//2:(n+1)//2]))
-    ag, mn = argmin_int(m)
-    if n%2 == 1:
-        ag_1 = ag[-1:]
-        ag = ag[:-1]
-    arg1 = ag * c
-    arg2 = ag - arg1
-    if n%2 == 1:
-        arg = np.concatenate((arg1, ag_1, arg2))
-    else:
-        arg = np.concatenate((arg1, arg2))
-    return arg, mn
-
-
-def argmin_rat(nd):
-    secintarr = type(nd)
-    N = nd.shape[1]
-    if N == 1:
-        return (secintarr(np.array([1])), (nd[0, 0], nd[1, 0]))
-
-    if N == 2:
-        b = mpc.in_prod([nd[0, 0], -nd[1, 0]], [nd[1, 1], nd[0, 1]]) < 0
-        arg = mpc.np_fromlist([b, 1 - b])
-        min = b * (nd[:, 0] - nd[:, 1]) + nd[:, 1]
-        return arg, (min[0], min[1])
-
-    a = nd[:, :N//2]
-    b = nd[:, (N+1)//2:]
-    c = a[0] * b[1] < b[0] * a[1]
-    m = c * (a - b) + b
-    if N%2 == 1:
-        m = np.concatenate((m, nd[:, N//2:(N+1)//2]), axis=1)
-    ag, mn = argmin_rat(m)
-    if N%2 == 1:
-        ag_1 = ag[-1:]
-        ag = ag[:-1]
-    arg1 = ag * c
-    arg2 = ag - arg1
-    if N%2 == 1:
-        arg = np.concatenate((arg1, ag_1, arg2))
-    else:
-        arg = np.concatenate((arg1, arg2))
-    return arg, mn
+    def __lt__(self, other):  # NB: __lt__() is basic comparison as in Python's list.sort()
+        return self.a[:, 0] * other.a[:, 1] < self.a[:, 1] * other.a[:, 0]


 def pow_list(a, x, n):
@@ -135,13 +78,13 @@ async def main():
     args = parser.parse_args()

     settings = [('uvlp', 8, 1, 2),
-                ('wiki', 6, 1, 2),
+                ('wiki', 6, 1, 1),
                 ('tb2x2', 6, 1, 2),
                 ('woody', 8, 1, 3),
-                ('LPExample_R20', 70, 1, 5),
+                ('LPExample_R20', 70, 1, 9),
                 ('sc50b', 104, 10, 55),
-                ('kb2', 536, 100000, 106),
-                ('LPExample', 110, 1, 178)]
+                ('kb2', 560, 100000, 154),
+                ('LPExample', 110, 1, 175)]
     name, bit_length, scale, n_iter = settings[args.dataset]
     if args.bit_length:
         bit_length = args.bit_length
@@ -173,17 +116,15 @@ async def main():
     previous_pivot = secint(1)

     iteration = 0
-    while True:
+    while await mpc.output((arg_min := T[0, :-1].argmin())[1] < 0):
         # find index of pivot column
-        p_col_index, minimum = argmin_int(T[0, :-1])
-        if await mpc.output(minimum >= 0):
-            break  # maximum reached
+        p_col_index = arg_min[0]

         # find index of pivot row
         p_col = T[:, :-1] @ p_col_index
-        den = p_col[1:]
-        num = T[1:, -1] + (den <= 0)
-        p_row_index, (_, pivot) = argmin_rat(np.stack((num, den)))
+        denominator = p_col[1:]
+        constraints = np.column_stack((T[1:, -1] + (denominator <= 0), denominator))
+        p_row_index, (_, pivot) = constraints.argmin(key=SecureFraction)

         # reveal progress a bit
         iteration += 1
@@ -198,7 +139,6 @@
         # update tableau Tij = Tij*Tkl/Tkl' - (Til/Tkl' - bool(i==k)) * (Tkj + bool(j==l)*Tkl')
         p_col_index = np.concatenate((p_col_index, np.array([0])))
         p_row_index = np.concatenate((np.array([0]), p_row_index))
-
         pp_inv = 1 / previous_pivot
         p_col = p_col * pp_inv - p_row_index
         p_row = p_row_index @ T + previous_pivot * p_col_index
diff --git a/demos/np_lpsolverfxp.py b/demos/np_lpsolverfxp.py
index f1d86702..8ea4ef62 100644
--- a/demos/np_lpsolverfxp.py
+++ b/demos/np_lpsolverfxp.py
@@ -15,106 +15,19 @@
 from mpyc.runtime import mpc

-# TODO: unify approaches for argmin etc. with secure NumPy arrays (also see lpsolver.py)
-
-# def argmin_int(xs):
-    # a, m = mpc.argmin(xs)
-    # u = mpc.unit_vector(a, len(xs))
-    # u = mpc.np_fromlist(u)
-    # u.integral = True
-    # return u, m
-
-
-def argmin_int(x):
-    secarr = type(x)
-    n = len(x)
-    if n == 1:
-        return (secarr(np.array([1])), x[0])
-
-    if n == 2:
-        b = x[0] < x[1]
-        arg = mpc.np_fromlist([b, 1 - b])
-        min = b * (x[0] - x[1]) + x[1]
-        arg.integral = True
-        return arg, min
-
-    a = x[:n//2]  ## split even odd? start at n%2 as in reduce in mpctools
-    b = x[(n+1)//2:]
-    c = a < b
-    m = c * (a - b) + b
-    if n%2 == 1:
-        m = np.concatenate((m, x[n//2:(n+1)//2]))
-    ag, mn = argmin_int(m)
-    if n%2 == 1:
-        ag_1 = ag[-1:]
-        ag = ag[:-1]
-    arg1 = ag * c
-    arg2 = ag - arg1
-    if n%2 == 1:
-        arg = np.concatenate((arg1, ag_1, arg2))
-    else:
-        arg = np.concatenate((arg1, arg2))
-    arg.integral = True
-    return arg, mn
-
-
-def argmin_rat(nd, p):
-    secarr = type(nd)
-    n = nd.shape[1]
-    if n == 1:
-        return secarr(np.array([1])), (nd[0, 0], nd[1, 0])
-
-    if n == 2:
-        b = mpc.in_prod([nd[0, 0], -nd[1, 0]], [nd[1, 1], nd[0, 1]]) < 0
-        c0 = p[0]
-# c0.integral = True
-        b = mpc.if_else(c0, b, 0)
-        c1 = p[1]
-# c1.integral = True
-        b = mpc.if_else(c1, b, 1)
-# b.integral = True
-        assert b.integral
-        arg = mpc.np_fromlist([b, 1 - b])
-        min = b * (nd[:, 0] - nd[:, 1]) + nd[:, 1]
-        return arg, min
-
-    a = nd[:, :n//2]
-    b = nd[:, (n+1)//2:]
-    aa = np.stack((-a[1], a[0]))
-    aa.integral = False
-    b.integral = False
-    aa = aa.T.reshape(n//2, 1, 2)
-    c = aa @ b.T.reshape(n//2, 2, 1) < 0  # c = a[0] * b[1] < b[0] * a[1]
-    c = c.reshape(len(c))
-    assert c.integral
-    assert p.integral
-# c = a[2] * c
-    a2 = p[:n//2]
-    c *= a2
-    assert c.integral
-# c = b[2] * (c - 1) + 1
-    b2 = p[(n+1)//2:]
-    c = b2 * (c - 1) + 1
-    assert c.integral
-    m = c * (a - b) + b
-    mp = c * (a2 - b2) + b2
-    assert mp.integral
-    if n%2 == 1:
-        m = np.concatenate((m, nd[:, n//2:(n+1)//2]), axis=1)
-        mp = np.concatenate((mp, p[n//2:(n+1)//2]))
-        mp.integral = True
-    ag, mn = argmin_rat(m, mp)
-    if n%2 == 1:
-        ag_1 = ag[-1:]
-        ag = ag[:-1]
-    arg1 = ag * c
-    arg2 = ag - arg1
-    if n%2 == 1:
-        arg = np.concatenate((arg1, ag_1, arg2))
-    else:
-        arg = np.concatenate((arg1, arg2))
-    arg.integral = True
-    return arg, mn
+class SecureFraction:
+    def __init__(self, a):
+        self.a = a  # numerator, denominator, pos
+
+    def __lt__(self, other):  # NB: __lt__() is basic comparison as in Python's list.sort()
+        b = self.a[:, 0] * other.a[:, 1] < other.a[:, 0] * self.a[:, 1]
+        c0 = self.a[:, 2]
+        c0.integral = True
+        b *= c0  # b = b if c0 else 0
+        c1 = other.a[:, 2]
+        c1.integral = True
+        b = c1 * (b - 1) + 1  # b = b if c1 else 1
+        return b


 async def main():
@@ -154,18 +67,15 @@ async def main():
     basis = secfxp.array(n + np.arange(m))

     iteration = 0
-    while True:
+    while await mpc.output((arg_min := T[0, :-1].argmin())[1] < 0):
         # find index of pivot column
-        p_col_index, minimum = argmin_int(T[0, :-1])
-        if await mpc.output(minimum >= 0):
-            break  # maximum reached
+        p_col_index = arg_min[0]

         # find index of pivot row
-        assert p_col_index.integral
-        p_col = T[:, :-1] @ p_col_index
-        p_col1 = p_col[1:]
-        pos = p_col1 > 0.0001
-        p_row_index, (_, pivot) = argmin_rat(np.stack((T[1:, -1], p_col1)), pos)
+        p_col = T[:, :-1] @ p_col_index
+        denominator = p_col[1:]
+        constraints = np.column_stack((T[1:, -1], denominator, denominator > 0.0001))
+        p_row_index, (_, pivot, _) = constraints.argmin(key=SecureFraction)

         # reveal progress a bit
         iteration += 1
@@ -181,8 +91,6 @@
         # update tableau Tij = Tij - (Til - bool(i==k))/Tkl *outer (Tkj + bool(j==l))
         p_col_index = np.concatenate((p_col_index, np.array([0])))
         p_row_index = np.concatenate((np.array([0]), p_row_index))
-        p_col_index.integral = True
-        p_row_index.integral = True
         p_col = (p_col - p_row_index) / pivot
         p_row = p_row_index @ T + p_col_index
         T -= np.outer(p_col, p_row)
@@ -193,19 +101,15 @@
     logging.info('Solution x')
     x = np.sum(np.fromiter((T[i+1, -1] * mpc.np_fromlist(mpc.unit_vector(basis[i], m + n)[:n]) for i in range(m)), 'O'))
-    cx = c @ x
-    Ax = A @ x
-    Ax_bounded_by_b = mpc.all((Ax <= 1.01 * b + 0.0001).tolist())
+    Ax_bounded_by_b = mpc.all((A @ x <= 1.01 * b + 0.0001).tolist())
     x_nonnegative = mpc.all((x >= 0).tolist())

     logging.info('Dual solution y')
     y = np.sum(np.fromiter((T[0, j] * mpc.np_fromlist(mpc.unit_vector(cobasis[j], m + n)[n:]) for j in range(n)), 'O'))
-    yb = y @ b
-    yA = y @ A
-    yA_bounded_by_c = mpc.all((yA >= np.where(c > 0, 1/1.01, 1.01) * c - 0.0001).tolist())
+    yA_bounded_by_c = mpc.all((y @ A >= np.where(c > 0, 1/1.01, 1.01) * c - 0.0001).tolist())
     y_nonnegative = mpc.all((y >= 0).tolist())

-    cx_eq_yb = abs(cx - yb) <= 0.01 * abs(cx)
+    cx_eq_yb = abs((cx := c @ x) - y @ b) <= 0.01 * abs(cx)
     check = mpc.all([cx_eq_yb, Ax_bounded_by_b, x_nonnegative, yA_bounded_by_c, y_nonnegative])
     check = bool(await mpc.output(check))
     print(f'verification c.x == y.b, A.x <= b, x >= 0, y.A >= c, y >= 0: {check}')
diff --git a/mpyc/README.md b/mpyc/README.md
index 4f53b5e8..ceffb14e 100644
--- a/mpyc/README.md
+++ b/mpyc/README.md
@@ -2,8 +2,8 @@

 [MPyC](https://lschoe.github.io/mpyc) currently consists of 14 modules (all in pure Python):

-1. [gmpy](https://lschoe.github.io/mpyc/mpyc.gmpy.html): some basic number theoretic algorithms (using GMP via Python package gmpy2, if installed)
-2. [numpy](https://lschoe.github.io/mpyc/mpyc.numpy.html): stub to avoid dependency on NumPy package (also handling version issues, etc.)
+1. [numpy](https://lschoe.github.io/mpyc/mpyc.numpy.html): stub to avoid dependency on NumPy package (also handling version issues, etc.)
+2. [gmpy](https://lschoe.github.io/mpyc/mpyc.gmpy.html): some basic number theoretic algorithms (using GMP via Python package gmpy2, if installed)
 3. [gfpx](https://lschoe.github.io/mpyc/mpyc.gfpx.html): polynomial arithmetic over arbitrary prime fields
 4. [finfields](https://lschoe.github.io/mpyc/mpyc.finfields.html): arbitrary finite fields, including binary fields and prime fields
 5. [fingroups](https://lschoe.github.io/mpyc/mpyc.fingroups.html): finite groups, in particular for use in cryptography (elliptic curves, Schnorr groups, etc.)
diff --git a/mpyc/__init__.py b/mpyc/__init__.py
index a1e94edb..8edc5fb4 100644
--- a/mpyc/__init__.py
+++ b/mpyc/__init__.py
@@ -26,7 +26,7 @@
 and statistics (securely mimicking Python’s statistics module).
 """

-__version__ = '0.8.10'
+__version__ = '0.8.11'
 __license__ = 'MIT License'

 import os
diff --git a/mpyc/runtime.py b/mpyc/runtime.py
index db45f91c..6a5067ab 100644
--- a/mpyc/runtime.py
+++ b/mpyc/runtime.py
@@ -1301,6 +1301,8 @@ def argmin(self, *x, key=None):
         """Secure argmin of all given elements in x.

         See runtime.sorted() for details on key etc.
+        In case of multiple occurrences of the minimum value,
+        the index of the first occurrence is returned.
""" if len(x) == 1: x = x[0] @@ -1324,15 +1326,17 @@ def _argmin(self, x, key): i0, min0 = self._argmin(x[:n//2], key) i1, min1 = self._argmin(x[n//2:], key) i1 += n//2 - c = key(min0) < key(min1) - a = self.if_else(c, i0, i1) - m = self.if_else(c, min0, min1) # TODO: merge if_else's once integral attr per list element + c = key(min1) < key(min0) + a = self.if_else(c, i1, i0) + m = self.if_else(c, min1, min0) # TODO: merge if_else's once integral attr per list element return a, m def argmax(self, *x, key=None): """Secure argmax of all given elements in x. See runtime.sorted() for details on key etc. + In case of multiple occurrences of the maximum values, + the index of the first occurrence is returned. """ if len(x) == 1: x = x[0] @@ -1827,7 +1831,7 @@ async def np_all(self, x): containing 0s and 1s (Boolean). Runs in log_2 len(x) rounds. """ - # TODO: cover case of SecureArray (incl. case f > 0 + # TODO: cover case of SecureArray (incl. case f > 0) if iter(x) is x: x = list(x) else: @@ -2315,6 +2319,9 @@ async def np_copy(self, a, order='K'): @mpc_coro_no_pc async def np_transpose(self, a, axes=None): + if a.ndim == 1: + return a + stype = type(a) if axes is None: perm = range(a.ndim)[::-1] @@ -2345,7 +2352,6 @@ async def np_concatenate(self, arrays, axis=0): Default axis is 0. """ # TODO: handle array_like input arrays - # TODO: integral attr if axis is None: shape = (sum(a.size for a in arrays),) else: @@ -2356,16 +2362,29 @@ async def np_concatenate(self, arrays, axis=0): i = 0 while not isinstance(a := arrays[i], sectypes.SecureArray): i += 1 - await self.returnType((type(a), shape)) + stype = type(a) + if issubclass(stype, self.SecureFixedPointArray): + integral = all(a.integral if isinstance(a, stype) else stype(a).integral for a in arrays) + await self.returnType((stype, integral, shape)) + else: + await self.returnType((stype, shape)) arrays = await self.gather(arrays) return np.concatenate(arrays, axis=axis) @mpc_coro_no_pc async def np_stack(self, arrays, axis=0): - a = arrays[0] + i = 0 + while not isinstance(a := arrays[i], sectypes.SecureArray): + i += 1 shape = list(a.shape) shape.insert(axis, len(arrays)) - await self.returnType((type(a), tuple(shape))) + shape = tuple(shape) + stype = type(a) + if issubclass(stype, self.SecureFixedPointArray): + integral = all(a.integral if isinstance(a, stype) else stype(a).integral for a in arrays) + await self.returnType((stype, integral, shape)) + else: + await self.returnType((stype, shape)) arrays = await self.gather(arrays) return np.stack(arrays, axis=axis) @@ -2430,13 +2449,21 @@ async def np_hstack(self, tup): except for 1-D arrays where it concatenates along the first axis. Rebuilds arrays divided by hsplit. 
""" - a = tup[0] + i = 0 + while not isinstance(a := tup[i], sectypes.SecureArray): + i += 1 + stype = type(a) shape = list(a.shape) if a.ndim == 1: shape[0] = sum(a.shape[0] for a in tup) else: shape[1] = sum(a.shape[1] for a in tup) - await self.returnType((type(a), tuple(shape))) + shape = tuple(shape) + if issubclass(stype, self.SecureFixedPointArray): + integral = all(a.integral if isinstance(a, stype) else stype(a).integral for a in tup) + await self.returnType((stype, integral, shape)) + else: + await self.returnType((stype, shape)) tup = await self.gather(tup) return np.hstack(tup) @@ -2464,11 +2491,18 @@ async def np_dstack(self, tup): @mpc_coro_no_pc async def np_column_stack(self, tup): - a = tup[0] + i = 0 + while not isinstance(a := tup[i], sectypes.SecureArray): + i += 1 + stype = type(a) shape_0 = a.shape[0] shape_1 = sum(a.shape[1] if a.shape[1:] else 1 for a in tup) shape = (shape_0, shape_1) - await self.returnType((type(a), shape)) + if issubclass(stype, self.SecureFixedPointArray): + integral = all(a.integral if isinstance(a, stype) else stype(a).integral for a in tup) + await self.returnType((stype, integral, shape)) + else: + await self.returnType((stype, shape)) tup = await self.gather(tup) return np.column_stack(tup) @@ -2685,6 +2719,104 @@ async def np_sgn(self, a, l=None, LT=False, EQ=False): z = z.reshape(a.shape) return z + def np_argmin(self, a, key=None, raw=False, raw2=False): + # TODO: rename raw, raw2 + # TODO: generalize beyond 1D arrays + if key is None: + key = lambda a: a + + u, m = self._np_argmin(a, key) + if not raw: + iv = np.arange(len(u)) + if isinstance(a, self.SecureFixedPointArray): + iv = type(a)(iv) # TODO: remove once @ handles integral attrb for public values + u = u @ iv + if not raw2: + return u + + return u, m + + def _np_argmin(self, a, key): + # return first occurence if multiple hits + n = len(a) + if n == 1: + u = type(a)(np.array([1])) + m = a[0] + elif n == 2: + # Redundant case, except for some small savings. + a1, a2 = a[:1], a[1:] + c = key(a2) < key(a1) + u = self.np_concatenate((1 - c, c)) # save * + m = (c * (a2 - a1) + a1)[0] # save .T (3x) + else: + n0 = n%2 + a1, a2 = a[n0::2], a[n0 + 1::2] + c = key(a2) < key(a1) + a1, a2 = a1.T, a2.T + m = c * (a2 - a1) + a1 + m = m.T + del a1, a2 + if n0: + m = self.np_concatenate((a[:1], m)) + u, m = self._np_argmin(m, key) + if n0: + u0, u = u[:1], u[1:] + u2 = u * c + u1 = u - u2 + u = self.np_column_stack((u1, u2)).flatten() + if n0: + u = self.np_concatenate((u0, u)) + return u, m + + def np_argmax(self, a, key=None, raw=False, raw2=False): + # TODO: rename raw, raw2 + # TODO: generalize beyond 1D arrays + if key is None: + key = lambda a: a + + u, m = self._np_argmax(a, key) + if not raw: + iv = np.arange(len(u)) + if isinstance(a, self.SecureFixedPointArray): + iv = type(a)(iv) # TODO: remove once @ handles integral attrb for public values + u = u @ iv + if not raw2: + return u + + return u, m + + def _np_argmax(self, a, key): + # return first occurence if multiple hits + n = len(a) + if n == 1: + u = type(a)(np.array([1])) + m = a[0] + elif n == 2: + # Redundant case, except for some small savings. 
+            a1, a2 = a[:1], a[1:]
+            c = key(a1) < key(a2)
+            u = self.np_concatenate((1 - c, c))  # save *
+            m = (c * (a2 - a1) + a1)[0]  # save .T (3x)
+        else:
+            n0 = n%2
+            a1, a2 = a[n0::2], a[n0 + 1::2]
+            c = key(a1) < key(a2)
+            a1, a2 = a1.T, a2.T
+            m = c * (a2 - a1) + a1
+            m = m.T
+            del a1, a2
+            if n0:
+                m = self.np_concatenate((a[:1], m))
+            u, m = self._np_argmax(m, key)
+            if n0:
+                u0, u = u[:1], u[1:]
+            u2 = u * c
+            u1 = u - u2
+            u = self.np_column_stack((u1, u2)).flatten()
+            if n0:
+                u = self.np_concatenate((u0, u))
+        return u, m
+
     @mpc_coro
     async def np_det(self, A):
         """Secure determinant for nonsingular matrices."""
diff --git a/mpyc/sectypes.py b/mpyc/sectypes.py
index a34d63d8..8c5396b9 100644
--- a/mpyc/sectypes.py
+++ b/mpyc/sectypes.py
@@ -1182,8 +1182,7 @@ def __eq__(self, other):
     def __ge__(self, other):
         """Greater-than or equal comparison."""
         # self >= other <=> not (self < other)
-        a = 1 - runtime.np_less(self, other)
-        return a
+        return 1 - runtime.np_less(self, other)

     def __gt__(self, other):
         """Strictly greater-than comparison."""
@@ -1247,6 +1246,20 @@ def swapaxes(self, axis1, axis2):
     def sum(self, *args, **kwargs):
         return runtime.np_sum(self, *args, **kwargs)

+    def argmin(self, *args, **kwargs):
+        if 'raw' not in kwargs:
+            kwargs['raw'] = True
+        if 'raw2' not in kwargs:
+            kwargs['raw2'] = True
+        return runtime.np_argmin(self, *args, **kwargs)
+
+    def argmax(self, *args, **kwargs):
+        if 'raw' not in kwargs:
+            kwargs['raw'] = True
+        if 'raw2' not in kwargs:
+            kwargs['raw2'] = True
+        return runtime.np_argmax(self, *args, **kwargs)
+

 class SecureFiniteFieldArray(SecureArray):
     """Base class for secure (secret-shared) arrays of finite field elements."""
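Usage sketch (not part of the patch): the snippet below illustrates how the new secure-array argmin/argmax API introduced above is meant to be called, based only on the code in this diff. With the array-method defaults (raw=True, raw2=True) the result is a secret-shared unit vector plus the secret minimum, as used for the pivot column in np_lpsolver.py; with raw=False and raw2=False only a secure index is returned, as in np_id3gini.py. The 32-bit SecInt and the sample values are arbitrary; run like any MPyC demo (e.g. with -M3 for three parties).

import numpy as np
from mpyc.runtime import mpc

secint = mpc.SecInt(32)


async def main():
    await mpc.start()
    x = secint.array(np.array([5, -2, 7, -2]))

    # Defaults of the new SecureArray.argmin (raw=True, raw2=True):
    # returns the argmin as a secret-shared unit vector plus the secret minimum.
    u, m = x.argmin()
    print('unit vector:', await mpc.output(u))  # [0 1 0 0], first occurrence wins
    print('minimum:    ', await mpc.output(m))  # -2

    # raw=False turns the unit vector into a secure index;
    # raw2=False drops the minimum so only that index is returned.
    i = x.argmin(raw=False, raw2=False)
    print('index:      ', await mpc.output(i))  # 1

    await mpc.shutdown()

mpc.run(main())

A custom comparison such as the SecureFraction classes in the demos above plugs into the same recursion via the key argument, e.g. constraints.argmin(key=SecureFraction) for rows of (numerator, denominator) pairs.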