diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index a62b1e705..28ba50130 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -58,7 +58,7 @@ void GarbleJob::middle_round(RealProgramParty& party, Protocol& second_pro { second_protocol.prepare_mul(party.shared_delta(j), lambda_uv + lambda_v * alpha + lambda_u * beta - + T(alpha * beta, me, party.MC->get_alphai()) + + T::constant(alpha * beta, me, party.MC->get_alphai()) + lambda_w); } } @@ -131,7 +131,7 @@ void RealGarbleWire::input(party_id_t from, char input) assert(party.MC != 0); auto& protocol = party.shared_proc->protocol; protocol.init_mul(party.shared_proc); - protocol.prepare_mul(mask, T(1, party.P->my_num(), party.mac_key) - mask); + protocol.prepare_mul(mask, T::constant(1, party.P->my_num(), party.mac_key) - mask); protocol.exchange(); if (party.MC->open(protocol.finalize_mul(), *party.P) != 0) throw runtime_error("input mask not a bit"); diff --git a/BMR/Register.h b/BMR/Register.h index 87a3af8d8..e648f5c73 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -195,6 +195,7 @@ class BlackHole template BlackHole& operator<<(T) { return *this; } BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; } + void activate(bool) {} }; inline BlackHole& endl(BlackHole& b) { return b; } inline BlackHole& flush(BlackHole& b) { return b; } @@ -205,7 +206,7 @@ class Phase typedef NoMemory DynamicMemory; typedef BlackHole out_type; - static const BlackHole out; + static BlackHole out; static void check(const int128& value, word share, int128 mac) { (void)value; (void)share; (void)mac; } diff --git a/CHANGELOG.md b/CHANGELOG.md index 30f217019..91efc1445 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.4 (Dec 23, 2019) + +- Mixed circuit computation with secret sharing +- Binary computation for dishonest majority using secret sharing as in [FKOS15](https://eprint.iacr.org/2015/901) +- Fixed security bug: insufficient OT correlation check in SPDZ2k +- This version breaks bytecode compatibilty. + ## 0.1.3 (Nov 21, 2019) - Python 3 diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index b7eed45c4..d3ed17c53 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -28,13 +28,32 @@ class ClearBitsAF(base.RegisterArgFormat): LDBITS = 0x20a, ANDS = 0x20b, TRANS = 0x20c, - XORCI = 0x210, + BITB = 0x20d, + ANDM = 0x20e, + LDMSB = 0x240, + STMSB = 0x241, + LDMSBI = 0x242, + STMSBI = 0x243, + MOVSB = 0x244, + INPUTB = 0x246, + XORCBI = 0x210, BITDECC = 0x211, CONVCINT = 0x213, REVEAL = 0x214, STMSDCI = 0x215, - INPUTB = 0x216, + LDMCB = 0x217, + STMCB = 0x218, + XORCB = 0x219, + ADDCB = 0x21a, + ADDCBI = 0x21b, + MULCBI = 0x21c, + SHRCBI = 0x21d, + SHLCBI = 0x21e, PRINTREGSIGNED = 0x220, + PRINTREGB = 0x221, + PRINTREGPLAINB = 0x222, + PRINTFLOATPLAINB = 0x223, + CONDPRINTSTRB = 0x224, CONVCBIT = 0x230, ) @@ -46,12 +65,12 @@ class xorm(base.Instruction): code = opcodes['XORM'] arg_format = ['int','sbw','sb','cb'] -class xorc(base.Instruction): - code = base.opcodes['XORC'] +class xorcb(base.Instruction): + code = opcodes['XORCB'] arg_format = ['cbw','cb','cb'] -class xorci(base.Instruction): - code = opcodes['XORCI'] +class xorcbi(base.Instruction): + code = opcodes['XORCBI'] arg_format = ['cbw','cb','int'] class andrs(base.Instruction): @@ -62,16 +81,20 @@ class ands(base.Instruction): code = opcodes['ANDS'] arg_format = tools.cycle(['int','sbw','sb','sb']) -class addc(base.Instruction): - code = base.opcodes['ADDC'] +class andm(base.Instruction): + code = opcodes['ANDM'] + arg_format = ['int','sbw','sb','cb'] + +class addcb(base.Instruction): + code = opcodes['ADDCB'] arg_format = ['cbw','cb','cb'] -class addci(base.Instruction): - code = base.opcodes['ADDCI'] +class addcbi(base.Instruction): + code = opcodes['ADDCBI'] arg_format = ['cbw','cb','int'] -class mulci(base.Instruction): - code = base.opcodes['MULCI'] +class mulcbi(base.Instruction): + code = opcodes['MULCBI'] arg_format = ['cbw','cb','int'] class bitdecs(base.VarArgsInstruction): @@ -86,44 +109,44 @@ class bitdecc(base.VarArgsInstruction): code = opcodes['BITDECC'] arg_format = tools.chain(['cb'], itertools.repeat('cbw')) -class shrci(base.Instruction): - code = base.opcodes['SHRCI'] +class shrcbi(base.Instruction): + code = opcodes['SHRCBI'] arg_format = ['cbw','cb','int'] -class shlci(base.Instruction): - code = base.opcodes['SHLCI'] +class shlcbi(base.Instruction): + code = opcodes['SHLCBI'] arg_format = ['cbw','cb','int'] class ldbits(base.Instruction): code = opcodes['LDBITS'] arg_format = ['sbw','i','i'] -class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): - code = base.opcodes['LDMS'] +class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction): + code = opcodes['LDMSB'] arg_format = ['sbw','int'] -class stms(base.DirectMemoryWriteInstruction): - code = base.opcodes['STMS'] +class stmsb(base.DirectMemoryWriteInstruction): + code = opcodes['STMSB'] arg_format = ['sb','int'] # def __init__(self, *args, **kwargs): # super(type(self), self).__init__(*args, **kwargs) # import inspect # self.caller = [frame[1:] for frame in inspect.stack()[1:]] -class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): - code = base.opcodes['LDMC'] +class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction): + code = opcodes['LDMCB'] arg_format = ['cbw','int'] -class stmc(base.DirectMemoryWriteInstruction): - code = base.opcodes['STMC'] +class stmcb(base.DirectMemoryWriteInstruction): + code = opcodes['STMCB'] arg_format = ['cb','int'] -class ldmsi(base.ReadMemoryInstruction): - code = base.opcodes['LDMSI'] +class ldmsbi(base.ReadMemoryInstruction): + code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] -class stmsi(base.WriteMemoryInstruction): - code = base.opcodes['STMSI'] +class stmsbi(base.WriteMemoryInstruction): + code = opcodes['STMSBI'] arg_format = ['sb','ci'] class ldmsdi(base.ReadMemoryInstruction): @@ -158,8 +181,8 @@ class convcbit(base.Instruction): code = opcodes['CONVCBIT'] arg_format = ['ciw','cb'] -class movs(base.Instruction): - code = base.opcodes['MOVS'] +class movsb(base.Instruction): + code = opcodes['MOVSB'] arg_format = ['sbw','sb'] class trans(base.VarArgsInstruction): @@ -169,8 +192,8 @@ def __init__(self, *args): ['sb'] * (len(args) - 1 - args[0]) super(trans, self).__init__(*args) -class bit(base.Instruction): - code = base.opcodes['BIT'] +class bitb(base.Instruction): + code = opcodes['BITB'] arg_format = ['sbw'] class reveal(base.Instruction): @@ -182,28 +205,28 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): code = opcodes['INPUTB'] arg_format = tools.cycle(['p','int','int','sbw']) -class print_reg(base.IOInstruction): - code = base.opcodes['PRINTREG'] +class print_regb(base.IOInstruction): + code = opcodes['PRINTREGB'] arg_format = ['cb','i'] def __init__(self, reg, comment=''): - super(print_reg, self).__init__(reg, self.str_to_int(comment)) + super(print_regb, self).__init__(reg, self.str_to_int(comment)) -class print_reg_plain(base.IOInstruction): - code = base.opcodes['PRINTREGPLAIN'] +class print_reg_plainb(base.IOInstruction): + code = opcodes['PRINTREGPLAINB'] arg_format = ['cb'] class print_reg_signed(base.IOInstruction): code = opcodes['PRINTREGSIGNED'] arg_format = ['int','cb'] -class print_float_plain(base.IOInstruction): +class print_float_plainb(base.IOInstruction): __slots__ = [] - code = base.opcodes['PRINTFLOATPLAIN'] + code = opcodes['PRINTFLOATPLAINB'] arg_format = ['cb', 'cb', 'cb', 'cb'] -class cond_print_str(base.IOInstruction): +class cond_print_strb(base.IOInstruction): r""" Print a 4 character string. """ - code = base.opcodes['CONDPRINTSTR'] + code = opcodes['CONDPRINTSTRB'] arg_format = ['cb', 'int'] def __init__(self, cond, val): diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 988f0a300..2f95a7743 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -1,5 +1,5 @@ -from Compiler.types import MemValue, read_mem_value, regint, Array -from Compiler.types import _bitint, _number, _fix, _structure +from Compiler.types import MemValue, read_mem_value, regint, Array, cint +from Compiler.types import _bitint, _number, _fix, _structure, _bit from Compiler.program import Tape, Program from Compiler.exceptions import * from Compiler import util, oram, floatingpoint, library @@ -7,7 +7,7 @@ import operator from functools import reduce -class bits(Tape.Register, _structure): +class bits(Tape.Register, _structure, _bit): n = 40 size = 1 PreOp = staticmethod(floatingpoint.PreOpN) @@ -97,10 +97,15 @@ def set_length(self, n): raise Exception('too long: %d' % n) self.n = n def load_other(self, other): + if isinstance(other, cint): + size = other.size + other = sum(x << i for i, x in enumerate(other)) + other = other.to_regint(size) if isinstance(other, int): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): + assert(other.size == 1) self.conv_regint(self.n, self, other) elif isinstance(self, type(other)) or isinstance(other, type(self)): self.mov(self, other) @@ -122,8 +127,8 @@ class cbits(bits): max_length = 64 reg_type = 'cb' is_clear = True - load_inst = (None, inst.ldmc) - store_inst = (None, inst.stmc) + load_inst = (None, inst.ldmcb) + store_inst = (None, inst.stmcb) bitdec = inst.bitdecc conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) types = {} @@ -146,9 +151,9 @@ def clear_op(self, other, c_inst, ci_inst, op): else: return op(self, cbits(other)) __add__ = lambda self, other: \ - self.clear_op(other, inst.addc, inst.addci, operator.add) + self.clear_op(other, inst.addcb, inst.addcbi, operator.add) __xor__ = lambda self, other: \ - self.clear_op(other, inst.xorc, inst.xorci, operator.xor) + self.clear_op(other, inst.xorcb, inst.xorcbi, operator.xor) __radd__ = __add__ __rxor__ = __xor__ def __mul__(self, other): @@ -157,25 +162,25 @@ def __mul__(self, other): else: try: res = cbits(n=min(self.max_length, self.n+util.int_len(other))) - inst.mulci(res, self, other) + inst.mulcbi(res, self, other) return res except TypeError: return NotImplemented def __rshift__(self, other): res = cbits(n=self.n-other) - inst.shrci(res, self, other) + inst.shrcbi(res, self, other) return res def __lshift__(self, other): res = cbits(n=self.n+other) - inst.shlci(res, self, other) + inst.shlcbi(res, self, other) return res def print_reg(self, desc=''): - inst.print_reg(self, desc) + inst.print_regb(self, desc) def print_reg_plain(self): inst.print_reg_signed(self.n, self) output = print_reg_plain def print_if(self, string): - inst.cond_print_str(self, string) + inst.cond_print_strb(self, string) def reveal(self): return self def to_regint(self, dest): @@ -189,12 +194,12 @@ class sbits(bits): is_clear = False clear_type = cbits default_type = cbits - load_inst = (inst.ldmsi, inst.ldms) - store_inst = (inst.stmsi, inst.stms) + load_inst = (inst.ldmsbi, inst.ldmsb) + store_inst = (inst.stmsbi, inst.stmsb) bitdec = inst.bitdecs bitcom = inst.bitcoms conv_regint = inst.convsint - mov = inst.movs + mov = inst.movsb types = {} def __init__(self, *args, **kwargs): bits.__init__(self, *args, **kwargs) @@ -207,7 +212,7 @@ def new(value=None, n=None): @staticmethod def get_random_bit(): res = sbit() - inst.bit(res) + inst.bitb(res) return res @classmethod def get_input_from(cls, player, n_bits=None): @@ -238,7 +243,7 @@ def load_int(self, value): if self.n <= 32: inst.ldbits(self, self.n, value) elif self.n <= 64: - self.load_other(regint(value)) + self.load_other(regint(value, size=1)) elif self.n <= 128: lower = sbits.get_type(64)(value % 2**64) upper = sbits.get_type(self.n - 64)(value >> 64) @@ -251,7 +256,7 @@ def __add__(self, other): return self.xor_int(other) else: if not isinstance(other, sbits): - other = sbits(other) + other = self.conv(other) n = min(self.n, other.n) res = self.new(n=n) inst.xors(n, res, self, other) @@ -300,13 +305,20 @@ def __and__(self, other): return 0 elif util.is_all_ones(other, self.n): return self - assert(self.n == other.n) res = self.new(n=self.n) + if not isinstance(other, sbits): + other = cbits.get_type(self.n).conv(other) + inst.andm(self.n, res, self, other) + return res + other = self.conv(other) + assert(self.n == other.n) inst.ands(self.n, res, self, other) return res def xor_int(self, other): if other == 0: return self + elif other == self.long_one(): + return ~self self_bits = self.bit_decompose() other_bits = util.bit_decompose(other, max(self.n, util.int_len(other))) extra_bits = [self.new(b, n=1) for b in other_bits[self.n:]] @@ -332,9 +344,8 @@ def __invert__(self): # res = type(self)(n=self.n) # inst.nots(res, self) # return res - one = self.new(value=1, n=1) - bits = [one + bit for bit in self.bit_decompose()] - return self.bit_compose(bits) + one = self.new(value=self.long_one(), n=self.n) + return self + one def __neg__(self): return self def reveal(self): @@ -381,6 +392,9 @@ def trans(cls, rows): def if_else(self, x, y): # vectorized if/else return result_conv(x, y)(self & (x ^ y) ^ y) + @staticmethod + def bit_adder(*args, **kwargs): + return sbitint.bit_adder(*args, **kwargs) class sbitvec(object): @classmethod @@ -634,7 +648,7 @@ def output(self): bits = self.v.bit_decompose(self.k) sign = bits[-1] v = self.v + (sign << (self.k)) * -1 - inst.print_float_plain(v, cbits(-self.f, n=32), cbits(0), cbits(0)) + inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0)) class sbitfix(_fix): float_type = type(None) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index f6d1f0e0f..df0ef7be2 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -41,16 +41,16 @@ def alloc_reg(self, reg, free): raise RegisterOverflowError() self.alloc[base] = res - if base.vector: - for i,r in enumerate(base.vector): - r.i = self.alloc[base] + i base.i = self.alloc[base] def dealloc_reg(self, reg, inst, free): - self.dealloc.add(reg) + if reg.vector: + self.dealloc |= reg.vector + else: + self.dealloc.add(reg) base = reg.vectorbase - if base.vector and not inst.is_vec(): + if base.vector: for i in base.vector: if i not in self.dealloc: # not all vector elements ready for deallocation diff --git a/Compiler/comparison.py b/Compiler/comparison.py index e9cb21dac..62834887d 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -139,8 +139,12 @@ def TruncLeakyInRing(a, k, m, signed): from .types import sint, intbitint, cint, cgf2n n_bits = k - m n_shift = int(program.options.ring) - n_bits - r_bits = [sint.get_random_bit() for i in range(n_bits)] - r = sint.bit_compose(r_bits) + if program.use_dabit and n_bits > 1: + r, r_bits = zip(*(sint.get_dabit() for i in range(n_bits))) + r = sint.bit_compose(r) + else: + r_bits = [sint.get_random_bit() for i in range(n_bits)] + r = sint.bit_compose(r_bits) if signed: a += (1 << (k - 1)) shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() @@ -195,21 +199,26 @@ def Mod2mRing(a_prime, a, k, m, signed): assert(int(program.options.ring) >= k) from Compiler.types import sint, intbitint, cint shift = int(program.options.ring) - m - r = [sint.get_random_bit() for i in range(m)] + if program.use_dabit: + r, r_bin = zip(*(sint.get_dabit() for i in range(m))) + else: + r = [sint.get_random_bit() for i in range(m)] + r_bin = r r_prime = sint.bit_compose(r) tmp = a + r_prime c_prime = (tmp << shift).reveal() >> shift u = sint() - BitLTL(u, c_prime, r, 0) + BitLTL(u, c_prime, r_bin, 0) res = (u << m) + c_prime - r_prime if a_prime is not None: movs(a_prime, res) return res def Mod2mField(a_prime, a, k, m, kappa, signed): + from .types import sint r_dprime = program.curr_block.new_reg('s') r_prime = program.curr_block.new_reg('s') - r = [program.curr_block.new_reg('s') for i in range(m)] + r = [sint() for i in range(m)] c = program.curr_block.new_reg('c') c_prime = program.curr_block.new_reg('c') v = program.curr_block.new_reg('s') @@ -238,7 +247,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): adds(a_prime, t[5], t[4]) return r_dprime, r_prime, c, c_prime, u, t, c2k1 -def PRandM(r_dprime, r_prime, b, k, m, kappa): +def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): """ r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1] r_prime = random secret integer in range [0, 2^m - 1] @@ -249,6 +258,12 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa): PRandInt(r_dprime, k + kappa - m) # r_dprime is always multiplied by 2^m program.curr_tape.require_bit_length(k + kappa) + if use_dabit and program.use_dabit and m > 1: + from .types import sint + r, b[:] = zip(*(sint.get_dabit() for i in range(m))) + r = sint.bit_compose(r) + movs(r_prime, r) + return bit(b[-1]) for i in range(1,m): adds(t[i][0], t[i-1][1], t[i-1][1]) @@ -345,14 +360,13 @@ def carry(b, a, compute_p): return a t = [program.curr_block.new_reg('s') for i in range(3)] if compute_p: - muls(t[0], a[0], b[0]) - muls(t[1], a[0], b[1]) - adds(t[2], a[1], t[1]) + t[0] = a[0].bit_and(b[0]) + t[2] = a[0].bit_and(b[1]) + a[1] return t[0], t[2] # from WP9 report # length of a is even -def CarryOutAux(d, a, kappa): +def CarryOutAux(a, kappa): k = len(a) if k > 1 and k % 2 == 1: a.append(None) @@ -362,9 +376,9 @@ def CarryOutAux(d, a, kappa): if k > 1: for i in range(k//2): u[i] = carry(a[2*i+1], a[2*i], i != k//2-1) - CarryOutAux(d, u[:k//2][::-1], kappa) + return CarryOutAux(u[:k//2][::-1], kappa) else: - movs(d, a[0][1]) + return a[0][1] # carry out with carry-in bit c def CarryOut(res, a, b, c=0, kappa=None): @@ -378,19 +392,14 @@ def CarryOut(res, a, b, c=0, kappa=None): k = len(a) from . import types d = [program.curr_block.new_reg('s') for i in range(k)] - t = [[types.sint() for i in range(k)] for i in range(4)] s = [program.curr_block.new_reg('s') for i in range(3)] for i in range(k): - mulm(t[0][i], b[i], a[i]) - mulsi(t[1][i], t[0][i], 2) - addm(t[2][i], b[i], a[i]) - subs(t[3][i], t[2][i], t[1][i]) - d[i] = [t[3][i], t[0][i]] + d[i] = list(b[i].half_adder(a[i])) s[0] = d[-1][0] * c s[1] = d[-1][1] + s[0] d[-1][1] = s[1] - CarryOutAux(res, d[::-1], kappa) + movs(res, types.sint.conv(CarryOutAux(d[::-1], kappa))) def CarryOutLE(a, b, c=0): """ Little-endian version """ @@ -412,7 +421,7 @@ def BitLTL(res, a, b, kappa): s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)] t = [program.curr_block.new_reg('s') for i in range(1)] for i in range(len(b)): - subsfi(s[0][i], b[i], 1) + s[0][i] = b[0].long_one() - b[i] CarryOut(t[0], a_bits[::-1], s[0][::-1], 1, kappa) subsfi(res, t[0], 1) return a_bits, s[0] diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 2a51ebd46..e49fc456f 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -32,8 +32,12 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - r = [types.sint.get_random_bit() for i in range(k)] - r_prime = types.sint.bit_compose(r) + if program.Program.prog.use_dabit: + rr, r = zip(*(types.sint.get_dabit() for i in range(k))) + r_prime = types.sint.bit_compose(rr) + else: + r = [types.sint.get_random_bit() for i in range(k)] + r_prime = types.sint.bit_compose(r) c = ((a + r_prime) << shift).reveal() >> shift return c, r @@ -53,8 +57,8 @@ def EQZ(a, k, kappa): c, r = maskField(a, k, kappa) d = [None]*k for i,b in enumerate(bits(c, k)): - d[i] = b + r[i] - 2*b*r[i] - return 1 - KOR(d, kappa) + d[i] = r[i].bit_xor(b) + return 1 - types.sint.conv(KOR(d, kappa)) def bits(a,m): """ Get the bits of an int """ @@ -82,10 +86,10 @@ def carry(b, a, compute_p=True): (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1)) """ if compute_p: - t1 = a[0]*b[0] + t1 = a[0].bit_and(b[0]) else: t1 = None - t2 = a[1] + a[0]*b[1] + t2 = a[1] + a[0].bit_and(b[1]) return (t1, t2) def or_op(a, b, void=None): @@ -197,7 +201,7 @@ def KORL(a, kappa): else: t1 = KORL(a[:k//2], kappa) t2 = KORL(a[k//2:], kappa) - return t1 + t2 - t1*t2 + return t1 + t2 - t1.bit_and(t2) def KORC(a, kappa): return PreORC(a, kappa, 1)[0] @@ -295,11 +299,16 @@ def BitDec(a, k, m, kappa, bits_to_compute=None): def BitDecRing(a, k, m): n_shift = int(program.Program.prog.options.ring) - m - r_bits = [types.sint.get_random_bit() for i in range(m)] - r = types.sint.bit_compose(r_bits) + if program.Program.prog.use_dabit: + r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) + r = types.sint.bit_compose(r) + else: + r_bits = [types.sint.get_random_bit() for i in range(m)] + r = types.sint.bit_compose(r_bits) shifted = ((a - r) << n_shift).reveal() masked = shifted >> n_shift - return types.intbitint.bit_adder(r_bits, masked.bit_decompose(m)) + bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) + return [types.sint.conv(bit) for bit in bits] def BitDecField(a, k, m, kappa, bits_to_compute=None): r_dprime = types.sint() @@ -319,7 +328,8 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): print('BitDec assertion failed') print('a =', a.value) print('a mod 2^%d =' % k, (a.value % 2**k)) - return types.intbitint.bit_adder(list(bits(c,m)), r) + res = r[0].bit_adder(r, list(bits(c,m))) + return [types.sint.conv(bit) for bit in res] def Pow2(a, l, kappa): @@ -345,17 +355,21 @@ def B2U_from_Pow2(pow2a, l, kappa): r = [types.sint() for i in range(l)] t = types.sint() c = types.cint() - for i in range(l): - bit(r[i]) + if program.Program.prog.use_dabit: + r, r_bits = zip(*(types.sint.get_dabit() for i in range(l))) + else: + for i in range(l): + bit(r[i]) + r_bits = r comparison.PRandInt(t, kappa) asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l))) comparison.program.curr_tape.require_bit_length(l + kappa) c = list(bits(c, l)) - x = [c[i] + r[i] - 2*c[i]*r[i] for i in range(l)] + x = [r_bits[i].bit_xor(c[i]) for i in range(l)] #print ' '.join(str(b.value) for b in x) y = PreOR(x, kappa) #print ' '.join(str(b.value) for b in y) - return [1 - y[i] for i in range(l)] + return [types.sint.conv(1 - y[i]) for i in range(l)] def Trunc(a, l, m, kappa, compute_modulo=False, signed=False): """ Oblivious truncation by secret m """ @@ -532,7 +546,7 @@ def TruncPrField(a, k, m, kappa=None): b = two_power(k-1) + a r_prime, r_dprime = types.sint(), types.sint() comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)], - k, m, kappa) + k, m, kappa, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime c = (b + r).reveal() diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 2209cfac7..23db3fc14 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -801,6 +801,15 @@ class bit(base.DataInstruction): def execute(self): self.args[0].value = randint(0,1) +@base.vectorize +class dabit(base.DataInstruction): + """ daBit """ + __slots__ = [] + code = base.opcodes['DABIT'] + arg_format = ['sw', 'sbw'] + field_type = 'modp' + data_type = 'bit' + @base.gf2n @base.vectorize class square(base.DataInstruction): diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 8683c305b..89dff9ca6 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -102,6 +102,7 @@ GBITGF2NTRIPLE = 0x155, INPUTMASK = 0x56, PREP = 0x57, + DABIT = 0x58, # Input INPUT = 0x60, INPUTFIX = 0xF0, @@ -237,11 +238,11 @@ def __init__(self, size, *args, **kwargs): if issubclass(ArgFormats[f], RegisterArgFormat): arg.set_size(size) def get_code(self): - return (self.size << 9) + self.code + return (self.size << 10) + self.code def get_pre_arg(self): return "%d, " % self.size def is_vec(self): - return self.size > 1 + return True def get_size(self): return self.size def expand(self): @@ -547,8 +548,8 @@ def check_args(self): try: ArgFormats[f].check(arg) except ArgumentError as e: - raise CompilerError('Invalid argument "%s" to instruction: %s' - % (e.arg, self) + '\n' + e.msg) + raise CompilerError('Invalid argument %d "%s" to instruction: %s' + % (n, e.arg, self) + '\n' + e.msg) except KeyError as e: raise CompilerError('Unknown argument %s for instruction %s' % (f, self)) diff --git a/Compiler/library.py b/Compiler/library.py index 5314ee94c..99345bb1f 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -909,6 +909,7 @@ def exit_elimination(block): for block in blocks[-n_to_merge + 1:]: merged.instructions += block.instructions exit_elimination(block) + block.purge() del blocks[-n_to_merge + 1:] del get_tape().req_node.children[-1] merged.children = [] @@ -956,8 +957,6 @@ def multithread(n_threads, n_items): Distribute the computation of n_items to n_threads threads, but leave the in-thread repetition up to the user """ - if n_threads == 1 or n_items == 1: - return lambda loop_body: loop_body(0, n_items) return map_reduce(n_threads, None, n_items, initializer=lambda: [], reducer=None, looping=False) @@ -977,13 +976,6 @@ def new_body(i): return new_body new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req) return lambda loop_body: new_dec(decorator(loop_body)) - if n_threads == 1 or n_loops == 1: - dec = map_reduce_single(n_parallel, n_loops, initializer, reducer) - if thread_mem_req: - thread_mem = Array(thread_mem_req[regint], regint) - return lambda loop_body: dec(lambda i: loop_body(i, thread_mem)) - else: - return dec def decorator(loop_body): thread_rounds = n_loops // n_threads remainder = n_loops % n_threads diff --git a/Compiler/oram.py b/Compiler/oram.py index 93e34c844..4d61b2d86 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1404,7 +1404,7 @@ def __init__(self, size, entry_size=None, value_type=sint, init_rounds=-1, \ self.value_type = value_type for demux_bits in range(max_demux_bits + 1): self.log_entries_per_element = min(log2(size), \ - int(math.floor(math.log(float(get_value_size(value_type)) // \ + int(math.floor(math.log(float(get_value_size(value_type)) / \ sum(self.entry_size), 2)))) self.log_elements_per_block = \ max(0, min(demux_bits, log2(size) - \ diff --git a/Compiler/program.py b/Compiler/program.py index eebc411a3..7b2537884 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -94,6 +94,7 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] self.use_trunc_pr = False + self.use_dabit = options.mixed Program.prog = self self.reset_values() @@ -462,6 +463,7 @@ def __init__(self, parent, name, scope, exit_condition=None): self.purged = False self.n_rounds = 0 self.n_to_merge = 0 + self.defined_registers = None def __len__(self): return len(self.instructions) @@ -502,9 +504,14 @@ def adjust_jump(self): #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self): - relevant = lambda inst: inst.add_usage is not \ - Compiler.instructions_base.Instruction.add_usage + def relevant(inst): + req_node = Tape.ReqNode('') + req_node.num = Tape.ReqNum() + inst.add_usage(req_node) + return req_node.num != {} self.usage_instructions = list(filter(relevant, self.instructions)) + if len(self.usage_instructions) > 1000: + print('Retaining %d instructions' % len(self.usage_instructions)) del self.instructions del self.defined_registers self.purged = True @@ -899,7 +906,7 @@ class Register(object): The 'value' property is for emulation. """ - __slots__ = ["reg_type", "program", "i", "_is_active", \ + __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ "size", "vector", "vectorbase", "caller", \ "can_eliminate"] @@ -916,16 +923,16 @@ def __init__(self, reg_type, program, value=None, size=None, i=None): if size is None: size = Compiler.instructions_base.get_global_vector_size() self.size = size + self.vectorbase = self + self.relative_i = 0 if i is not None: self.i = i else: self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size self.vector = [] - self.vectorbase = self if value is not None: self.value = value - self._is_active = False self.can_eliminate = True if Program.prog.DEBUG: self.caller = [frame[1:] for frame in inspect.stack()[1:]] @@ -934,9 +941,19 @@ def __init__(self, reg_type, program, value=None, size=None, i=None): if self.i % 1000000 == 0 and self.i > 0: print("Initialized %d registers at" % self.i, time.asctime()) + @property + def i(self): + return self.vectorbase.absolute_i + self.relative_i + + @i.setter + def i(self, value): + self.vectorbase.absolute_i = value - self.relative_i + def set_size(self, size): if self.size == size: return + elif not self.program.options.assemblymode: + raise CompilerError('Mismatch of instruction and register size') elif self.size == 1 and self.vectorbase is self: if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS: # create vector register in assembly mode @@ -955,10 +972,18 @@ def set_vectorbase(self, vectorbase): if self.vectorbase is not self: raise CompilerError('Cannot assign one register' \ 'to several vectors') + self.relative_i = self.i - vectorbase.i self.vectorbase = vectorbase - def _new_by_number(self, i): - return Tape.Register(self.reg_type, self.program, size=1, i=i) + def _new_by_number(self, i, size=1): + return Tape.Register(self.reg_type, self.program, size=size, i=i) + + def get_vector(self, base, size): + res = self._new_by_number(self.i + base, size=size) + res.set_vectorbase(self) + self.create_vector_elements() + res.vector = self.vector[base:base+size] + return res def create_vector_elements(self): if self.vector: @@ -983,14 +1008,6 @@ def __getitem__(self, index): def __len__(self): return self.size - def activate(self): - """ Activating a register signals that it will at some point be used - in the program. - - Inactive registers are reserved for temporaries for CISC instructions. """ - if not self._is_active: - self._is_active = True - @property def value(self): return self.program.reg_values[self.reg_type][self.i] @@ -1000,10 +1017,6 @@ def value(self, val): while (len(self.program.reg_values[self.reg_type]) <= self.i): self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX self.program.reg_values[self.reg_type][self.i] = val - - @property - def is_active(self): - return self._is_active @property def is_gf2n(self): diff --git a/Compiler/types.py b/Compiler/types.py index f72690a7b..4ac5282f2 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -175,6 +175,10 @@ def max(self, other): return (self < other).if_else(other, self) class _int(object): + @staticmethod + def bit_adder(*args, **kwargs): + return intbitint.bit_adder(*args, **kwargs) + def if_else(self, a, b): if hasattr(a, 'for_mux'): f, a, b = a.for_mux(b) @@ -189,7 +193,24 @@ def cond_swap(self, a, b): def bit_xor(self, other): return self + other - 2 * self * other -class _gf2n(object): + def bit_and(self, other): + return self * other + + def half_adder(self, other): + carry = self * other + return self + other - 2 * carry, carry + +class _bit(object): + def bit_xor(self, other): + return self ^ other + + def bit_and(self, other): + return self & other + + def half_adder(self, other): + return self ^ other, self & other + +class _gf2n(_bit): def if_else(self, a, b): return b ^ self * self.hard_conv(a ^ b) @@ -310,6 +331,12 @@ def __init__(self, reg_type, val, size): elif val is not None: self.load_other(val) + def _new_by_number(self, i, size=1): + res = type(self)(size=size) + res.i = i + res.program = self.program + return res + def sizeof(self): return self.size @@ -509,6 +536,7 @@ def load_int(self, val): elif chunk: sum += sign * chunk + @vectorize def to_regint(self, n_bits=None, dest=None): dest = regint() if dest is None else dest convmodp(dest, self, bitlength=n_bits) @@ -599,6 +627,7 @@ def right_shift(self, other, bit_length=None): def greater_than(self, other, bit_length=None): return self > other + @vectorize def bit_decompose(self, bit_length=None): if bit_length == 0: return [] @@ -1065,19 +1094,20 @@ def load_clear(self, val): @read_mem_value @vectorize def load_other(self, val): + from Compiler.GC.types import sbits if isinstance(val, self.clear_type): self.load_clear(val) elif isinstance(val, type(self)): movs(self, val) + elif isinstance(val, sbits): + assert(val.n == self.size) + r = self.get_dabit() + v = regint() + bitdecint_class(regint((r[1] ^ val).reveal()), *v) + movs(self, r[0].bit_xor(v)) else: self.load_clear(self.clear_type(val)) - def _new_by_number(self, i): - res = type(self)(size=1) - res.i = i - res.program = self.program - return res - @set_instruction_type @read_mem_value @vectorize @@ -1179,6 +1209,18 @@ def get_input_from(cls, player): inputmixed('int', res, player) return res + @vectorized_classmethod + def get_dabit(cls): + """ Bit in arithmetic and binary circuit according to security model """ + from Compiler.GC.types import sbits + res = cls(), sbits.get_type(get_global_vector_size())() + dabit(*res) + return res + + @staticmethod + def long_one(): + return 1 + @classmethod def get_raw_input_from(cls, player): res = cls() @@ -1354,6 +1396,7 @@ def __rlshift__(self, other): def __rrshift__(self, other): return floatingpoint.Trunc(other, program.bit_length, self, program.security) + @vectorize def bit_decompose(self, bit_length=None, security=None): if bit_length == 0: return [] @@ -1519,6 +1562,10 @@ class _bitint(object): log_rounds = False linear_rounds = False + @staticmethod + def half_adder(a, b): + return a.half_adder(b) + @classmethod def bit_adder(cls, a, b, carry_in=0, get_carry=False): a, b = list(a), list(b) @@ -1618,10 +1665,6 @@ def full_adder(a, b, carry): s = a + b return s + carry, util.if_else(s, carry, a) - @staticmethod - def half_adder(a, b): - return a + b, a & b - @staticmethod def bit_comparator(a, b): long_one = util.long_one(a + b) @@ -1814,11 +1857,6 @@ def full_adder(a, b, carry): s = a.bit_xor(b) return s.bit_xor(carry), util.if_else(s, carry, a) - @staticmethod - def half_adder(a, b): - carry = a * b - return a + b - 2 * carry, carry - @staticmethod def sum_from_carries(a, b, carries): return [a[i] + b[i] + carries[i] - 2 * carries[i + 1] \ @@ -3284,6 +3322,7 @@ def assign_all(self, value, use_threads=True, conv=True): if conv: value = self.value_type.conv(value) mem_value = MemValue(value) + self.address = MemValue(self.address) n_threads = 8 if use_threads and len(self) > 2**20 else 1 @library.for_range_multithread(n_threads, 1024, len(self)) def f(i): diff --git a/Compiler/util.py b/Compiler/util.py index 8b7ea214a..1583b86e8 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -197,6 +197,11 @@ def __iter__(self): def add(self, value): self.content[id(value)] = value + def __ior__(self, values): + for value in values: + self.add(value) + return self + class dict_by_id(object): def __init__(self): self.content = {} diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index 3ab383a80..3863ff16b 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -21,6 +21,9 @@ #include "Processor/Input.hpp" #include "Processor/Processor.hpp" #include "Processor/Data_Files.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/RepPrep.hpp" +#include "GC/ThreadMaster.hpp" #include diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index 6b2fa5650..0cc65edf2 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -3,8 +3,13 @@ * */ +#include "GC/TinierSecret.h" +#include "GC/TinyMC.h" + #include "Protocols/Share.hpp" #include "Protocols/MAC_Check.hpp" +#include "GC/Secret.hpp" +#include "GC/TinierSharePrep.hpp" #include "ot-ecdsa-party.hpp" #include diff --git a/ECDSA/semi-ecdsa-party.cpp b/ECDSA/semi-ecdsa-party.cpp index 14b0a2925..6bdcec286 100644 --- a/ECDSA/semi-ecdsa-party.cpp +++ b/ECDSA/semi-ecdsa-party.cpp @@ -3,6 +3,9 @@ * */ +#include "GC/SemiSecret.h" +#include "GC/SemiPrep.h" + #include "Protocols/SemiMC.hpp" #include "Protocols/SemiPrep.hpp" #include "Protocols/SemiInput.hpp" diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 27efea6c9..6eb713c0c 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -212,4 +212,13 @@ class closed_connection } }; +class no_singleton : runtime_error +{ +public: + no_singleton(string msg) : + runtime_error(msg) + { + } +}; + #endif diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index f6a6d7b98..b5f54e7b6 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -34,12 +34,6 @@ class FakeSecret typedef FakeSecret DynamicType; typedef Memory DynamicMemory; - // dummy - typedef DummyMC MC; - typedef DummyProtocol Protocol; - - static MC* new_mc(Machine& _) { (void) _; return new MC; } - static string type_string() { return "fake secret"; } static string phase_name() { return "Faking"; } diff --git a/GC/Instruction.h b/GC/Instruction.h index b8e3d7958..e77c50e57 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -15,50 +15,29 @@ using namespace std; namespace GC { -// Register types -enum RegType { - SBIT, - CBIT, - INT, - DYN_SBIT, - MAX_REG_TYPE, - NONE -}; - template class Processor; -template class Instruction : public ::BaseInstruction { - bool (*code)(const Instruction& instruction, Processor& processor); public: Instruction(); - int get_r(int i) const { return r[i]; } - unsigned int get_n() const { return n; } - const vector& get_start() const { return start; } - int get_opcode() const { return opcode; } - // Reads a single instruction from the istream void parse(istream& s, int pos); // Return whether usage is known bool get_offline_data_usage(int& usage); - int get_reg_type() const; - - // Returns the maximal register used - unsigned get_max_reg(int reg_type) const; - // Returns the memory size used if applicable and known unsigned get_mem(RegType reg_type) const; // Execute this instruction - bool exe(Processor& processor) const { return code(*this, processor); } - template + template bool execute(Processor& processor, U& dynamic_memory) const; }; +} /* namespace GC */ + enum { // GC specific @@ -77,20 +56,37 @@ enum LDBITS = 0x20a, ANDS = 0x20b, TRANS = 0x20c, + BITB = 0x20d, + ANDM = 0x20e, + LDMSB = 0x240, + STMSB = 0x241, + LDMSBI = 0x242, + STMSBI = 0x243, + MOVSB = 0x244, + INPUTB = 0x246, // write to clear CLEAR_WRITE = 0x210, - XORCI = 0x210, + XORCBI = 0x210, BITDECC = 0x211, CONVCINT = 0x213, REVEAL = 0x214, STMSDCI = 0x215, - INPUTB = 0x216, + LDMCB = 0x217, + STMCB = 0x218, + XORCB = 0x219, + ADDCB = 0x21a, + ADDCBI = 0x21b, + MULCBI = 0x21c, + SHRCBI = 0x21d, + SHLCBI = 0x21e, // don't write PRINTREGSIGNED = 0x220, + PRINTREGB = 0x221, + PRINTREGPLAINB = 0x222, + PRINTFLOATPLAINB = 0x223, + CONDPRINTSTRB = 0x224, // write to regint CONVCBIT = 0x230, }; -} /* namespace GC */ - #endif /* PROCESSOR_GC_INSTRUCTION_H_ */ diff --git a/GC/Instruction.hpp b/GC/Instruction.hpp index add5c7b78..578e7d5b4 100644 --- a/GC/Instruction.hpp +++ b/GC/Instruction.hpp @@ -17,16 +17,15 @@ namespace GC { -template -Instruction::Instruction() : +inline +Instruction::Instruction() : BaseInstruction() { - code = fallback_code; size = 1; } -template -bool Instruction::get_offline_data_usage(int& usage) +inline +bool Instruction::get_offline_data_usage(int& usage) { switch (opcode) { @@ -38,79 +37,8 @@ bool Instruction::get_offline_data_usage(int& usage) } } -template -int Instruction::get_reg_type() const -{ - switch (opcode & 0x2F0) - { - case SECRET_WRITE: - return SBIT; - case CLEAR_WRITE: - return CBIT; - default: - switch (::BaseInstruction::get_reg_type()) - { - case ::INT: - return INT; - case ::MODP: - switch (opcode) - { - case LDMC: - case STMC: - case XORC: - case ADDC: - case ADDCI: - case MULCI: - case SHRCI: - case SHLCI: - return CBIT; - } - return SBIT; - } - return NONE; - } -} - -template -unsigned GC::Instruction::get_max_reg(int reg_type) const -{ - int skip; - int offset = 0; - switch (opcode) - { - case LDMSD: - case LDMSDI: - skip = 3; - break; - case STMSD: - case STMSDI: - skip = 2; - break; - case ANDRS: - case XORS: - case ANDS: - skip = 4; - offset = 1; - break; - case INPUTB: - skip = 4; - offset = 3; - break; - case CONVCBIT: - return BaseInstruction::get_max_reg(INT); - default: - return BaseInstruction::get_max_reg(reg_type); - } - - unsigned m = 0; - if (reg_type == SBIT) - for (size_t i = offset; i < start.size(); i += skip) - m = max(m, (unsigned)start[i] + 1); - return m; -} - -template -unsigned Instruction::get_mem(RegType reg_type) const +inline +unsigned Instruction::get_mem(RegType reg_type) const { unsigned m = n + 1; switch (opcode) @@ -133,27 +61,15 @@ unsigned Instruction::get_mem(RegType reg_type) const return m; } break; - case LDMS: - case STMS: - if (reg_type == SBIT) - return m; - break; - case LDMC: - case STMC: - if (reg_type == CBIT) - return m; - break; - case LDMINT: - case STMINT: - if (reg_type == INT) - return m; - break; + default: + return BaseInstruction::get_mem(reg_type, MAX_SECRECY_TYPE); } + return 0; } -template -void Instruction::parse(istream& s, int pos) +inline +void Instruction::parse(istream& s, int pos) { n = 0; start.resize(0); @@ -162,67 +78,7 @@ void Instruction::parse(istream& s, int pos) int file_pos = s.tellg(); opcode = ::get_int(s); - try { - parse_operands(s, pos); - } - catch (Invalid_Instruction& e) - { - int m; - switch (opcode) - { - case XORM: - n = get_int(s); - get_ints(r, s, 3); - break; - case XORCI: - case MULCI: - case LDBITS: - get_ints(r, s, 2); - n = get_int(s); - break; - case BITDECS: - case BITCOMS: - case BITDECC: - m = get_int(s) - 1; - get_ints(r, s, 1); - get_vector(m, start, s); - break; - case CONVCINT: - case CONVCBIT: - get_ints(r, s, 2); - break; - case REVEAL: - case CONVSINT: - n = get_int(s); - get_ints(r, s, 2); - break; - case LDMSDI: - case STMSDI: - case LDMSD: - case STMSD: - case STMSDCI: - case XORS: - case ANDRS: - case ANDS: - case INPUTB: - get_vector(get_int(s), start, s); - break; - case PRINTREGSIGNED: - n = get_int(s); - get_ints(r, s, 1); - break; - case TRANS: - m = get_int(s) - 1; - n = get_int(s); - get_vector(m, start, s); - break; - default: - ostringstream os; - os << "Invalid instruction " << showbase << hex << opcode - << " at " << dec << pos << "/" << hex << file_pos << dec; - throw Invalid_Instruction(os.str()); - } - } + parse_operands(s, pos, file_pos); switch(opcode) { diff --git a/GC/Instruction_inline.h b/GC/Instruction_inline.h index 1ab2735b5..2eddb3775 100644 --- a/GC/Instruction_inline.h +++ b/GC/Instruction_inline.h @@ -19,7 +19,7 @@ namespace GC { #include "instructions.h" template -inline bool fallback_code(const Instruction& instruction, Processor& processor) +inline bool fallback_code(const Instruction& instruction, Processor& processor) { (void)processor; cout << "Undefined instruction " << showbase << hex @@ -27,9 +27,8 @@ inline bool fallback_code(const Instruction& instruction, Processor& proce return true; } -template -template -MAYBE_INLINE bool Instruction::execute(Processor& processor, +template +MAYBE_INLINE bool Instruction::execute(Processor& processor, U& dynamic_memory) const { #ifdef DEBUG_OPS diff --git a/GC/Machine.h b/GC/Machine.h index 88c62c54b..c3d8987e1 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -21,11 +21,22 @@ namespace GC template class Program; template -class Machine : public ::BaseMachine +class Memories { public: Memory MS; Memory MC; + + template + void reset(const U& program); + + void write_memory(int my_num); +}; + +template +class Machine : public ::BaseMachine, public Memories +{ +public: Memory MI; vector > progs; @@ -50,8 +61,6 @@ class Machine : public ::BaseMachine void run_tape(int thread_number, int tape_number, int arg); void join_tape(int thread_numer); - - void write_memory(int my_num); }; } /* namespace GC */ diff --git a/GC/Machine.hpp b/GC/Machine.hpp index 2c5b18752..09ac9fd26 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_MACHINE_HPP_ +#define GC_MACHINE_HPP_ + #include #include "GC/Program.h" @@ -52,13 +55,20 @@ void Machine::load_schedule(string progname) print_compiler(); } +template +template +void Memories::reset(const U& program) +{ + MS.resize_min(*program.direct_mem(SBIT), "memory"); + MC.resize_min(*program.direct_mem(CBIT), "memory"); +} + template template void Machine::reset(const U& program) { - MS.resize_min(program.direct_mem(SBIT), "memory"); - MC.resize_min(program.direct_mem(CBIT), "memory"); - MI.resize_min(program.direct_mem(INT), "memory"); + Memories::reset(program); + MI.resize_min(*program.direct_mem(INT), "memory"); } template @@ -66,7 +76,7 @@ template void Machine::reset(const U& program, V& MD) { reset(program); - MD.resize_min(program.direct_mem(DYN_SBIT), "dynamic memory"); + MD.resize_min(*program.direct_mem(DYN_SBIT), "dynamic memory"); #ifdef DEBUG_MEMORY cerr << "reset dynamic mem to " << program.direct_mem(DYN_SBIT) << endl; #endif @@ -85,12 +95,14 @@ void Machine::join_tape(int thread_number) } template -void GC::Machine::write_memory(int my_num) +void Memories::write_memory(int my_num) { - ofstream outf(memory_filename("B", my_num)); + ofstream outf(BaseMachine::memory_filename("B", my_num)); outf << 0 << endl; outf << MC.size() << endl << MC; outf << 0 << endl << 0 << endl << 0 << endl << 0 << endl; } } /* namespace GC */ + +#endif diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h index 5b4dd0b3a..5acbfcb69 100644 --- a/GC/MaliciousRepSecret.h +++ b/GC/MaliciousRepSecret.h @@ -8,6 +8,7 @@ #include "ShareSecret.h" #include "Machine.h" +#include "ThreadMaster.h" #include "Protocols/Beaver.h" #include "Protocols/MaliciousRepMC.h" #include "Processor/DummyProtocol.h" @@ -34,12 +35,19 @@ class MaliciousRepSecret : public ReplicatedSecret typedef ReplicatedInput Input; typedef RepPrep LivePrep; - static MC* new_mc(Machine& machine) + typedef MaliciousRepSecret part_type; + + static MC* new_mc(mac_key_type) { - if (machine.more_comm_less_comp) - return new CommMaliciousRepMC; - else - return new HashMaliciousRepMC; + try + { + if (ThreadMaster::s().machine.more_comm_less_comp) + return new CommMaliciousRepMC; + } + catch(no_singleton& e) + { + } + return new HashMaliciousRepMC; } static MaliciousRepSecret constant(const BitVec& other, int my_num, const BitVec& alphai) diff --git a/GC/NoShare.h b/GC/NoShare.h new file mode 100644 index 000000000..d32bc2905 --- /dev/null +++ b/GC/NoShare.h @@ -0,0 +1,122 @@ +/* + * NoShare.h + * + */ + +#ifndef GC_NOSHARE_H_ +#define GC_NOSHARE_H_ + +#include "BMR/Register.h" +#include "Processor/DummyProtocol.h" + +namespace GC +{ + +class NoValue +{ +public: + const static int n_bits = 0; + + static bool allows(Dtype) + { + return false; + } + + static int size() + { + return 0; + } + + static void fail() + { + throw runtime_error("VM does not support binary circuits"); + } + + void assign(const char*) { fail(); } + + int get() const { fail(); return 0; } +}; + +class NoShare : public Phase +{ +public: + typedef DummyMC MC; + typedef DummyProtocol Protocol; + typedef NotImplementedInput Input; + typedef DummyLivePrep LivePrep; + typedef DummyMC MAC_Check; + + typedef NoValue open_type; + typedef NoValue clear; + typedef NoValue mac_key_type; + + typedef NoShare bit_type; + typedef NoShare part_type; + + static const bool needs_ot = false; + + static MC* new_mc(mac_key_type) + { + return new MC; + } + + template + static void generate_mac_key(mac_key_type, T) + { + } + + static DataFieldType field_type() + { + throw not_implemented(); + } + + static string type_short() + { + return ""; + } + + static string type_string() + { + return "no"; + } + + static int size() + { + return 0; + } + + static void fail() + { + NoValue::fail(); + } + + static void inputb(Processor&, const vector&) { fail(); } + + static void input(Processor&, InputArgs&) { fail(); } + static void trans(Processor&, Integer, const vector&) { fail(); } + + NoShare() {} + + NoShare(int) { fail(); } + + void load_clear(Integer, Integer) { fail(); } + void random_bit() { fail(); } + void and_(int, NoShare&, NoShare&, bool) { fail(); } + void xor_(int, NoShare&, NoShare&) { fail(); } + void bitdec(vector&, const vector&) const { fail(); } + void bitcom(vector&, const vector&) const { fail(); } + void reveal(Integer, Integer) { fail(); } + + void assign(const char*) { fail(); } + + NoShare operator&(const Clear&) const { fail(); return {}; } + + NoShare operator<<(int) const { fail(); return {}; } + void operator^=(NoShare) { fail(); } + + NoShare operator+(const NoShare&) const { fail(); return {}; } +}; + +} /* namespace GC */ + +#endif /* GC_NOSHARE_H_ */ diff --git a/GC/Processor.h b/GC/Processor.h index d2f8082c4..3304d182a 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -39,7 +39,8 @@ class Processor : public ::ProcessorBase static void check_input(bigint in, int n_bits); - Machine& machine; + Machine* machine; + Memories& memories; unsigned int PC; unsigned int time; @@ -54,6 +55,7 @@ class Processor : public ::ProcessorBase ExecutionStats stats; Processor(Machine& machine); + Processor(Memories& memories, Machine* machine = 0); ~Processor(); template diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 9ee5c2b92..5757718b9 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_PROCESSOR_HPP_ +#define GC_PROCESSOR_HPP_ + #include #include @@ -13,6 +16,7 @@ using namespace std; #include "Access.h" #include "Processor/FixInput.h" +#include "GC/Machine.hpp" #include "Processor/ProcessorBase.hpp" namespace GC @@ -20,7 +24,13 @@ namespace GC template Processor::Processor(Machine& machine) : - machine(machine), PC(0), time(0), + Processor(machine, &machine) +{ +} + +template +Processor::Processor(Memories& memories, Machine* machine) : + machine(machine), memories(memories), PC(0), time(0), complexity(0) { } @@ -49,7 +59,9 @@ template void Processor::reset(const U& program) { reset(program, 0); - machine.reset(program); + if (machine) + machine->reset(program); + memories.reset(program); } template @@ -248,3 +260,5 @@ void Processor::print_float_prec(int n) } } /* namespace GC */ + +#endif diff --git a/GC/Program.h b/GC/Program.h index 53b0e3a08..c91dec955 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -7,6 +7,7 @@ #define GC_PROGRAM_H_ #include "GC/Instruction.h" +#include "Processor/Program.h" #include using namespace std; @@ -26,7 +27,7 @@ template class Processor; template class Program { - vector< Instruction > p; + vector p; int offline_data_used; // Maximal register used @@ -57,8 +58,8 @@ class Program unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } - unsigned direct_mem(RegType reg_type) const - { return max_mem[reg_type]; } + const unsigned* direct_mem(RegType reg_type) const + { return &max_mem[reg_type]; } template BreakType execute(Processor& Proc, U& dynamic_memory, int PC = -1) const; diff --git a/GC/Program.hpp b/GC/Program.hpp index 24ceee76f..b90694719 100644 --- a/GC/Program.hpp +++ b/GC/Program.hpp @@ -61,7 +61,7 @@ template void Program::parse(istream& s) { p.resize(0); - Instruction instr; + Instruction instr; s.peek(); int pos = 0; CALLGRIND_STOP_INSTRUMENTATION; diff --git a/GC/RepPrep.h b/GC/RepPrep.h index 5f1e03828..90464b192 100644 --- a/GC/RepPrep.h +++ b/GC/RepPrep.h @@ -13,13 +13,16 @@ namespace GC { +template class ShareThread; + template class RepPrep : public BufferPrep, ShiftableTripleBuffer { ReplicatedBase* protocol; public: - RepPrep(DataPositions& usage, Thread& thread); + RepPrep(DataPositions& usage, ShareThread& thread); + RepPrep(DataPositions& usage); ~RepPrep(); void set_protocol(typename T::Protocol& protocol); diff --git a/GC/RepPrep.hpp b/GC/RepPrep.hpp index 2c35b8e91..2b6fe091f 100644 --- a/GC/RepPrep.hpp +++ b/GC/RepPrep.hpp @@ -16,12 +16,18 @@ namespace GC { template -RepPrep::RepPrep(DataPositions& usage, Thread& thread) : - BufferPrep(usage), protocol(0) +RepPrep::RepPrep(DataPositions& usage, ShareThread& thread) : + RepPrep(usage) { (void) thread; } +template +RepPrep::RepPrep(DataPositions& usage) : + BufferPrep(usage), protocol(0) +{ +} + template RepPrep::~RepPrep() { @@ -39,7 +45,7 @@ template void RepPrep::buffer_triples() { assert(protocol != 0); - auto MC = ShareThread::s().new_mc(); + auto MC = ShareThread::s().new_mc({}); shuffle_triple_generation(this->triples, protocol->P, *MC, 64); delete MC; } diff --git a/GC/Secret.h b/GC/Secret.h index 02def12df..9ed15fa71 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -70,11 +70,7 @@ class Secret public: typedef typename T::DynamicMemory DynamicMemory; - // dummy - typedef DummyMC MC; - typedef DummyProtocol Protocol; - - static MC* new_mc(Machine& _) { (void) _; return new MC; } + typedef NoShare bit_type; static string type_string() { return "evaluation secret"; } static string phase_name() { return T::name(); } diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 90ba1ddbe..d71adede7 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_SECRET_HPP_ +#define GC_SECRET_HPP_ + #include "Secret.h" #include "Secret_inline.h" @@ -336,3 +339,5 @@ void Secret::reveal(size_t n_bits, U& x) } } /* namespace GC */ + +#endif diff --git a/GC/SemiHonestRepPrep.h b/GC/SemiHonestRepPrep.h index 678a436e5..2470eb2bc 100644 --- a/GC/SemiHonestRepPrep.h +++ b/GC/SemiHonestRepPrep.h @@ -15,8 +15,12 @@ namespace GC class SemiHonestRepPrep : public RepPrep { public: - SemiHonestRepPrep(DataPositions& usage, Thread& thread) : - RepPrep(usage, thread) + SemiHonestRepPrep(DataPositions& usage, ShareThread&) : + RepPrep(usage) + { + } + SemiHonestRepPrep(DataPositions& usage) : + RepPrep(usage) { } diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 37c1e4115..949154bfc 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -15,8 +15,13 @@ namespace GC { -SemiPrep::SemiPrep(DataPositions& usage, Thread& thread) : - BufferPrep(usage), thread(thread), triple_generator(0) +SemiPrep::SemiPrep(DataPositions& usage, ShareThread&) : + SemiPrep(usage) +{ +} + +SemiPrep::SemiPrep(DataPositions& usage) : + BufferPrep(usage), triple_generator(0) { } @@ -25,9 +30,9 @@ void SemiPrep::set_protocol(Beaver& protocol) (void) protocol; params.set_passive(); triple_generator = new SemiSecret::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), - thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, {}, thread.P); + BaseMachine::s().fresh_ot_setup(), + protocol.P.N, -1, OnlineOptions::singleton.batch_size, + 1, params, {}, &protocol.P); triple_generator->multi_threaded = false; } @@ -50,6 +55,7 @@ SemiPrep::~SemiPrep() void SemiPrep::buffer_bits() { + auto& thread = Thread::s(); word r = thread.secure_prng.get_word(); for (size_t i = 0; i < sizeof(word) * 8; i++) this->bits.push_back((r >> i) & 1); diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 166e97e95..04633da88 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -16,15 +16,16 @@ template class Beaver; namespace GC { +template class ShareThread; + class SemiPrep : public BufferPrep, ShiftableTripleBuffer { - Thread& thread; - SemiSecret::TripleGenerator* triple_generator; MascotParams params; public: - SemiPrep(DataPositions& usage, Thread& thread); + SemiPrep(DataPositions& usage, ShareThread& thread); + SemiPrep(DataPositions& usage); ~SemiPrep(); void set_protocol(Beaver& protocol); diff --git a/GC/SemiSecret.cpp b/GC/SemiSecret.cpp index 56dc82cdb..26805b1c1 100644 --- a/GC/SemiSecret.cpp +++ b/GC/SemiSecret.cpp @@ -26,7 +26,7 @@ void SemiSecret::trans(Processor& processor, int n_outputs, void SemiSecret::load_clear(int n, const Integer& x) { check_length(n, x); - *this = constant(x, Thread::s().P->my_num()); + *this = constant(x, ShareThread::s().P->my_num()); } void SemiSecret::bitcom(Memory& S, const vector& regs) @@ -45,7 +45,7 @@ void SemiSecret::bitdec(Memory& S, void SemiSecret::reveal(size_t n_bits, Clear& x) { - auto& thread = Thread::s(); + auto& thread = ShareThread::s(); x = thread.MC->POpen(*this, *thread.P).mask(n_bits); } diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index a0f9e0805..53dfb181c 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -29,12 +29,19 @@ class SemiSecret : public SemiShare, public ShareSecret typedef SemiPrep LivePrep; typedef SemiInput Input; + typedef SemiSecret part_type; + static const int default_length = sizeof(BitVec) * 8; static string type_string() { return "binary secret"; } static string phase_name() { return "Binary computation"; } - static MC* new_mc(Machine& _) { (void) _; return new MC; } + static MC* new_mc(mac_key_type) { return new MC; } + + template + static void generate_mac_key(mac_key_type, T) + { + } static void trans(Processor& processor, int n_outputs, const vector& args); @@ -50,6 +57,11 @@ class SemiSecret : public SemiShare, public ShareSecret SemiShare(other) { } + template + SemiSecret(const Z2& other) : + SemiShare(other) + { + } void load_clear(int n, const Integer& x); diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index d8a04964b..8c461594a 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -69,6 +69,7 @@ ShareParty::ShareParty(int argc, const char** argv, int default_batch_size) : "--communication" // Flag token. ); online_opts.finalize(opt, argc, argv); + OnlineOptions::singleton = online_opts; this->progname = online_opts.progname; int my_num = online_opts.playerno; @@ -122,7 +123,7 @@ ShareParty::ShareParty(int argc, const char** argv, int default_batch_size) : template Thread* ShareParty::new_thread(int i) { - return new ShareThread(i, *this); + return new StandaloneShareThread(i, *this); } template @@ -130,7 +131,7 @@ void ShareParty::post_run() { DataPositions usage; for (auto thread : this->threads) - usage.increase(dynamic_cast*>(thread)->usage); + usage.increase(dynamic_cast*>(thread)->usage); usage.print_cost(); } diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index dfa9ef7b5..e2135ac2c 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -13,6 +13,7 @@ using namespace std; #include "GC/Clear.h" #include "GC/Access.h" #include "GC/ArgTuples.h" +#include "GC/NoShare.h" #include "Math/FixedVec.h" #include "Math/BitVec.h" #include "Tools/SwitchableOutput.h" @@ -73,6 +74,8 @@ class ReplicatedSecret : public FixedVec, public ShareSecret typedef ReplicatedBase Protocol; + typedef NoShare bit_type; + static const int N_BITS = clear::N_BITS; static const bool dishonest_majority = false; @@ -83,9 +86,19 @@ class ReplicatedSecret : public FixedVec, public ShareSecret static int default_length; + static int threshold(int) + { + return 1; + } + static void trans(Processor& processor, int n_outputs, const vector& args); + template + static void generate_mac_key(mac_key_type, T) + { + } + ReplicatedSecret() {} template ReplicatedSecret(const T& other) : super(other) {} @@ -101,6 +114,9 @@ class ReplicatedSecret : public FixedVec, public ShareSecret { *this = x ^ y; (void)n; } void reveal(size_t n_bits, Clear& x); + + ReplicatedSecret operator&(const Clear& other) + { return super::operator&(BitVec(other)); } }; class SemiHonestRepPrep; @@ -118,7 +134,9 @@ class SemiHonestRepSecret : public ReplicatedSecret typedef SemiHonestRepPrep LivePrep; typedef ReplicatedInput Input; - static MC* new_mc(Machine& _) { (void) _; return new MC; } + typedef SemiHonestRepSecret part_type; + + static MC* new_mc(mac_key_type) { return new MC; } SemiHonestRepSecret() {} template diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 1e18d7e1e..4976e9dbe 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_SHARESECRET_HPP +#define GC_SHARESECRET_HPP + #include "ShareSecret.h" #include "MaliciousRepSecret.h" @@ -17,6 +20,7 @@ #include "Protocols/Beaver.hpp" #include "ShareParty.h" #include "ShareThread.hpp" +#include "Thread.hpp" namespace GC { @@ -30,8 +34,11 @@ SwitchableOutput ShareSecret::out; template void ShareSecret::check_length(int n, const Integer& x) { - if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n)) - throw out_of_range("public value too long"); + if ((size_t) n < 8 * sizeof(x) + and (unsigned long long) abs(x.get()) >= (1ULL << n)) + throw out_of_range( + "public value too long for " + to_string(n) + " bits: " + + to_string(x.get()) + "/" + to_string(1ULL << n)); } template @@ -89,7 +96,7 @@ void ShareSecret::inputb(Processor& processor, input.reset_all(*party.P); InputArgList a(args); - bool interactive = party.n_interactive_inputs_from_me(a) > 0; + bool interactive = Thread::s().n_interactive_inputs_from_me(a) > 0; for (auto x : a) { @@ -153,7 +160,7 @@ void ReplicatedSecret::reveal(size_t n_bits, Clear& x) vector opened; auto& party = ShareThread::s(); party.MC->POpen(opened, {share}, *party.P); - x = IntBase(opened[0]); + x = BitVec::super(opened[0]); } template @@ -165,3 +172,5 @@ void ShareSecret::random_bit() } } + +#endif diff --git a/GC/ShareThread.h b/GC/ShareThread.h index b0eb12a2d..e53c76588 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -19,25 +19,42 @@ namespace GC { template -class ShareThread : public Thread +class ShareThread { static thread_local ShareThread* singleton; public: static ShareThread& s(); + Player* P; + typename T::MC* MC; + typename T::Protocol* protocol; + DataPositions usage; Preprocessing& DataF; - ShareThread(int i, ThreadMaster& master); + ShareThread(const Names& N, OnlineOptions& opts); virtual ~ShareThread(); - void pre_run(); + virtual typename T::MC* new_mc(typename T::mac_key_type mac_key) + { return T::new_mc(mac_key); } + + void pre_run(Player& P, typename T::mac_key_type mac_key); void post_run(); void and_(Processor& processor, const vector& args, bool repeat); }; +template +class StandaloneShareThread : public ShareThread, public Thread +{ +public: + StandaloneShareThread(int i, ThreadMaster& master); + + void pre_run(); + void post_run() { ShareThread::post_run(); } +}; + template thread_local ShareThread* ShareThread::singleton = 0; diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index fa5c39159..e1c1fd968 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -7,6 +7,7 @@ #define GC_SHARETHREAD_HPP_ #include +#include "GC/ShareParty.h" #include "Protocols/MaliciousRepMC.h" #include "Math/Setup.h" @@ -16,15 +17,19 @@ namespace GC { template -ShareThread::ShareThread(int i, - ThreadMaster& master) : - Thread(i, master), usage(master.N.num_players()), DataF( - master.opts.live_prep ? - *(Preprocessing*) new typename T::LivePrep(usage, - *this) : - *(Preprocessing*) new Sub_Data_Files(master.N, - get_prep_dir(master.N.num_players(), 128, 128), - usage)) +StandaloneShareThread::StandaloneShareThread(int i, ThreadMaster& master) : + ShareThread(master.N, master.opts), Thread(i, master) +{ +} + +template +ShareThread::ShareThread(const Names& N, OnlineOptions& opts) : + P(0), MC(0), protocol(0), usage(N.num_players()), DataF( + opts.live_prep ? + *static_cast*>(new typename T::LivePrep( + usage, *this)) : + *static_cast*>(new Sub_Data_Files(N, + get_prep_dir(N.num_players(), 128, 128), usage))) { } @@ -32,21 +37,34 @@ template ShareThread::~ShareThread() { delete &DataF; + if (MC) + delete MC; + if (protocol) + delete protocol; } template -void ShareThread::pre_run() +void ShareThread::pre_run(Player& P, typename T::mac_key_type mac_key) { + this->P = &P; if (singleton) throw runtime_error("there can only be one"); singleton = this; - assert(this->protocol != 0); + protocol = new typename T::Protocol(*this->P); + MC = this->new_mc(mac_key); DataF.set_protocol(*this->protocol); } +template +void StandaloneShareThread::pre_run() +{ + ShareThread::pre_run(*Thread::P, ShareParty::s().mac_key); +} + template void ShareThread::post_run() { + MC->Check(*this->P); #ifndef INSECURE cerr << "Removing used pre-processed data" << endl; DataF.prune(); diff --git a/GC/Thread.h b/GC/Thread.h index 5546907f1..790b988dd 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -34,8 +34,6 @@ class Thread ThreadMaster& master; Machine& machine; Processor processor; - typename T::MC* MC; - typename T::Protocol* protocol; Names& N; Player* P; PRNG secure_prng; @@ -50,8 +48,6 @@ class Thread Thread(int thread_num, ThreadMaster& master); virtual ~Thread(); - virtual typename T::MC* new_mc() { return T::new_mc(machine); } - void run(); virtual void pre_run() {} virtual void run(Program& program); @@ -72,7 +68,8 @@ Thread& Thread::s() if (singleton) return *singleton; else - throw runtime_error("no singleton"); + throw runtime_error( + "no singleton / not implemented with arithmetic VMs"); } } /* namespace GC */ diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 7c1d0ef58..4ab43ad3f 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_THREAD_HPP_ +#define GC_THREAD_HPP_ + #include "Thread.h" #include "Program.h" @@ -24,7 +27,7 @@ void* Thread::run_thread(void* thread) template Thread::Thread(int thread_num, ThreadMaster& master) : master(master), machine(master.machine), processor(machine), - protocol(0), N(master.N), P(0), + N(master.N), P(0), thread_num(thread_num) { pthread_create(&thread, 0, run_thread, this); @@ -33,12 +36,8 @@ Thread::Thread(int thread_num, ThreadMaster& master) : template Thread::~Thread() { - if (MC) - delete MC; if (P) delete P; - if (protocol) - delete protocol; } template @@ -47,13 +46,12 @@ void Thread::run() if (singleton) throw runtime_error("there can only be one"); singleton = this; + BaseMachine::s().thread_num = thread_num; secure_prng.ReSeed(); if (machine.use_encryption) P = new CryptoPlayer(N, thread_num << 16); else P = new PlainPlayer(N, thread_num << 16); - protocol = new typename T::Protocol(*P); - MC = this->new_mc(); processor.open_input_file(N.my_num(), thread_num); done.push(0); pre_run(); @@ -67,7 +65,6 @@ void Thread::run() } post_run(); - MC->Check(*P); } template @@ -104,3 +101,5 @@ int GC::Thread::n_interactive_inputs_from_me(InputArgList& args) } } /* namespace GC */ + +#endif diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 0a533b2ea..750189f3a 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_THREADMASTER_HPP_ +#define GC_THREADMASTER_HPP_ + #include "ThreadMaster.h" #include "Program.h" @@ -20,7 +23,7 @@ ThreadMaster& ThreadMaster::s() if (singleton) return *singleton; else - throw runtime_error("no singleton, maybe threads not supported"); + throw no_singleton("no singleton, maybe threads not supported"); } template @@ -107,3 +110,5 @@ void ThreadMaster::run() } } /* namespace GC */ + +#endif diff --git a/GC/TinierPrep.h b/GC/TinierPrep.h new file mode 100644 index 000000000..4849d7203 --- /dev/null +++ b/GC/TinierPrep.h @@ -0,0 +1,31 @@ +/* + * TinierPrep.h + * + */ + +#ifndef GC_TINIERPREP_H_ +#define GC_TINIERPREP_H_ + +#include "TinyPrep.h" + +namespace GC +{ + +template +class TinierPrep : public TinyPrep +{ +public: + TinierPrep(DataPositions& usage, ShareThread& thread) : + TinyPrep(usage, thread) + { + } + + void buffer_inputs(int player) + { + this->buffer_inputs_(player, this->triple_generator); + } +}; + +} + +#endif /* GC_TINIERPREP_H_ */ diff --git a/GC/TinierSecret.h b/GC/TinierSecret.h new file mode 100644 index 000000000..0e98b9555 --- /dev/null +++ b/GC/TinierSecret.h @@ -0,0 +1,95 @@ +/* + * TinierSecret.h + * + */ + +#ifndef GC_TINIERSECRET_H_ +#define GC_TINIERSECRET_H_ + +#include "TinySecret.h" +#include "TinierShare.h" + +template class TinierMultiplier; + +namespace GC +{ + +template class TinierPrep; + +template +class TinierSecret : public VectorSecret> +{ + typedef VectorSecret> super; + typedef TinierSecret This; + +public: + typedef TinyMC MC; + typedef MC MAC_Check; + typedef Beaver Protocol; + typedef ::Input Input; + typedef TinierPrep LivePrep; + typedef Memory DynamicMemory; + + typedef NPartyTripleGenerator TripleGenerator; + typedef NPartyTripleGenerator InputGenerator; + typedef TinierMultiplier Multiplier; + + typedef typename super::part_type check_type; + typedef Share input_check_type; + typedef check_type input_type; + + static string type_short() + { + return "TT"; + } + + static MC* new_mc(typename super::mac_key_type mac_key) + { + return new MC(mac_key); + } + + template + static void generate_mac_key(typename super::mac_key_type& dest, const U&) + { + SeededPRNG G; + dest.randomize(G); + } + + static void store_clear_in_dynamic(Memory& mem, + const vector& accesses) + { + auto& party = ShareThread::s(); + for (auto access : accesses) + mem[access.address] = super::constant(access.value, + party.P->my_num(), {}); + } + + + TinierSecret() + { + } + TinierSecret(const super& other) : + super(other) + { + } + TinierSecret(const typename super::super& other) : + super(other) + { + } + TinierSecret(const typename super::part_type& other) + { + this->get_regs().push_back(other); + } + + void reveal(size_t n_bits, Clear& x) + { + auto& to_open = *this; + to_open.resize_regs(n_bits); + auto& party = ShareThread::s(); + x = party.MC->POpen(to_open, *party.P); + } +}; + +} /* namespace GC */ + +#endif /* GC_TINIERSECRET_H_ */ diff --git a/GC/TinierShare.h b/GC/TinierShare.h new file mode 100644 index 000000000..c0d8a6831 --- /dev/null +++ b/GC/TinierShare.h @@ -0,0 +1,107 @@ +/* + * TinierShare.h + * + */ + +#ifndef GC_TINIERSHARE_H_ +#define GC_TINIERSHARE_H_ + +#include "Processor/DummyProtocol.h" +#include "Protocols/Share.h" +#include "Math/Bit.h" +#include "TinierSharePrep.h" + +namespace GC +{ + +template class TinierSecret; + +template +class TinierShare: public Share_, SemiShare>, + public ShareSecret> +{ + typedef TinierShare This; + +public: + typedef Share_, SemiShare> super; + + typedef T mac_key_type; + typedef T mac_type; + typedef T sacri_type; + typedef Share input_check_type; + + typedef MAC_Check_ MAC_Check; + typedef TinierSharePrep LivePrep; + typedef ::Input Input; + typedef Beaver Protocol; + typedef NPartyTripleGenerator> TripleGenerator; + + typedef void DynamicMemory; + typedef SwitchableOutput out_type; + + static string name() + { + return "tinier share"; + } + + static string type_string() + { + return "Tinier"; + } + + static ShareThread>& get_party() + { + return ShareThread>::s(); + } + + static MAC_Check* new_mc(mac_key_type mac_key) + { + return new MAC_Check(mac_key); + } + + static This new_reg() + { + return {}; + } + + TinierShare() + { + } + TinierShare(const super& other) : + super(other) + { + } + TinierShare(const typename super::share_type& share, const typename super::mac_type& mac) : + super(share, mac) + { + } + + void XOR(const This& a, const This& b) + { + *this = a + b; + } + + This& operator^=(const This& other) + { + *this += other; + return *this; + } + + void public_input(bool input) + { + auto& party = get_party(); + *this = super::constant(input, party.P->my_num(), + party.MC->get_alphai()); + } + + void random() + { + TinierSecret tmp; + get_party().DataF.get_one(DATA_BIT, tmp); + *this = tmp.get_reg(0); + } +}; + +} /* namespace GC */ + +#endif /* GC_TINIERSHARE_H_ */ diff --git a/GC/TinierSharePrep.h b/GC/TinierSharePrep.h new file mode 100644 index 000000000..955bee2f7 --- /dev/null +++ b/GC/TinierSharePrep.h @@ -0,0 +1,40 @@ +/* + * TinierSharePrep.h + * + */ + +#ifndef GC_TINIERSHAREPREP_H_ +#define GC_TINIERSHAREPREP_H_ + +#include "Protocols/ReplicatedPrep.h" +#include "OT/NPartyTripleGenerator.h" +#include "ShareThread.h" + +namespace GC +{ + +template +class TinierSharePrep : public BufferPrep +{ + typename T::TripleGenerator* triple_generator; + MascotParams params; + + void buffer_triples() { throw not_implemented(); } + void buffer_squares() { throw not_implemented(); } + void buffer_bits() { throw not_implemented(); } + void buffer_inverses() { throw not_implemented(); } + + void buffer_inputs(int player); + +public: + TinierSharePrep(DataPositions& usage); + ~TinierSharePrep(); + + void set_protocol(typename T::Protocol& protocol); + + size_t data_sent(); +}; + +} + +#endif /* GC_TINIERSHAREPREP_H_ */ diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp new file mode 100644 index 000000000..86e775742 --- /dev/null +++ b/GC/TinierSharePrep.hpp @@ -0,0 +1,59 @@ +/* + * TinierSharePrep.cpp + * + */ + +#include "TinierSharePrep.h" + +namespace GC +{ + +template +TinierSharePrep::TinierSharePrep(DataPositions& usage) : + BufferPrep(usage), triple_generator(0) +{ +} + +template +TinierSharePrep::~TinierSharePrep() +{ + if (triple_generator) + delete triple_generator; +} + +template +void TinierSharePrep::set_protocol(typename T::Protocol& protocol) +{ + params.generateMACs = true; + params.amplify = false; + params.check = false; + auto& thread = ShareThread>::s(); + triple_generator = new typename T::TripleGenerator( + BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, + OnlineOptions::singleton.batch_size + * TinierSecret::default_length, 1, + params, thread.MC->get_alphai(), &protocol.P); + triple_generator->multi_threaded = false; + this->inputs.resize(thread.P->num_players()); +} + +template +void TinierSharePrep::buffer_inputs(int player) +{ + auto& inputs = this->inputs; + assert(triple_generator); + triple_generator->generateInputs(player); + for (auto& x : triple_generator->inputs) + inputs.at(player).push_back(x); +} + +template +size_t TinierSharePrep::data_sent() +{ + if (triple_generator) + return triple_generator->data_sent(); + else + return 0; +} + +} diff --git a/GC/TinyMC.h b/GC/TinyMC.h index d3d45d304..9abb5ed47 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -14,9 +14,9 @@ namespace GC template class TinyMC : public MAC_Check_Base { - typename T::part_type::MAC_Check part_MC; + typename T::check_type::MAC_Check part_MC; vector part_values; - vector part_shares; + vector part_shares; public: TinyMC(typename T::mac_key_type mac_key) : diff --git a/GC/TinyPrep.h b/GC/TinyPrep.h index 31e3eccac..57b900201 100644 --- a/GC/TinyPrep.h +++ b/GC/TinyPrep.h @@ -18,18 +18,16 @@ namespace GC template class TinyPrep : public BufferPrep, public RandomPrep { - typedef Share> res_type; - - Thread& thread; +protected: + ShareThread& thread; typename T::TripleGenerator* triple_generator; - typename T::part_type::TripleGenerator* input_generator; MascotParams params; vector> triple_buffer; public: - TinyPrep(DataPositions& usage, Thread& thread); + TinyPrep(DataPositions& usage, ShareThread& thread); ~TinyPrep(); void set_protocol(Beaver& protocol); @@ -37,14 +35,35 @@ class TinyPrep : public BufferPrep, public RandomPrep get_triple(int n_bits); + + size_t data_sent(); +}; + +template +class TinyOnlyPrep : public TinyPrep +{ + typename T::part_type::TripleGenerator* input_generator; + +public: + TinyOnlyPrep(DataPositions& usage, ShareThread& thread); + ~TinyOnlyPrep(); + + void set_protocol(Beaver& protocol); + + void buffer_inputs(int player) + { + this->buffer_inputs_(player, input_generator); + } + + size_t data_sent(); }; } /* namespace GC */ diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index e332be638..894ea92f5 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -5,13 +5,21 @@ #include "TinyPrep.h" +#include "Protocols/MascotPrep.hpp" + namespace GC { template -TinyPrep::TinyPrep(DataPositions& usage, Thread& thread) : - BufferPrep(usage), thread(thread), triple_generator(0), - input_generator(0) +TinyPrep::TinyPrep(DataPositions& usage, ShareThread& thread) : + BufferPrep(usage), thread(thread), triple_generator(0) +{ + +} + +template +TinyOnlyPrep::TinyOnlyPrep(DataPositions& usage, ShareThread& thread) : + TinyPrep(usage, thread), input_generator(0) { } @@ -20,6 +28,11 @@ TinyPrep::~TinyPrep() { if (triple_generator) delete triple_generator; +} + +template +TinyOnlyPrep::~TinyOnlyPrep() +{ if (input_generator) delete input_generator; } @@ -31,19 +44,24 @@ void TinyPrep::set_protocol(Beaver& protocol) params.generateMACs = true; params.amplify = false; params.check = false; + auto& thread = ShareThread::s(); triple_generator = new typename T::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), - thread.master.N, thread.thread_num, - thread.master.opts.batch_size, - 1, params, thread.MC->get_alphai(), thread.P); + BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, + OnlineOptions::singleton.batch_size, 1, params, + thread.MC->get_alphai(), &protocol.P); triple_generator->multi_threaded = false; +} + +template +void TinyOnlyPrep::set_protocol(Beaver& protocol) +{ + TinyPrep::set_protocol(protocol); input_generator = new typename T::part_type::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), - thread.master.N, thread.thread_num, - thread.master.opts.batch_size, - 1, params, thread.MC->get_alphai(), thread.P); + BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, + OnlineOptions::singleton.batch_size, 1, this->params, + this->thread.MC->get_alphai(), &protocol.P); input_generator->multi_threaded = false; - thread.MC->get_part_MC().set_prep(*this); + this->thread.MC->get_part_MC().set_prep(*this); } template @@ -51,9 +69,9 @@ void TinyPrep::buffer_triples() { auto& triple_generator = this->triple_generator; params.generateBits = false; - vector> triples; - ShuffleSacrifice sacrifice; - while (int(triples.size()) < sacrifice.minimum_n_inputs()) + vector> triples; + ShuffleSacrifice sacrifice; + while (int(triples.size()) < sacrifice.minimum_n_inputs_with_combining()) { triple_generator->generatePlainTriples(); triple_generator->unlock(); @@ -84,6 +102,8 @@ void TinyPrep::buffer_triples() } sacrifice.triple_sacrifice(triples, triples, *thread.P, thread.MC->get_part_MC()); + sacrifice.triple_combine(triples, triples, *thread.P, + thread.MC->get_part_MC()); for (size_t i = 0; i < triples.size() / T::default_length; i++) { this->triples.push_back({}); @@ -112,21 +132,22 @@ void TinyPrep::buffer_bits() } template -void TinyPrep::buffer_inputs(int player) +void TinyPrep::buffer_inputs_(int player, typename T::InputGenerator* input_generator) { auto& inputs = this->inputs; - inputs.resize(thread.P->num_players()); - assert(this->input_generator); - this->input_generator->generateInputs(player); - for (size_t i = 0; i < this->input_generator->inputs.size() / T::default_length; i++) + inputs.resize(this->thread.P->num_players()); + assert(input_generator); + input_generator->generateInputs(player); + assert(input_generator->inputs.size() >= T::default_length); + for (size_t i = 0; i < input_generator->inputs.size() / T::default_length; i++) { inputs[player].push_back({}); inputs[player].back().share.resize_regs(T::default_length); for (int j = 0; j < T::default_length; j++) { - auto& source_input = this->input_generator->inputs[j + auto& source_input = input_generator->inputs[j + i * T::default_length]; - inputs[player].back().share.get_reg(j) = res_type(source_input.share); + inputs[player].back().share.get_reg(j) = source_input.share; inputs[player].back().value ^= typename T::open_type( source_input.value.get_bit(0)) << j; } @@ -170,4 +191,22 @@ array TinyPrep::get_triple(int n_bits) return res; } +template +size_t TinyPrep::data_sent() +{ + size_t res = 0; + if (triple_generator) + res += triple_generator->data_sent(); + return res; +} + +template +size_t TinyOnlyPrep::data_sent() +{ + auto res = TinyPrep::data_sent(); + if (input_generator) + res += input_generator->data_sent(); + return res; +} + } /* namespace GC */ diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 54fa4ca22..5b732457b 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -19,16 +19,16 @@ template class TinyMultiplier; namespace GC { -template class TinyPrep; +template class TinyOnlyPrep; template class TinyMC; -template -class TinySecret : public Secret> +template +class VectorSecret : public Secret { - typedef TinySecret This; + typedef VectorSecret This; public: - typedef TinyShare part_type; + typedef T part_type; typedef Secret super; typedef typename part_type::mac_key_type mac_key_type; @@ -36,29 +36,17 @@ class TinySecret : public Secret> typedef BitVec open_type; typedef BitVec clear; - typedef TinyMC MC; - typedef MC MAC_Check; - typedef Beaver Protocol; - typedef ::Input Input; - typedef TinyPrep LivePrep; - typedef Memory DynamicMemory; - - typedef OTTripleGenerator TripleGenerator; - typedef TinyMultiplier Multiplier; typedef typename part_type::sacri_type sacri_type; typedef typename part_type::mac_type mac_type; typedef BitDiagonal Rectangle; + typedef typename T::super check_type; + static const bool dishonest_majority = true; static const bool needs_ot = true; static const int default_length = 64; - static string type_short() - { - return "T"; - } - static DataFieldType field_type() { return BitVec::field_type(); @@ -69,19 +57,9 @@ class TinySecret : public Secret> return part_type::size() * default_length; } - static MC* new_mc(Machine& machine) + static void generate_mac_key(mac_key_type& dest, const mac_key_type& source) { - (void) machine; - return new MC(ShareParty::s().mac_key); - } - - static void store_clear_in_dynamic(Memory& mem, - const vector& accesses) - { - auto& party = ShareThread::s(); - for (auto access : accesses) - mem[access.address] = constant(access.value, party.P->my_num(), - {}); + dest = source; } static This constant(BitVec other, int my_num, mac_key_type alphai) @@ -93,13 +71,17 @@ class TinySecret : public Secret> return res; } - TinySecret() + VectorSecret() { } - TinySecret(const super& other) : + VectorSecret(const super& other) : super(other) { } + VectorSecret(const part_type& other) + { + this->get_regs().push_back(other); + } void assign(const char* buffer) { @@ -113,6 +95,11 @@ class TinySecret : public Secret> return *this + other; } + This& operator^=(const This& other) + { + return *this = *this + other; + } + This operator*(const BitVec& other) const { This res = *this; @@ -122,6 +109,11 @@ class TinySecret : public Secret> return res; } + This operator&(const BitVec::super& other) const + { + return *this * BitVec(other); + } + This extend_bit() const { This res; @@ -136,14 +128,6 @@ class TinySecret : public Secret> return res; } - void reveal(size_t n_bits, Clear& x) - { - auto& to_open = *this; - to_open.resize_regs(n_bits); - auto& party = ShareThread::s(); - x = party.MC->POpen(to_open, *party.P); - } - void output(ostream& s, bool human = true) const { assert(this->get_regs().size() == default_length); @@ -153,7 +137,66 @@ class TinySecret : public Secret> }; template -inline TinySecret operator*(const BitVec& clear, const TinySecret& share) +class TinySecret : public VectorSecret> +{ + typedef VectorSecret> super; + typedef TinySecret This; + +public: + typedef TinyMC MC; + typedef MC MAC_Check; + typedef Beaver Protocol; + typedef ::Input Input; + typedef TinyOnlyPrep LivePrep; + typedef Memory DynamicMemory; + + typedef OTTripleGenerator TripleGenerator; + typedef typename super::part_type::TripleGenerator InputGenerator; + + typedef TinyMultiplier Multiplier; + + static string type_short() + { + return "T"; + } + + static MC* new_mc(typename super::mac_key_type mac_key) + { + return new MC(mac_key); + } + + static void store_clear_in_dynamic(Memory& mem, + const vector& accesses) + { + auto& party = ShareThread::s(); + for (auto access : accesses) + mem[access.address] = super::constant(access.value, + party.P->my_num(), {}); + } + + TinySecret() + { + } + TinySecret(const super& other) : + super(other) + { + } + TinySecret(const typename super::part_type& other) : + super(other) + { + } + + void reveal(size_t n_bits, Clear& x) + { + auto& to_open = *this; + to_open.resize_regs(n_bits); + auto& party = ShareThread::s(); + x = party.MC->POpen(to_open, *party.P); + } +}; + +template +inline VectorSecret operator*(const BitVec& clear, const VectorSecret& share) { return share * clear; } diff --git a/GC/TinyShare.h b/GC/TinyShare.h index 51d724e62..419921fec 100644 --- a/GC/TinyShare.h +++ b/GC/TinyShare.h @@ -50,7 +50,11 @@ class TinyShare : public Spdz2kShare<1, S>, public ShareSecret> TinyShare() { } - TinyShare(const typename super::super& other) : + TinyShare(const typename super::super::super& other) : + super(other) + { + } + TinyShare(const super& other) : super(other) { } @@ -64,7 +68,7 @@ class TinyShare : public Spdz2kShare<1, S>, public ShareSecret> { auto& party = get_party(); *this = super::constant(input, party.P->my_num(), - ShareParty < TinySecret < S >> ::s().mac_key); + party.MC->get_alphai()); } void random() diff --git a/GC/instructions.h b/GC/instructions.h index 18ee9edda..8f0f7505f 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -14,87 +14,96 @@ #define MD dynamic_memory #define R0 instruction.get_r(0) -#define R1 instruction.get_r(1) -#define R2 instruction.get_r(2) +#define REG1 instruction.get_r(1) #define S0 processor.S[instruction.get_r(0)] #define PS1 processor.S[instruction.get_r(1)] #define PS2 processor.S[instruction.get_r(2)] #define C0 processor.C[instruction.get_r(0)] -#define C1 processor.C[instruction.get_r(1)] -#define C2 processor.C[instruction.get_r(2)] +#define PC1 processor.C[instruction.get_r(1)] +#define PC2 processor.C[instruction.get_r(2)] #define I0 processor.I[instruction.get_r(0)] -#define I1 processor.I[instruction.get_r(1)] -#define I2 processor.I[instruction.get_r(2)] +#define PI1 processor.I[instruction.get_r(1)] +#define PI2 processor.I[instruction.get_r(2)] #define IMM instruction.get_n() #define EXTRA instruction.get_start() -#define MSD MACH.MS[IMM] -#define MMC MACH.MC[IMM] -#define MID MACH.MI[IMM] +#define MSD processor.memories.MS[IMM] +#define MMC processor.memories.MC[IMM] +#define MID MACH->MI[IMM] -#define MSI MACH.MS[I1.get()] -#define MII MACH.MI[I1.get()] +#define MSI processor.memories.MS[PI1.get()] +#define MII MACH->MI[PI1.get()] -#define INSTRUCTIONS \ +#define BIT_INSTRUCTIONS \ X(XORS, PROC.xors(EXTRA)) \ - X(XORC, C0.xor_(C1, C2)) \ - X(XORCI, C0.xor_(C1, IMM)) \ + X(XORCB, C0.xor_(PC1, PC2)) \ + X(XORCBI, C0.xor_(PC1, IMM)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ X(INPUTB, T::inputb(PROC, EXTRA)) \ - X(ADDC, C0 = C1 + C2) \ - X(ADDCI, C0 = C1 + IMM) \ - X(MULCI, C0 = C1 * IMM) \ + X(ADDCB, C0 = PC1 + PC2) \ + X(ADDCBI, C0 = PC1 + IMM) \ + X(MULCBI, C0 = PC1 * IMM) \ X(BITDECS, PROC.bitdecs(EXTRA, S0)) \ X(BITCOMS, PROC.bitcoms(S0, EXTRA)) \ X(BITDECC, PROC.bitdecc(EXTRA, C0)) \ - X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \ - X(SHRCI, C0 = C1 >> IMM) \ - X(SHLCI, C0 = C1 << IMM) \ - X(LDBITS, S0.load_clear(R1, IMM)) \ - X(LDMS, S0 = MSD) \ - X(STMS, MSD = S0) \ - X(LDMSI, S0 = MSI) \ - X(STMSI, MSI = S0) \ - X(LDMC, C0 = MMC) \ - X(STMC, MMC = C0) \ + X(SHRCBI, C0 = PC1 >> IMM) \ + X(SHLCBI, C0 = PC1 << IMM) \ + X(LDBITS, S0.load_clear(REG1, IMM)) \ + X(LDMSB, S0 = MSD) \ + X(STMSB, MSD = S0) \ + X(LDMCB, C0 = MMC) \ + X(STMCB, MMC = C0) \ + X(MOVSB, S0 = PS1) \ + X(TRANS, T::trans(PROC, IMM, EXTRA)) \ + X(BITB, PROC.random_bit(S0)) \ + X(REVEAL, PS1.reveal(IMM, C0)) \ + X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \ + X(PRINTREGB, PROC.print_reg(R0, IMM)) \ + X(PRINTREGPLAINB, PROC.print_reg_plain(C0)) \ + X(PRINTFLOATPLAINB, PROC.print_float(EXTRA)) \ + X(CONDPRINTSTRB, if(C0.get()) PROC.print_str(IMM)) \ + +#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \ + X(ANDM, S0 = PS1 & PC2) \ + X(LDMSBI, S0 = processor.memories.MS[Proc.read_Ci(REG1)]) \ + X(STMSBI, processor.memories.MS[Proc.read_Ci(REG1)] = S0) \ + X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ + X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ + X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ + X(DABIT, Proc.dabit(INST)) \ + +#define GC_INSTRUCTIONS \ + X(LDMSBI, S0 = MSI) \ + X(STMSBI, MSI = S0) \ X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \ X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \ X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \ X(STMSDI, PROC.store_dynamic_indirect(EXTRA, MD)) \ X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA, MD)) \ - X(CONVSINT, S0.load_clear(IMM, I1)) \ - X(CONVCINT, C0 = I1) \ - X(CONVCBIT, T::convcbit(I0, C1)) \ - X(MOVS, S0 = PS1) \ - X(TRANS, T::trans(PROC, IMM, EXTRA)) \ - X(BIT, PROC.random_bit(S0)) \ - X(REVEAL, PS1.reveal(IMM, C0)) \ - X(PRINTREG, PROC.print_reg(R0, IMM)) \ - X(PRINTREGPLAIN, PROC.print_reg_plain(C0)) \ - X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \ + X(CONVSINT, S0.load_clear(IMM, PI1)) \ + X(CONVCINT, C0 = PI1) \ + X(CONVCBIT, T::convcbit(I0, PC1)) \ X(PRINTCHR, PROC.print_chr(IMM)) \ X(PRINTSTR, PROC.print_str(IMM)) \ - X(PRINTFLOATPLAIN, PROC.print_float(EXTRA)) \ X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \ - X(CONDPRINTSTR, if(C0.get()) PROC.print_str(IMM)) \ X(LDINT, I0 = int(IMM)) \ - X(ADDINT, I0 = I1 + I2) \ - X(SUBINT, I0 = I1 - I2) \ - X(MULINT, I0 = I1 * I2) \ - X(DIVINT, I0 = I1 / I2) \ + X(ADDINT, I0 = PI1 + PI2) \ + X(SUBINT, I0 = PI1 - PI2) \ + X(MULINT, I0 = PI1 * PI2) \ + X(DIVINT, I0 = PI1 / PI2) \ X(JMP, PROC.PC += IMM) \ X(JMPNZ, if (I0 != 0) PROC.PC += IMM) \ X(JMPEQZ, if (I0 == 0) PROC.PC += IMM) \ - X(EQZC, I0 = I1 == 0) \ - X(LTZC, I0 = I1 < 0) \ - X(LTC, I0 = I1 < I2) \ - X(GTC, I0 = I1 > I2) \ - X(EQC, I0 = I1 == I2) \ + X(EQZC, I0 = PI1 == 0) \ + X(LTZC, I0 = PI1 < 0) \ + X(LTC, I0 = PI1 < PI2) \ + X(GTC, I0 = PI1 > PI2) \ + X(EQC, I0 = PI1 == PI2) \ X(JMPI, PROC.PC += I0) \ X(LDMINT, I0 = MID) \ X(STMINT, MID = I0) \ @@ -102,18 +111,23 @@ X(STMINTI, MII = I0) \ X(PUSHINT, PROC.pushi(I0.get())) \ X(POPINT, long x; PROC.popi(x); I0 = x) \ - X(MOVINT, I0 = I1) \ + X(MOVINT, I0 = PI1) \ + X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \ X(LDARG, I0 = PROC.get_arg()) \ X(STARG, PROC.set_arg(I0.get())) \ - X(TIME, MACH.time()) \ - X(START, MACH.start(IMM)) \ - X(STOP, MACH.stop(IMM)) \ + X(TIME, MACH->time()) \ + X(START, MACH->start(IMM)) \ + X(STOP, MACH->stop(IMM)) \ X(GLDMS, ) \ X(GLDMC, ) \ + X(LDMS, ) \ + X(LDMC, ) \ X(PRINTINT, S0.out << I0) \ X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \ X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \ - X(RUN_TAPE, MACH.run_tape(R0, IMM, R1)) \ - X(JOIN_TAPE, MACH.join_tape(R0)) \ + X(RUN_TAPE, MACH->run_tape(R0, IMM, REG1)) \ + X(JOIN_TAPE, MACH->join_tape(R0)) \ + +#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS #endif /* GC_INSTRUCTIONS_H_ */ diff --git a/Machines/Player-Online.cpp b/Machines/Player-Online.cpp index b47379e2a..9c5491f12 100644 --- a/Machines/Player-Online.cpp +++ b/Machines/Player-Online.cpp @@ -5,6 +5,7 @@ #include "Processor/config.h" #include "Protocols/Share.h" +#include "GC/TinierSecret.h" #include "Player-Online.hpp" diff --git a/Machines/Player-Online.hpp b/Machines/Player-Online.hpp index 08fc64d0c..9b3ff77be 100644 --- a/Machines/Player-Online.hpp +++ b/Machines/Player-Online.hpp @@ -18,7 +18,8 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr OnlineOptions& online_opts = OnlineOptions::singleton; online_opts = {opt, argc, argv, 1000, live_prep_default}; - opt.example = "./Player-Online.x -lgp 64 -lg2 128 -m new 0 sample-prog\n./Player-Online.x -pn 13000 -h localhost 1 sample-prog\n"; + opt.example = string() + argv[0] + " -p 0 -N 2 sample-prog\n" + argv[0] + + " -h localhost -p 1 -N 2 sample-prog\n"; opt.add( to_string(U::clear::default_degree()).c_str(), // Default. diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index e1b6031b6..1338e96ad 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -3,9 +3,6 @@ * */ -#include "Protocols/MaliciousRep3Share.h" -#include "Protocols/MalRepRingShare.h" -#include "Protocols/BrainShare.h" #include "Protocols/BrainPrep.h" #include "Protocols/MalRepRingPrep.h" @@ -15,12 +12,14 @@ #include "Protocols/BrainPrep.hpp" #include "Protocols/MalRepRingPrep.hpp" #include "Protocols/MaliciousRepPrep.hpp" -#include "Protocols/Spdz2kPrep.hpp" #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/MaliciousRepMC.hpp" #include "Protocols/Beaver.hpp" #include "Math/Z2k.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/RepPrep.hpp" +#include "GC/ThreadMaster.hpp" template<> Preprocessing>* Preprocessing>::get_live_prep( diff --git a/Machines/RepRing.hpp b/Machines/RepRing.hpp new file mode 100644 index 000000000..cec263e58 --- /dev/null +++ b/Machines/RepRing.hpp @@ -0,0 +1,2 @@ +#include "Rep.hpp" +#include "Protocols/Spdz2kPrep.hpp" diff --git a/Machines/SPDZ.cpp b/Machines/SPDZ.cpp index 452dbfd42..aea1dbf07 100644 --- a/Machines/SPDZ.cpp +++ b/Machines/SPDZ.cpp @@ -8,4 +8,14 @@ #include "Protocols/MascotPrep.hpp" +#include "GC/TinierSecret.h" +#include "GC/TinyMC.h" +#include "GC/TinierPrep.h" + +#include "GC/ShareParty.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/TinierSharePrep.hpp" + template class Machine, Share>; diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 2fe84a949..d52bf5dc0 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -4,6 +4,7 @@ */ #include "Protocols/Spdz2kShare.h" +#include "Protocols/Spdz2kPrep.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" diff --git a/Machines/brain-party.cpp b/Machines/brain-party.cpp index a3fa5b560..f776c1092 100644 --- a/Machines/brain-party.cpp +++ b/Machines/brain-party.cpp @@ -8,7 +8,7 @@ #include "Processor/RingOptions.h" #include "Protocols/ReplicatedMachine.hpp" -#include "Machines/Rep.hpp" +#include "Machines/RepRing.hpp" int main(int argc, const char** argv) { diff --git a/Machines/cowgear-party.cpp b/Machines/cowgear-party.cpp index 065d918f5..0b9a12487 100644 --- a/Machines/cowgear-party.cpp +++ b/Machines/cowgear-party.cpp @@ -11,6 +11,10 @@ #include "FHE/FFT_Data.h" #include "FHE/NTL-Subs.h" +#include "GC/TinierSecret.h" +#include "GC/TinierPrep.h" +#include "GC/TinyMC.h" + #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" @@ -18,6 +22,11 @@ #include "Protocols/fake-stuff.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Share.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/Secret.hpp" +#include "GC/TinierSharePrep.hpp" +#include "OT/NPartyTripleGenerator.hpp" #include "Player-Online.hpp" diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp index 2ec4618bc..457c096be 100644 --- a/Machines/hemi-party.cpp +++ b/Machines/hemi-party.cpp @@ -8,6 +8,8 @@ #include "Math/gf2n.h" #include "FHE/P2Data.h" #include "Tools/ezOptionParser.h" +#include "GC/SemiSecret.h" +#include "GC/SemiPrep.h" #include "Player-Online.hpp" #include "Protocols/HemiPrep.hpp" @@ -21,6 +23,8 @@ #include "Protocols/fake-stuff.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/SemiHonestRepPrep.h" int main(int argc, const char** argv) { diff --git a/Machines/mal-rep-bmr-party.cpp b/Machines/mal-rep-bmr-party.cpp index 6e1fe264d..3fda28fab 100644 --- a/Machines/mal-rep-bmr-party.cpp +++ b/Machines/mal-rep-bmr-party.cpp @@ -3,6 +3,8 @@ * */ +#include "Protocols/MaliciousRep3Share.h" + #include "Machines/Rep.hpp" #include "BMR/RealProgramParty.hpp" diff --git a/Machines/malicious-rep-ring-party.cpp b/Machines/malicious-rep-ring-party.cpp index ab5bb0813..3d5238a68 100644 --- a/Machines/malicious-rep-ring-party.cpp +++ b/Machines/malicious-rep-ring-party.cpp @@ -7,7 +7,7 @@ #include "Protocols/MalRepRingOptions.h" #include "Protocols/ReplicatedMachine.hpp" #include "Processor/RingOptions.h" -#include "Machines/Rep.hpp" +#include "Machines/RepRing.hpp" int main(int argc, const char** argv) { diff --git a/Machines/mascot-party.cpp b/Machines/mascot-party.cpp index f63235840..ad60f2faf 100644 --- a/Machines/mascot-party.cpp +++ b/Machines/mascot-party.cpp @@ -1,6 +1,7 @@ #include "Player-Online.hpp" #include "Math/gfp.h" +#include "GC/TinierSecret.h" int main(int argc, const char** argv) { diff --git a/Machines/ps-rep-ring-party.cpp b/Machines/ps-rep-ring-party.cpp index d74f36444..df54f9cb5 100644 --- a/Machines/ps-rep-ring-party.cpp +++ b/Machines/ps-rep-ring-party.cpp @@ -7,7 +7,7 @@ #include "Protocols/PostSacriRepFieldShare.h" #include "Protocols/ReplicatedMachine.hpp" #include "Processor/RingOptions.h" -#include "Machines/Rep.hpp" +#include "Machines/RepRing.hpp" #include "Protocols/PostSacrifice.hpp" int main(int argc, const char** argv) diff --git a/Machines/replicated-ring-party.cpp b/Machines/replicated-ring-party.cpp index 045e4d2c5..e55d6657a 100644 --- a/Machines/replicated-ring-party.cpp +++ b/Machines/replicated-ring-party.cpp @@ -4,9 +4,11 @@ */ #include "Protocols/ReplicatedMachine.hpp" +#include "Protocols/Rep3Share2k.h" +#include "Protocols/ReplicatedPrep2k.h" #include "Processor/RingOptions.h" #include "Math/Integer.h" -#include "Machines/Rep.hpp" +#include "Machines/RepRing.hpp" int main(int argc, const char** argv) { @@ -15,11 +17,11 @@ int main(int argc, const char** argv) switch (opts.R) { case 64: - ReplicatedMachine>, Rep3Share>(argc, argv, + ReplicatedMachine, Rep3Share>(argc, argv, "replicated-ring", opt); break; case 72: - ReplicatedMachine>, Rep3Share>(argc, argv, + ReplicatedMachine, Rep3Share>(argc, argv, "replicated-ring", opt); break; default: diff --git a/Machines/semi-party.cpp b/Machines/semi-party.cpp index a03459365..128a6b18e 100644 --- a/Machines/semi-party.cpp +++ b/Machines/semi-party.cpp @@ -5,9 +5,12 @@ #include "Math/gfp.h" #include "Protocols/SemiShare.h" +#include "Tools/SwitchableOutput.h" +#include "GC/SemiPrep.h" #include "Player-Online.hpp" #include "Semi.hpp" +#include "GC/ShareSecret.hpp" int main(int argc, const char** argv) { diff --git a/Machines/semi2k-party.cpp b/Machines/semi2k-party.cpp index bb47d44f1..6cf3a572f 100644 --- a/Machines/semi2k-party.cpp +++ b/Machines/semi2k-party.cpp @@ -4,11 +4,14 @@ */ #include "Protocols/Semi2kShare.h" +#include "Protocols/SemiPrep2k.h" #include "Math/gf2n.h" #include "Processor/RingOptions.h" +#include "GC/SemiPrep.h" #include "Player-Online.hpp" #include "Semi.hpp" +#include "GC/ShareSecret.hpp" int main(int argc, const char** argv) { diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index ad15c085a..6674a5dc6 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -3,6 +3,10 @@ * */ +#include "GC/TinySecret.h" +#include "GC/TinyMC.h" +#include "GC/TinyPrep.h" +#include "GC/TinierSecret.h" #include "Processor/Machine.h" #include "Processor/RingOptions.h" #include "Protocols/Spdz2kShare.h" @@ -11,6 +15,11 @@ #include "Player-Online.hpp" #include "SPDZ2k.hpp" +#include "GC/ShareParty.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/TinierSharePrep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp new file mode 100644 index 000000000..695d9178a --- /dev/null +++ b/Machines/tinier-party.cpp @@ -0,0 +1,33 @@ +/* + * tinier-party.cpp + * + */ + +#include "GC/TinierSecret.h" +#include "GC/TinierPrep.h" +#include "GC/ShareParty.h" +#include "GC/TinyMC.h" + +#include "GC/ShareParty.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Program.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" + +#include "Processor/Machine.hpp" +#include "Processor/Instruction.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Beaver.hpp" +#include "Protocols/MascotPrep.hpp" + +int main(int argc, const char** argv) +{ + gf2n_short::init_field(40); + GC::ShareParty>(argc, argv, 1000); +} diff --git a/Makefile b/Makefile index 443280f16..56c2dd2aa 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) FHEOFFLINE = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR) +GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o OT = $(patsubst %.cpp,%.o,$(filter-out OT/OText_main.cpp,$(wildcard OT/*.cpp))) OT_EXE = ot.x ot-offline.x @@ -20,7 +21,7 @@ COMMON = $(MATH) $(TOOLS) $(NETWORK) COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(OT) -VM = $(PROCESSOR) $(COMMON) +VM = $(PROCESSOR) $(COMMON) GC/square64.o LIB = libSPDZ.a @@ -35,10 +36,14 @@ DEPS := $(wildcard */*.d) .SECONDARY: $(OBJS) -all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x +all: arithmetic binary gen_input online offline externalIO bmr + +arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x spdz2k-party.x mascot-party.x +binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x real-bmr ifeq ($(USE_NTL),1) -all: overdrive she-offline cowgear-party.x hemi-party.x +all: overdrive she-offline +arithmetic: hemi-party.x cowgear-party.x endif -include $(DEPS) @@ -157,23 +162,27 @@ default-prime-length.x: Utils/default-prime-length.cpp $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) $(ECLIB) replicated-bin-party.x: GC/square64.o +replicated-ring-party.x: GC/square64.o +replicated-field-party.x: GC/square64.o +brain-party.x: GC/square64.o malicious-rep-bin-party.x: GC/square64.o semi-bin-party.x: $(VM) $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o tiny-party.x: $(OT) +tinier-party.x: $(OT) shamir-party.x: Machines/ShamirMachine.o malicious-shamir-party.x: Machines/ShamirMachine.o spdz2k-party.x: $(OT) -semi-party.x: $(OT) -semi2k-party.x: $(OT) -hemi-party.x: $(FHEOFFLINE) -cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o +semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) +cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) mascot-party.x: Machines/SPDZ.o $(OT) Player-Online.x: Machines/SPDZ.o $(OT) ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o mal-shamir-ecdsa-party.x: Machines/ShamirMachine.o shamir-ecdsa-party.x: Machines/ShamirMachine.o -semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) +semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) $(LIBSIMPLEOT): SimpleOT/Makefile diff --git a/Math/Bit.h b/Math/Bit.h new file mode 100644 index 000000000..1d9447cdf --- /dev/null +++ b/Math/Bit.h @@ -0,0 +1,54 @@ +/* + * Bit.h + * + */ + +#ifndef MATH_BIT_H_ +#define MATH_BIT_H_ + +#include "BitVec.h" + +class Bit : public BitVec_ +{ + typedef BitVec_ super; + +public: + static int size_in_bits() + { + return 1; + } + + Bit() + { + } + Bit(bool other) : + super(other) + { + } + Bit(const super::super& other) : + super(other) + { + } + + Bit operator*(const Bit& other) const + { + return super::operator*(other); + } + + template + T operator*(const T& other) const + { + return other * *this; + } + + void pack(octetStream& os, int = -1) const + { + super::pack(os, 1); + } + void unpack(octetStream& os, int = -1) + { + super::unpack(os, 1); + } +}; + +#endif /* MATH_BIT_H_ */ diff --git a/Math/BitVec.h b/Math/BitVec.h index a0f5d3e9c..705ee3d6e 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -8,57 +8,63 @@ #include "Integer.h" #include "field_types.h" -#include "Square.h" class BitDiagonal; -class BitVec : public IntBase +template +class BitVec_ : public IntBase { public: - typedef BitVec Scalar; + typedef IntBase super; - typedef BitVec next; + typedef BitVec_ Scalar; + + typedef BitVec_ next; typedef BitDiagonal Square; - static const int n_bits = sizeof(a) * 8; + static const int n_bits = sizeof(T) * 8; static char type_char() { return 'B'; } static DataFieldType field_type() { return DATA_GF2; } static bool allows(Dtype dtype) { return dtype == DATA_TRIPLE or dtype == DATA_BIT; } - BitVec() {} - BitVec(long a) : IntBase(a) {} - BitVec(const IntBase& a) : IntBase(a) {} + BitVec_() {} + BitVec_(long a) : super(a) {} + BitVec_(const super& a) : super(a) {} + template + BitVec_(const Z2& a) : super(a.get_limb(0)) {} - BitVec operator+(const BitVec& other) const { return a ^ other.a; } - BitVec operator-(const BitVec& other) const { return a ^ other.a; } - BitVec operator*(const BitVec& other) const { return a & other.a; } + BitVec_ operator+(const BitVec_& other) const { return *this ^ other; } + BitVec_ operator-(const BitVec_& other) const { return *this ^ other; } + BitVec_ operator*(const BitVec_& other) const { return *this & other; } - BitVec operator/(const BitVec& other) const { (void) other; throw not_implemented(); } + BitVec_ operator/(const BitVec_& other) const { (void) other; throw not_implemented(); } - BitVec& operator+=(const BitVec& other) { *this ^= other; return *this; } - BitVec& operator-=(const BitVec& other) { *this ^= other; return *this; } + BitVec_& operator+=(const BitVec_& other) { *this ^= other; return *this; } + BitVec_& operator-=(const BitVec_& other) { *this ^= other; return *this; } - BitVec extend_bit() const { return -(a & 1); } - BitVec mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; } + BitVec_ extend_bit() const { return -(this->a & 1); } + BitVec_ mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; } template - void add(octetStream& os) { *this += os.get(); } + void add(octetStream& os) { *this += os.get(); } - void mul(const BitVec& a, const BitVec& b) { *this = a * b; } + void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; } - void randomize(PRNG& G, int n = n_bits) { IntBase::randomize(G); *this = mask(n); } + void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = mask(n); } - void pack(octetStream& os, int n = n_bits) const { os.store_int(a, DIV_CEIL(n, 8)); } - void unpack(octetStream& os, int n = n_bits) { a = os.get_int(DIV_CEIL(n, 8)); } + void pack(octetStream& os, int n = n_bits) const { os.store_int(this->a, DIV_CEIL(n, 8)); } + void unpack(octetStream& os, int n = n_bits) { this->a = os.get_int(DIV_CEIL(n, 8)); } - static BitVec unpack_new(octetStream& os, int n = n_bits) + static BitVec_ unpack_new(octetStream& os, int n = n_bits) { - BitVec res; + BitVec_ res; res.unpack(os, n); return res; } }; +typedef BitVec_ BitVec; + #endif /* MATH_BITVEC_H_ */ diff --git a/Math/Integer.cpp b/Math/Integer.cpp index 21ba9a302..aa297ef40 100644 --- a/Math/Integer.cpp +++ b/Math/Integer.cpp @@ -5,7 +5,8 @@ #include "Integer.h" -void IntBase::output(ostream& s,bool human) const +template +void IntBase::output(ostream& s,bool human) const { if (human) s << a; @@ -13,7 +14,8 @@ void IntBase::output(ostream& s,bool human) const s.write((char*)&a, sizeof(a)); } -void IntBase::input(istream& s,bool human) +template +void IntBase::input(istream& s,bool human) { if (human) s >> a; @@ -35,3 +37,14 @@ void Integer::reqbl(int n) throw Processor_Error("Program compiled for fields not rings"); } } + +Integer::Integer(const Integer& x, int n_bits) +{ + a = abs(x.get()); + a &= ~(uint64_t(-1) << (n_bits - 1) << 1); + if (x < 0) + a = -a; +} + +template class IntBase; +template class IntBase; diff --git a/Math/Integer.h b/Math/Integer.h index 8f8e175ff..c9bd0da78 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -18,10 +18,11 @@ using namespace std; // Functionality shared between integers and bit vectors +template class IntBase : public ValueInterface { protected: - long a; + T a; public: static const int N_BYTES = sizeof(a); @@ -37,9 +38,9 @@ class IntBase : public ValueInterface static bool allows(Dtype type) { return type <= DATA_BIT; } IntBase() { a = 0; } - IntBase(long a) : a(a) {} + IntBase(T a) : a(a) {} - long get() const { return a; } + T get() const { return a; } bool get_bit(int i) const { return (a >> i) & 1; } char* get_ptr() const { return (char*)&a; } @@ -55,17 +56,17 @@ class IntBase : public ValueInterface bool is_one() const { return a == 1; } bool is_bit() const { return is_zero() or is_one(); } - long operator>>(const IntBase& other) const + long operator>>(const IntBase& other) const { - if (other.a < N_BITS) - return (unsigned long) a >> other.a; + if (other.get() < N_BITS) + return (unsigned long) a >> other.get(); else return 0; } - long operator<<(const IntBase& other) const + long operator<<(const IntBase& other) const { - if (other.a < N_BITS) - return a << other.a; + if (other.get() < N_BITS) + return a << other.get(); else return 0; } @@ -79,8 +80,8 @@ class IntBase : public ValueInterface bool equal(const IntBase& other) const { return *this == other; } - long operator^=(const IntBase& other) { return a ^= other.a; } - long operator&=(const IntBase& other) { return a &= other.a; } + T& operator^=(const IntBase& other) { return a ^= other.a; } + T& operator&=(const IntBase& other) { return a &= other.a; } friend ostream& operator<<(ostream& s, const IntBase& x) { x.output(s, true); return s; } @@ -95,7 +96,7 @@ class IntBase : public ValueInterface }; // Wrapper class for integer -class Integer : public IntBase +class Integer : public IntBase { public: @@ -114,6 +115,10 @@ class Integer : public IntBase Integer(const bigint& x) { *this = (x > 0) ? x.get_ui() : -x.get_ui(); } template Integer(const Z2& x) : Integer(x.get_limb(0)) {} + template + Integer(const gfp_& x); + + Integer(const Integer& x, int n_bits); void convert_destroy(bigint& other) { *this = other.get_si(); } @@ -150,19 +155,23 @@ class Integer : public IntBase void SHR(const Integer& x, const Integer& y) { *this = (unsigned long)x.a >> y.a; } }; -inline void IntBase::randomize(PRNG& G) +template<> +inline void IntBase::randomize(PRNG& G) { a = G.get_word(); } -inline void to_bigint(bigint& res, const Integer& x) +template<> +inline void IntBase::randomize(PRNG& G) { - res = (unsigned long)x.get(); + a = G.get_bit(); } -inline void to_signed_bigint(bigint& res, const Integer& x) +template +Integer::Integer(const gfp_& x) { - res = x.get(); + to_signed_bigint(bigint::tmp, x); + *this = bigint::tmp; } // slight misnomer diff --git a/Math/Z2k.h b/Math/Z2k.h index 2bd9fbed7..eb94c44d8 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -53,6 +53,7 @@ class Z2 : public ValueInterface static int size() { return N_BYTES; } static int size_in_limbs() { return N_WORDS; } + static int size_in_bits() { return size() * 8; } static int t() { return 0; } static char type_char() { return 'R'; } @@ -113,6 +114,8 @@ class Z2 : public ValueInterface Z2 operator/(const Z2& other) const { (void) other; throw not_implemented(); } + Z2 operator&(const Z2& other) const; + Z2& operator+=(const Z2& other); Z2& operator-=(const Z2& other); @@ -148,6 +151,9 @@ class Z2 : public ValueInterface void SHL(const Z2& a, const bigint& i) { *this = a << i.get_ui(); } void SHR(const Z2& a, const bigint& i) { *this = a >> i.get_ui(); } + void SHL(const Z2& a, int i) { *this = a << i; } + void SHR(const Z2& a, int i) { *this = a >> i; } + void AND(const Z2& a, const Z2& b); void OR(const Z2& a, const Z2& b); void XOR(const Z2& a, const Z2& b); @@ -348,10 +354,9 @@ ostream& operator<<(ostream& o, const SignedZ2& x) } template -inline void to_signed_bigint(bigint& res, const SignedZ2& x, int n) +void to_bigint(bigint& res, const SignedZ2& a) { - bigint tmp = x; - to_signed_bigint(res, tmp, n); + res = a; } #endif /* MATH_Z2K_H_ */ diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 2bc5f652a..b600497e1 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -58,6 +58,14 @@ bool Z2::get_bit(int i) const return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); } +template +Z2 Z2::operator&(const Z2& other) const +{ + Z2 res; + res.AND(*this, other); + return res; +} + template bool Z2::operator==(const Z2& other) const { diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 457441abb..dcfee12b1 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -17,6 +17,7 @@ void Zp_Data::init(const bigint& p,bool mont) #endif pr=p; + pr_half = p / 2; mask=static_cast(1ULL<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1; pr_byte_length = numBytes(pr); pr_bit_length = numBits(pr); diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 989f9a873..cf353fafa 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -47,6 +47,7 @@ class Zp_Data public: bigint pr; + bigint pr_half; mp_limb_t mask; size_t pr_byte_length; size_t pr_bit_length; diff --git a/Math/bigint.cpp b/Math/bigint.cpp index f7afa931f..0ef5751e6 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -146,17 +146,6 @@ bigint::bigint(const GC::Clear& x) : bigint(SignedZ2<64>(x)) { } -void to_signed_bigint(bigint& res, const bigint& x, int n) -{ - res = abs(x); - bigint& tmp = bigint::tmp = 1; - tmp <<= n; - tmp -= 1; - res &= tmp; - if (x < 0) - res.negate(); -} - #ifdef REALLOC_POLICE void bigint::lottery() { diff --git a/Math/bigint.h b/Math/bigint.h index 8cd57665a..05c10784b 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -105,8 +105,6 @@ class bigint : public mpz_class }; -void to_signed_bigint(bigint& res, const bigint& x, int n); - void inline_mpn_zero(mp_limb_t* x, mp_size_t size); void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size); diff --git a/Math/field_types.h b/Math/field_types.h index a361e48ac..e02cd1b8f 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -9,6 +9,16 @@ enum DataFieldType { DATA_INT, DATA_GF2N, DATA_GF2, N_DATA_FIELD_TYPE }; -enum Dtype { DATA_TRIPLE, DATA_SQUARE, DATA_BIT, DATA_INVERSE, DATA_BITTRIPLE, DATA_BITGF2NTRIPLE, N_DTYPE }; +enum Dtype +{ + DATA_TRIPLE, + DATA_SQUARE, + DATA_BIT, + DATA_INVERSE, + DATA_BITTRIPLE, + DATA_BITGF2NTRIPLE, + DATA_DABIT, + N_DTYPE +}; #endif /* MATH_FIELD_TYPES_H_ */ diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 3610ada91..9a7a87550 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -1,5 +1,6 @@ #include "Math/gf2n.h" +#include "Math/Bit.h" #include "Exceptions/Exceptions.h" @@ -249,6 +250,11 @@ void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y) reduce(hi,lo); } +gf2n_short gf2n_short::operator*(const Bit& x) const +{ + return x.get() * a; +} + diff --git a/Math/gf2n.h b/Math/gf2n.h index 05e414dd1..5d6a29f70 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -14,6 +14,7 @@ using namespace std; class gf2n_short; class P2Data; +class Bit; template class Square; typedef Square gf2n_short_square; @@ -82,6 +83,7 @@ class gf2n_short static string type_string() { return "gf2n"; } static int size() { return sizeof(a); } + static int size_in_bits() { return sizeof(a) * 8; } static int t() { return 0; } static int default_length() { return 40; } @@ -166,6 +168,8 @@ class gf2n_short gf2n_short& operator-=(const gf2n_short& x) { sub(x); return *this; } gf2n_short operator/(const gf2n_short& x) const { gf2n_short tmp; tmp.invert(x); return *this * tmp; } + gf2n_short operator*(const Bit& x) const; + void square(); void square(const gf2n_short& aa); void invert(); diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 22bb2f241..1c51b2ae7 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -129,6 +129,7 @@ class gf2n_long static string type_string() { return "gf2n_long"; } static int size() { return sizeof(a); } + static int size_in_bits() { return sizeof(a) * 8; } static int t() { return 0; } static int default_length() { return 128; } diff --git a/Math/gfp.cpp b/Math/gfp.cpp index e2f53f5b0..aafca38e1 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -209,8 +209,7 @@ void to_signed_bigint(bigint& ans, const gfp& x) { to_bigint(ans, x); // get sign and abs(x) - bigint& p_half = bigint::tmp = (gfp::pr()-1)/2; - if (mpz_cmp(ans.get_mpz_t(), p_half.get_mpz_t()) > 0) + if (mpz_cmp(ans.get_mpz_t(), gfp::get_ZpD().pr_half.get_mpz_t()) > 0) ans -= gfp::pr(); } diff --git a/Math/gfp.h b/Math/gfp.h index 8360bee8c..310db3159 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -78,6 +78,7 @@ class gfp_ static string type_string() { return "gfp"; } static int size() { return t() * sizeof(mp_limb_t); } + static int size_in_bits() { return 8 * size(); } static int length() { return ZpD.pr_bit_length; } static void reqbl(int n); diff --git a/Math/operators.h b/Math/operators.h index b3714cd3f..15307f9be 100644 --- a/Math/operators.h +++ b/Math/operators.h @@ -22,8 +22,8 @@ T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; } //template //T& operator+=(T& x, const U& y) { x.add(y); return x; } -template -T& operator*=(T& x, const T& y) { x.mul(y); return x; } +//template +//T& operator*=(T& x, const T& y) { x.mul(y); return x; } //template //T& operator-=(T& x, const U& y) { x.sub(y); return x; } diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 5edd49509..193bf3c3d 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -6,6 +6,7 @@ #include "Tools/int.h" #include "Tools/NetworkOptions.h" #include "Networking/Server.h" +#include "Networking/ServerSocket.h" #include #include diff --git a/Networking/Player.h b/Networking/Player.h index 88e3293f9..e6655a465 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -17,7 +17,6 @@ using namespace std; #include "Tools/octetStream.h" #include "Tools/FlexBuffer.h" #include "Networking/sockets.h" -#include "Networking/ServerSocket.h" #include "Tools/sha1.h" #include "Tools/int.h" #include "Networking/Receiver.h" @@ -26,6 +25,7 @@ using namespace std; template class MultiPlayer; class Server; +class ServerSocket; /* Class to get the names off the server */ class Names diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index 95687bb00..824cf27e0 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -458,9 +458,10 @@ void BitMatrix::resize(int length) squares.resize(length / 128); } -int BitMatrix::size() +template +size_t Matrix::vertical_size() { - return squares.size() * 128; + return squares.size() * U::N_ROWS; } template @@ -681,6 +682,8 @@ Y(72, 48) Y(74, 48) Y(72, 64) Y(74, 64) +Y(1, 48) +Y(1, 64) template class Matrix; @@ -697,6 +700,7 @@ BMS #define XXXX(BM, GF) \ template class Slice; \ + template size_t BM::vertical_size(); \ XX(BM, GF) XXXX(Matrix, gf2n_short) diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 9ddfde0c0..1b06f3c03 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -120,6 +120,8 @@ class Matrix vector< U, aligned_allocator > squares; + size_t vertical_size(); + void resize_vertical(int length) { squares.resize(DIV_CEIL(length, U::N_ROWS)); } bool operator==(Matrix& other); @@ -145,7 +147,6 @@ class BitMatrix : public Matrix __m128i& operator[](int i) { return squares[i / 128].rows[i % 128]; } void resize(int length); - int size(); void transpose(); void check_transpose(BitMatrix& dual); diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 484112647..bb9528eeb 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -131,12 +131,12 @@ class NPartyTripleGenerator : public OTTripleGenerator typedef typename T::mac_key_type mac_key_type; typedef typename T::sacri_type sacri_type; - virtual void generateTriples() = 0; - virtual void generateBits() = 0; + virtual void generateTriples() { throw not_implemented(); } + virtual void generateBits() { throw not_implemented(); } public: vector< ShareTriple_ > uncheckedTriples; - vector>> inputs; + vector> inputs; NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 9460106df..568a5aa64 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -195,25 +195,25 @@ void NPartyTripleGenerator::generate() template void NPartyTripleGenerator::generateInputs(int player) { - typedef open_type T; + typedef typename W::input_type::share_type::open_type T; auto& nTriplesPerLoop = this->nTriplesPerLoop; auto& valueBits = this->valueBits; auto& share_prg = this->share_prg; - auto& field_size = this->field_size; auto& ot_multipliers = this->ot_multipliers; auto& nparties = this->nparties; auto& globalPlayer = this->globalPlayer; // extra value for sacrifice - int toCheck = nTriplesPerLoop + 1; + int toCheck = nTriplesPerLoop + + DIV_CEIL(W::mac_key_type::size_in_bits(), T::size_in_bits()); this->signal_multipliers({player, toCheck}); bool mine = player == globalPlayer.my_num(); valueBits.resize(1); if (mine) { - valueBits[0].resize(toCheck * field_size); + valueBits[0].resize(toCheck * T::size_in_bits()); valueBits[0].template randomize_blocks(share_prg); this->signal_multipliers({}); } @@ -221,7 +221,7 @@ void NPartyTripleGenerator::generateInputs(int player) this->wait_for_multipliers(); GlobalPRNG G(globalPlayer); - Share check_sum; + typename W::input_check_type check_sum; inputs.resize(toCheck); auto mac_key = this->get_mac_key(); SemiInput> input(0, globalPlayer); @@ -236,7 +236,8 @@ void NPartyTripleGenerator::generateInputs(int player) input.exchange(); for (int j = 0; j < toCheck; j++) { - T share, mac_sum; + T share; + typename W::mac_type mac_sum; share = input.finalize(player); if (mine) { @@ -250,11 +251,12 @@ void NPartyTripleGenerator::generateInputs(int player) mac_sum = (ot_multipliers[i_thread])->input_macs[j]; } inputs[j] = {{share, mac_sum}, secrets[j]}; - check_sum += inputs[j].share * G.get(); + auto r = G.get(); + check_sum += typename W::input_check_type(r * share, r * mac_sum); } inputs.resize(nTriplesPerLoop); - typename W::MAC_Check MC(mac_key); + typename W::input_check_type::MAC_Check MC(mac_key); MC.POpen(check_sum, globalPlayer); // use zero element because all is perfectly randomized MC.set_random_element({}); @@ -323,18 +325,21 @@ void MascotTripleGenerator::generateBitsGf2n() } template<> +inline void MascotTripleGenerator>::generateBits() { generateBitsGf2n(); } template<> +inline void MascotTripleGenerator>::generateBits() { generateBitsGf2n(); } template<> +inline void MascotTripleGenerator>::generateBits() { generateTriples(); @@ -740,7 +745,7 @@ void MascotTripleGenerator>::generateBitsFromTriples( a_squared[i] = triples[i].a[0] * opened[i] - triples[i].c[0]; MC.POpen_Begin(opened, a_squared, globalPlayer); MC.POpen_End(opened, a_squared, globalPlayer); - Share one(gfp1(1), globalPlayer.my_num(), MC.get_alphai()); + auto one = Share::constant(1, globalPlayer.my_num(), MC.get_alphai()); bits.clear(); for (int i = 0; i < nTriplesPerLoop; i++) { @@ -797,21 +802,25 @@ void OTTripleGenerator::print_progress(int k) } } +inline void GeneratorThread::lock() { pthread_mutex_lock(&mutex); } +inline void GeneratorThread::unlock() { pthread_mutex_unlock(&mutex); } +inline void GeneratorThread::signal() { pthread_cond_signal(&ready); } +inline void GeneratorThread::wait() { if (multi_threaded) diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 39be81ce9..cee723655 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -270,8 +270,24 @@ void OTCorrelator::correlate(int start, int slice, #endif } +template +void OTCorrelator::expand_correlate_unchecked(const BitVector& delta, int n_bits) +{ + if (n_bits < 0) + n_bits = delta.size(); + resize(n_bits); + int slice = receiverOutputMatrix.squares.size(); + expand(0, slice); + BitVector tmp = delta; + tmp.resize_zero(receiverOutputMatrix.vertical_size()); + correlate(0, slice, tmp, true); +} + void OTExtensionWithMatrix::transpose(int start, int slice) { + if (slice < 0) + slice = receiverOutputMatrix.squares.size(); + BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice); BitMatrixSlice senderOutputSlices[2] = { BitMatrixSlice(senderOutputMatrices[0], start, slice), @@ -587,3 +603,5 @@ Y(72, 48) Y(74, 48) Y(72, 64) Y(74, 64) +Y(1, 48) +Y(1, 64) diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index 72007fa17..08b85ef98 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -44,6 +44,7 @@ class OTCorrelator : public OTExtension vector& baseSenderOutputs, U& baseReceiverOutput); void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1); + void expand_correlate_unchecked(const BitVector& delta, int n_bits = -1); template void reduce_squares(unsigned int nTriples, vector& output, int start = 0); @@ -79,7 +80,7 @@ class OTExtensionWithMatrix : public OTCorrelator void extend(int nOTs, BitVector& newReceiverInput); void extend_correlated(const BitVector& newReceiverInput); void extend_correlated(int nOTs, const BitVector& newReceiverInput); - void transpose(int start, int slice); + void transpose(int start = 0, int slice = -1); void expand_transposed(); template void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput, diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 88b1135e5..bae9078e5 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -99,8 +99,7 @@ class MascotMultiplier : public OTMultiplier template class TinyMultiplier : public OTMultiplier { - OTVole mac_vole; + OTVole mac_vole; void after_correlation(); void init_authenticator(const BitVector& baseReceiverInput, @@ -115,6 +114,24 @@ class TinyMultiplier : public OTMultiplier void multiplyForInputs(MultJob job) { (void) job; throw not_implemented(); } }; +template +class TinierMultiplier : public OTMultiplier +{ + OTExtensionWithMatrix auth_ot_ext; + + void after_correlation(); + void init_authenticator(const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput); + +public: + vector c_output; + + TinierMultiplier(OTTripleGenerator& generator, int thread_num); + + void multiplyForInputs(MultJob job); +}; + template class Spdz2kShare; template @@ -135,8 +152,8 @@ class Spdz2kMultiplier: public OTMultiplier> static const int MAC_BITS = K + 2 * S; vector > c_output; - OTVoleBase, Z2>* mac_vole; - OTVoleBase, Z2>* input_mac_vole; + OTVoleBase>* mac_vole; + OTVoleBase>* input_mac_vole; Spdz2kMultiplier(OTTripleGenerator& generator, int thread_num); ~Spdz2kMultiplier(); diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index f5ba18d93..e237e3597 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -44,7 +44,7 @@ template TinyMultiplier::TinyMultiplier(OTTripleGenerator& generator, int thread_num) : OTMultiplier(generator, thread_num), - mac_vole( + mac_vole(T::part_type::mac_key_type::N_BITS, 128, 128, 0, 1, generator.players[thread_num], { }, @@ -54,17 +54,26 @@ TinyMultiplier::TinyMultiplier(OTTripleGenerator& generator, c_output.resize(generator.nTriplesPerLoop); } +template +TinierMultiplier::TinierMultiplier(OTTripleGenerator& generator, + int thread_num) : + OTMultiplier(generator, thread_num), + auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true) +{ + c_output.resize(generator.nTriplesPerLoop); +} + template Spdz2kMultiplier::Spdz2kMultiplier(OTTripleGenerator>& generator, int thread_num) : OTMultiplier> (generator, thread_num) { #ifdef USE_OPT_VOLE - mac_vole = new OTVole, Z2>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); - input_mac_vole = new OTVole, Z2>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); + mac_vole = new OTVole>(S, 128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); + input_mac_vole = new OTVole>(S, 128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); #else - mac_vole = new OTVoleBase, Z2>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); - input_mac_vole = new OTVoleBase, Z2>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); + mac_vole = new OTVoleBase>(S, 128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); + input_mac_vole = new OTVoleBase>(S, 128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); #endif } @@ -200,6 +209,33 @@ void TinyMultiplier::init_authenticator(const BitVector& keyBits, mac_vole.init(keyBits, senderOutput, receiverOutput); } +template +void TinierMultiplier::init_authenticator(const BitVector& keyBits, + const vector< vector >& senderOutput, + const vector& receiverOutput) +{ + auto tmpBits = keyBits; + tmpBits.resize_zero(128); + auto tmpSenderOutput = senderOutput; + tmpSenderOutput.resize(128); + SeededPRNG G; + for (auto& x : tmpSenderOutput) + { + x.resize(2); + for (auto& y : x) + if (y.size() == 0) + { + y.resize(128); + y.randomize(G); + } + } + auto tmpReceiverOutput = receiverOutput; + tmpReceiverOutput.resize(128); + for (auto& y : tmpReceiverOutput) + y.resize_zero(128); + auth_ot_ext.init(tmpBits, tmpSenderOutput, tmpReceiverOutput); +} + template void Spdz2kMultiplier::init_authenticator(const BitVector& keyBits, const vector< vector >& senderOutput, @@ -301,6 +337,33 @@ void TinyMultiplier::after_correlation() this->outbox.push(job); } +template +void TinierMultiplier::after_correlation() +{ + this->auth_ot_ext.set_role(BOTH); + + this->otCorrelator.reduce_squares(this->generator.nTriplesPerLoop, + this->c_output); + + this->outbox.push({}); + + this->macs.resize(3); + MultJob job; + this->inbox.pop(job); + for (int j = 0; j < 3; j++) + { + auth_ot_ext.expand_correlate_unchecked(this->generator.valueBits[j]); + auth_ot_ext.transpose(); + this->macs[j].clear(); + for (size_t i = 0; i < this->generator.valueBits[j].size(); i++) + this->macs[j].push_back( + int128( + auth_ot_ext.receiverOutputMatrix[i] + ^ auth_ot_ext.senderOutputMatrices[0][i]).get_lower()); + } + this->outbox.push(job); +} + template void Spdz2kMultiplier::after_correlation() { @@ -340,18 +403,21 @@ void Spdz2kMultiplier::after_correlation() } template<> +inline void OTMultiplier>::multiplyForBits() { multiplyForTriples(); } template<> +inline void OTMultiplier>::multiplyForBits() { multiplyForGf2nBits(); } template<> +inline void OTMultiplier>::multiplyForBits() { multiplyForGf2nBits(); @@ -447,6 +513,29 @@ void Spdz2kMultiplier::multiplyForInputs(MultJob job) this->outbox.push(job); } +template +void TinierMultiplier::multiplyForInputs(MultJob job) +{ + assert(job.input); + auto& generator = this->generator; + bool mine = job.player == generator.my_num; + auth_ot_ext.set_role(mine ? RECEIVER : SENDER); + if (mine) + this->inbox.pop(); + assert(not mine or job.n_inputs <= (int)generator.valueBits[0].size()); + auth_ot_ext.expand_correlate_unchecked(generator.valueBits[0], job.n_inputs); + auth_ot_ext.transpose(); + auto& input_macs = this->input_macs; + input_macs.resize(job.n_inputs); + if (mine) + for (int j = 0; j < job.n_inputs; j++) + input_macs[j] = int128(auth_ot_ext.receiverOutputMatrix[j]).get_lower(); + else + for (int j = 0; j < job.n_inputs; j++) + input_macs[j] = int128(auth_ot_ext.senderOutputMatrices[0][j]).get_lower(); + this->outbox.push(job); +} + template void OTMultiplier::multiplyForBits() { diff --git a/OT/OTVole.h b/OT/OTVole.h index 2bb1a9580..5e9c6e3c9 100644 --- a/OT/OTVole.h +++ b/OT/OTVole.h @@ -5,16 +5,17 @@ #include "Math/Z2k.h" #include "OTExtension.h" #include "Row.h" +#include "config.h" using namespace std; -template +template class OTVoleBase : public OTExtension { public: - static const int S = U::N_BITS; + const int S; - OTVoleBase(int nbaseOTs, int baseLength, + OTVoleBase(int S, int nbaseOTs, int baseLength, int nloops, int nsubloops, TwoPartyPlayer* player, const BitVector& baseReceiverInput, @@ -24,12 +25,13 @@ class OTVoleBase : public OTExtension bool passive=false) : OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, baseSenderInput, baseReceiverOutput, INV_ROLE(role), passive), + S(S), corr_prime(), - t0(U::N_BITS), - t1(U::N_BITS), - u(U::N_BITS), - t(U::N_BITS), - a(U::N_BITS) { + t0(S), + t1(S), + u(S), + t(S), + a(S) { // need to flip roles for OT extension init, reset to original role here this->ot_role = role; local_prng.ReSeed(); @@ -39,6 +41,9 @@ class OTVoleBase : public OTExtension void evaluate(vector& output, int nValues, const BitVector& newReceiverInput); + virtual int n_challenges() { return S; } + virtual int get_challenge(PRNG&, int i) { return i; } + protected: // Sender fields @@ -61,12 +66,12 @@ class OTVoleBase : public OTExtension }; -template -class OTVole : public OTVoleBase +template +class OTVole : public OTVoleBase { public: - OTVole(int nbaseOTs, int baseLength, + OTVole(int S, int nbaseOTs, int baseLength, int nloops, int nsubloops, TwoPartyPlayer* player, const BitVector& baseReceiverInput, @@ -74,16 +79,12 @@ class OTVole : public OTVoleBase const vector& baseReceiverOutput, OT_ROLE role=BOTH, bool passive=false) - : OTVoleBase(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, + : OTVoleBase(S, nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, baseSenderInput, baseReceiverOutput, INV_ROLE(role), passive) { } -protected: - - Row tmp; - - void consistency_check(vector& os); - + int n_challenges() { return NUM_VOLE_CHALLENGES; } + int get_challenge(PRNG& G, int) { return G.get_uint(this->S); } }; #endif diff --git a/OT/OTVole.hpp b/OT/OTVole.hpp index ffb168178..247dc509c 100644 --- a/OT/OTVole.hpp +++ b/OT/OTVole.hpp @@ -6,8 +6,8 @@ //#define OTVOLE_TIMER -template -void OTVoleBase::evaluate(vector& output, const vector& newReceiverInput) { +template +void OTVoleBase::evaluate(vector& output, const vector& newReceiverInput) { const int N1 = newReceiverInput.size() + 1; output.resize(newReceiverInput.size()); vector os(2); @@ -65,8 +65,8 @@ void OTVoleBase::evaluate(vector& output, const vector& newReceiverI output[i] = res.rows[i]; } -template -void OTVoleBase::evaluate(vector& output, int nValues, const BitVector& newReceiverInput) { +template +void OTVoleBase::evaluate(vector& output, int nValues, const BitVector& newReceiverInput) { vector values(nValues); if (ot_role & SENDER) { @@ -78,22 +78,20 @@ void OTVoleBase::evaluate(vector& output, int nValues, const BitVector& evaluate(output, values); } -template -void OTVoleBase::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) const { - avx_memzero(coefficients, num_blocks); - for (int i = 0; i < num_blocks; ++i) - coefficients[i] = G.get_doubleword(); +template +void OTVoleBase::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) const { + G.get_octets((octet*) coefficients, num_blocks * sizeof(__m128i)); } -template -void OTVoleBase::hash_row(octetStream& os, const Row& row, const __m128i* coefficients) { +template +void OTVoleBase::hash_row(octetStream& os, const Row& row, const __m128i* coefficients) { octet hash[VOLE_HASH_SIZE] = {0}; this->hash_row(hash, row, coefficients); os.append(hash, VOLE_HASH_SIZE); } -template -void OTVoleBase::hash_row(octet* hash, const Row& row, const __m128i* coefficients) { +template +void OTVoleBase::hash_row(octet* hash, const Row& row, const __m128i* coefficients) { int num_blocks = DIV_CEIL(row.size() * T::size(), 16); octetStream os; @@ -118,8 +116,8 @@ void OTVoleBase::hash_row(octet* hash, const Row& row, const __m128i* c (octet*) res, crypto_generichash_BYTES, NULL, 0); } -template -void OTVoleBase::consistency_check(vector& os) { +template +void OTVoleBase::consistency_check(vector& os) { PRNG coef_prng_sender; PRNG coef_prng_receiver; @@ -149,8 +147,10 @@ void OTVoleBase::consistency_check(vector& os) { Row t00(t0.size()), t01(t0.size()), t10(t0.size()), t11(t0.size()); for (int alpha = 0; alpha < S; ++alpha) { - for (int beta = 0; beta < S; ++beta) + for (int i = 0; i < n_challenges(); i++) { + int beta = get_challenge(coef_prng_sender, i); + t00 = t0[alpha] - t0[beta]; t01 = t0[alpha] - t1[beta]; t10 = t1[alpha] - t0[beta]; @@ -190,8 +190,10 @@ void OTVoleBase::consistency_check(vector& os) { for (int alpha = 0; alpha < S; ++alpha) { - for (int beta = 0; beta < S; ++beta) + for (int i = 0; i < n_challenges(); i++) { + int beta = get_challenge(coef_prng_receiver, i); + os[1].consume(hashes[0][0], VOLE_HASH_SIZE); os[1].consume(hashes[0][1], VOLE_HASH_SIZE); os[1].consume(hashes[1][0], VOLE_HASH_SIZE); @@ -230,120 +232,3 @@ void OTVoleBase::consistency_check(vector& os) { #endif } } - -template -void OTVole::consistency_check(vector& os) { - PRNG sender_prg; - PRNG receiver_prg; - - if (this->ot_role & RECEIVER) { - receiver_prg.ReSeed(); - os[0].append(receiver_prg.get_seed(), SEED_SIZE); - } - send_if_ot_receiver(this->player, os, this->ot_role); - if (this->ot_role & SENDER) { - octet seed[SEED_SIZE]; - os[1].consume(seed, SEED_SIZE); - sender_prg.SetSeed(seed); - } - os[0].reset_write_head(); - os[1].reset_write_head(); - - if (this->ot_role & SENDER) { -#ifdef OTVOLE_TIMER - timeval totalstartv, totalendv; - gettimeofday(&totalstartv, NULL); -#endif - int total_bytes = this->t0[0].size() * T::size(); - int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0); - __m128i coefficients[num_blocks]; - this->set_coeffs(coefficients, sender_prg, num_blocks); - - Row t00(this->t0.size()), t01(this->t0.size()), t10(this->t0.size()), t11(this->t0.size()); - for (int alpha = 0; alpha < U::N_BITS; ++alpha) - { - for (int i = 0; i < NUM_VOLE_CHALLENGES; ++i) - { - int beta = sender_prg.get_uint(U::N_BITS); - - t00 = this->t0[alpha] - this->t0[beta]; - t01 = this->t0[alpha] - this->t1[beta]; - t10 = this->t1[alpha] - this->t0[beta]; - t11 = this->t1[alpha] - this->t1[beta]; - - this->hash_row(os[0], t00, coefficients); - this->hash_row(os[0], t01, coefficients); - this->hash_row(os[0], t10, coefficients); - this->hash_row(os[0], t11, coefficients); - } - } -#ifdef OTVOLE_TIMER - gettimeofday(&totalendv, NULL); - double elapsed = timeval_diff(&totalstartv, &totalendv); - cout << "\t\tCheck time sender: " << elapsed/1000000 << endl << flush; -#endif - } - - send_if_ot_sender(this->player, os, this->ot_role); - if (this->ot_role & RECEIVER) { -#ifdef OTVOLE_TIMER - timeval totalstartv, totalendv; - gettimeofday(&totalstartv, NULL); -#endif - int total_bytes = this->t[0].size() * T::size(); - int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0); - __m128i coefficients[num_blocks]; - this->set_coeffs(coefficients, receiver_prg, num_blocks); - - octet h00[VOLE_HASH_SIZE] = {0}; - octet h01[VOLE_HASH_SIZE] = {0}; - octet h10[VOLE_HASH_SIZE] = {0}; - octet h11[VOLE_HASH_SIZE] = {0}; - vector> hashes(2); - hashes[0] = {h00, h01}; - hashes[1] = {h10, h11}; - - for (int alpha = 0; alpha < U::N_BITS; ++alpha) - { - for (int i = 0; i < NUM_VOLE_CHALLENGES; ++i) - { - int beta = receiver_prg.get_uint(U::N_BITS); - - os[1].consume(hashes[0][0], VOLE_HASH_SIZE); - os[1].consume(hashes[0][1], VOLE_HASH_SIZE); - os[1].consume(hashes[1][0], VOLE_HASH_SIZE); - os[1].consume(hashes[1][1], VOLE_HASH_SIZE); - - int choice_alpha = this->baseReceiverInput.get_bit(alpha); - int choice_beta = this->baseReceiverInput.get_bit(beta); - - tmp = this->t[alpha] - this->t[beta]; - octet* choice_hash = hashes[choice_alpha][choice_beta]; - octet diff_t[VOLE_HASH_SIZE] = {0}; - this->hash_row(diff_t, tmp, coefficients); - - octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta]; - octet other_diff[VOLE_HASH_SIZE] = {0}; - tmp = this->u[alpha] - this->u[beta]; - tmp -= this->t[alpha]; - tmp += this->t[beta]; - this->hash_row(other_diff, tmp, coefficients); - - if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) { - throw consistency_check_fail(); - } - if (!OCTETS_EQUAL(not_choice_hash, other_diff, VOLE_HASH_SIZE)) { - throw consistency_check_fail(); - } - if (alpha != beta && this->u[alpha] == this->u[beta]) { - throw consistency_check_fail(); - } - } - } -#ifdef OTVOLE_TIMER - gettimeofday(&totalendv, NULL); - double elapsed = timeval_diff(&totalstartv, &totalendv); - cout << "\t\tCheck receiver: " << elapsed/1000000 << endl << flush; -#endif - } -} diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index a42f784eb..db8627dfd 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -10,6 +10,7 @@ using namespace std; BaseMachine* BaseMachine::singleton = 0; +thread_local int BaseMachine::thread_num; BaseMachine& BaseMachine::s() { diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 5adebec57..cb0b337b7 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -27,6 +27,8 @@ class BaseMachine virtual void load_program(string threadname, string filename); public: + static thread_local int thread_num; + string progname; int nthreads; @@ -47,6 +49,13 @@ class BaseMachine void stop(int n); virtual void reqbl(int n) { (void)n; throw runtime_error("not defined"); } + + OTTripleSetup fresh_ot_setup(); }; +inline OTTripleSetup BaseMachine::fresh_ot_setup() +{ + return ot_setups.at(thread_num).get_fresh(); +} + #endif /* PROCESSOR_BASEMACHINE_H_ */ diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index e5671d26d..d6396aaa7 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -72,6 +72,7 @@ class Preprocessing { DataPositions& usage; +protected: void count(Dtype dtype) { usage.files[T::field_type()][dtype]++; } void count(DataTag tag, int n = 1) { usage.extended[T::field_type()][tag] += n; } void count_input(int player) { usage.inputs[player][T::field_type()]++; } @@ -111,6 +112,7 @@ class Preprocessing void get(vector& S, DataTag tag, const vector& regs, int vector_size); virtual array get_triple(int n_bits); + virtual void get_dabit(T&, typename T::bit_type&) { throw runtime_error("no daBit"); } virtual void buffer_triples() {} virtual void buffer_inverses() {} diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index e1956a5bf..a83669da2 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -10,26 +10,103 @@ using namespace std; #include "Math/BitVec.h" +#include "Data_Files.h" class Player; +class DataPositions; template class SubProcessor; +namespace GC +{ +class NoShare; + +template class ShareThread; +} + +template class DummyMC { public: + void POpen(vector&, vector&, Player&) + { + throw not_implemented(); + } + void Check(Player& P) { (void) P; } + + DummyMC& get_part_MC() + { + return *new DummyMC; + } + + typename T::mac_key_type get_alphai() + { + throw not_implemented(); + return {}; + } }; class DummyProtocol { public: - DummyProtocol(Player& P) + Player& P; + + static int get_n_relevant_players() + { + throw not_implemented(); + } + + DummyProtocol(Player& P) : + P(P) { - (void) P; + } +}; + +template +class DummyLivePrep : public Preprocessing +{ +public: + static void fail() + { + throw runtime_error( + "live preprocessing not implemented for " + T::type_string()); + } + + DummyLivePrep(DataPositions& usage, GC::ShareThread&) : + Preprocessing(usage) + { + } + DummyLivePrep(DataPositions& usage) : + Preprocessing(usage) + { + } + + void set_protocol(typename T::Protocol&) + { + } + void get_three_no_count(Dtype, T&, T&, T&) + { + fail(); + } + void get_two_no_count(Dtype, T&, T&) + { + fail(); + } + void get_one_no_count(Dtype, T&) + { + fail(); + } + void get_input_no_count(T&, typename T::open_type&, int) + { + fail(); + } + void get_no_count(vector&, DataTag, const vector&, int) + { + fail(); } }; @@ -38,7 +115,7 @@ class NotImplementedInput { public: template - NotImplementedInput(T& proc, U& MC) + NotImplementedInput(const T& proc, const U& MC) { (void) proc, (void) MC; } @@ -77,16 +154,20 @@ class NotImplementedInput (void) P; throw not_implemented(); } - void add_mine(int a, int b) + void add_mine(int a, int b = 0) { (void) a, (void) b; throw not_implemented(); } + void add_other(int) + { + throw not_implemented(); + } void exchange() { throw not_implemented(); } - V finalize(int a, int b) + V finalize(int a, int b = 0) { (void) a, (void) b; throw not_implemented(); diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index 3a5e8329d..ec83a2011 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -1,4 +1,5 @@ #include "Processor/ExternalClients.h" +#include "Networking/ServerSocket.h" #include #include #include diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h index be4004cf0..687cac942 100644 --- a/Processor/ExternalClients.h +++ b/Processor/ExternalClients.h @@ -1,7 +1,6 @@ #ifndef _ExternalClients #define _ExternalClients -#include "Networking/ServerSocket.h" #include "Networking/sockets.h" #include "Exceptions/Exceptions.h" #include @@ -11,6 +10,8 @@ #include #include +class AnonymousServerSocket; + /* * Manage the reading and writing of data from/to external clients via Sockets. * Generate the session keys for encryption/decryption of secret communication with external clients. diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 8bb376d39..51f959303 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -25,14 +25,14 @@ InputBase::InputBase(ArithmeticProcessor* proc) : template Input::Input(SubProcessor& proc, MAC_Check& mc) : - InputBase(&proc.Proc), proc(&proc), MC(mc), prep(proc.DataF), P(proc.P), + InputBase(proc.Proc), proc(&proc), MC(mc), prep(proc.DataF), P(proc.P), shares(proc.P.num_players()) { } template Input::Input(SubProcessor* proc, Player& P) : - InputBase(&proc->Proc), proc(proc), MC(proc->MC), prep(proc->DataF), P( + InputBase(proc->Proc), proc(proc), MC(proc->MC), prep(proc->DataF), P( proc->P), shares(P.num_players()) { assert (proc != 0); @@ -207,11 +207,12 @@ void InputBase::prepare(SubProcessor& Proc, int player, const int* params, int size) { auto& input = Proc.input; + assert(Proc.Proc != 0); if (player == Proc.P.my_num()) { for (int j = 0; j < size; j++) { - U tuple = Proc.Proc.template get_input(Proc.Proc.use_stdin(), + U tuple = Proc.Proc->template get_input(Proc.Proc->use_stdin(), params); for (auto x : tuple.items) input.add_mine(x); @@ -247,7 +248,7 @@ void InputBase::input(SubProcessor& Proc, int n_from_me = 0; - if (Proc.Proc.use_stdin()) + if (Proc.Proc and Proc.Proc->use_stdin()) { for (size_t i = n_arg_tuple - 1; i < args.size(); i += n_arg_tuple) n_from_me += (args[i] == Proc.P.my_num()) * size; @@ -293,7 +294,7 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, case U::TYPE: \ n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ player = args[i + n_arg_tuple - 1]; \ - if (type != last_type and Proc.Proc.use_stdin()) \ + if (type != last_type and Proc.Proc and Proc.Proc->use_stdin()) \ cout << "Please input " << U::NAME << "s:" << endl; \ prepare(Proc, player, &args[i + U::N_DEST + 1], size); \ break; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 6d22f2c7c..aff1b82af 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -103,6 +103,7 @@ enum INV = 0x53, INPUTMASK = 0x56, PREP = 0x57, + DABIT = 0x58, // Input INPUT = 0x60, INPUTFIX = 0xF0, @@ -273,6 +274,9 @@ enum RegType { MODP, GF2N, INT, + SBIT, + CBIT, + DYN_SBIT, MAX_REG_TYPE, NONE }; @@ -310,13 +314,22 @@ class BaseInstruction public: virtual ~BaseInstruction() {}; - void parse_operands(istream& s, int pos); + int get_r(int i) const { return r[i]; } + unsigned int get_n() const { return n; } + const vector& get_start() const { return start; } + int get_opcode() const { return opcode; } + int get_size() const { return size; } + + void parse_operands(istream& s, int pos, int file_pos); bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); } virtual int get_reg_type() const; bool is_direct_memory_access(SecrecyType sec_type) const; + // Returns the memory size used if applicable and known + unsigned get_mem(RegType reg_type, SecrecyType sec_type) const; + // Returns the maximal register used unsigned get_max_reg(int reg_type) const; }; @@ -327,14 +340,11 @@ class Instruction : public BaseInstruction { public: // Reads a single instruction from the istream - void parse(istream& s); + void parse(istream& s, int inst_pos); // Return whether usage is known bool get_offline_data_usage(DataPositions& usage); - // Returns the memory size used if applicable and known - unsigned get_mem(RegType reg_type, SecrecyType sec_type) const; - friend ostream& operator<<(ostream& s,const Instruction& instr); // Execute this instruction, updateing the processor and memory diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index a7d9a36f7..b339681f0 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -8,6 +8,8 @@ #include "Exceptions/Exceptions.h" #include "Tools/time-func.h" #include "Tools/parse.h" +#include "GC/Instruction.h" +#include "GC/instructions.h" //#include "Processor/Processor.hpp" #include "Processor/Binary_File_IO.hpp" @@ -28,53 +30,32 @@ #include "Tools/callgrind.h" -// Convert modp to signed bigint of a given bit length inline -void to_signed_bigint(bigint& bi, const gfp& x, int len) -{ - to_bigint(bi, x); - int neg; - // get sign and abs(x) - bigint& p_half = bigint::tmp = (gfp::pr()-1)/2; - if (mpz_cmp(bi.get_mpz_t(), p_half.get_mpz_t()) < 0) - neg = 0; - else - { - bi = gfp::pr() - bi; - neg = 1; - } - // reduce to range -2^(len-1), ..., 2^(len-1) - bigint& one = bigint::tmp = 1; - bi &= (one << len) - 1; - if (neg) - bi = -bi; -} - -inline -void Instruction::parse(istream& s) +void Instruction::parse(istream& s, int inst_pos) { n=0; start.resize(0); r[0]=0; r[1]=0; r[2]=0; r[3]=0; int pos=s.tellg(); opcode=get_int(s); - size=opcode>>9; - opcode&=0x1FF; + size=opcode>>10; + opcode&=0x3FF; if (size==0) size=1; - parse_operands(s, pos); + parse_operands(s, inst_pos, pos); } inline -void BaseInstruction::parse_operands(istream& s, int pos) +void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) { int num_var_args = 0; switch (opcode) { // instructions with 3 register operands case ADDC: + case ADDCB: case ADDS: case ADDM: case SUBC: @@ -88,6 +69,7 @@ void BaseInstruction::parse_operands(istream& s, int pos) case TRIPLE: case ANDC: case XORC: + case XORCB: case ORC: case SHLC: case SHRC: @@ -125,8 +107,11 @@ void BaseInstruction::parse_operands(istream& s, int pos) case LDMSI: case STMCI: case STMSI: + case LDMSBI: + case STMSBI: case MOVC: case MOVS: + case MOVSB: case MOVINT: case LDMINTI: case STMINTI: @@ -154,13 +139,16 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GPROTECTMEMC: case PROTECTMEMINT: case CONDPRINTPLAIN: + case DABIT: r[0]=get_int(s); r[1]=get_int(s); break; // instructions with 1 register operand case BIT: + case BITB: case PRINTMEM: case PRINTREGPLAIN: + case PRINTREGPLAINB: case LDTN: case LDARG: case STARG: @@ -190,20 +178,25 @@ void BaseInstruction::parse_operands(istream& s, int pos) break; // instructions with 2 registers + 1 integer operand case ADDCI: + case ADDCBI: case ADDSI: case SUBCI: case SUBSI: case SUBCFI: case SUBSFI: case MULCI: + case MULCBI: case MULSI: case DIVCI: case MODCI: case ANDCI: case XORCI: + case XORCBI: case ORCI: case SHLCI: case SHRCI: + case SHLCBI: + case SHRCBI: case NOTC: case CONVMODP: case GADDCI: @@ -238,6 +231,10 @@ void BaseInstruction::parse_operands(istream& s, int pos) case LDMS: case STMC: case STMS: + case LDMSB: + case STMSB: + case LDMCB: + case STMCB: case LDMINT: case STMINT: case JMPNZ: @@ -249,6 +246,7 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GSTMC: case GSTMS: case PRINTREG: + case PRINTREGB: case GPRINTREG: case LDINT: case STARTINPUT: @@ -260,6 +258,7 @@ void BaseInstruction::parse_operands(istream& s, int pos) case ACCEPTCLIENTCONNECTION: case INV2M: case CONDPRINTSTR: + case CONDPRINTSTRB: r[0]=get_int(s); n = get_int(s); break; @@ -281,6 +280,7 @@ void BaseInstruction::parse_operands(istream& s, int pos) break; // instructions with 4 register operands case PRINTFLOATPLAIN: + case PRINTFLOATPLAINB: get_vector(4, start, s); break; // open instructions + read/write instructions with variable length args @@ -385,11 +385,55 @@ void BaseInstruction::parse_operands(istream& s, int pos) throw Processor_Error(ss.str()); } break; + case XORM: + case ANDM: + n = get_int(s); + get_ints(r, s, 3); + break; + case LDBITS: + get_ints(r, s, 2); + n = get_int(s); + break; + case BITDECS: + case BITCOMS: + case BITDECC: + num_var_args = get_int(s) - 1; + get_ints(r, s, 1); + get_vector(num_var_args, start, s); + break; + case CONVCINT: + case CONVCBIT: + get_ints(r, s, 2); + break; + case REVEAL: + case CONVSINT: + n = get_int(s); + get_ints(r, s, 2); + break; + case LDMSDI: + case STMSDI: + case LDMSD: + case STMSD: + case STMSDCI: + case XORS: + case ANDRS: + case ANDS: + case INPUTB: + get_vector(get_int(s), start, s); + break; + case PRINTREGSIGNED: + n = get_int(s); + get_ints(r, s, 1); + break; + case TRANS: + num_var_args = get_int(s) - 1; + n = get_int(s); + get_vector(num_var_args, start, s); + break; default: ostringstream os; - os << "Invalid instruction " << hex << showbase << opcode << " at " << dec << pos << endl; - os << "This virtual machine executes arithmetic circuits only." << endl; - os << "Try compiling without '-B' and don't use sbit* types." << endl; + os << "Invalid instruction " << showbase << hex << opcode << " at " << dec + << pos << "/" << hex << file_pos << dec << endl; throw Invalid_Instruction(os.str()); } } @@ -427,6 +471,14 @@ bool Instruction::get_offline_data_usage(DataPositions& usage) inline int BaseInstruction::get_reg_type() const { + switch (opcode & 0x2B0) + { + case SECRET_WRITE: + return SBIT; + case CLEAR_WRITE: + return CBIT; + } + switch (opcode) { case LDMINT: case STMINT: @@ -467,8 +519,19 @@ int BaseInstruction::get_reg_type() const inline unsigned BaseInstruction::get_max_reg(int reg_type) const { + if (opcode == DABIT) + { + if (reg_type == SBIT) + return r[1] + size; + else if (reg_type == MODP) + return r[0] + size; + } + if (get_reg_type() != reg_type) { return 0; } + int skip = 0; + int offset = 0; + switch (opcode) { case DOTPRODS: @@ -483,28 +546,44 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const } return res; } + case LDMSD: + case LDMSDI: + skip = 3; + break; + case STMSD: + case STMSDI: + skip = 2; + break; + case ANDRS: + case XORS: + case ANDS: + skip = 4; + offset = 1; + break; + case INPUTB: + skip = 4; + offset = 3; + break; } - const int *begin, *end; - if (start.size()) - { - begin = start.data(); - end = start.data() + start.size(); - } - else - { - begin = r; - end = r + 3; - } + if (skip > 0) + { + unsigned m = 0; + for (size_t i = offset; i < start.size(); i += skip) + m = max(m, (unsigned)start[i] + 1); + return m; + } unsigned res = 0; - for (auto it = begin; it != end; it++) - res = max(res, (unsigned)*it); + for (auto x : start) + res = max(res, (unsigned)x); + for (auto x : r) + res = max(res, (unsigned)x); return res + size; } inline -unsigned Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const +unsigned BaseInstruction::get_mem(RegType reg_type, SecrecyType sec_type) const { if (get_reg_type() == reg_type and is_direct_memory_access(sec_type)) return n + size; @@ -515,33 +594,27 @@ unsigned Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const inline bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const { - if (sec_type == SECRET) - { - switch (opcode) - { - case LDMS: - case STMS: - case GLDMS: - case GSTMS: - return true; - default: - return false; - } - } - else + switch (opcode) { - switch (opcode) - { - case LDMC: - case STMC: - case GLDMC: - case GSTMC: - case LDMINT: - case STMINT: - return true; - default: - return false; - } + case LDMS: + case STMS: + case GLDMS: + case GSTMS: + return sec_type == SECRET; + case LDMC: + case STMC: + case GLDMC: + case GSTMC: + return sec_type == CLEAR; + case LDMINT: + case STMINT: + case LDMSB: + case STMSB: + case LDMCB: + case STMCB: + return true; + default: + return false; } } @@ -572,6 +645,11 @@ inline void Instruction::execute(Processor& Proc) const auto& Procp = Proc.Procp; auto& Proc2 = Proc.Proc2; + // binary instructions + typedef typename sint::bit_type T; + auto& processor = Proc.Procb; + auto& instruction = *this; + // optimize some instructions switch (opcode) { @@ -683,6 +761,9 @@ inline void Instruction::execute(Processor& Proc) const Proc.get_Cp_ref(r[0] + i) = Proc.temp.ansp; } return; +#define X(NAME, CODE) case NAME: CODE; return; + COMBI_INSTRUCTIONS +#undef X } int r[3] = {this->r[0], this->r[1], this->r[2]}; @@ -1293,8 +1374,7 @@ inline void Instruction::execute(Processor& Proc) const if (n > 64) throw Processor_Error(to_string(n) + "-bit conversion impossible; " "integer registers only have 64 bits"); - to_signed_bigint(Proc.temp.aa,Proc.read_Cp(r[1]),n); - Proc.write_Ci(r[0], Proc.temp.aa.get_si()); + Proc.write_Ci(r[0], Integer(Proc.read_Cp(r[1]), n).get()); } break; case GCONVGF2N: @@ -1344,8 +1424,7 @@ inline void Instruction::execute(Processor& Proc) const typename sint::clear z = Proc.read_Cp(start[2]); typename sint::clear s = Proc.read_Cp(start[3]); // MPIR can't handle more precision in exponent - to_signed_bigint(Proc.temp.aa2, p, 31); - long exp = Proc.temp.aa2.get_si(); + long exp = Integer(p, 31).get(); Proc.out << bigint::get_float(v, exp, z, s) << flush; } break; @@ -1525,9 +1604,12 @@ inline void Instruction::execute(Processor& Proc) const Proc2.DataF.get(Proc.Proc2.get_S(), r, start, size); return; default: - printf("Case of opcode=%d not implemented yet\n",opcode); + printf("Case of opcode=0x%x not implemented yet\n",opcode); throw not_implemented(); break; +#define X(NAME, CODE) case NAME: throw runtime_error("wrong case statement"); return; + COMBI_INSTRUCTIONS +#undef X } if (size > 1) { diff --git a/Processor/Machine.h b/Processor/Machine.h index f055b5516..bbf3572c4 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -13,6 +13,8 @@ #include "Processor/Online-Thread.h" +#include "GC/Machine.h" + #include "Tools/time-func.h" #include @@ -36,6 +38,7 @@ class Machine : public BaseMachine Names& N; typename sint::mac_key_type alphapi; typename sgf2n::mac_key_type alpha2i; + typename sint::bit_type::mac_key_type alphabi; // Keep record of used offline data DataPositions pos; @@ -55,6 +58,7 @@ class Machine : public BaseMachine Memory M2; Memory Mp; Memory Mi; + GC::Memories bit_memories; vector join_timer; Timer finish_timer; @@ -90,6 +94,8 @@ class Machine : public BaseMachine Machine(Names& N = *(new Names())): N(N) {} void reqbl(int n); + + typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } }; #endif /* MACHINE_H_ */ diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index bc4eef067..7754a2d53 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -86,6 +86,12 @@ Machine::Machine(int my_number, Names& playerNames, cerr << "MAC Key 2 = " << alpha2i << endl; #endif + // MAC key for bits might depend on sint MAC key + sint::bit_type::generate_mac_key(alphabi, alphapi); + + // deactivate output if necessary + sint::bit_type::out.activate(my_number == 0 or opts.interactive); + // for OT-based preprocessing sint::clear::next::template init(false); @@ -120,7 +126,8 @@ Machine::Machine(int my_number, Names& playerNames, progs[0].print_offline_cost(); #endif - if (live_prep and (sint::needs_ot or sgf2n::needs_ot)) + if (live_prep + and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot)) { Player* P; if (use_encryption) @@ -349,6 +356,8 @@ void Machine::run() outf << M2 << Mp << Mi; outf.close(); + bit_memories.write_memory(N.my_num()); + #ifdef OLD_USAGE for (int dtype = 0; dtype < N_DTYPE; dtype++) { diff --git a/Processor/NoLivePrep.h b/Processor/NoLivePrep.h index 600f74ed8..2d2703c84 100644 --- a/Processor/NoLivePrep.h +++ b/Processor/NoLivePrep.h @@ -26,6 +26,9 @@ class NoLivePrep : public Sub_Data_Files { (void) _; } + NoLivePrep(DataPositions& usage) : NoLivePrep(0, usage) + { + } }; #endif /* PROCESSOR_NOLIVEPREP_H_ */ diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index fd419b984..bde1925d8 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -29,6 +29,8 @@ void* Sub_Main_Func(void* ptr) vector& progs = machine.progs; int num=tinfo->thread_num; + BaseMachine::s().thread_num = num; + #ifdef DEBUG_THREADS fprintf(stderr, "\tI am in thread %d\n",num); #endif @@ -168,6 +170,7 @@ void* Sub_Main_Func(void* ptr) // destruct protocol before last MAC check and data statistics size_t prep_sent = Proc.DataF.data_sent(); + prep_sent += Proc.share_thread.DataF.data_sent(); delete processor; // MACCheck diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 8f3b22986..b87fd84a0 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -127,7 +127,7 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, if (allArgs.size() != 3u - opt.isSet("-p")) { - cerr << "ERROR: incorrect number of arguments to Player-Online.x\n"; + cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl; cerr << "Arguments given were:\n"; for (unsigned int j = 1; j < allArgs.size(); j++) cout << "'" << *allArgs[j] << "'" << endl; diff --git a/Processor/PrivateOutput.hpp b/Processor/PrivateOutput.hpp index da8ce1e10..c78358bf9 100644 --- a/Processor/PrivateOutput.hpp +++ b/Processor/PrivateOutput.hpp @@ -20,11 +20,11 @@ void PrivateOutput::start(int player, int target, int source) template void PrivateOutput::stop(int player, int source) { - if (player == proc.P.my_num()) + if (player == proc.P.my_num() and proc.Proc) { open_type value; value.sub(proc.get_C_ref(source), masks.front()); - value.output(proc.Proc.private_output, false); + value.output(proc.Proc->private_output, false); masks.pop_front(); } } diff --git a/Processor/Processor.h b/Processor/Processor.h index c3217c095..3915d1acf 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -18,6 +18,8 @@ #include "ProcessorBase.h" #include "OnlineOptions.h" #include "Tools/SwitchableOutput.h" +#include "GC/Processor.h" +#include "GC/ShareThread.h" class Program; @@ -40,7 +42,7 @@ class SubProcessor template friend class Beaver; public: - ArithmeticProcessor& Proc; + ArithmeticProcessor* Proc; typename T::MAC_Check& MC; Player& P; Preprocessing& DataF; @@ -50,6 +52,8 @@ class SubProcessor SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC, Preprocessing& DataF, Player& P); + SubProcessor(typename T::MAC_Check& MC, Preprocessing& DataF, Player& P, + ArithmeticProcessor* Proc = 0); // Access to PO (via calls to POpen start/stop) void POpen(const vector& reg,const Player& P,int size); @@ -120,6 +124,8 @@ class Processor : public ArithmeticProcessor SubProcessor Proc2; SubProcessor Procp; + GC::Processor Procb; + GC::ShareThread share_thread; typename sgf2n::PrivateOutput privateOutput2; typename sint::PrivateOutput privateOutputp; @@ -185,6 +191,8 @@ class Processor : public ArithmeticProcessor void write_Ci(int i,const long& x) { Ci[i]=x; } + void dabit(const Instruction& instruction); + // Access to external client sockets for reading clear/shared data void read_socket_ints(int client_id, const vector& registers); // Setup client public key diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index b28164e5c..e3b617aca 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -7,6 +7,8 @@ #include "Protocols/ReplicatedInput.hpp" #include "Protocols/ReplicatedPrivateOutput.hpp" #include "Processor/ProcessorBase.hpp" +#include "GC/Processor.hpp" +#include "GC/ShareThread.hpp" #include #include @@ -14,6 +16,13 @@ template SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC, Preprocessing& DataF, Player& P) : + SubProcessor(MC, DataF, P, &Proc) +{ +} + +template +SubProcessor::SubProcessor(typename T::MAC_Check& MC, + Preprocessing& DataF, Player& P, ArithmeticProcessor* Proc) : Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC) { DataF.set_proc(this); @@ -28,6 +37,7 @@ Processor::Processor(int thread_num,Player& P, : ArithmeticProcessor(machine.opts, thread_num),DataF(machine, &Procp, &Proc2),P(P), MC2(MC2),MCp(MCp),machine(machine), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), + Procb(machine.bit_memories), share_thread(machine.get_N(), machine.opts), privateOutput2(Proc2),privateOutputp(Procp), external_clients(ExternalClients(P.my_num(), machine.prep_dir_prefix)), binary_file_io(Binary_File_IO()) @@ -46,6 +56,8 @@ Processor::Processor(int thread_num,Player& P, shared_prng.SeedGlobally(P); out.activate(P.my_num() == 0 or machine.opts.interactive); + + share_thread.pre_run(P, machine.get_bit_mac_key()); } @@ -86,6 +98,22 @@ void Processor::reset(const Program& program,int arg) Procp.resize(reg_maxp); Ci.resize(reg_maxi); this->arg = arg; + Procb.reset(program); +} + +template +void Processor::dabit(const Instruction& instruction) +{ + int size = instruction.get_size(); + assert(size <= sint::bit_type::clear::n_bits); + auto& bin = Procb.S[instruction.get_r(1)]; + bin = {}; + for (int i = 0; i < size; i++) + { + typename sint::bit_type tmp; + Procp.DataF.get_dabit(Procp.get_S_ref(instruction.get_r(0) + i), tmp); + bin ^= tmp << i; + } } #include "Networking/sockets.h" @@ -419,8 +447,11 @@ void SubProcessor::POpen(const vector& reg,const Player& P,int size) } } - Proc.sent += reg.size() * size; - Proc.rounds++; + if (Proc != 0) + { + Proc->sent += reg.size() * size; + Proc->rounds++; + } } template diff --git a/Processor/Program.cpp b/Processor/Program.cpp index 52f0474fa..34075eb06 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -26,6 +26,8 @@ void Program::compute_constants() p[i].get_mem(RegType(reg_type), SecrecyType(sec_type))); } } + + max_mem[INT][SECRET] = 0; } void Program::parse(istream& s) @@ -34,7 +36,7 @@ void Program::parse(istream& s) Instruction instr; s.peek(); while (!s.eof()) - { instr.parse(s); + { instr.parse(s, p.size()); p.push_back(instr); //cerr << "\t" << instr << endl; s.peek(); diff --git a/Programs/Source/bankers_bonus.mpc b/Programs/Source/bankers_bonus.mpc index 004a1eaaf..e523fd9aa 100644 --- a/Programs/Source/bankers_bonus.mpc +++ b/Programs/Source/bankers_bonus.mpc @@ -22,6 +22,10 @@ from Compiler.util import if_else PORTNUM = 14000 MAX_NUM_CLIENTS = 8 +n_rounds = 0 + +if len(program.args) > 1: + n_rounds = int(program.args[1]) def accept_client(): client_socket_id = regint() @@ -78,8 +82,7 @@ def main(): listen(PORTNUM) print_ln('Listening for client connections on base port %s', PORTNUM) - @do_while - def game_loop(): + def game_loop(_=None): print_ln('Starting a new round of the game.') # Clients socket id (integer). @@ -123,4 +126,11 @@ def main(): return True + if n_rounds > 0: + print('run %d rounds' % n_rounds) + for_range(n_rounds)(game_loop) + else: + print('run forever') + do_while(game_loop) + main() diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 9e7c3f81d..4b64b0ceb 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -44,7 +44,7 @@ class Beaver : public ProtocolBase void start_exchange(); void stop_exchange(); - int get_n_relevant_players() { return P.num_players(); } + int get_n_relevant_players() { return 1 + T::threshold(P.num_players()); } }; #endif /* PROTOCOLS_BEAVER_H_ */ diff --git a/Protocols/CowGearPrep.h b/Protocols/CowGearPrep.h index 2f7c129ec..4d80b56f4 100644 --- a/Protocols/CowGearPrep.h +++ b/Protocols/CowGearPrep.h @@ -12,7 +12,7 @@ class PairwiseMachine; template class PairwiseGenerator; template -class CowGearPrep : public RingPrep +class CowGearPrep : public MaliciousRingPrep { typedef typename T::mac_key_type mac_key_type; typedef typename T::clear::FD FD; @@ -31,7 +31,7 @@ class CowGearPrep : public RingPrep static void teardown(); CowGearPrep(SubProcessor* proc, DataPositions& usage) : - RingPrep(proc, usage), pairwise_generator(0) + MaliciousRingPrep(proc, usage), pairwise_generator(0) { } ~CowGearPrep(); diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index bac48021f..a0d2787a2 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -7,6 +7,8 @@ #include "FHEOffline/PairwiseMachine.h" #include "Tools/Bundle.h" +#include "Protocols/ReplicatedPrep.hpp" + template PairwiseMachine* CowGearPrep::pairwise_machine = 0; template diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 29f509f64..2b51d61e5 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -9,7 +9,6 @@ using namespace std; #include "Protocols/Share.h" #include "Networking/Player.h" -#include "Networking/ServerSocket.h" #include "Protocols/Summer.h" #include "Protocols/MAC_Check_Base.h" #include "Protocols/RandomPrep.h" @@ -87,7 +86,7 @@ class MAC_Check_ : public TreeSum, public MAC_Check_Base< public: - MAC_Check_(const typename T::Scalar& ai, int opening_sum = 10, + MAC_Check_(const typename U::mac_key_type::Scalar& ai, int opening_sum = 10, int max_broadcast = 10, int send_player = 0); virtual ~MAC_Check_(); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index 06b360d05..093cd5859 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -30,7 +30,7 @@ const char* TreeSum::mc_timer_names[] = { }; template -MAC_Check_::MAC_Check_(const typename T::Scalar& ai, int opening_sum, +MAC_Check_::MAC_Check_(const typename U::mac_key_type::Scalar& ai, int opening_sum, int max_broadcast, int send_player) : TreeSum(opening_sum, max_broadcast, send_player) { @@ -147,7 +147,7 @@ void MAC_Check_::Check(const Player& P) if (popen_cnt < 10) { - vector deltas; + vector deltas; Bundle bundle(P); for (int i = 0; i < popen_cnt; i++) { @@ -161,7 +161,7 @@ void MAC_Check_::Check(const Player& P) { for (auto& os : bundle) if (&os != &bundle.mine) - delta += os.get(); + delta += os.get(); if (not delta.is_zero()) throw mac_fail(); } @@ -176,15 +176,15 @@ void MAC_Check_::Check(const Player& P) G.SetSeed(seed); U sj; - T a,gami,temp; - typename T::Scalar h; - vector tau(P.num_players()); + typename U::mac_type a,gami,temp; + typename U::mac_type::Scalar h; + vector tau(P.num_players()); a.assign_zero(); gami.assign_zero(); for (int i=0; i::Check(const Player& P) //cerr << "\tFinal Check" << endl; - T t; + typename U::mac_type t; t.assign_zero(); for (int i=0; i MalRepRingPrep::MalRepRingPrep(SubProcessor* proc, DataPositions& usage) : MaliciousRingPrep(proc, usage) @@ -81,13 +86,11 @@ void shuffle_triple_generation(vector>& triples, Player& P, } template -void ShuffleSacrifice::triple_sacrifice(vector>& triples, - vector>& check_triples, Player& P, - typename T::MAC_Check& MC) +template +void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) { int buffer_size = check_triples.size(); assert(buffer_size >= minimum_n_inputs()); - int N = (buffer_size - C) / B; // shuffle GlobalPRNG G(P); @@ -97,6 +100,17 @@ void ShuffleSacrifice::triple_sacrifice(vector>& triples, int pos = G.get_uint(remaining); swap(check_triples[i], check_triples[i + pos]); } +} + +template +void ShuffleSacrifice::triple_sacrifice(vector>& triples, + vector>& check_triples, Player& P, + typename T::MAC_Check& MC) +{ + int buffer_size = check_triples.size(); + int N = (buffer_size - C) / B; + + shuffle(check_triples, P); // opening C triples vector shares; @@ -119,7 +133,7 @@ void ShuffleSacrifice::triple_sacrifice(vector>& triples, { T& a = check_triples[i][0]; T& b = check_triples[i][1]; - for (int j = 1; j < C; j++) + for (int j = 1; j < B; j++) { T& f = check_triples[i + N * j][0]; T& g = check_triples[i + N * j][1]; @@ -135,7 +149,7 @@ void ShuffleSacrifice::triple_sacrifice(vector>& triples, { T& b = check_triples[i][1]; T& c = check_triples[i][2]; - for (int j = 1; j < C; j++) + for (int j = 1; j < B; j++) { T& f = check_triples[i + N * j][0]; T& h = check_triples[i + N * j][2]; @@ -158,7 +172,7 @@ void MalRepRingPrepWithBits::buffer_bits() typename BitShare::MAC_Check MC; DataPositions usage; MalRepRingPrep prep(0, usage); - SubProcessor bit_proc(proc->Proc, MC, prep, proc->P); + SubProcessor bit_proc(MC, prep, proc->P); prep.set_proc(&bit_proc); bits_from_square_in_ring(this->bits, OnlineOptions::singleton.batch_size, &prep); } @@ -168,3 +182,5 @@ void MalRepRingPrep::buffer_inputs(int player) { this->buffer_inputs_as_usual(player, this->proc); } + +#endif diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index ce6da3270..9b36f684f 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -12,6 +12,11 @@ template class HashMaliciousRepMC; template class Beaver; template class MaliciousRepPrep; +namespace GC +{ +class MaliciousRepSecret; +} + template class MaliciousRep3Share : public Rep3Share { @@ -28,6 +33,8 @@ class MaliciousRep3Share : public Rep3Share typedef MaliciousRep3Share prep_type; typedef T random_type; + typedef GC::MaliciousRepSecret bit_type; + static string type_short() { return "M" + string(1, T::type_char()); diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index 3457426d0..d35b4d5a5 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -23,6 +23,11 @@ class MaliciousRepMC : public ReplicatedMC const vector& S, const Player& P); virtual void Check(const Player& P); + + MaliciousRepMC& get_part_MC() + { + return *this; + } }; template diff --git a/Protocols/MaliciousRepPrep.h b/Protocols/MaliciousRepPrep.h index 12ec9d034..c499ccba6 100644 --- a/Protocols/MaliciousRepPrep.h +++ b/Protocols/MaliciousRepPrep.h @@ -20,7 +20,7 @@ template void sacrifice(const vector& check_triples, Player& P); template -class MaliciousRepPrep : public BufferPrep +class MaliciousRepPrep : public MaliciousRingPrep { template friend class MalRepRingPrep; @@ -28,7 +28,8 @@ class MaliciousRepPrep : public BufferPrep DataPositions honest_usage; ReplicatedPrep honest_prep; - typename T::Honest::Protocol* replicated; + typename T::Honest::MAC_Check honest_mc; + SubProcessor* honest_proc; typename T::MAC_Check MC; SubProcessor* proc; diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 4dbe345bc..e617c4cc0 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -12,33 +12,32 @@ MaliciousRepPrep::MaliciousRepPrep(SubProcessor* proc, DataPositions& usag MaliciousRepPrep(usage) { this->proc = proc; + MaliciousRingPrep::proc = proc; } template MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage) : - BufferPrep(usage), honest_usage(usage.num_players()), - honest_prep(0, honest_usage), replicated(0), proc(0) + MaliciousRingPrep(0, usage), honest_usage(usage.num_players()), + honest_prep(0, honest_usage), honest_proc(0), proc(0) { } template MaliciousRepPrep::~MaliciousRepPrep() { - if (replicated) - delete replicated; } template void MaliciousRepPrep::set_protocol(typename T::Protocol& protocol) { + RingPrep::set_protocol(protocol); init_honest(protocol.P); } template void MaliciousRepPrep::init_honest(Player& P) { - replicated = new typename T::Honest::Protocol(P); - honest_prep.set_protocol(*replicated); + honest_proc = new SubProcessor(honest_mc, honest_prep, P); } template diff --git a/Protocols/MaliciousRingPrep.hpp b/Protocols/MaliciousRingPrep.hpp new file mode 100644 index 000000000..ff2e79429 --- /dev/null +++ b/Protocols/MaliciousRingPrep.hpp @@ -0,0 +1,24 @@ +/* + * MaliciousRingPrep.hpp + * + */ + +#ifndef PROTOCOLS_MALICIOUSRINGPREP_HPP_ +#define PROTOCOLS_MALICIOUSRINGPREP_HPP_ + +#include "ReplicatedPrep.h" + +#include "ShuffleSacrifice.hpp" + +template +void MaliciousRingPrep::buffer_dabits() +{ + assert(this->proc != 0); + vector> check_dabits; + ShuffleSacrifice shuffle_sacrifice; + this->buffer_dabits_without_check(check_dabits, + shuffle_sacrifice.minimum_n_inputs()); + shuffle_sacrifice.dabit_sacrifice(this->dabits, check_dabits, *this->proc); +} + +#endif /* PROTOCOLS_MALICIOUSRINGPREP_HPP_ */ diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 19a4f7cbf..658858b10 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -48,6 +48,7 @@ class MascotFieldPrep : public MascotPrep { void buffer_inverses(); void buffer_bits(); + void buffer_dabits(); public: MascotFieldPrep(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 9a94f484f..a9da5a38a 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -12,6 +12,7 @@ #include "OT/OTTripleSetup.h" #include "OT/Triple.hpp" #include "OT/NPartyTripleGenerator.hpp" +#include "Protocols/ShuffleSacrifice.hpp" template OTPrep::OTPrep(SubProcessor* proc, DataPositions& usage) : @@ -32,10 +33,9 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) RingPrep::set_protocol(protocol); SubProcessor* proc = this->proc; assert(proc != 0); - auto& ot_setups = BaseMachine::s().ot_setups.at(proc->Proc.thread_num); - OTTripleSetup setup = ot_setups.get_fresh(); - triple_generator = new typename T::TripleGenerator(setup, - proc->P.N, proc->Proc.thread_num, + triple_generator = new typename T::TripleGenerator( + BaseMachine::s().fresh_ot_setup(), + proc->P.N, -1, OnlineOptions::singleton.batch_size, 1, params, proc->MC.get_alphai(), &proc->P); triple_generator->multi_threaded = false; @@ -74,6 +74,17 @@ void MascotFieldPrep::buffer_bits() this->bits.push_back(bit); } +template +void MascotFieldPrep::buffer_dabits() +{ + assert(this->proc != 0); + vector> check_dabits; + ShuffleSacrifice shuffle_sacrifice; + this->buffer_dabits_without_check(check_dabits, + shuffle_sacrifice.minimum_n_inputs()); + shuffle_sacrifice.dabit_sacrifice(this->dabits, check_dabits, *this->proc); +} + template void MascotPrep::buffer_inputs(int player) { @@ -110,10 +121,10 @@ T BufferPrep::get_random_from_inputs(int nplayers) template size_t OTPrep::data_sent() { + size_t res = RingPrep::data_sent(); if (triple_generator) - return triple_generator->data_sent(); - else - return 0; + res += triple_generator->data_sent(); + return res; } template diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 0d32cb0a2..02cc184ef 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -7,6 +7,7 @@ #define PROTOCOLS_POSTSACRIREPRINGSHARE_H_ #include "Protocols/MaliciousRep3Share.h" +#include "Protocols/MalRepRingShare.h" template class MalRepRingPrepWithBits; template class PostSacrifice; diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index ba0a3a730..a72e1cf16 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -9,6 +9,7 @@ #include "Math/FixedVec.h" #include "Math/Integer.h" #include "Protocols/Replicated.h" +#include "GC/ShareSecret.h" template class ReplicatedRingPrep; template class PrivateOutput; @@ -30,6 +31,8 @@ class Rep3Share : public FixedVec typedef ReplicatedRingPrep LivePrep; typedef Rep3Share Honest; + typedef GC::SemiHonestRepSecret bit_type; + const static bool needs_ot = false; const static bool dishonest_majority = false; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h new file mode 100644 index 000000000..6ef9305ab --- /dev/null +++ b/Protocols/Rep3Share2k.h @@ -0,0 +1,40 @@ +/* + * Rep3Share2k.h + * + */ + +#ifndef PROTOCOLS_REP3SHARE2K_H_ +#define PROTOCOLS_REP3SHARE2K_H_ + +#include "Rep3Share.h" +#include "Math/Z2k.h" + +template class ReplicatedPrep2k; + +template +class Rep3Share2 : public Rep3Share> +{ + typedef SignedZ2 T; + +public: + typedef Replicated Protocol; + typedef ReplicatedMC MAC_Check; + typedef MAC_Check Direct_MC; + typedef ReplicatedInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPrep2k LivePrep; + typedef Rep3Share2 Honest; + + typedef GC::SemiHonestRepSecret bit_type; + + Rep3Share2() + { + } + template + Rep3Share2(const FixedVec& other) + { + FixedVec::operator=(other); + } +}; + +#endif /* PROTOCOLS_REP3SHARE2K_H_ */ diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 5d1d1e781..7834c6d98 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -20,7 +20,7 @@ class PrepLessInput : public InputBase public: PrepLessInput(SubProcessor* proc) : - InputBase(proc ? &proc->Proc : 0), processor(proc), i_share(0) {} + InputBase(proc ? proc->Proc : 0), processor(proc), i_share(0) {} virtual ~PrepLessInput() {} void start(int player, int n_inputs); diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index 6bc657bc5..143e0b83e 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -31,6 +31,11 @@ class ReplicatedMC : public MAC_Check_Base void POpen_End(vector& values,const vector& S,const Player& P); void Check(const Player& P) { (void)P; } + + ReplicatedMC& get_part_MC() + { + return *this; + } }; #endif /* PROTOCOLS_REPLICATEDMC_H_ */ diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 1c4ab044a..57ea09918 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -9,6 +9,7 @@ #include "Networking/Player.h" #include "Processor/Data_Files.h" #include "Protocols/Rep3Share.h" +#include "Protocols/ShuffleSacrifice.h" #include "Math/gfp.h" #include @@ -27,6 +28,8 @@ class BufferPrep : public Preprocessing vector bits; vector>> inputs; + vector> dabits; + int n_bit_rounds; virtual void buffer_triples() = 0; @@ -38,6 +41,8 @@ class BufferPrep : public Preprocessing // don't call this if T::Input requires input tuples void buffer_inputs_as_usual(int player, SubProcessor* proc); + virtual void buffer_dabits() { throw runtime_error("no daBits"); } + public: typedef T share_type; @@ -60,11 +65,16 @@ class BufferPrep : public Preprocessing int vector_size); T get_random_from_inputs(int nplayers); + + virtual void get_dabit(T& a, typename T::bit_type& b); }; template class RingPrep : public BufferPrep { + void buffer_ring_bits_without_check(vector& bits, PRNG& G, + int buffer_size); + protected: template friend class MaliciousRepPrep; @@ -73,10 +83,14 @@ class RingPrep : public BufferPrep int base_player; + size_t sent; + void buffer_squares(); void buffer_inverses() { throw runtime_error("not inverses in rings"); } void buffer_bits_without_check(); + void buffer_dabits_without_check(vector>& dabits, + int buffer_size = -1); public: RingPrep(SubProcessor* proc, DataPositions& usage); @@ -87,6 +101,8 @@ class RingPrep : public BufferPrep vector& get_bits() { return this->bits; } void set_protocol(typename T::Protocol& protocol); + + virtual size_t data_sent() { return sent; } }; template @@ -100,6 +116,9 @@ class SemiHonestRingPrep : public virtual RingPrep virtual void buffer_bits() { this->buffer_bits_without_check(); } virtual void buffer_inputs(int player) { this->buffer_inputs_as_usual(player, this->proc); } + + virtual void buffer_dabits() + { this->buffer_dabits_without_check(this->dabits); } }; template @@ -111,6 +130,7 @@ class MaliciousRingPrep : public RingPrep virtual ~MaliciousRingPrep() {} virtual void buffer_bits(); + virtual void buffer_dabits(); }; template diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index f7b601a68..01357f77e 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -10,6 +10,10 @@ #include "Math/gfp.h" #include "Processor/OnlineOptions.h" +#include "MaliciousRingPrep.hpp" +#include "ShuffleSacrifice.hpp" +#include "GC/ShareThread.hpp" + template BufferPrep::BufferPrep(DataPositions& usage) : Preprocessing(usage), n_bit_rounds(0), @@ -28,7 +32,7 @@ BufferPrep::~BufferPrep() template RingPrep::RingPrep(SubProcessor* proc, DataPositions& usage) : - BufferPrep(usage), protocol(0), proc(proc), base_player(0) + BufferPrep(usage), protocol(0), proc(proc), base_player(0), sent(0) { } @@ -36,8 +40,8 @@ template void RingPrep::set_protocol(typename T::Protocol& protocol) { this->protocol = &protocol; - if (proc) - base_player = proc->Proc.thread_num; + if (proc and proc->Proc) + base_player = proc->Proc->thread_num; } template @@ -199,9 +203,11 @@ void BufferPrep::get_two_no_count(Dtype dtype, T& a, T& b) } template -void XOR(vector& res, vector& x, vector& y, int buffer_size, +void XOR(vector& res, vector& x, vector& y, typename T::Protocol& prot, SubProcessor* proc) { + assert(x.size() == y.size()); + int buffer_size = x.size(); res.resize(buffer_size); if (T::clear::field_type() == DATA_GF2N) @@ -258,22 +264,29 @@ void buffer_bits_spec(ReplicatedPrep>& prep, vector>& bits, template void RingPrep::buffer_bits_without_check() { - assert(protocol != 0); - auto buffer_size = OnlineOptions::singleton.batch_size; - auto& bits = this->bits; - auto& P = protocol->P; - int n_relevant_players = protocol->get_n_relevant_players(); - vector> player_bits(n_relevant_players, vector(buffer_size)); - typename T::Input input(proc, P); + SeededPRNG G; + buffer_ring_bits_without_check(this->bits, G, + OnlineOptions::singleton.batch_size); +} + +template +void buffer_bits_from_players(vector>& player_bits, PRNG& G, + SubProcessor& proc, int base_player, int buffer_size, + int n_bits = -1) +{ + auto& protocol = proc.protocol; + auto& P = protocol.P; + int n_relevant_players = protocol.get_n_relevant_players(); + player_bits.resize(n_relevant_players, vector(buffer_size)); + typename T::Input input(proc, proc.MC); input.reset_all(P); for (int i = 0; i < n_relevant_players; i++) { int input_player = (base_player + i) % P.num_players(); if (input_player == P.my_num()) { - SeededPRNG G; for (int i = 0; i < buffer_size; i++) - input.add_mine(G.get_bit()); + input.add_mine(G.get_bit(), n_bits); } else for (int i = 0; i < buffer_size; i++) @@ -282,14 +295,60 @@ void RingPrep::buffer_bits_without_check() input.exchange(); for (int i = 0; i < n_relevant_players; i++) for (auto& x : player_bits[i]) - x = input.finalize((base_player + i) % P.num_players()); + x = input.finalize((base_player + i) % P.num_players(), n_bits); +} + +template +void RingPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, + int buffer_size) +{ + assert(protocol != 0); + assert(proc != 0); + int n_relevant_players = protocol->get_n_relevant_players(); + vector> player_bits; + buffer_bits_from_players(player_bits, G, *proc, base_player, + buffer_size); auto& prot = *protocol; - XOR(bits, player_bits[0], player_bits[1], buffer_size, prot, proc); + XOR(bits, player_bits[0], player_bits[1], prot, proc); for (int i = 2; i < n_relevant_players; i++) - XOR(bits, bits, player_bits[i], buffer_size, prot, proc); + XOR(bits, bits, player_bits[i], prot, proc); base_player++; } +template +void RingPrep::buffer_dabits_without_check(vector>& dabits, + int buffer_size) +{ + if (buffer_size < 0) + buffer_size = OnlineOptions::singleton.batch_size; + assert(protocol != 0); + assert(proc != 0); + SeededPRNG G; + PRNG G2 = G; + typedef typename T::bit_type::part_type bit_type; + vector> player_bits; + auto& party = GC::ShareThread::s(); + DataPositions usage(proc->P.num_players()); + typename bit_type::LivePrep bit_prep(usage); + SubProcessor bit_proc(party.MC->get_part_MC(), + bit_prep, proc->P); + typename T::bit_type::Protocol bit_protocol(protocol->P); + buffer_bits_from_players(player_bits, G, bit_proc, base_player, + buffer_size, 1); + vector int_bits; + buffer_ring_bits_without_check(int_bits, G2, buffer_size); + for (auto& pb : player_bits) + assert(pb.size() == int_bits.size()); + for (size_t i = 0; i < int_bits.size(); i++) + { + bit_type bit = player_bits[0][i]; + for (int j = 1; j < protocol->get_n_relevant_players(); j++) + bit ^= player_bits[j][i]; + dabits.push_back({int_bits[i], bit}); + } + sent += bit_prep.data_sent(); +} + template<> inline void SemiHonestRingPrep>::buffer_bits() @@ -364,6 +423,17 @@ void BufferPrep::get_input_no_count(T& a, typename T::open_type& x, int i) inputs[i].pop_back(); } +template +void BufferPrep::get_dabit(T& a, typename T::bit_type& b) +{ + if (dabits.empty()) + buffer_dabits(); + a = dabits.back().first; + b = dabits.back().second; + dabits.pop_back(); + this->count(DATA_DABIT); +} + template inline void BufferPrep::buffer_inputs(int player) { @@ -382,10 +452,11 @@ void BufferPrep::buffer_inputs_as_usual(int player, SubProcessor* proc) auto buffer_size = OnlineOptions::singleton.batch_size; if (P.my_num() == player) { + SeededPRNG G; for (int i = 0; i < buffer_size; i++) { typename T::clear r; - r.randomize(proc->Proc.secure_prng); + r.randomize(G); input.add_mine(r); this->inputs[player].push_back({input.finalize_mine(), r}); } diff --git a/Protocols/ReplicatedPrep2k.h b/Protocols/ReplicatedPrep2k.h new file mode 100644 index 000000000..bfaf9b0e0 --- /dev/null +++ b/Protocols/ReplicatedPrep2k.h @@ -0,0 +1,27 @@ +/* + * Rep2kPrep.h + * + */ + +#ifndef PROTOCOLS_REPLICATEDPREP2K_H_ +#define PROTOCOLS_REPLICATEDPREP2K_H_ + +#include "ReplicatedPrep.h" + +template +class ReplicatedPrep2k : public ReplicatedRingPrep +{ +public: + ReplicatedPrep2k(SubProcessor* proc, DataPositions& usage) : + RingPrep(proc, usage), ReplicatedRingPrep(proc, usage) + { + } + + void get_dabit(T& a, typename T::bit_type& b) + { + this->get_one(DATA_BIT, a); + b = a & 1; + } +}; + +#endif /* PROTOCOLS_REPLICATEDPREP2K_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index 8889d3433..b28f6c826 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -8,6 +8,9 @@ #include "SemiShare.h" #include "OT/Rectangle.h" +#include "GC/SemiSecret.h" + +template class SemiPrep2k; template class Semi2kShare : public SemiShare> @@ -22,13 +25,15 @@ class Semi2kShare : public SemiShare> typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; typedef SPDZ Protocol; - typedef SemiPrep LivePrep; + typedef SemiPrep2k LivePrep; typedef Semi2kShare prep_type; typedef SemiMultiplier Multiplier; typedef OTTripleGenerator TripleGenerator; typedef Z2kSquare Rectangle; + typedef GC::SemiSecret bit_type; + Semi2kShare() { } diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index 8eefc5a26..1eab6c018 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -21,6 +21,8 @@ class SemiMC : public TreeSum, public MAC_Check_Base void POpen_End(vector& values,const vector& S,const Player& P); void Check(const Player& P) { (void)P; } + + SemiMC& get_part_MC() { return *this; } }; template diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h new file mode 100644 index 000000000..a1e0bc05c --- /dev/null +++ b/Protocols/SemiPrep2k.h @@ -0,0 +1,28 @@ +/* + * SemiPrep2k.h + * + */ + +#ifndef PROTOCOLS_SEMIPREP2K_H_ +#define PROTOCOLS_SEMIPREP2K_H_ + +#include "SemiPrep.h" + +template +class SemiPrep2k : public SemiPrep +{ +public: + SemiPrep2k(SubProcessor* proc, DataPositions& usage) : + RingPrep(proc, usage), OTPrep(proc, usage),SemiHonestRingPrep(proc, usage), SemiPrep(proc, usage) + + { + } + + void get_dabit(T& a, typename T::bit_type& b) + { + this->get_one(DATA_BIT, a); + b = a & 1; + } +}; + +#endif /* PROTOCOLS_SEMIPREP2K_H_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 20d5e4ad4..6c75f0da4 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -23,6 +23,11 @@ template class PrivateOutput; template class SemiMultiplier; template class OTTripleGenerator; +namespace GC +{ +class SemiSecret; +} + template class SemiShare : public T { @@ -47,6 +52,8 @@ class SemiShare : public T typedef T sacri_type; typedef typename T::Square Rectangle; + typedef GC::SemiSecret bit_type; + const static bool needs_ot = true; const static bool dishonest_majority = true; diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index eef06c155..85c7622db 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -9,6 +9,7 @@ #include "Protocols/Shamir.h" #include "Protocols/ShamirInput.h" #include "Machines/ShamirMachine.h" +#include "GC/NoShare.h" template class ReplicatedPrep; @@ -28,6 +29,8 @@ class ShamirShare : public T typedef ReplicatedPrep LivePrep; typedef ShamirShare Honest; + typedef GC::NoShare bit_type; + const static bool needs_ot = false; const static bool dishonest_majority = false; diff --git a/Protocols/Share.h b/Protocols/Share.h index b011ecf43..6479ca315 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -10,6 +10,7 @@ using namespace std; #include "Math/gf2n.h" #include "Protocols/SPDZ.h" +#include "Protocols/SemiShare.h" // Forward declaration as apparently this is needed for friends in templates template class Share; @@ -25,42 +26,33 @@ template class MascotPrep; union square128; -template -class Share +namespace GC +{ +template class TinierSecret; +} + +// abstracting SPDZ and SPDZ-wise +template +class Share_ { T a; // The share - T mac; // Shares of the mac + V mac; // Shares of the mac public: - typedef T mac_key_type; - typedef T mac_type; - typedef T open_type; - typedef T clear; - - typedef Share prep_type; - typedef MascotMultiplier Multiplier; - typedef MascotTripleGenerator TripleGenerator; - typedef T sacri_type; - typedef typename T::Square Rectangle; - typedef Rectangle Square; - - typedef MAC_Check_ MAC_Check; - typedef Direct_MAC_Check Direct_MC; - typedef ::Input Input; - typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; - typedef MascotFieldPrep LivePrep; - typedef MascotPrep RandomPrep; - - const static bool needs_ot = true; - const static bool dishonest_majority = true; + typedef V mac_key_type; + typedef V mac_type; + typedef T share_type; + typedef typename T::open_type open_type; + typedef typename T::clear clear; - static int size() - { return 2 * T::size(); } + typedef GC::TinierSecret bit_type; - static string type_string() - { return "SPDZ " + T::type_string(); } + const static bool needs_ot = T::needs_ot; + const static bool dishonest_majority = T::dishonest_majority; + + static int size() + { return T::size() + V::size(); } static string type_short() { return string(1, T::type_char()); } @@ -72,13 +64,13 @@ class Share { return T::field_type(); } static int threshold(int nplayers) - { return nplayers - 1; } + { return T::threshold(nplayers); } - static Share constant(const clear& aa, int my_num, const typename T::Scalar& alphai) - { return Share(aa, my_num, alphai); } + static Share_ constant(const clear& aa, int my_num, const typename V::Scalar& alphai) + { return Share_(aa, my_num, alphai); } - template - void assign(const Share& S) + template + void assign(const Share_& S) { a=S.get_share(); mac=S.get_mac(); } void assign(const char* buffer) { a.assign(buffer); mac.assign(buffer + T::size()); } @@ -86,55 +78,54 @@ class Share { a.assign_zero(); mac.assign_zero(); } - void assign(const clear& aa, int my_num, const typename T::Scalar& alphai); + void assign(const clear& aa, int my_num, const typename V::Scalar& alphai); - Share() { assign_zero(); } - template - Share(const Share& S) { assign(S); } - Share(const clear& aa, int my_num, const typename T::Scalar& alphai) + Share_() { assign_zero(); } + template + Share_(const Share_& S) { assign(S); } + Share_(const clear& aa, int my_num, const typename V::Scalar& alphai) { assign(aa, my_num, alphai); } - Share(const T& share, const T& mac) : a(share), mac(mac) {} - ~Share() { ; } + Share_(const T& share, const V& mac) : a(share), mac(mac) {} const T& get_share() const { return a; } - const T& get_mac() const { return mac; } + const V& get_mac() const { return mac; } void set_share(const T& aa) { a=aa; } - void set_mac(const T& aa) { mac=aa; } + void set_mac(const V& aa) { mac=aa; } /* Arithmetic Routines */ - void mul(const Share& S,const T& aa); - void mul_by_bit(const Share& S,const T& aa); - void add(const Share& S,const clear& aa,int my_num,const T& alphai); + void mul(const Share_& S,const clear& aa); + void mul_by_bit(const Share_& S,const clear& aa); + void add(const Share_& S,const clear& aa,int my_num,const T& alphai); void negate() { a.negate(); mac.negate(); } - void sub(const Share& S,const clear& aa,int my_num,const T& alphai); - void sub(const clear& aa,const Share& S,int my_num,const T& alphai); - void add(const Share& S1,const Share& S2); - void sub(const Share& S1,const Share& S2); - void add(const Share& S1) { add(*this,S1); } + void sub(const Share_& S,const clear& aa,int my_num,const T& alphai); + void sub(const clear& aa,const Share_& S,int my_num,const T& alphai); + void add(const Share_& S1,const Share_& S2); + void sub(const Share_& S1,const Share_& S2); + void add(const Share_& S1) { add(*this,S1); } // obsolete interface - void add(const Share& S,const clear& aa,bool playerone,const T& alphai); - void sub(const Share& S,const clear& aa,bool playerone,const T& alphai); - void sub(const clear& aa,const Share& S,bool playerone,const T& alphai); - - Share operator+(const Share& x) const - { Share res; res.add(*this, x); return res; } - Share operator-(const Share& x) const - { Share res; res.sub(*this, x); return res; } + void add(const Share_& S,const clear& aa,bool playerone,const T& alphai); + void sub(const Share_& S,const clear& aa,bool playerone,const T& alphai); + void sub(const clear& aa,const Share_& S,bool playerone,const T& alphai); + + Share_ operator+(const Share_& x) const + { Share_ res; res.add(*this, x); return res; } + Share_ operator-(const Share_& x) const + { Share_ res; res.sub(*this, x); return res; } template - Share operator*(const U& x) const - { Share res; res.mul(*this, x); return res; } - Share operator/(const T& x) const - { Share res; res.set_share(a / x); res.set_mac(mac / x); return res; } + Share_ operator*(const U& x) const + { Share_ res; res.mul(*this, x); return res; } + Share_ operator/(const T& x) const + { Share_ res; res.set_share(a / x); res.set_mac(mac / x); return res; } - Share& operator+=(const Share& x) { add(x); return *this; } + Share_& operator+=(const Share_& x) { add(x); return *this; } template - Share& operator*=(const U& x) { mul(*this, x); return *this; } + Share_& operator*=(const U& x) { mul(*this, x); return *this; } - Share operator<<(int i) { return this->operator*(T(1) << i); } - Share& operator<<=(int i) { return *this = *this << i; } + Share_ operator<<(int i) { return this->operator*(T(1) << i); } + Share_& operator<<=(int i) { return *this = *this << i; } - Share operator>>(int i) { return {a >> i, mac >> i}; } + Share_ operator>>(int i) { return {a >> i, mac >> i}; } void force_to_bit() { a.force_to_bit(); } @@ -149,7 +140,7 @@ class Share mac.input(s,human); } - friend ostream& operator<<(ostream& s, const Share& x) { x.output(s, true); return s; } + friend ostream& operator<<(ostream& s, const Share_& x) { x.output(s, true); return s; } void pack(octetStream& os, bool full = true) const; void unpack(octetStream& os, bool full = true); @@ -167,61 +158,94 @@ class Share friend bool check_macs(const vector< Share >& S,const T& key); }; -template -using Share_ = Share; +// SPDZ(2k) only +template +class Share : public Share_, SemiShare> +{ +public: + typedef Share_, SemiShare> super; + + typedef T mac_key_type; + + typedef Share prep_type; + typedef Share input_check_type; + typedef Share input_type; + typedef MascotMultiplier Multiplier; + typedef MascotTripleGenerator TripleGenerator; + typedef T sacri_type; + typedef typename T::Square Rectangle; + typedef Rectangle Square; + + typedef MAC_Check_ MAC_Check; + typedef Direct_MAC_Check Direct_MC; + typedef ::Input Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef MascotFieldPrep LivePrep; + typedef MascotPrep RandomPrep; + + static string type_string() + { return "SPDZ " + T::type_string(); } + + Share() {} + template + Share(const U& other) : super(other) {} + Share(const SemiShare& share, const SemiShare& mac) : + super(share, mac) {} +}; // specialized mul by bit for gf2n template <> -void Share::mul_by_bit(const Share& S,const gf2n& aa); +void Share_, SemiShare>::mul_by_bit(const Share_, SemiShare>& S,const gf2n& aa); -template -Share operator*(const T& y, const Share& x) { Share res; res.mul(x, y); return res; } +template +Share_ operator*(const typename T::clear& y, const Share_& x) { Share_ res; res.mul(x, y); return res; } -template -inline void Share::add(const Share& S1,const Share& S2) +template +inline void Share_::add(const Share_& S1,const Share_& S2) { a.add(S1.a,S2.a); mac.add(S1.mac,S2.mac); } -template -void Share::sub(const Share& S1,const Share& S2) +template +void Share_::sub(const Share_& S1,const Share_& S2) { a.sub(S1.a,S2.a); mac.sub(S1.mac,S2.mac); } -template -inline void Share::mul(const Share& S,const T& aa) +template +inline void Share_::mul(const Share_& S,const clear& aa) { a.mul(S.a,aa); - mac.mul(S.mac,aa); + mac = aa * S.mac; } -template -inline void Share::add(const Share& S,const clear& aa,int my_num,const T& alphai) +template +inline void Share_::add(const Share_& S,const clear& aa,int my_num,const T& alphai) { - *this = S + Share(aa, my_num, alphai); + *this = S + Share_(aa, my_num, alphai); } -template -inline void Share::sub(const Share& S,const clear& aa,int my_num,const T& alphai) +template +inline void Share_::sub(const Share_& S,const clear& aa,int my_num,const T& alphai) { - *this = S - Share(aa, my_num, alphai); + *this = S - Share_(aa, my_num, alphai); } -template -inline void Share::sub(const clear& aa,const Share& S,int my_num,const T& alphai) +template +inline void Share_::sub(const clear& aa,const Share_& S,int my_num,const T& alphai) { - *this = Share(aa, my_num, alphai) - S; + *this = Share_(aa, my_num, alphai) - S; } -template -inline void Share::assign(const clear& aa, int my_num, - const typename T::Scalar& alphai) +template +inline void Share_::assign(const clear& aa, int my_num, + const typename V::Scalar& alphai) { - Protocol::assign(a, aa, my_num); - mac.mul(aa, alphai); + a = T::constant(aa, my_num); + mac = aa * alphai; #ifdef DEBUG_MAC cout << "load " << hex << mac << " = " << aa << " * " << alphai << endl; #endif diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index a2aa823ab..9da2da422 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -7,9 +7,9 @@ #include "Math/Integer.h" -template +template inline -void Share::mul_by_bit(const Share& S,const T& aa) +void Share_::mul_by_bit(const Share_& S,const clear& aa) { a.mul(S.a,aa); mac.mul(S.mac,aa); @@ -17,7 +17,8 @@ void Share::mul_by_bit(const Share& S,const T& aa) template<> inline -void Share::mul_by_bit(const Share& S, const gf2n& aa) +void Share_, SemiShare>::mul_by_bit( + const Share_, SemiShare>& S, const gf2n& aa) { a.mul_by_bit(S.a,aa); mac.mul_by_bit(S.mac,aa); @@ -26,7 +27,7 @@ void Share::mul_by_bit(const Share& S, const gf2n& aa) -template +template T combine(const vector< Share >& S) { T ans=S[0].a; @@ -38,16 +39,16 @@ T combine(const vector< Share >& S) -template -inline void Share::pack(octetStream& os, bool full) const +template +inline void Share_::pack(octetStream& os, bool full) const { a.pack(os); if (full) mac.pack(os); } -template -inline void Share::unpack(octetStream& os, bool full) +template +inline void Share_::unpack(octetStream& os, bool full) { a.unpack(os); if (full) @@ -55,7 +56,7 @@ inline void Share::unpack(octetStream& os, bool full) } -template +template bool check_macs(const vector< Share >& S,const T& key) { T val=combine(S); diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index 9a792e097..f325bb778 100644 --- a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -12,6 +12,9 @@ using namespace std; class Player; +template +using dabit = pair; + template class ShuffleSacrifice { @@ -23,14 +26,33 @@ class ShuffleSacrifice { return max(n_outputs, minimum_n_outputs()) * B + C; } + static int minimum_n_inputs_with_combining() + { + return minimum_n_inputs(B * minimum_n_outputs()); + } static int minimum_n_outputs() { +#ifdef INSECURE +#ifdef FAKE_BATCH + cout << "FAKE FAKE FAKE" << endl; + return 1 << 10; +#endif +#endif return 1 << 20; } + template + void shuffle(vector& items, Player& P); + void triple_sacrifice(vector>& triples, vector>& check_triples, Player& P, typename T::MAC_Check& MC); + void triple_combine(vector>& triples, + vector>& to_combine, Player& P, + typename T::MAC_Check& MC); + + void dabit_sacrifice(vector>& dabits, + vector>& check_dabits, SubProcessor& proc); }; #endif /* PROTOCOLS_SHUFFLESACRIFICE_H_ */ diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp new file mode 100644 index 000000000..e62b9cb62 --- /dev/null +++ b/Protocols/ShuffleSacrifice.hpp @@ -0,0 +1,126 @@ +/* + * ShuffleSacrifice.hpp + * + */ + +#ifndef PROTOCOLS_SHUFFLESACRIFICE_HPP_ +#define PROTOCOLS_SHUFFLESACRIFICE_HPP_ + +#include "ShuffleSacrifice.h" + +#include "MalRepRingPrep.hpp" + +template +inline void ShuffleSacrifice::triple_combine(vector >& triples, + vector >& to_combine, Player& P, + typename T::MAC_Check& MC) +{ + int buffer_size = to_combine.size(); + int N = buffer_size / B; + assert(minimum_n_outputs() <= N); + + shuffle(to_combine, P); + + vector opened; + vector masked; + masked.reserve(N); + for (int i = 0; i < N; i++) + { + T& b = to_combine[i][1]; + for (int j = 1; j < B; j++) + { + T& g = to_combine[i + N * j][1]; + masked.push_back(b - g); + } + } + MC.POpen(opened, masked, P); + auto it = opened.begin(); + for (int i = 0; i < N; i++) + { + T& a = to_combine[i][0]; + T& c = to_combine[i][2]; + for (int j = 1; j < B; j++) + { + T& f = to_combine[i + N * j][0]; + T& h = to_combine[i + N * j][2]; + auto& rho = *(it++); + a += f; + c += h + f * rho; + } + } + to_combine.resize(N); + triples = to_combine; +} + +template +void ShuffleSacrifice::dabit_sacrifice(vector >& output, + vector >& to_check, SubProcessor& proc) +{ + auto& P = proc.P; + auto& MC = proc.MC; + + int buffer_size = to_check.size(); + int N = (buffer_size - C) / B; + + shuffle(to_check, P); + + // opening C + vector shares; + vector bit_shares; + for (int i = 0; i < C; i++) + { + shares.push_back(to_check.back().first); + bit_shares.push_back(to_check.back().second); + to_check.pop_back(); + } + vector opened; + MC.POpen(opened, shares, P); + vector bits; + auto& MCB = *T::bit_type::part_type::new_mc( + GC::ShareThread::s().MC->get_alphai()); + MCB.POpen(bits, bit_shares, P); + for (int i = 0; i < C; i++) + if (opened[i] != bits[i].get()) + throw Offline_Check_Error("dabit shuffle opening"); + + // sacrifice buckets + typename T::Protocol protocol(P); + protocol.init_mul(&proc); + for (int i = 0; i < N; i++) + { + auto& a = to_check[i].first; + for (int j = 1; j < B; j++) + { + auto& f = to_check[i + N * j].first; + protocol.prepare_mul(a, f); + } + } + protocol.exchange(); + shares.clear(); + bit_shares.clear(); + shares.reserve((B - 1) * N); + bit_shares.reserve((B - 1) * N); + for (int i = 0; i < N; i++) + { + auto& a = to_check[i].first; + auto& b = to_check[i].second; + for (int j = 1; j < B; j++) + { + auto& f = to_check[i + N * j].first; + auto& g = to_check[i + N * j].second; + shares.push_back(a + f - protocol.finalize_mul() * 2); + bit_shares.push_back(b + g); + } + } + MC.POpen(opened, shares, P); + MCB.POpen(bits, bit_shares, P); + for (int i = 0; i < (B - 1) * N; i++) + if (opened[i] != bits[i].get()) + throw Offline_Check_Error("dabit shuffle opening"); + + to_check.resize(N); + output = to_check; + delete &MCB; +} + +#endif /* PROTOCOLS_SHUFFLESACRIFICE_HPP_ */ diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index df410fcfa..52b35f521 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -8,6 +8,7 @@ #include "MascotPrep.h" #include "Spdz2kShare.h" +#include "GC/TinySecret.h" template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); @@ -32,6 +33,8 @@ class Spdz2kPrep : public MascotPrep void buffer_inverses() { throw division_by_zero(); } void buffer_bits(); + void get_dabit(T& a, GC::TinySecret& b); + size_t data_sent(); }; diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index 66d41a991..fc7f8ff8b 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -40,7 +40,7 @@ void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) // just dummies bit_pos = DataPositions(proc->P.num_players()); bit_DataF = new Sub_Data_Files(0, 0, "", bit_pos, 0); - bit_proc = new SubProcessor(proc->Proc, *bit_MC, *bit_DataF, proc->P); + bit_proc = new SubProcessor(*bit_MC, *bit_DataF, proc->P); bit_prep = new MascotPrep(bit_proc, bit_pos); bit_prep->params.amplify = false; bit_protocol = new typename BitShare::Protocol(proc->P); @@ -88,7 +88,7 @@ void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep) assert(bit_proc != 0); auto bit_MC = &bit_proc->MC; vector squares, random_shares; - BitShare one(1, bit_proc->P.my_num(), bit_MC->get_alphai()); + auto one = BitShare::constant(1, bit_proc->P.my_num(), bit_MC->get_alphai()); for (int i = 0; i < buffer_size; i++) { BitShare a, a2; @@ -109,6 +109,14 @@ void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep) bit_MC->Check(bit_proc->P); } +template +void Spdz2kPrep::get_dabit(T& a, GC::TinySecret& b) +{ + this->get_one(DATA_BIT, a); + b.resize_regs(1); + b.get_reg(0) = Spdz2kShare<1, T::s>(a); +} + template size_t Spdz2kPrep::data_sent() { diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index c94807f60..c04d8349b 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -15,6 +15,11 @@ template class Spdz2kMultiplier; template class Spdz2kTripleGenerator; +namespace GC +{ +template class TinySecret; +} + template class Spdz2kShare : public Share> { @@ -42,6 +47,8 @@ class Spdz2kShare : public Share> typedef SPDZ Protocol; typedef Spdz2kPrep LivePrep; + typedef GC::TinySecret bit_type; + const static int k = K; const static int s = S; @@ -49,12 +56,10 @@ class Spdz2kShare : public Share> static string type_short() { return "Z" + to_string(K) + "," + to_string(S); } Spdz2kShare() {} - template - Spdz2kShare(const Share& x) : super(x) {} - Spdz2kShare(const clear& x, int my_num, const mac_key_type& alphai) : - super(x, my_num, alphai) - { - } + template + Spdz2kShare(const Share_& x) : super(x) {} + template + Spdz2kShare(const T& share, const V& mac) : super(share, mac) {} }; diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 23d4c51af..028546ab6 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -1,3 +1,5 @@ +#ifndef PROTOCOLS_FAKE_STUFF_HPP_ +#define PROTOCOLS_FAKE_STUFF_HPP_ #include "Protocols/fake-stuff.h" #include "Processor/Data_Files.h" @@ -13,19 +15,22 @@ template class Share; template class SemiShare; template class ShamirShare; template class FixedVec; +template class Share_; namespace GC { template class TinySecret; +template class TinierSecret; } -template -void make_share(Share* Sa,const U& a,int N,const V& key,PRNG& G) +template +void make_share(Share_* Sa,const U& a,int N,const V& key,PRNG& G) { insecure("share generation", false); - T mac,x,y; - mac.mul(a,key); - Share S; + T x; + W mac, y; + mac = a * key; + Share_ S; S.set_share(a); S.set_mac(mac); @@ -39,21 +44,33 @@ void make_share(Share* Sa,const U& a,int N,const V& key,PRNG& G) Sa[N-1]=S; } -template -void make_share(GC::TinySecret* Sa,const U& a,int N,const V& key,PRNG& G) +template +void make_vector_share(T* Sa,const U& a,int N,const V& key,PRNG& G) { int length = Sa[0].default_length; for (int i = 0; i < N; i++) Sa[i].resize_regs(length); for (int j = 0; j < length; j++) { - typename GC::TinySecret::part_type shares[N]; - make_share(shares, a.get_bit(j), N, key, G); + typename T::part_type shares[N]; + make_share(shares, typename T::part_type::clear(a.get_bit(j)), N, key, G); for (int i = 0; i < N; i++) Sa[i].get_reg(j) = shares[i]; } } +template +void make_share(GC::TinySecret* Sa, const U& a, int N, const V& key, PRNG& G) +{ + make_vector_share(Sa, a, N, key, G); +} + +template +void make_share(GC::TinierSecret* Sa, const U& a, int N, const V& key, PRNG& G) +{ + make_vector_share(Sa, a, N, key, G); +} + template void make_share(SemiShare* Sa,const T& a,int N,const T& key,PRNG& G) { @@ -383,3 +400,5 @@ void make_inverse(const typename T::mac_key_type& key, int N, int ntrip, bool ze { outf[i].close(); } delete[] outf; } + +#endif diff --git a/README.md b/README.md index 865321df3..bce22c47d 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ The following table lists all protocols that are fully supported. | Security model | Mod prime / GF(2^n) | Mod 2^k | Bin. SS | Garbling | | --- | --- | --- | --- | --- | -| Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny](#secret-sharing) | [BMR](#bmr) | +| Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS](#honest-majority) | [Brain / Rep3 / PS](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | @@ -128,7 +128,7 @@ phase outputs the amount of offline material required, which allows to compute the preprocessing time for a particular computation. #### Requirements - - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7). We recommend clang because it performs better. + - GCC 5 or later (tested with up to 9) or LLVM/clang 5 or later (tested with up to 9). We recommend clang because it performs better. - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - libsodium library, tested against 1.0.16 - OpenSSL, tested against and 1.0.2 and 1.1.0 @@ -186,8 +186,7 @@ fail if the minimum is not met. ```./compile.py -R ``` -Currently, 64 is the only supported bit length, but it still has to be -specified for future compatibility. +Currently, most machines support bit lengths 64 and 72. #### Binary circuits @@ -203,6 +202,28 @@ binary circuits. This can be changed with `sfix.set_precision`. See If you would like to use integers of various precisions, you can use `sbitint.get_type(n)` to get a type for `n`-bit arithmetic. +#### Mixed circuits + +MP-SPDZ allows to mix computation between arithmetic and binary +secret sharing in the same security model. In the compiler, this is +used to switch from arithmetic to binary computation for certain +non-linear functions. At the time of writing, these include +comparison, bit decomposition, truncation, and modulo power of two. +You activate all this by adding `-X` when compiling arithmetic +circuits, that is +```./compile.py [-F ] ``` +for computation modulo a prime and +```./compile.py -R ``` +for computation modulo 2^k. + +Internally, this uses daBits described by [Rotaru and +Wood](https://eprint.iacr.org/2019/207), that is secret random bits +shared in different domains. Some security models allow direct +conversion of random bits from arithmetic to binary while others +require inputs from several parties followed by computing XOR and +checking for malicious security as described by Rotaru and Wood in +Section 4.1. + #### Compiling and running programs from external directories Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: @@ -242,6 +263,7 @@ The following table shows all programs for dishonest-majority computation using | `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | +| `tinier-party.x` | [FKOS15](https://eprint.iacr.org/2015/901) | Binary | Malicious | `tinier.sh` | Semi and Semi2k denote the result of stripping MASCOT/SPDZ2k of all steps required for malicious security, namely amplifying, sacrificing, @@ -409,7 +431,7 @@ or `Scripts/ring.sh tutorial` -The `-I` enable interactive inputs, and in the tutorial party 0 and 1 +The `-I` argument enables interactive inputs, and in the tutorial party 0 and 1 will be asked to provide three numbers. Otherwise, and when using the script, the inputs are read from `Player-Data/Input-P-0`. diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 0a519ea27..4a4ff14d2 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -1,31 +1,70 @@ #!/bin/bash +while getopts XC opt; do + case $opt in + X) compile_opts=-X + dabit=1 + ;; + C) cheap=1 + ;; + esac +done + +shift $[OPTIND-1] + for i in 0 1; do seq 0 3 > Player-Data/Input-P$i-0 done -function test +function test_vm { + ulimit -c unlimited if ! Scripts/$1.sh tutorial | grep 'weighted average: 2.333'; then Scripts/$1.sh tutorial exit 1 fi } -./compile.py -R 64 tutorial +for dabit in ${dabit:-0 1}; do + if [[ $dabit = 1 ]]; then + compile_opts="$compile_opts -X" + fi -for i in ring brain mal-rep-ring ps-rep-ring semi2k spdz2k; do - test $i -done + ./compile.py -R 64 $compile_opts tutorial + + for i in ring brain mal-rep-ring ps-rep-ring semi2k; do + test_vm $i + done -./compile.py tutorial + if ! test "$dabit" = 1 -a "$cheap" = 1; then + test_vm spdz2k + fi -for i in rep-field mal-rep-field ps-rep-field shamir mal-shamir hemi cowgear semi mascot; do - test $i + ./compile.py $compile_opts tutorial + + for i in rep-field mal-rep-field ps-rep-field; do + test_vm $i + done + + if [[ ! "$dabit" = 1 ]]; then + for i in shamir mal-shamir; do + test_vm $i + done + fi + + for i in hemi semi; do + test_vm $i + done + + if ! test "$dabit" = 1 -a "$cheap" = 1; then + for i in cowgear mascot; do + test_vm $i + done + fi done -./compile.py -B 16 tutorial +./compile.py -B 16 $compile_opts tutorial -for i in replicated mal-rep-bin semi-bin yao tiny rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do - test $i +for i in replicated mal-rep-bin semi-bin yao tinier tiny rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do + test_vm $i done diff --git a/Scripts/tinier.sh b/Scripts/tinier.sh new file mode 100755 index 000000000..0244607d7 --- /dev/null +++ b/Scripts/tinier.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player tinier-party.x $* || exit 1 diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index e4668ce14..1e8a987eb 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -13,7 +13,8 @@ void BitVector::resize_zero(size_t new_nbits) { size_t old_nbytes = nbytes; resize(new_nbits); - avx_memzero(bytes + old_nbytes, nbytes - old_nbytes); + if (old_nbytes < nbytes) + avx_memzero(bytes + old_nbytes, nbytes - old_nbytes); } const void* BitVector::get_ptr_to_byte(size_t i, size_t block_size) const diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 3c174a188..982ba56de 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -189,7 +189,8 @@ class BitVector bool get_bit(int i) const { - assert(i < (int)nbits); + if (i >= (int)nbits) + throw out_of_range("BitVector access: " + to_string(i) + "/" + to_string(nbits)); return (bytes[i/8] >> (i % 8)) & 1; } void set_bit(int i,unsigned int a) @@ -268,7 +269,10 @@ class BitVector template T inline BitVector::get_portion(int i) const { - return (char*)&bytes[T::size() * i]; + if (T::size_in_bits() == 1) + return get_bit(i); + else + return (char*)&bytes[T::size() * i]; } template diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index 8c78ce3dd..211729271 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -140,3 +140,4 @@ Z(gf2n_short) Z(BitVec) Z(Z2<41>) Z(Z2<120>) Z(Z2<122>) Z(Z2<136>) Z(Z2<138>) +Z(Z2<65>) Z(Z2<49>) diff --git a/Tools/names.cpp b/Tools/names.cpp index 75446ec67..a428f9bb5 100644 --- a/Tools/names.cpp +++ b/Tools/names.cpp @@ -1,3 +1,5 @@ #include "Processor/Data_Files.h" -const char* DataPositions::dtype_names[N_DTYPE + 1] = { "Triples", "Squares", "Bits", "Inverses", "BitTriples", "BitGF2NTriples", "None" }; +const char* DataPositions::dtype_names[N_DTYPE + 1] = +{ "Triples", "Squares", "Bits", "Inverses", "BitTriples", "BitGF2NTriples", + "daBits", "None" }; diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 88e96ddb0..3e194797c 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -13,6 +13,7 @@ #include "GC/MaliciousRepSecret.h" #include "GC/SemiSecret.h" #include "GC/TinySecret.h" +#include "GC/TinierSecret.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" @@ -543,6 +544,11 @@ int generate(ez::ezOptionParser& opt) make_mult_triples>(keyt, nplayers, default_num, zero, prep_data_prefix); make_bits>(keyt, nplayers, default_num, zero); + gf2n_short keytt; + generate_mac_keys>(keytt, _, nplayers, prep_data_prefix); + make_mult_triples>(keytt, nplayers, default_num, zero, prep_data_prefix); + make_bits>(keytt, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); diff --git a/Utils/spdz2-offline.cpp b/Utils/spdz2-offline.cpp index f47229eab..ed39ce0b4 100644 --- a/Utils/spdz2-offline.cpp +++ b/Utils/spdz2-offline.cpp @@ -15,6 +15,7 @@ using namespace std; #include "FHE/NTL-Subs.h" #include "Tools/ezOptionParser.h" #include "Tools/mkpath.h" +#include "Tools/Signal.h" #include "Math/Setup.h" #include "Protocols/MAC_Check.hpp" diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index 3f2c4dd61..6c4444548 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -14,8 +14,6 @@ class YaoEvalWire : public Phase { public: - typedef DummyMC MC; - static string name() { return "YaoEvalWire"; } typedef ostream& out_type; diff --git a/compile.py b/compile.py index 77c2c5775..e3a636bfb 100755 --- a/compile.py +++ b/compile.py @@ -68,6 +68,8 @@ def main(): parser.add_option("-b", "--budget", dest="budget", default=100000, help="set budget for optimized loop unrolling " "(default: 100000)") + parser.add_option("-X", "--mixed", action="store_true", dest="mixed", + help="mixing arithmetic and binary computation") options,args = parser.parse_args() if len(args) < 1: parser.print_help()