Skip to content

Commit

Permalink
First version of vectorized mpc.np_argmin/max().
Browse files Browse the repository at this point in the history
  • Loading branch information
lschoe authored Nov 17, 2022
1 parent f3fe903 commit 1dc3fe1
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 231 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ python:
- 3.8
- 3.9
- 3.10
- 3.11
- pypy3.8-7.3.9
install:
- pip install --upgrade pip
Expand Down
14 changes: 6 additions & 8 deletions demos/lpsolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ async def main():
args = parser.parse_args()

settings = [('uvlp', 8, 1, 2),
('wiki', 6, 1, 2),
('wiki', 6, 1, 1),
('tb2x2', 6, 1, 2),
('woody', 8, 1, 3),
('LPExample_R20', 70, 1, 5),
('LPExample_R20', 70, 1, 9),
('sc50b', 104, 10, 55),
('kb2', 536, 100000, 106),
('LPExample', 110, 1, 178)]
('kb2', 560, 100000, 154),
('LPExample', 110, 1, 175)]
name, bit_length, scale, n_iter = settings[args.dataset]
if args.bit_length:
bit_length = args.bit_length
Expand Down Expand Up @@ -200,11 +200,9 @@ async def main():
previous_pivot = secint(1)

iteration = 0
while True:
while await mpc.output((arg_min := argmin_int(T[0][:-1]))[1] < 0):
# find index of pivot column
p_col_index, minimum = argmin_int(T[0][:-1])
if await mpc.output(minimum >= 0):
break # maximum reached
p_col_index = arg_min[0]

# find index of pivot row
p_col = mpc.matrix_prod([p_col_index], T, True)[0]
Expand Down
7 changes: 2 additions & 5 deletions demos/lpsolverfxp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,9 @@ async def main():
basis = [secfxp(n + i) for i in range(m)]

iteration = 0
while True:
while await mpc.output((arg_min := argmin_int(T[0][:-1]))[1] < 0):
# find index of pivot column
p_col_index, minimum = argmin_int(T[0][:-1])

if await mpc.output(minimum >= 0):
break # maximum reached
p_col_index = arg_min[0]

# find index of pivot row
p_col = mpc.matrix_prod([p_col_index], T, True)[0]
Expand Down
3 changes: 1 addition & 2 deletions demos/np_bnnmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ async def main():
if args.no_legendre:
secint.bit_length = 14
for i in range(batch_size):
prediction = int(await mpc.output(mpc.argmax(L[i].tolist())[0]))

prediction = await mpc.output(np.argmax(L[i]))
error = '******* ERROR *******' if prediction != labels[i] else ''
print(f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}')
print(await mpc.output(L[i]))
Expand Down
2 changes: 1 addition & 1 deletion demos/np_cnnmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ async def main():
secnum.bit_length = 37

for i in range(batch_size):
prediction = int(await mpc.output(mpc.argmax(x[i].tolist())[0]))
prediction = int(await mpc.output(np.argmax(x[i])))
error = '******* ERROR *******' if prediction != labels[i] else ''
print(f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}')
print(await mpc.output(x[i]))
Expand Down
12 changes: 6 additions & 6 deletions demos/np_id3gini.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
See id3gini.py for background information on decision tree learning and ID3.
"""
# TODO: vectorize mpc.argmax()

import os
import logging
Expand All @@ -20,7 +19,7 @@
@mpc.coroutine
async def id3(T, R) -> asyncio.Future:
sizes = S[C] @ T
i, mx = mpc.argmax(sizes)
i, mx = sizes.argmax(raw=False)
sizeT = sizes.sum()
stop = (sizeT <= int(args.epsilon * len(T))) + (mx == sizeT)
if not (R and await mpc.is_zero_public(stop)):
Expand All @@ -29,7 +28,8 @@ async def id3(T, R) -> asyncio.Future:
tree = i
else:
T_SC = (T * S[C]).T
k = mpc.argmax([GI(S[A] @ T_SC) for A in R], key=SecureFraction)[0]
CT = np.stack(tuple(GI(S[A] @ T_SC) for A in R))
k = CT.argmax(key=SecureFraction, raw=False, raw2=False)
A = list(R)[await mpc.output(k)]
logging.info(f'Attribute node {A}')
T_SA = T * S[A]
Expand All @@ -46,15 +46,15 @@ def GI(x):
y = args.alpha * np.sum(x, axis=1) + 1 # NB: alternatively, use s + (s == 0)
D = mpc.prod(y.tolist())
G = np.sum(np.sum(x * x, axis=1) / y)
return [D * G, D] # numerator, denominator
return mpc.np_fromlist([D * G, D]) # numerator, denominator


class SecureFraction:
def __init__(self, a):
self.n, self.d = a # numerator, denominator
self.a = a # numerator, denominator

def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort()
return mpc.in_prod([self.n, -self.d], [other.d, other.n]) < 0
return self.a[:, 0] * other.a[:, 1] < self.a[:, 1] * other.a[:, 0]


depth = lambda tree: 0 if isinstance(tree, int) else max(map(depth, tree[1])) + 1
Expand Down
88 changes: 14 additions & 74 deletions demos/np_lpsolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,12 @@
from mpyc.runtime import mpc


# TODO: unify approaches for argmin etc. with secure NumPy arrays (also see lpsolverfxp.py)
class SecureFraction:
def __init__(self, a):
self.a = a # numerator, denominator


def argmin_int(x):
secintarr = type(x)
n = len(x)
if n == 1:
return (secintarr(np.array([1])), x[0])

if n == 2:
b = x[0] < x[1]
arg = mpc.np_fromlist([b, 1 - b])
min = b * (x[0] - x[1]) + x[1]
return arg, min

a = x[:n//2] ## split even odd? start at n%2 as in reduce in mpctools
b = x[(n+1)//2:]
c = a < b
m = c * (a - b) + b
if n%2 == 1:
m = np.concatenate((m, x[n//2:(n+1)//2]))
ag, mn = argmin_int(m)
if n%2 == 1:
ag_1 = ag[-1:]
ag = ag[:-1]
arg1 = ag * c
arg2 = ag - arg1
if n%2 == 1:
arg = np.concatenate((arg1, ag_1, arg2))
else:
arg = np.concatenate((arg1, arg2))
return arg, mn


def argmin_rat(nd):
secintarr = type(nd)
N = nd.shape[1]
if N == 1:
return (secintarr(np.array([1])), (nd[0, 0], nd[1, 0]))

if N == 2:
b = mpc.in_prod([nd[0, 0], -nd[1, 0]], [nd[1, 1], nd[0, 1]]) < 0
arg = mpc.np_fromlist([b, 1 - b])
min = b * (nd[:, 0] - nd[:, 1]) + nd[:, 1]
return arg, (min[0], min[1])

a = nd[:, :N//2]
b = nd[:, (N+1)//2:]
c = a[0] * b[1] < b[0] * a[1]
m = c * (a - b) + b
if N%2 == 1:
m = np.concatenate((m, nd[:, N//2:(N+1)//2]), axis=1)
ag, mn = argmin_rat(m)
if N%2 == 1:
ag_1 = ag[-1:]
ag = ag[:-1]
arg1 = ag * c
arg2 = ag - arg1
if N%2 == 1:
arg = np.concatenate((arg1, ag_1, arg2))
else:
arg = np.concatenate((arg1, arg2))
return arg, mn
def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort()
return self.a[:, 0] * other.a[:, 1] < self.a[:, 1] * other.a[:, 0]


def pow_list(a, x, n):
Expand Down Expand Up @@ -135,13 +78,13 @@ async def main():
args = parser.parse_args()

settings = [('uvlp', 8, 1, 2),
('wiki', 6, 1, 2),
('wiki', 6, 1, 1),
('tb2x2', 6, 1, 2),
('woody', 8, 1, 3),
('LPExample_R20', 70, 1, 5),
('LPExample_R20', 70, 1, 9),
('sc50b', 104, 10, 55),
('kb2', 536, 100000, 106),
('LPExample', 110, 1, 178)]
('kb2', 560, 100000, 154),
('LPExample', 110, 1, 175)]
name, bit_length, scale, n_iter = settings[args.dataset]
if args.bit_length:
bit_length = args.bit_length
Expand Down Expand Up @@ -173,17 +116,15 @@ async def main():
previous_pivot = secint(1)

iteration = 0
while True:
while await mpc.output((arg_min := T[0, :-1].argmin())[1] < 0):
# find index of pivot column
p_col_index, minimum = argmin_int(T[0, :-1])
if await mpc.output(minimum >= 0):
break # maximum reached
p_col_index = arg_min[0]

# find index of pivot row
p_col = T[:, :-1] @ p_col_index
den = p_col[1:]
num = T[1:, -1] + (den <= 0)
p_row_index, (_, pivot) = argmin_rat(np.stack((num, den)))
denominator = p_col[1:]
constraints = np.column_stack((T[1:, -1] + (denominator <= 0), denominator))
p_row_index, (_, pivot) = constraints.argmin(key=SecureFraction)

# reveal progress a bit
iteration += 1
Expand All @@ -198,7 +139,6 @@ async def main():
# update tableau Tij = Tij*Tkl/Tkl' - (Til/Tkl' - bool(i==k)) * (Tkj + bool(j==l)*Tkl')
p_col_index = np.concatenate((p_col_index, np.array([0])))
p_row_index = np.concatenate((np.array([0]), p_row_index))

pp_inv = 1 / previous_pivot
p_col = p_col * pp_inv - p_row_index
p_row = p_row_index @ T + previous_pivot * p_col_index
Expand Down
Loading

0 comments on commit 1dc3fe1

Please sign in to comment.