From 0f7020d791a667ede375aa365f109ac286e89d43 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 17 Feb 2022 13:21:19 +1100 Subject: [PATCH] Semi-honest computation based on threshold semi-homomorphic encryption. --- CHANGELOG.md | 13 +- CONFIG | 1 + Compiler/GC/instructions.py | 37 ++- Compiler/GC/types.py | 28 +- Compiler/allocator.py | 20 +- Compiler/instructions.py | 177 +++++++------ Compiler/instructions_base.py | 57 ++++- Compiler/library.py | 57 ++++- Compiler/ml.py | 351 +++++++++++++++++++++++--- Compiler/oram.py | 1 + Compiler/program.py | 11 +- Compiler/types.py | 97 +++++-- FHE/FHE_Keys.cpp | 16 +- FHE/FHE_Keys.h | 2 + FHE/FHE_Params.cpp | 15 ++ FHE/FHE_Params.h | 6 +- FHE/NTL-Subs.cpp | 11 +- FHE/NTL-Subs.h | 2 +- FHE/NoiseBounds.cpp | 5 +- FHE/Ring_Element.cpp | 1 + FHE/Rq_Element.cpp | 9 +- FHE/Rq_Element.h | 8 +- FHEOffline/DataSetup.cpp | 2 +- FHEOffline/Multiplier.cpp | 7 + FHEOffline/Multiplier.h | 3 + FHEOffline/PairwiseSetup.cpp | 14 +- FHEOffline/PairwiseSetup.h | 2 +- FHEOffline/SimpleDistDecrypt.cpp | 8 + FHEOffline/SimpleDistDecrypt.h | 1 + FHEOffline/TemiSetup.cpp | 59 +++++ FHEOffline/TemiSetup.h | 34 +++ GC/Memory.h | 2 +- GC/ShareSecret.h | 1 + GC/TinySecret.h | 1 + GC/instructions.h | 2 +- Machines/ShamirMachine.hpp | 1 + Machines/temi-party.cpp | 37 +++ Makefile | 5 +- Math/FixedVec.h | 5 - Math/Zp_Data.h | 2 +- Math/gf2n.cpp | 16 +- Math/mpn_fixed.h | 6 + Networking/Player.h | 1 + OT/BaseOT.cpp | 18 +- Processor/Binary_File_IO.hpp | 13 +- Processor/Input.h | 19 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 2 + Processor/Instruction.hpp | 46 ++-- Processor/Machine.hpp | 15 +- Processor/Memory.h | 4 +- Processor/Memory.hpp | 7 +- Processor/PrivateOutput.h | 12 +- Processor/PrivateOutput.hpp | 33 ++- Processor/Processor.h | 6 +- Processor/Processor.hpp | 36 ++- Processor/Program.cpp | 2 +- Processor/Program.h | 4 +- Processor/SpecificPrivateOutput.h | 65 +++++ Programs/Source/falcon_alex.mpc | 100 ++++++++ Programs/Source/keras_cifar_lenet.mpc | 45 ++++ Programs/Source/keras_mnist_dense.mpc | 3 +- Programs/Source/keras_mnist_lenet.mpc | 13 + Programs/Source/mnist_full_A.mpc | 6 + Programs/Source/mnist_full_C.mpc | 8 +- Protocols/Atlas.hpp | 6 + Protocols/Hemi.hpp | 2 +- Protocols/HemiMatrixPrep.h | 5 +- Protocols/HemiMatrixPrep.hpp | 68 +++-- Protocols/HemiPrep.h | 3 + Protocols/HemiPrep.hpp | 14 + Protocols/HemiShare.h | 1 + Protocols/LowGearKeyGen.hpp | 8 +- Protocols/MAC_Check.h | 17 +- Protocols/MAC_Check.hpp | 1 + Protocols/MAC_Check_Base.h | 4 + Protocols/MalRepRingShare.h | 4 +- Protocols/MaliciousRep3Share.h | 3 +- Protocols/MaliciousShamirPO.h | 3 +- Protocols/MaliciousShamirShare.h | 4 +- Protocols/MamaShare.h | 6 - Protocols/PostSacriRepFieldShare.h | 4 +- Protocols/PostSacriRepRingShare.h | 4 +- Protocols/ProtocolSet.h | 25 +- Protocols/Rep3Share.h | 7 +- Protocols/Rep3Share2k.h | 3 +- Protocols/Rep4Input.h | 1 - Protocols/Rep4Input.hpp | 6 - Protocols/Replicated.h | 7 - Protocols/Replicated.hpp | 6 +- Protocols/ReplicatedPrep.hpp | 33 ++- Protocols/ReplicatedPrivateOutput.h | 26 -- Protocols/ReplicatedPrivateOutput.hpp | 30 --- Protocols/Semi.h | 6 + Protocols/SemiInput.h | 29 +-- Protocols/SemiInput.hpp | 62 ++++- Protocols/Shamir.h | 1 - Protocols/Shamir.hpp | 36 +-- Protocols/ShamirInput.h | 7 +- Protocols/ShamirInput.hpp | 33 ++- Protocols/ShamirMC.h | 4 + Protocols/ShamirMC.hpp | 13 + Protocols/ShamirShare.h | 7 +- Protocols/Share.h | 1 + Protocols/ShareInterface.h | 1 + Protocols/SpdzWiseInput.h | 3 - Protocols/SpdzWiseInput.hpp | 18 -- Protocols/SpdzWiseMC.h | 2 +- Protocols/SpdzWisePrep.hpp | 1 - Protocols/TemiPrep.h | 72 ++++++ Protocols/TemiPrep.hpp | 129 ++++++++++ Protocols/TemiShare.h | 42 +++ Protocols/fake-stuff.hpp | 9 +- README.md | 33 ++- Scripts/prep-usage.py | 23 ++ Scripts/temi.sh | 8 + Scripts/test_tutorial.sh | 2 +- Tools/Buffer.h | 4 + Tools/Exceptions.cpp | 4 +- Tools/Exceptions.h | 2 +- Tools/octetStream.h | 2 + Utils/binary-example.cpp | 4 +- Utils/mixed-example.cpp | 4 +- Utils/paper-example.cpp | 4 +- doc/instructions.rst | 10 +- doc/low-level.rst | 5 + doc/non-linear.rst | 2 +- doc/preprocessing.rst | 34 ++- doc/requirements.txt | 1 + 129 files changed, 1973 insertions(+), 539 deletions(-) create mode 100644 FHEOffline/TemiSetup.cpp create mode 100644 FHEOffline/TemiSetup.h create mode 100644 Machines/temi-party.cpp create mode 100644 Processor/SpecificPrivateOutput.h create mode 100644 Programs/Source/falcon_alex.mpc create mode 100644 Programs/Source/keras_cifar_lenet.mpc delete mode 100644 Protocols/ReplicatedPrivateOutput.h delete mode 100644 Protocols/ReplicatedPrivateOutput.hpp create mode 100644 Protocols/TemiPrep.h create mode 100644 Protocols/TemiPrep.hpp create mode 100644 Protocols/TemiShare.h create mode 100755 Scripts/prep-usage.py create mode 100755 Scripts/temi.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b75d24f8..6a0406a8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,17 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.2.9 (Jan 11, 2021) +## 0.3.0 (Feb 17, 2022) + +- Semi-honest computation based on threshold semi-homomorphic encryption +- Batch normalization backward propagation +- AlexNet for CIFAR-10 +- Specific private output protocols +- Semi-honest additive secret sharing without communication +- Sending of personal values +- Allow overwriting of persistence files +- Protocol signature in persistence files + +## 0.2.9 (Jan 11, 2022) - Disassembler - Run-time parameter for probabilistic truncation error diff --git a/CONFIG b/CONFIG index ba6855ea9..05b3683d5 100644 --- a/CONFIG +++ b/CONFIG @@ -42,6 +42,7 @@ else AVX_OT = 1 endif else +ARCH = AVX_OT = 0 endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index fc64ae2d2..ef9c14a3f 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -497,7 +497,7 @@ class movsb(NonVectorInstruction): code = opcodes['MOVSB'] arg_format = ['sbw','sb'] -class trans(base.VarArgsInstruction): +class trans(base.VarArgsInstruction, base.DynFormatInstruction): """ Secret bit register vector transpose. The first destination vector will contain the least significant bits of all source vectors etc. @@ -511,10 +511,22 @@ class trans(base.VarArgsInstruction): code = opcodes['TRANS'] is_vec = lambda self: True def __init__(self, *args): - self.arg_format = ['int'] + ['sbw'] * args[0] + \ - ['sb'] * (len(args) - 1 - args[0]) super(trans, self).__init__(*args) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + n = next(args) + for i in range(n): + yield 'sbw' + next(args) + while True: + try: + yield 'sb' + next(args) + except StopIteration: + break + class bitb(NonVectorInstruction): """ Copy fresh secret random bit to secret bit register. @@ -560,7 +572,7 @@ def add_usage(self, req_node): req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1]) class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, - base.Mergeable): + base.Mergeable, base.DynFormatInstruction): """ Copy private input to secret bit registers bit by bit. The input is read as floating-point number, multiplied by a power of two, rounded to an integer, and then decomposed into bits. @@ -577,11 +589,18 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, code = opcodes['INPUTBVEC'] def __init__(self, *args, **kwargs): - self.arg_format = [] - for x in self.get_arg_tuples(args): - self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3) super(inputbvec, self).__init__(*args, **kwargs) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + yield 'p' + for j in range(n - 3): + yield 'sbw' + yield 'int' + @staticmethod def get_arg_tuples(args): i = 0 @@ -590,10 +609,6 @@ def get_arg_tuples(args): i += args[i] assert i == len(args) - def merge(self, other): - self.args += other.args - self.arg_format += other.arg_format - def add_usage(self, req_node): for x in self.get_arg_tuples(self.args): req_node.increment(('bit', 'input', x[2]), x[0] - 3) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 13619c7f9..38c37a261 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -41,7 +41,7 @@ class bitsn(cls): return cls.types[length] @classmethod def conv(cls, other): - if isinstance(other, cls): + if isinstance(other, cls) and cls.n == other.n: return other elif isinstance(other, MemValue): return cls.conv(other.read()) @@ -246,14 +246,20 @@ def conv_regint_by_bit(cls, n, res, other): assert n == res.n assert n == other.size cls.conv_cint_vec(cint(other, size=other.size), res) + @classmethod + def conv(cls, other): + if isinstance(other, cbits) and cls.n != None and \ + cls.n // cls.unit == other.n // cls.unit: + return other + else: + return super(cbits, cls).conv(other) types = {} def load_int(self, value): - if self.n <= 64: - tmp = regint(value) - elif value == self.long_one(): - tmp = cint(1, size=self.n) - else: - raise CompilerError('loading long integers to cbits not supported') + n_limbs = math.ceil(self.n / self.unit) + tmp = regint(size=n_limbs) + for i in range(n_limbs): + tmp[i].load_int(value % 2 ** self.unit) + value >>= self.unit self.load_other(tmp) def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) @@ -1163,14 +1169,14 @@ class cbitfix(object): @classmethod def _new(cls, value): res = cls() + if cls.k < value.unit: + bits = value.bit_decompose(cls.k) + sign = bits[-1] + value += (sign << (cls.k)) * -1 res.v = value return res def output(self): v = self.v - if self.k < v.unit: - bits = self.v.bit_decompose(self.k) - sign = bits[-1] - v += (sign << (self.k)) * -1 inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), cbits(0), cbits(0)) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 9871d97fa..cf2f13ef4 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -403,6 +403,20 @@ def keep_merged_order(instr, n, t): add_edge(last_input[t][1], n) last_input[t][0] = n + def keep_text_order(inst, n): + if inst.get_players() is None: + # switch + for x in list(last_input.keys()): + if isinstance(x, int): + add_edge(last_input[x][0], n) + del last_input[x] + keep_merged_order(instr, n, None) + elif last_input[None][0] is not None: + keep_merged_order(instr, n, None) + else: + for player in inst.get_players(): + keep_merged_order(instr, n, player) + for n,instr in enumerate(block.instructions): outputs,inputs = instr.get_def(), instr.get_used() @@ -427,7 +441,7 @@ def keep_merged_order(instr, n, t): # will be merged if isinstance(instr, TextInputInstruction): - keep_merged_order(instr, n, TextInputInstruction) + keep_text_order(instr, n) elif isinstance(instr, RawInputInstruction): keep_merged_order(instr, n, RawInputInstruction) @@ -479,10 +493,6 @@ def keep_merged_order(instr, n, t): last_print_str = n elif isinstance(instr, PublicFileIOInstruction): keep_order(instr, n, instr.__class__) - elif isinstance(instr, startprivateoutput_class): - keep_order(instr, n, startprivateoutput_class, 2) - elif isinstance(instr, stopprivateoutput_class): - keep_order(instr, n, stopprivateoutput_class, 2) elif isinstance(instr, prep_class): keep_order(instr, n, instr.args[0]) elif isinstance(instr, StackInstruction): diff --git a/Compiler/instructions.py b/Compiler/instructions.py index a85fb25ad..e06797684 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -421,6 +421,10 @@ class use_matmul(base.Instruction): code = base.opcodes['USE_MATMUL'] arg_format = ['int','int','int','int'] + @classmethod + def get_usage(cls, args): + return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i} + class run_tape(base.Instruction): """ Start tape/bytecode file in another thread. @@ -1229,15 +1233,20 @@ def __init__(self, *args, **kwargs): @base.gf2n @base.vectorize class inputmask(base.Instruction): - r""" Load secret $s_i$ with the next input mask for player $p$ and - write the mask on player $p$'s private output. """ + """ Store fresh random input mask(s) in secret register (vector) and clear + register (vector) of the relevant player. + + :param: mask (sint) + :param: mask (cint, player only) + :param: player (int) + """ __slots__ = [] code = base.opcodes['INPUTMASK'] - arg_format = ['sw', 'p'] + arg_format = ['sw', 'cw', 'p'] field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[1]), \ + req_node.increment((self.field_type, 'input', self.args[2]), \ self.get_size()) @base.vectorize @@ -1293,10 +1302,8 @@ class asm_input(base.TextInputInstruction): arg_format = tools.cycle(['sw', 'p']) field_type = 'modp' - def add_usage(self, req_node): - for player in self.args[1::2]: - req_node.increment((self.field_type, 'input', player), \ - self.get_size()) + def get_players(self): + return self.args[1::2] @base.vectorize class inputfix(base.TextInputInstruction): @@ -1305,10 +1312,8 @@ class inputfix(base.TextInputInstruction): arg_format = tools.cycle(['sw', 'int', 'p']) field_type = 'modp' - def add_usage(self, req_node): - for player in self.args[2::3]: - req_node.increment((self.field_type, 'input', player), \ - self.get_size()) + def get_players(self): + return self.args[2::3] @base.vectorize class inputfloat(base.TextInputInstruction): @@ -1322,7 +1327,7 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', player), \ 4 * self.get_size()) -class inputmixed_base(base.TextInputInstruction): +class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction): __slots__ = [] field_type = 'modp' # the following has to match TYPE: (N_DEST, N_PARAM) @@ -1341,22 +1346,30 @@ def __init__(self, name, *args): type_id = self.type_ids[name] super(inputmixed_base, self).__init__(type_id, *args) - @property - def arg_format(self): - for i in self.bases(): - t = self.args[i] - yield 'int' + @classmethod + def dynamic_arg_format(self, args): + yield 'int' + for i, t in self.bases(iter(args)): for j in range(self.types[t][0]): yield 'sw' for j in range(self.types[t][1]): yield 'int' yield self.player_arg_type + yield 'int' - def bases(self): + @classmethod + def bases(self, args): i = 0 - while i < len(self.args): - yield i - i += sum(self.types[self.args[i]]) + 2 + while True: + try: + t = next(args) + except StopIteration: + return + yield i, t + n = sum(self.types[t]) + i += n + 2 + for j in range(n + 1): + next(args) @base.vectorize class inputmixed(inputmixed_base): @@ -1380,13 +1393,16 @@ class inputmixed(inputmixed_base): player_arg_type = 'p' def add_usage(self, req_node): - for i in self.bases(): - t = self.args[i] + for i, t in self.bases(iter(self.args)): player = self.args[i + sum(self.types[t]) + 1] n_dest = self.types[t][0] req_node.increment((self.field_type, 'input', player), \ n_dest * self.get_size()) + def get_players(self): + for i, t in self.bases(iter(self.args)): + yield self.args[i + sum(self.types[t]) + 1] + @base.vectorize class inputmixedreg(inputmixed_base): """ Store private input in secret registers (vectors). The input is @@ -1412,6 +1428,9 @@ def add_usage(self, req_node): # player 0 as proxy req_node.increment((self.field_type, 'input', 0), float('inf')) + def get_players(self): + pass + @base.gf2n @base.vectorize class rawinput(base.RawInputInstruction, base.Mergeable): @@ -1433,7 +1452,23 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', player), \ self.get_size()) -class inputpersonal(base.Instruction, base.Mergeable): +class personal_base(base.Instruction, base.Mergeable): + __slots__ = [] + field_type = 'modp' + + def __init__(self, *args): + super(personal_base, self).__init__(*args) + for i in range(0, len(args), 4): + assert args[i + 2].size == args[i] + assert args[i + 3].size == args[i] + + def add_usage(self, req_node): + for i in range(0, len(self.args), 4): + player = self.args[i + 1] + req_node.increment((self.field_type, 'input', player), \ + self.args[i]) + +class inputpersonal(personal_base): """ Private input from cint. :param: vector size (int) @@ -1445,19 +1480,39 @@ class inputpersonal(base.Instruction, base.Mergeable): __slots__ = [] code = base.opcodes['INPUTPERSONAL'] arg_format = tools.cycle(['int','p','sw','c']) - field_type = 'modp' + +class privateoutput(personal_base): + """ Private input from cint. + + :param: vector size (int) + :param: player (int) + :param: destination (cint) + :param: source (sint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['PRIVATEOUTPUT'] + arg_format = tools.cycle(['int','p','cw','s']) + +class sendpersonal(base.Instruction, base.Mergeable): + """ Private input from cint. + + :param: vector size (int) + :param: destination player (int) + :param: destination (cint) + :param: source player (int) + :param: source (cint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['SENDPERSONAL'] + arg_format = tools.cycle(['int','p','cw','p','c']) def __init__(self, *args): - super(inputpersonal, self).__init__(*args) - for i in range(0, len(args), 4): + super(sendpersonal, self).__init__(*args) + for i in range(0, len(args), 5): assert args[i + 2].size == args[i] - assert args[i + 3].size == args[i] - - def add_usage(self, req_node): - for i in range(0, len(self.args), 4): - player = self.args[i + 1] - req_node.increment((self.field_type, 'input', player), \ - self.args[i]) + assert args[i + 4].size == args[i] @base.gf2n @base.vectorize @@ -1789,27 +1844,6 @@ class floatoutput(base.PublicFileIOInstruction): code = base.opcodes['FLOATOUTPUT'] arg_format = ['p','c','c','c','c'] -@base.gf2n -@base.vectorize -class startprivateoutput(base.Instruction): - r""" Initiate private output to $n$ of $s_j$ via $s_i$. """ - __slots__ = [] - code = base.opcodes['STARTPRIVATEOUTPUT'] - arg_format = ['sw','s','p'] - field_type = 'modp' - - def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[2]), \ - self.get_size()) - -@base.gf2n -@base.vectorize -class stopprivateoutput(base.Instruction): - r""" Previously iniated private output to $n$ via $c_i$. """ - __slots__ = [] - code = base.opcodes['STOPPRIVATEOUTPUT'] - arg_format = ['cw','c','p'] - @base.vectorize class rand(base.Instruction): """ Store insecure random value of specified length in clear integer @@ -2210,7 +2244,8 @@ def get_used(self): @base.gf2n @base.vectorize -class dotprods(base.VarArgsInstruction, base.DataInstruction): +class dotprods(base.VarArgsInstruction, base.DataInstruction, + base.DynFormatInstruction): """ Dot product of secret registers (vectors). Note that the vectorized version works element-wise. @@ -2238,31 +2273,29 @@ def __init__(self, *args): flat_args += [x, y] base.Instruction.__init__(self, *flat_args) - @property - def arg_format(self): + @classmethod + def dynamic_arg_format(self, args): field = 'g' if self.is_gf2n() else '' - for i in self.bases(): - yield 'int' + yield 'int' + for i, n in self.bases(args): yield 's' + field + 'w' - for j in range(self.args[i] - 2): + for j in range(n - 2): yield 's' + field + yield 'int' - gf2n_arg_format = arg_format - - def bases(self): - i = 0 - while i < len(self.args): - yield i - i += self.args[i] + @property + def gf2n_arg_format(self): + return self.arg_format() def get_repeat(self): - return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() + return sum(self.args[i] // 2 + for i, n in self.bases(iter(self.args))) * self.get_size() def get_def(self): - return [self.args[i + 1] for i in self.bases()] + return [self.args[i + 1] for i, n in self.bases(iter(self.args))] def get_used(self): - for i in self.bases(): + for i, n in self.bases(iter(self.args)): for reg in self.args[i + 2:i + self.args[i]]: yield reg diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index fb2a67b89..d6c647add 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -105,6 +105,7 @@ MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -128,6 +129,7 @@ INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -364,6 +366,7 @@ class GF2N_Instruction(instruction_cls): arg_format = copy.deepcopy(instruction_cls.arg_format) reformat(arg_format) + @classmethod def is_gf2n(self): return True @@ -505,8 +508,12 @@ def expand_merged(self, skip): for arg in self.args: try: new_regs.append(type(arg)(size=size)) - except: + except TypeError: break + except: + print([call[0][0].size for call in self.calls]) + raise + assert len(new_regs) > 1 base = 0 for call in self.calls: for new_reg, reg in zip(new_regs[1:], call[0][1:]): @@ -854,6 +861,7 @@ def has_var_args(self): def is_vec(self): return False + @classmethod def is_gf2n(self): return False @@ -902,6 +910,10 @@ def get_new_args(self, size, subs): new_args.append(arg) return new_args + @staticmethod + def get_usage(args): + return {} + # String version of instruction attempting to replicate encoded version def __str__(self): @@ -949,9 +961,18 @@ def __init__(self, f): if name == 'cisc': arg_format = itertools.chain(['str'], itertools.repeat('int')) else: - arg_format = itertools.repeat('int') - self.args = [ArgFormats[next(arg_format)](f) - for i in range(n_args)] + def arg_iter(): + i = 0 + while True: + try: + yield self.args[i].i + except AttributeError: + yield None + i += 1 + arg_format = t.dynamic_arg_format(arg_iter()) + self.args = [] + for i in range(n_args): + self.args.append(ArgFormats[next(arg_format)](f)) def __str__(self): name = self.type.__name__ @@ -963,6 +984,9 @@ def __str__(self): res += ', '.join(str(arg) for arg in self.args) return res + def get_usage(self): + return self.type.get_usage(self.args) + class VarArgsInstruction(Instruction): def has_var_args(self): return True @@ -974,6 +998,26 @@ class VectorInstruction(Instruction): def get_code(self): return super(VectorInstruction, self).get_code(len(self.args[0])) +class DynFormatInstruction(Instruction): + __slots__ = [] + + @property + def arg_format(self): + return self.dynamic_arg_format(iter(self.args)) + + @classmethod + def bases(self, args): + i = 0 + while True: + try: + n = next(args) + except StopIteration: + return + yield i, n + i += n + for j in range(n - 1): + next(args) + ### ### Basic arithmetic ### @@ -1072,6 +1116,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction): """ Input from text file or stdin """ __slots__ = [] + def add_usage(self, req_node): + for player in self.get_players(): + req_node.increment((self.field_type, 'input', player), \ + self.get_size()) + ### ### Data access instructions ### diff --git a/Compiler/library.py b/Compiler/library.py index 4f6c2de16..3f31499b0 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -223,7 +223,7 @@ def crash(condition=None): if isinstance(condition, localint): # allow crash on local values condition = condition._v - if condition == None: + if condition is None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -284,8 +284,8 @@ def get_arg(): def make_array(l): if isinstance(l, program.Tape.Register): - res = Array(1, type(l)) - res[0] = l + res = Array(len(l), type(l)) + res[:] = l else: l = list(l) res = Array(len(l), type(l[0]) if l else cint) @@ -1032,6 +1032,7 @@ def _(i): state = tuplify(initializer()) k = 0 block = get_block() + assert not isinstance(n_loops, int) or n_loops > 0 pre = copy.copy(loop_body.__globals__) while (not util.is_constant(n_loops) or k < n_loops) \ and (len(get_block()) < budget or k == 0) \ @@ -1211,7 +1212,13 @@ def decorator(loop_body): if t != regint: raise CompilerError('Not implemented for other than regint') args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') - state = tuple(initializer()) + state = initializer() + if len(state) == 0: + state_type = cint + elif isinstance(state, (tuple, list)): + state_type = type(state[0]) + else: + state_type = type(state) def f(inc): base = args[get_arg()][0] if not util.is_constant(thread_rounds): @@ -1224,8 +1231,7 @@ def f(inc): if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint, \ args[get_arg()].address + 2) - mem_state = Array(len(state), type(state[0]) \ - if state else cint, args[get_arg()][1]) + mem_state = Array(len(state), state_type, args[get_arg()][1]) @map_reduce_single(n_parallel, thread_rounds + inc, \ initializer, reducer, mem_state) def f(i): @@ -1257,14 +1263,14 @@ def f(i): threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) - if state: + if len(state): if thread_rounds: for i in range(n_threads - remainder): - state = reducer(Array(len(state), type(state[0]), \ + state = reducer(Array(len(state), state_type, \ args[remainder + i][1]), state) if remainder: for i in range(remainder): - state = reducer(Array(len(state), type(state[0]).reg_type, \ + state = reducer(Array(len(state), state_type, \ args[i][1]), state) def returner(): return untuplify(state) @@ -1300,6 +1306,39 @@ def summer(i): """ return map_sum(n_threads, None, n_loops, len(types), types) +def map_sum_simple(n_threads, n_loops, type, size): + """ Vectorized multi-threaded sum reduction. The following computes a + 100 sums of ten squares in three threads:: + + @map_sum_simple(3, 10, sint, 100) + def summer(i): + return sint(regint.inc(100, i, 0)) ** 2 + + result = summer() + + :param n_threads: number of threads (int) + :param n_loops: number of loop runs (regint/cint/int) + :param type: return type, must match the return statement + in the loop + :param size: vector size, must match the return statement + in the loop + + """ + initializer = lambda: type(0, size=size) + def summer(*args): + assert len(args) == 2 + args = list(args) + for i in (0, 1): + if isinstance(args[i], tuple): + assert len(args[i]) == 1 + args[i] = args[i][0] + for i in (0, 1): + assert len(args[i]) == size + if isinstance(args[i], Array): + args[i] = args[i][:] + return args[0] + args[1] + return map_reduce(n_threads, 1, n_loops, initializer, summer) + def tree_reduce_multithread(n_threads, function, vector): inputs = vector.Array(len(vector)) inputs.assign_vector(vector) diff --git a/Compiler/ml.py b/Compiler/ml.py index 5c4664be8..c521934fe 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -223,6 +223,7 @@ class Layer: thetas = lambda self: () debug_output = False back_batch_size = 128 + print_random_update = False @property def shape(self): @@ -254,6 +255,9 @@ def forward(self, batch=None, training=None): def __str__(self): return type(self).__name__ + str(self._Y.sizes) + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.Y.sizes) + class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None output_weights = lambda *args: None @@ -459,6 +463,10 @@ def __init__(self, N, d_out, approx=False, debug=False): self.debug = debug self.true_X = sfix.Array(N) + def __repr__(self): + return '%s(%s, %s, approx=%s)' % \ + (type(self).__name__, self.N, self.d_out, self.approx) + def _forward(self, batch): N = len(batch) d_out = self.X.sizes[1] @@ -609,10 +617,11 @@ def backward_params(self, f_schur_Y, batch): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) + A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address) + B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + @multithread(self.n_threads, self.d_in) def _(base, size): - A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) mp = B.direct_trans_mul(A, reduce=False, indices=(regint.inc(size, base), batch.get_vector(), @@ -622,16 +631,24 @@ def _(base, size): progress('nabla W (matmul)') - if self.d_in * self.d_out < 200000: - print('reduce at once') - @multithread(self.n_threads, self.d_in * self.d_out) - def _(base, size): - self.nabla_W.assign_vector( - tmp.get_vector(base, size).reduce_after_mul(), base=base) - else: - @for_range_opt(self.d_in) - def _(i): - self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() + @multithread(self.n_threads, self.d_in * self.d_out, + max_size=get_program().budget) + def _(base, size): + self.nabla_W.assign_vector( + tmp.get_vector(base, size).reduce_after_mul(), base=base) + + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % self.d_in + j = regint.get_random(64) % self.d_out + print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s', + str(self.nabla_W), i, j, tmp[i][j].v.reveal(), + self.nabla_W[i][j].reveal(), + A.get_column(j).reveal(), + B.get_column_by_row_indices( + batch.get_vector(), i).reveal()) + print_ln('batch=%s B=%s', batch, + [self.X[bi][0][i].reveal() for bi in batch]) progress('nabla W') @@ -699,6 +716,7 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): self.d_in = d_in self.d_out = d_out self.d = d + self.activation = activation self.X = MultiArray([N, d, d_in], sfix) self.Y = MultiArray([N, d, d_out], sfix) @@ -721,12 +739,17 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): else: self.f_input = self.Y + def __repr__(self): + return '%s(%s, %s, %s, activation=%s)' % \ + (type(self).__name__, self.N, self.d_in, + self.d_out, repr(self.activation)) + def reset(self): d_in = self.d_in d_out = self.d_out r = math.sqrt(6.0 / (d_in + d_out)) print('Initializing dense weights in [%f,%f]' % (-r, r)) - self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) + self.W.randomize(-r, r) self.b.assign_all(0) def input_from(self, player, raw=False): @@ -820,6 +843,12 @@ def _(base, size): regint.inc(self.d_in))), base) + if self.print_random_update: + print_ln('backward %s', self) + index = regint.get_random(64) % self.nabla_X.total_size() + print_ln('%s nabla_X at %s: %s', str(self.nabla_X), + index, self.nabla_X.to_array()[index].reveal()) + progress('nabla X') self.backward_params(f_schur_Y, batch=batch) @@ -890,6 +919,10 @@ def __init__(self, N, d1, d2=1, alpha=0.5): self.alpha = alpha self.B = MultiArray([N, d1, d2], sint) + def __repr__(self): + return '%s(%s, %s, alpha=%s)' % \ + (type(self).__name__, self.N, self.d1, self.alpha) + def forward(self, batch, training=False): if training: n_bits = -math.log(self.alpha, 2) @@ -1022,6 +1055,7 @@ class MaxPool(NoVariableLayer): def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), padding='VALID'): assert len(shape) == 4 + assert min(shape) > 0, shape for x in strides, ksize: for i in 0, 3: assert x[i] == 1 @@ -1033,12 +1067,18 @@ def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), self.Y = Tensor(output_shape, sfix) self.strides = strides self.ksize = ksize + self.padding = padding self.nabla_X = Tensor(shape, sfix) self.nabla_Y = Tensor(output_shape, sfix) self.N = shape[0] self.comparisons = MultiArray([self.N, self.X.sizes[3], ksize[1] * ksize[2]], sint) + def __repr__(self): + return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \ + (type(self).__name__, self.X.sizes, self.strides, + self.ksize, self.padding) + def _forward(self, batch): def process(pool, bi, k, i, j): def m(a, b): @@ -1165,7 +1205,7 @@ def _(base, size): self.Y[batch[0]].assign_vector(tmp, base) class FusedBatchNorm(Layer): - """ Fixed-point fused batch normalization layer. + """ Fixed-point fused batch normalization layer (inference only). :param shape: input/output shape (tuple/list of four int) """ @@ -1192,6 +1232,153 @@ def _(i, j): self.X[batch[0]][i][j].get_vector() * self.weights.get_vector() + self.bias.get_vector()) +class BatchNorm(Layer): + """ Fixed-point batch normalization layer. + + :param shape: input/output shape (tuple/list of four int) + :param approx: use approximate square root + + """ + thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) + + def __init__(self, shape, approx=True, args=None): + assert len(shape) in (2, 3, 4) + if len(shape) == 4: + shape = [shape[0], shape[1] * shape[2], shape[3]] + elif len(shape) == 2: + shape = [shape[0], 1, shape[1]] + tensors = (Tensor(shape, sfix) for i in range(4)) + self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.var, self.mu, self.weights, self.bias = arrays + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays + self.epsilon = 2 ** (-sfix.f + 1) + self.momentum = 0.1 + if args != None: + approx = 'precisebn' not in args + self.approx = approx + if approx: + print('Approximate square root inverse in batch normalization') + self.InvertSqrt = mpc_math.InvertSqrt + else: + print('Precise square root inverse in batch normalization') + self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x) + + def __repr__(self): + return '%s(%s, approx=%s)' % \ + (type(self).__name__, self.X.sizes, self.approx) + + def reset(self): + self.bias.assign_all(0) + self.weights.assign_all(1) + self.mu_hat.assign_all(0) + self.var_hat.assign_all(0) + + def _output(self, batch, mu, var): + factor = sfix.Array(len(mu)) + factor[:] = self.InvertSqrt(var[:] + self.epsilon) + @for_range_opt_multithread(self.n_threads, + [len(batch), self.X.sizes[1]]) + def _(i, j): + tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:] + self.Y[i][j][:] = self.bias[:] + tmp + + def forward(self, batch, training=False): + if training: + d = self.X.sizes[1] + d_in = self.X.sizes[2] + s = sfix.Array(d_in) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.X[batch[i]][j].get_vector()) + s.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.mu.assign_vector( + s.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + item = self.X[batch[i]][j].get_vector() + return ((item - self.mu[:]) ** 2) + self.var.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.var.assign_vector( + self.var.get_vector(base, size) / (len(batch) * d - 1), + base) + for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var): + x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:] + self._output(batch, self.mu, self.var) + if self.print_random_update: + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.mu, self.var: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', + str(self.Y), i, j, k, self.X[i][j][k].reveal(), + self.Y[i][j][k].reveal()) + else: + self._output(batch, self.mu_hat, self.var_hat) + + def backward(self, batch, compute_nabla_X=True): + factor = Array.create_from( + self.InvertSqrt(self.var[:] + self.epsilon)) + mynYf = self.X.same_shape() + gamnY = self.X.same_shape() + gamnYd = self.X.same_shape() + nYdf = self.X.same_shape() + d = self.X.sizes[1] + d_in = self.X.sizes[2] + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + tmp = self.weights[:] * self.nabla_Y[i][j][:] + gamnY[i][j] = tmp + gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:]) + mynYf[i][j] = tmp * factor[:] + nYdf[i][j] = self.nabla_Y[i][j][:] * \ + (self.X[i][j][:] - self.mu[:]) * factor[:] + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.nabla_Y[i][j][:]) + self.nabla_bias.assign(_()) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (nYdf[i][j]) + self.nabla_weights.assign(_()) + factor3 = Array.create_from(factor[:] ** 3) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (mynYf[i][j]) + s1 = Array.create_from(_()) + @multithread(self.n_threads, len(s1)) + def _(base, size): + s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (gamnYd[i][j][:] * factor3[:]) + s2 = Array.create_from(_()) + @multithread(self.n_threads, len(s2)) + def _(base, size): + s2.assign_vector( + s2.get_vector(base, size) / (len(batch) * d - 1), base) + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + self.nabla_X[i][j][:] = mynYf[i][j][:] \ + - s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:] + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.nabla_bias, self.nabla_weights: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k, + self.nabla_Y[i][j][k].reveal(), + self.nabla_X[i][j][k].reveal()) + class QuantBase(object): bias_before_reduction = True @@ -1298,6 +1485,8 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, self.padding.append(pad_total // 2) elif padding == 'VALID': self.padding = [0, 0] + elif isinstance(padding, int): + self.padding = [padding, padding] else: self.padding = padding @@ -1323,6 +1512,12 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, assert(len(output_shape) == 4) assert(len(weight_shape) == 4) + def __repr__(self): + return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \ + (type(self).__name__, self.X.sizes, self.weight_shape, + self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), + self.tf_weight_format) + def input_from(self, player, raw=False): self.input_params_from(player) self.weights.input_from(player, budget=100000, raw=raw) @@ -1545,20 +1740,20 @@ def _(i, j): self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i) if compute_nabla_X: - assert tuple(self.padding) == (0, 0) assert tuple(self.stride) == (1, 1) reverse_weights = MultiArray( [n_channels_in, weights_h, weights_w, n_channels_out], sfix) - @for_range(n_channels_out) - def _(i): + @for_range_opt_multithread(self.n_threads, n_channels_in) + def _(l): @for_range(weights_h) def _(j): @for_range(weights_w) def _(k): - @for_range(n_channels_in) - def _(l): - reverse_weights[l][weights_h-j-1][k][i] = \ - self.weights[i][j][weights_w-k-1][l] + addresses = regint.inc(n_channels_out, + self.weights[0][j][weights_w-k-1].get_address(l), + reduce(operator.mul, self.weights.sizes[1:])) + reverse_weights[l][weights_h-j-1][k].assign_vector( + self.weights.value_type.load_mem(addresses)) padded_w = inputs_w + 2 * padding_w padded_h = inputs_h + 2 * padding_h if padding_h or padding_w: @@ -1579,14 +1774,16 @@ def _(i, j): unreduced_sfix._new(res).reduce_after_mul(), i, None, None, j) if padding_h or padding_w: - @for_range(N) + @for_range_opt_multithread(self.n_threads, N) def _(i): @for_range(inputs_h) def _(j): @for_range(inputs_w) def _(k): + jj = j + padding_w + kk = k + padding_w self.nabla_X[i][j][k].assign_vector( - output[i][j][k].get_vector()) + output[i][jj][kk].get_vector()) if self.debug_output: @for_range(len(batch)) @@ -1806,6 +2003,7 @@ def __init__(self, report_loss=None): self.report_loss = report_loss self.X_by_label = None self.print_update_average = False + self.print_random_update = False self.print_losses = False self.print_loss_reduction = False self.i_epoch = MemValue(0) @@ -1846,6 +2044,7 @@ def reset(self): def batch_for(self, layer, batch): if layer in (self.layers[0], self.layers[-1]): + assert not isinstance(layer, BatchNorm) return batch else: batch = regint.Array(len(batch)) @@ -1876,6 +2075,21 @@ def forward(self, N=None, batch=None, keep_intermediate=True, if i != len(self.layers) - 1 or run_last: layer.forward(batch=self.batch_for(layer, batch), training=training) + if self.print_random_update: + print_ln('forward layer %s', layer) + l = min(100, layer.Y[i].total_size()) + i = regint.get_random(64) % len(batch) + if l < 100: + j = 0 + else: + j = regint.get_random(64) % \ + (layer.Y[i].total_size() - l) + print_ln('forward layer %s at (%s, %s): %s', layer, i, j, + layer.Y[i].to_array().get_vector(j, l).reveal()) + i = regint.get_random(64) % layer.Y[0].total_size() + print_ln('forward layer %s vertical at %s: %s', layer, i, + [layer.Y[j].to_array()[i].reveal() + for j in range(len(batch))]) if self.time_layers: stop_timer(100 + i) break_point() @@ -1979,7 +2193,11 @@ def _(j): label * n) self.forward(batch=batch, training=True) self.backward(batch=batch) + if self.time_layers: + start_timer(1000) self.update(i, batch=batch) + if self.time_layers: + stop_timer(1000) loss_sum.iadd(self.layers[-1].l) if self.print_loss_reduction: before = self.layers[-1].average_loss(N) @@ -2070,6 +2288,8 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, if 'nomom' in program.args: self.momentum = 0 self.print_losses = 'print_losses' in program.args + self.print_random_update = 'print_random_update' in program.args + Layer.print_random_update = self.print_random_update self.time_layers = 'time_layers' in program.args self.revealing_correctness = not 'no_acc' in program.args self.layers[-1].compute_loss = not 'no_loss' in program.args @@ -2099,6 +2319,16 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, print_ln('loss %s', self.layers[-1].l.reveal()) self.output_weights() return + if 'bench10' in program.args or 'bench1' in program.args: + n = 1 if 'bench1' in program.args else 10 + print('benchmarking %s iterations' % n) + @for_range(n) + def _(i): + batch = Array.create_from(regint.inc(batch_size)) + self.forward(batch=batch, training=True) + self.backward(batch=batch) + self.update(0, batch=batch) + return @for_range(n_runs) def _(i): if not acc_first: @@ -2115,6 +2345,7 @@ def _(i): cfix(self.n_correct, k=63, f=31) / n_trained, self.n_correct, n_trained) if test_X and test_Y: + print('use test set') n_test = len(test_Y) n_correct, loss = self.reveal_correctness(test_X, test_Y, acc_batch_size) @@ -2211,7 +2442,8 @@ def _(base, size): util.max, abs_g.get_vector()) scale = MemValue(sfix._new(library.AppRcr( max_g.v, max_g.k, max_g.f, simplex_flag=True))) - @multithread(self.n_threads, m.total_size()) + @multithread(self.n_threads, m.total_size(), + max_size=get_program().budget) def _(base, size): m_part = m.get_vector(base, size) v_part = v.get_vector(base, size) @@ -2333,20 +2565,33 @@ def _(i): print_ln_if((x > limit) + (x < -limit), 'theta epoch=%s %s index=%s %s', i_epoch.read(), str(theta), i, x) - index = regint.get_random(64) % len(a) - print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, - aa[1][index], aa[0][index], aa[2][index]) + if self.print_random_update: + print_ln('update') + l = min(100, nabla.total_size()) + if l < 100: + index = 0 + else: + index = regint.get_random(64) % (nabla.total_size() - l) + print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), + index, nabla.to_array().get_vector(index, l).reveal(), + delta_theta.to_array().get_vector(index, l).reveal(), + theta.to_array().get_vector(index, l).reveal()) self.gamma.imul(1 - 10 ** - 6) def apply_padding(input_shape, kernel_size, strides, padding): + if isinstance(padding, int): + input_shape = [x + 2 * padding for x in input_shape] + padding = 'valid' if padding == 'valid': - return (input_shape[0] - kernel_size[0] + 1) // strides[0], \ + res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \ (input_shape[1] - kernel_size[1] + 1) // strides[1], + assert min(res) > 0, (input_shape, kernel_size, strides, padding) + return res elif padding == 'same': - return (input_shape[1]) // strides[0], \ - (input_shape[2]) // strides[1], + return (input_shape[0]) // strides[0], \ + (input_shape[1]) // strides[1], else: - raise Exception('invalid padding: ' + padding) + raise Exception('invalid padding: %s' % padding) class keras: class layers: @@ -2354,7 +2599,7 @@ class layers: Dense = lambda *args, **kwargs: ('dense', args, kwargs) def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', - activation=None): + activation=None, input_shape=None): return 'conv2d', {'filters': filters, 'kernel_size': kernel_size, 'strides': strides, 'padding': padding, 'activation': activation} @@ -2369,6 +2614,13 @@ def Dropout(rate): raise Exception('rate needs to be a power of two') return 'dropout', rate + def Activation(activation): + assert(activation == 'relu') + return activation, + + def BatchNormalization(): + return 'batchnorm', + class optimizers: SGD = lambda *args, **kwargs: ('sgd', args, kwargs) Adam = lambda *args, **kwargs: ('adam', args, kwargs) @@ -2383,12 +2635,25 @@ def __init__(self, layers): def compile(self, optimizer): self.optimizer = optimizer + def compile_by_args(self, program): + if 'adam' in program.args: + self.optimizer = 'adam', [], {} + elif 'amsgrad' in program.args: + self.optimizer = 'adam', [], {'amsgrad': True} + else: + self.optimizer = 'sgd', [], {} + @property def trainable_variables(self): if self.opt == None: raise Exception('need to run build() or fit() first') return list(self.opt.thetas) + def summary(self): + sizes = [var.total_size() for var in self.trainable_variables] + print(sizes) + print('Trainable params:', sum(sizes)) + def build(self, input_shape, batch_size=128): data_input_shape = input_shape if self.opt != None and \ @@ -2415,12 +2680,11 @@ def build(self, input_shape, batch_size=128): if i == len(self.layers) - 1: if layer[2].get('activation', 'softmax') in \ ('softmax', 'sigmoid'): - del layer[2]['activation'] + layer[2].pop('activation', None) layers.append(Dense(N, n_units, layer[1][0], **layer[2])) + input_shape = layers[-1].Y.sizes elif name == 'conv2d': - if len(layers) != 0: - input_shape = layers[-1].Y.sizes input_shape = list(input_shape) + \ [1] * (4 - len(input_shape)) print (layer[1]) @@ -2437,9 +2701,13 @@ def build(self, input_shape, batch_size=128): output_shape = [batch_size] + list( apply_padding(input_shape[1:3], kernel_size, strides, padding)) + [filters] + padding = padding.upper() if isinstance(padding, str) \ + else padding layers.append(FixConv2d(input_shape, weight_shape, (filters,), output_shape, - strides, padding.upper())) + strides, padding)) + input_shape = output_shape + print('conv output shape', output_shape) elif name == 'maxpool': pool_size = layer[1]['pool_size'] strides = layer[1]['strides'] @@ -2450,16 +2718,23 @@ def build(self, input_shape, batch_size=128): strides = (strides, strides) if strides == None: strides = pool_size - layers.append(MaxPool(layers[-1].Y.sizes, + layers.append(MaxPool(input_shape, [1] + list(strides) + [1], [1] + list(pool_size) + [1], - padding.upper())) + padding)) + input_shape = layers[-1].Y.sizes elif name == 'dropout': layers.append(Dropout(batch_size, reduce( operator.mul, layers[-1].Y.sizes[1:]), alpha=layer[1])) + input_shape = layers[-1].Y.sizes elif name == 'flatten': pass + elif name == 'relu': + layers.append(Relu(layers[-1].Y.sizes)) + elif name == 'batchnorm': + input_shape = layers[-1].Y.sizes + layers.append(BatchNorm(layers[-1].Y.sizes)) else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: diff --git a/Compiler/oram.py b/Compiler/oram.py index 443d826cd..543fc4aab 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1493,6 +1493,7 @@ def f(i): self.l[i] = [0] * self.elements_per_block time() print_ln('packed ORAM init %s/%s', i, real_init_rounds) + print_ln('packed ORAM init done') print('index initialized, size', size) def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple diff --git a/Compiler/program.py b/Compiler/program.py index 5dad8e516..366723304 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -580,10 +580,19 @@ def disable_memory_warnings(self): @staticmethod def read_tapes(schedule): + m = re.search(r'([^/]*)\.mpc', schedule) + if m: + schedule = m.group(1) if not os.path.exists(schedule): schedule = 'Programs/Schedules/%s.sch' % schedule - lines = open(schedule).readlines() + try: + lines = open(schedule).readlines() + except FileNotFoundError: + print('%s not found, have you compiled the program?' % schedule, + file=sys.stderr) + sys.exit(1) + for tapename in lines[2].split(' '): yield tapename.strip() diff --git a/Compiler/types.py b/Compiler/types.py index 0063fdc16..1dbe1f909 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1675,6 +1675,13 @@ def output(self): __ne__ = lambda self, other: localint(self._v != other) class personal(Tape._no_truth): + """ Value known to one player. Supports operations with public + values and personal values known to the same player. Can be used + with :py:func:`~Compiler.library.print_ln_to`. + + :param player: player (int) + :param value: cleartext value (cint, cfix, cfloat) or array thereof + """ def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -1685,8 +1692,24 @@ def __init__(self, player, value): self._v = value def binary_output(self): + """ Write binary output to + ``Player-Data/Binary-Output-P-`` if + supported by underlying type. Player must be known at compile time.""" self._v.binary_output(self.player) + def reveal_to(self, player): + """ Pass personal value to another player. """ + if isinstance(self._v, Array): + source = self._v[:] + else: + source = self._v + source = cint.conv(source) + res = cint(size=source.size) + sendpersonal(source.size, player, res, self.player, source) + if isinstance(self._v, Array): + res = Array.create_from(res) + return personal(player, res) + def bit_decompose(self, length): return [personal(self.player, x) for x in self._v.bit_decompose(length)] @@ -1858,8 +1881,13 @@ def get_random_inverse(cls): @vectorized_classmethod @set_instruction_type def get_random_input_mask_for(cls, player): - res = cls() - inputmask(res, player) + """ Secret random input mask according to security model. + + :return: mask (sint), mask (personal cint) + :param size: vector size (int, default 1) + """ + res = cls(), personal(player, cls.clear_type()) + inputmask(res[0], res[1]._v, player) return res @classmethod @@ -2071,15 +2099,13 @@ def reveal(self): @set_instruction_type def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result written to ``Player-Data/Private-Output-P`` :param player: int - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ - masked = self.__class__() - res = personal(player, self.clear_type()) - startprivateoutput(masked, self, player) - stopprivateoutput(res._v, masked.reveal(), player) + mask = self.get_random_input_mask_for(player) + masked = self + mask[0] + res = personal(player, masked.reveal() - mask[1]) return res @@ -2633,21 +2659,20 @@ def raw_mod2m(self, m): @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result potentially written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. - :param player: public integer (int/regint/cint): - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` """ - if not util.is_constant(player) or self.size > 1: + if not util.is_constant(player): secret_mask = sint() player_mask = cint() inputmaskreg(secret_mask, player_mask, regint.conv(player)) return personal(player, (self + secret_mask).reveal() - player_mask) else: - return super(sint, self).reveal_to(player) + res = personal(player, self.clear_type()) + privateoutput(self.size, player, res._v, self) + return res def private_division(self, divisor, active=True, dividend_length=None, divisor_length=None): @@ -4366,12 +4391,9 @@ def multipliable(v, k, f, size): def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Raw representation possibly written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. :param player: public integer (int/regint/cint) - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) @@ -5221,6 +5243,9 @@ def __setitem__(self, index, value): return self.assign(value, addresses) self._store(value, self.get_address(index)) + def to_array(self): + return self + def get_sub(self, start, stop=None): if stop is None: stop = start @@ -5471,6 +5496,10 @@ def shuffle(self): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + def randomize(self, *args): + """ Randomize according to data type. """ + self.assign_vector(self.value_type.get_random(*args, size=len(self))) + def reveal(self): """ Reveal the whole array. @@ -5596,6 +5625,9 @@ def __len__(self): def __iter__(self): return (self[i] for i in range(len(self))) + def to_array(self): + return Array(self.total_size(), self.value_type, address=self.address) + def assign_all(self, value): """ Assign the same value to all entries. @@ -5958,6 +5990,7 @@ def direct_mul_trans(self, other, reduce=True, indices=None): """ assert len(self.sizes) == 2 assert len(other.sizes) == 2 + assert other.address != None if indices is None: assert self.sizes[1] == other.sizes[1] indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] @@ -6145,6 +6178,16 @@ def diag(self): n = self.sizes[0] return self.array.get(regint.inc(n, 0, n + 1)) + def randomize(self, *args): + """ Randomize according to data type. """ + if self.total_size() < program.options.budget: + self.assign_vector( + self.value_type.get_random(*args, size=self.total_size())) + else: + @library.for_range(self.sizes[0]) + def _(i): + self[i].randomize(*args) + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6251,6 +6294,22 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + def get_column(self, index): + """ Get column as vector. + + :param index: regint/cint/int + """ + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.sizes[1]) + return self.value_type.load_mem(addresses) + + def get_column_by_row_indices(self, rows, column): + assert self.value_type.n_elements() == 1 + addresses = rows * self.sizes[1] + \ + regint.inc(len(rows), self.address + column, 0) + return self.value_type.load_mem(addresses) + def set_column(self, index, vector): """ Change column. diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 2a4d6b123..20dfb1bb5 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -47,11 +47,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G) } void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) +{ + Rq_Element a(*this); + a.randomize(G); + partial_key_gen(sk, a, G, noise_boost); +} + +void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost) { FHE_PK& PK = *this; - // Generate the main public key - PK.a0.randomize(G); + a0 = a; // b0=a0*s+p*e0 Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); @@ -77,9 +84,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) mul(es,es,PK.pr); add(PK.Sw_b,PK.Sw_b,es); - // Lowering level as we only decrypt at level 0 - sk.lower_level(); - // bs=bs-p1*s^2 Rq_Element s2; mul(s2,sk,sk); // Mult at level 0 @@ -334,7 +338,7 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, template void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) { - check(*params, pk, pr); + check(*params, pk, FieldD.get_prime()); pk.check_noise(*this); if (decrypt(pk.encrypt(Plaintext_(FieldD)), FieldD) != Plaintext_(FieldD)) diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 72a7ddfa8..30ecc2925 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -150,6 +150,8 @@ class FHE_PK Rq_Element sample_secret_key(PRNG& G); void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1); + void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost = 1); void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 8ae6c2885..0de8bb1e9 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -3,6 +3,11 @@ #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" +FHE_Params::FHE_Params(int n_mults) : + FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1) +{ +} + void FHE_Params::set(const Ring& R, const vector& primes) { @@ -24,6 +29,14 @@ void FHE_Params::set_sec(int sec) throw runtime_error("distributed decryption bound is zero"); } +void FHE_Params::set_matrix_dim(int matrix_dim) +{ + assert(matrix_dim > 0); + if (FFTData[0].get_prime() != 0) + throw runtime_error("cannot change matrix dimension after parameter generation"); + this->matrix_dim = matrix_dim; +} + bigint FHE_Params::Q() const { bigint res = FFTData[0].get_prime(); @@ -40,6 +53,7 @@ void FHE_Params::pack(octetStream& o) const Chi.pack(o); Bval.pack(o); o.store(sec_p); + o.store(matrix_dim); } void FHE_Params::unpack(octetStream& o) @@ -52,6 +66,7 @@ void FHE_Params::unpack(octetStream& o) Chi.unpack(o); Bval.unpack(o); o.get(sec_p); + o.get(matrix_dim); } bool FHE_Params::operator!=(const FHE_Params& other) const diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 8ac400839..9407b0ba4 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -26,10 +26,11 @@ class FHE_Params // Data for distributed decryption int sec_p; bigint Bval; + int matrix_dim; public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {} + FHE_Params(int n_mults = 1); int n_mults() const { return FFTData.size() - 1; } @@ -37,6 +38,9 @@ class FHE_Params void set(const vector& primes); void set_sec(int sec); + void set_matrix_dim(int matrix_dim); + int get_matrix_dim() const { return matrix_dim; } + const vector& FFTD() const { return FFTData; } const bigint& p0() const { return FFTData[0].get_prime(); } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index c6e294a63..7c46a74fe 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2) template <> int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FFT_Data& FTD, bool round_up) + FHE_Params& params, FFT_Data& FTD, bool round_up, int n) { int m = 1024; int lgp = plaintext_length; @@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec, while (true) { tmp_params = params; - SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, + SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec, numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params); bigint p1 = 2 * p * m, p0 = p; while (nb.min_p0(params.n_mults() > 0, p1) > p0) @@ -89,14 +89,14 @@ int generate_semi_setup(int plaintext_length, int sec, template <> int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, P2Data& P2D, bool round_up) + FHE_Params& params, P2Data& P2D, bool round_up, int n) { if (params.n_mults() > 0) throw runtime_error("only implemented for 0-level BGV"); gf2n_short::init_field(plaintext_length); int m; char_2_dimension(m, plaintext_length); - SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec, + SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec, numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); int lgp0 = numBits(nb.min_p0(false, 0)); int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up); @@ -590,6 +590,9 @@ void char_2_dimension(int& m, int& lg2) m=5797; lg2=40; break; + case 16: + m = 13107; + break; default: throw runtime_error("field size not supported"); break; diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index c0a2ecfea..acaba70b6 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -52,7 +52,7 @@ void generate_setup(int nparties, int lgp, int lg2, // semi-homomorphic, includes slack template int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FD& FieldD, bool round_up); + FHE_Params& params, FD& FieldD, bool round_up, int n = 1); // field-independent semi-homomorphic setup int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index 7ab8e5172..f2e151c42 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -39,6 +39,7 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); + int matrix_dim = params.get_matrix_dim(); #ifdef NOISY cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "V_s: " << V_s << endl; @@ -48,9 +49,11 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, cout << "log(slack): " << slack << endl; cout << "B_clean: " << B_clean << endl; cout << "B_scale: " << B_scale << endl; + cout << "matrix dimension: " << matrix_dim << endl; #endif - drown = 1 + n * (bigint(1) << sec); + assert(matrix_dim > 0); + drown = 1 + matrix_dim * n * (bigint(1) << sec); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 812560a3a..554d4dc10 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -50,6 +50,7 @@ void Ring_Element::prepare_push() void Ring_Element::allocate() { + assert(FFTD); element.resize(FFTD->phi_m()); } diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index af7a664b5..531df90f7 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b) } } +void Rq_Element::add(octetStream& os) +{ + Rq_Element tmp(*this); + tmp.unpack(os); + *this += tmp; +} + void Rq_Element::randomize(PRNG& G,int l) { set_level(l); @@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p) // Now add delta back onto a0 Rq_Element bb(b0,b1); - add(*this,*this,bb); + ::add(*this,*this,bb); // Now divide by p1 mod p0 modp p1_inv,pp; diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index d5e718419..a58cb7de0 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -93,12 +93,14 @@ class Rq_Element friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); + void add(octetStream& os); + template Rq_Element& operator+=(const vector& other); - Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; } + Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; } - Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; } + Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; } Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; } template Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; } @@ -176,7 +178,7 @@ Rq_Element& Rq_Element::operator+=(const vector& other) { Rq_Element tmp = *this; tmp.from(Iterator(other), lev); - add(*this, *this, tmp); + ::add(*this, *this, tmp); return *this; } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 0f5d1fe86..48a8a6ef8 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -203,7 +203,7 @@ template void PartSetup::secure_init(Player& P, MachineBase& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); } template diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 732904b39..92632002c 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -130,6 +130,13 @@ void Multiplier::report_size(ReportType type, MemoryUsage& res) res += memory_usage; } +template +const vector& Multiplier::get_multiplicands( + const vector >& others_ct, const FHE_PK&) +{ + return others_ct[P.get_full_player().get_player(-P.get_offset())]; +} + template class Multiplier; template class Multiplier; diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index e2e1ce660..9ab517a66 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -55,6 +55,9 @@ class Multiplier size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); size_t report_volatile() { return volatile_capacity; } + + const vector& get_multiplicands( + const vector>& others_ct, const FHE_PK&); }; #endif /* FHEOFFLINE_MULTIPLIER_H_ */ diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index bba83b5fd..047c84f2c 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -9,6 +9,7 @@ #include "Math/Setup.h" #include "FHEOffline/Proof.h" #include "FHEOffline/PairwiseMachine.h" +#include "FHEOffline/TemiSetup.h" #include "Tools/Commit.h" #include "Tools/Bundle.h" #include "Processor/OnlineOptions.h" @@ -53,7 +54,7 @@ void PairwiseSetup::init(const Player& P, int sec, int plaintext_length, template void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); alpha = FieldD; machine.sk = FHE_SK(params, FieldD.get_prime()); for (auto& pk : machine.other_pks) @@ -62,13 +63,14 @@ void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int pla template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec) + int plaintext_length, int sec, FHE_Params& params) { machine.sec = sec; sec = max(sec, 40); machine.drown_sec = sec; string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); @@ -85,7 +87,6 @@ void secure_init(T& setup, Player& P, U& machine, { cout << "Finding parameters for security " << sec << " and field size ~2^" << plaintext_length << endl; - setup.params = setup.params.n_mults(); setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; @@ -208,5 +209,8 @@ void PairwiseSetup::set_alphai(T alphai) template class PairwiseSetup; template class PairwiseSetup; -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); + +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h index 8e16eaf34..f6482edec 100644 --- a/FHEOffline/PairwiseSetup.h +++ b/FHEOffline/PairwiseSetup.h @@ -15,7 +15,7 @@ class MachineBase; template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec); + int plaintext_length, int sec, FHE_Params& params); template class PairwiseSetup diff --git a/FHEOffline/SimpleDistDecrypt.cpp b/FHEOffline/SimpleDistDecrypt.cpp index 3774cd3c1..c8b923123 100644 --- a/FHEOffline/SimpleDistDecrypt.cpp +++ b/FHEOffline/SimpleDistDecrypt.cpp @@ -18,7 +18,12 @@ void SimpleDistDecrypt::reshare(Plaintext& EC) { (void)EC; + m = reshare(cm); +} +template +Plaintext_ SimpleDistDecrypt::reshare(const Ciphertext& cm) +{ PRNG G; G.ReSeed(); this->f.randomize(G, Full); @@ -27,10 +32,13 @@ void SimpleDistDecrypt::reshare(Plaintextrun(cm); // Step 4 + Plaintext_ m(this->f.get_field()); if (this->P.my_num()==0) { sub(m,this->mf,this->f); } else { m=this->f; m.negate(); } + + return m; } diff --git a/FHEOffline/SimpleDistDecrypt.h b/FHEOffline/SimpleDistDecrypt.h index 9589f15a1..c929a7990 100644 --- a/FHEOffline/SimpleDistDecrypt.h +++ b/FHEOffline/SimpleDistDecrypt.h @@ -20,6 +20,7 @@ class SimpleDistDecrypt : public DistDecrypt void reshare(Plaintext& m, const Ciphertext& cm, EncCommitBase& EC); + Plaintext_ reshare(const Ciphertext& cm); }; #endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */ diff --git a/FHEOffline/TemiSetup.cpp b/FHEOffline/TemiSetup.cpp new file mode 100644 index 000000000..fc222ed51 --- /dev/null +++ b/FHEOffline/TemiSetup.cpp @@ -0,0 +1,59 @@ +/* + * TemiSetup.cpp + * + */ + +#include "TemiSetup.h" +#include "PairwiseSetup.h" +#include "FHE/NTL-Subs.h" +#include "Protocols/HemiOptions.h" + +template +TemiSetup::TemiSetup() +{ + this->params = FHE_Params(0); + this->pk = {this->params, 0}; + this->sk = {this->params, 0}; + this->calpha = this->params; + this->params.set_matrix_dim( + HemiOptions::singleton.plain_matmul ? + 1 : OnlineOptions::singleton.batch_size); +} + +template +void TemiSetup::secure_init(Player& P, int plaintext_length) +{ + MachineBase machine; + ::secure_init(*this, P, machine, plaintext_length, 0, this->params); +} + +template +void TemiSetup::generate(Player& P, MachineBase&, + int plaintext_length, int sec) +{ + generate_semi_setup(plaintext_length, sec, this->params, this->FieldD, + false, P.num_players()); + this->sk = {this->params, this->FieldD.get_prime()}; + this->pk = {this->params, this->FieldD.get_prime()}; +} + +template +void TemiSetup::key_and_mac_generation(Player& P, MachineBase&, int, + true_type) +{ + Rq_Element a(this->params); + GlobalPRNG GG(P); + a.randomize(GG); + SeededPRNG G; + auto sk = this->pk.sample_secret_key(G); + this->sk.assign(sk); + this->pk.partial_key_gen(sk, a, G); + TreeSum ts; + vector pks; + pks.push_back(this->pk.b()); + ts.run(pks, P); + this->pk.assign(this->pk.a(), pks[0]); +} + +template class TemiSetup; +template class TemiSetup; diff --git a/FHEOffline/TemiSetup.h b/FHEOffline/TemiSetup.h new file mode 100644 index 000000000..483cb0ee6 --- /dev/null +++ b/FHEOffline/TemiSetup.h @@ -0,0 +1,34 @@ +/* + * TemiSetup.h + * + */ + +#ifndef FHEOFFLINE_TEMISETUP_H_ +#define FHEOFFLINE_TEMISETUP_H_ + +#include "FHE/FHE_Keys.h" +#include "FHEOffline/SimpleMachine.h" + +template +class TemiSetup : public PartSetup +{ +public: + static string name() + { + return "TemiParams"; + } + + static string protocol_name(int) + { + return "Temi"; + } + + TemiSetup(); + + void secure_init(Player& P, int plaintext_length); + void generate(Player& P, MachineBase&, int plaintext_length, int sec); + + void key_and_mac_generation(Player& P, MachineBase&, int, true_type); +}; + +#endif /* FHEOFFLINE_TEMISETUP_H_ */ diff --git a/GC/Memory.h b/GC/Memory.h index 359677a20..006a91d94 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -47,11 +47,11 @@ inline void Memory::check_index(Integer index) const ss << T::type_string() << " memory overflow: " << i << "/" << vector::size(); throw Processor_Error(ss.str()); } -#endif #ifdef DEBUG_MEMORY cout << typeid(T).name() << " at " << this << " index " << i << ": " << vector::operator[](i) << endl; #endif +#endif } template diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 48f75b8f2..6d9f26525 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -122,6 +122,7 @@ class RepSecretBase : public FixedVec, public ShareSecret static const bool dishonest_majority = false; static const bool variable_players = false; static const bool needs_ot = false; + static const bool has_mac = false; static string type_string() { return "replicated secret"; } static string phase_name() { return "Replicated computation"; } diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9b6c84782..9cdde3dc7 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -49,6 +49,7 @@ class VectorSecret : public Secret static const bool dishonest_majority = T::dishonest_majority; static const bool variable_players = T::variable_players; static const bool needs_ot = T::needs_ot; + static const bool has_mac = T::has_mac; static const bool expensive_triples = false; static const int default_length = 64; diff --git a/GC/instructions.h b/GC/instructions.h index 66ae46d22..49443cc23 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -55,7 +55,7 @@ X(BITDECC, PROC.bitdecc(EXTRA, C0)) \ X(SHRCBI, C0 = PC1 >> IMM) \ X(SHLCBI, C0 = PC1 << IMM) \ - X(LDBITS, S0.load_clear(REG1, IMM)) \ + X(LDBITS, S0.load_clear(REG1, int(IMM))) \ X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \ X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \ X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 7697c5124..9f18d3a6f 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -23,6 +23,7 @@ #include "Protocols/Shamir.hpp" #include "Protocols/ShamirMC.hpp" #include "Protocols/MaliciousShamirMC.hpp" +#include "Protocols/MaliciousShamirPO.hpp" #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" diff --git a/Machines/temi-party.cpp b/Machines/temi-party.cpp new file mode 100644 index 000000000..12e99dc27 --- /dev/null +++ b/Machines/temi-party.cpp @@ -0,0 +1,37 @@ +/* + * temi-party.cpp + * + */ + +#include "Protocols/TemiShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "FHE/P2Data.h" +#include "Tools/ezOptionParser.h" +#include "GC/SemiSecret.h" +#include "GC/SemiPrep.h" + +#include "Processor/FieldMachine.hpp" +#include "Protocols/TemiPrep.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/SemiPrep.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/SemiMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/Hemi.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/SemiHonestRepPrep.h" +#include "Math/gfp.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + HemiOptions::singleton = {opt, argc, argv}; + DishonestMajorityFieldMachine(argc, argv, + opt); +} diff --git a/Makefile b/Makefile index e40528b8c..4f558e1d6 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr all: overdrive she-offline -arithmetic: hemi-party.x soho-party.x gear +arithmetic: semi-he gear -include $(DEPS) include $(wildcard *.d static/*.d) @@ -87,6 +87,7 @@ she-offline: Check-Offline.x spdz2-offline.x overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x +semi-he: hemi-party.x soho-party.x temi-party.x rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x @@ -210,6 +211,7 @@ static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) +temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) @@ -217,6 +219,7 @@ lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/Lo highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOBJS) +static/temi-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS) static/cowgear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 55983e0b0..c0b2373ed 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -14,11 +14,6 @@ using namespace std; #include "Tools/random.h" #include "field_types.h" -template class ReplicatedMC; -template class ReplicatedInput; -template class ReplicatedPrivateOutput; -template class Replicated; - template class FixedVec { diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index f30e71037..13d700fc1 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -233,7 +233,7 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* if (mpn_cmp(ans+T,prA,T+1)>=0) { mpn_sub_fixed_n(z,ans+T,prA); } else - { inline_mpn_copyi(z,ans+T,T); } + { inline_mpn_copyi(z,ans+T); } #else Mont_Mult(z, x, y, t); #endif diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 1a6fe41d1..f9491fb7b 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -18,15 +18,21 @@ bool gf2n_::useC; word gf2n_short_table[256][256]; -#define num_2_fields 6 +#define num_2_fields 7 /* Require * 2*(n-1)-64+t1<64 */ -int fields_2[num_2_fields][4] = { - {4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1}, - }; - +int fields_2[num_2_fields][4] = +{ + { 4, 1, 0, 0 }, + { 8, 4, 3, 1 }, + { 16, 5, 3, 1 }, + { 28, 1, 0, 0 }, + { 40, 20, 15, 10 }, + { 63, 1, 0, 0 }, + { 128, 7, 2, 1 }, +}; template void gf2n_::init_tables() diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index b55a6b7ec..b1c5642be 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -24,6 +24,12 @@ inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t si avx_memcpy(dest, src, size * sizeof(mp_limb_t)); } +template +inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src) +{ + avx_memcpy(dest, src); +} + inline void debug_print(const char* name, const mp_limb_t* x, int n) { (void)name, (void)x, (void)n; diff --git a/Networking/Player.h b/Networking/Player.h index 9c90dbd1f..ff4bdcd1d 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -542,6 +542,7 @@ class OffsetPlayer : public TwoPartyPlayer int other_player_num() const { return P.get_player(offset); } int num_players() const { return 2; } int get_offset() const { return offset; } + Player& get_full_player() const { return P; } void send(octetStream& o) const { P.send_to(P.get_player(offset), o); } void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); } diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 988565854..730ffa6f6 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -206,6 +206,18 @@ void BaseOT::exec_base(bool new_receiver_inputs) receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); } } + +#ifdef BASE_OT_DEBUG + for (j = 0; j < 4; j++) + for (k = 0; k < AES_BLK_SIZE; k++) + { + printf("%4d-th receiver key:", i+j); + for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); + printf("\n"); + } + + printf("\n"); +#endif } } @@ -244,12 +256,6 @@ void BaseOT::exec_base(bool new_receiver_inputs) for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]); printf("\n"); } - if (ot_role & RECEIVER) - { - printf("%4d-th receiver key:", i+j); - for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); - printf("\n"); - } } printf("\n"); diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index ef735279a..ea8239a5b 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -25,7 +25,7 @@ void Binary_File_IO::write_to_file(const string filename, if (start_pos != -1) { - long write_pos = start_pos * T::size(); + long write_pos = file_signature().get_total_length() + start_pos * T::size(); // fill with zeros if needed for (long i = outf.tellp(); i < write_pos; i++) outf.put(0); @@ -50,10 +50,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, inf.open(filename, ios::in | ios::binary); if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } + check_file_signature(inf, filename).get_length(); + auto data_start = inf.tellg(); + int size_in_bytes = T::size() * buffer.size(); int n_read = 0; char read_buffer[size_in_bytes]; - inf.seekg(start_posn * T::size()); + inf.seekg(start_posn * T::size(), iostream::cur); do { inf.read(read_buffer + n_read, size_in_bytes - n_read); @@ -62,7 +65,9 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, if (inf.eof()) { stringstream ss; - ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes)."; + ss << "Got to EOF when reading from disk (expecting " << size_in_bytes + << " bytes from " << (long(data_start) + start_posn * T::size()) + << ")."; throw file_error(ss.str()); } if (inf.fail()) @@ -74,7 +79,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, } while (n_read < size_in_bytes); - end_posn = inf.tellg() / T::size(); + end_posn = (inf.tellg() - data_start) / T::size(); assert (end_posn == start_posn + int(buffer.size())); //Check if at end of file by getting 1 more char. diff --git a/Processor/Input.h b/Processor/Input.h index 98c6c83b0..728c81f6a 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -32,6 +32,15 @@ class InputBase Buffer buffer; Timer timer; + // Send my inputs (not generally available) + virtual void send_mine() { throw not_implemented(); } + // Get share for next input of mine (not generally available) + virtual T finalize_mine() { throw not_implemented(); } + // Store share for next input from ``player`` from buffer ``o`` + // in ``target`` (not generally available) + virtual void finalize_other(int, T&, octetStream&, int = -1) + { throw not_implemented(); } + public: vector os; int values_input; @@ -61,18 +70,12 @@ class InputBase /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; /// Schedule input from all players - void add_from_all(const clear& input, int n_bits = -1); + void add_from_all(const typename T::open_type& input, int n_bits = -1); - /// Send my inputs - virtual void send_mine() = 0; /// Run input protocol for all players virtual void exchange(); - /// Get share for next input of mine - virtual T finalize_mine() = 0; - /// Store share for next input from ``player`` from buffer ``o`` in ``target`` - virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; - /// Get share for next input from ``player` + /// Get share for next input from ``player`` virtual T finalize(int player, int n_bits = -1); void raw_input(SubProcessor& proc, const vector& args, int size); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index b9f7a77ab..246c9eb1d 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -113,7 +113,7 @@ void Input::add_other(int player, int) } template -void InputBase::add_from_all(const clear& input, int n_bits) +void InputBase::add_from_all(const typename T::open_type& input, int n_bits) { for (int i = 0; i < P->num_players(); i++) if (i == P->my_num()) diff --git a/Processor/Instruction.h b/Processor/Instruction.h index ca062cbcb..a7e1e3185 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -106,6 +106,7 @@ enum MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -127,6 +128,7 @@ enum INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 25fa666f4..1bc46f94f 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -200,14 +200,17 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case USE: case USE_INP: case USE_EDABIT: - case STARTPRIVATEOUTPUT: - case GSTARTPRIVATEOUTPUT: - case STOPPRIVATEOUTPUT: - case GSTOPPRIVATEOUTPUT: case DIGESTC: + case INPUTMASK: + case GINPUTMASK: get_ints(r, s, 2); n = get_int(s); break; + case STARTPRIVATEOUTPUT: + case GSTARTPRIVATEOUTPUT: + case STOPPRIVATEOUTPUT: + case GSTOPPRIVATEOUTPUT: + throw runtime_error("two-stage private output not supported any more"); case USE_MATMUL: get_ints(r, s, 3); n = get_int(s); @@ -237,8 +240,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case PRINTREGB: case GPRINTREG: case LDINT: - case INPUTMASK: - case GINPUTMASK: case INV2M: case CONDPRINTSTR: case CONDPRINTSTRB: @@ -290,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RAWINPUT: case GRAWINPUT: case INPUTPERSONAL: + case SENDPERSONAL: + case PRIVATEOUTPUT: case TRUNC_PR: case RUN_TAPE: num_var_args = get_int(s); @@ -599,6 +602,7 @@ int BaseInstruction::get_reg_type() const case PUBINPUT: case FLOATOUTPUT: case READSOCKETC: + case PRIVATEOUTPUT: return CINT; default: if (is_gf2n_instruction()) @@ -738,10 +742,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const skip = 1; break; case INPUTPERSONAL: + case PRIVATEOUTPUT: size_offset = -2; offset = 2; skip = 4; break; + case SENDPERSONAL: + size_offset = -2; + offset = 2; + skip = 5; + break; case READSOCKETS: case READSOCKETC: case READSOCKETINT: @@ -939,13 +949,11 @@ inline void Instruction::execute(Processor& Proc) const break; case INPUTMASK: Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); - if (n == Proc.P.my_num()) - Proc.temp.rrp.output(Proc.private_output, false); + Proc.write_Cp(r[1], Proc.temp.rrp); break; case GINPUTMASK: Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); - if (n == Proc.P.my_num()) - Proc.temp.ans2.output(Proc.private_output, false); + Proc.write_C2(r[1], Proc.temp.ans2); break; case INPUT: sint::Input::template input>(Proc.Procp, start, size); @@ -974,6 +982,12 @@ inline void Instruction::execute(Processor& Proc) const case INPUTPERSONAL: Proc.Procp.input_personal(start); return; + case SENDPERSONAL: + Proc.Procp.send_personal(start); + return; + case PRIVATEOUTPUT: + Proc.Procp.private_output(start); + return; // Note: Fp version has different semantics for NOTC than GNOTC case NOTC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); @@ -1202,18 +1216,6 @@ inline void Instruction::execute(Processor& Proc) const Proc.binary_output.write((char*) &tmp, sizeof(double)); } break; - case STARTPRIVATEOUTPUT: - Proc.privateOutputp.start(n,r[0],r[1]); - break; - case GSTARTPRIVATEOUTPUT: - Proc.privateOutput2.start(n,r[0],r[1]); - break; - case STOPPRIVATEOUTPUT: - Proc.privateOutputp.stop(n,r[0],r[1]); - break; - case GSTOPPRIVATEOUTPUT: - Proc.privateOutput2.stop(n,r[0],r[1]); - break; case PREP: Procp.DataF.get(Proc.Procp.get_S(), r, start, size); return; diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index a43c9d475..cd318f1aa 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -97,12 +97,19 @@ Machine::Machine(int my_number, Names& playerNames, // initialize persistence if necessary for (auto& prog : progs) { - if (prog.writes_persistance) + if (prog.writes_persistence) { string filename = Binary_File_IO::filename(my_number); ifstream pers(filename); - if (pers.fail()) - ofstream pers(filename, ios::binary); + try + { + check_file_signature(pers, filename); + } + catch (signature_mismatch&) + { + ofstream pers(filename, ios::binary); + file_signature().output(pers); + } break; } } @@ -418,12 +425,14 @@ void Machine::run() cerr << "Full broadcast" << endl; #endif +#ifdef CHOP_MEMORY // Reduce memory size to speed up unsigned max_size = 1 << 20; if (M2.size_s() > max_size) M2.resize_s(max_size); if (Mp.size_s() > max_size) Mp.resize_s(max_size); +#endif // Write out the memory to use next time ofstream outf(memory_filename(), ios::out | ios::binary); diff --git a/Processor/Memory.h b/Processor/Memory.h index 9ec02d2b8..1fbeda7ec 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -44,9 +44,9 @@ class Memory static void check_index(const vector& M, size_t i) { (void) M, (void) i; -#ifdef NO_CHECK_INDEX +#ifndef NO_CHECK_INDEX if (i >= M.size()) - throw overflow("memory", i, M.size()); + throw overflow(U::type_string() + " memory", i, M.size()); #endif } diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index c3c3e01bf..ef767441b 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -19,6 +19,9 @@ void MemoryPart::minimum_size(size_t size) { if (size > this->size()) this->resize(size); +#ifdef DEBUG_MEMORY_SIZE + cerr << T::type_string() << " memory has now size " << this->size() << endl; +#endif } catch (bad_alloc&) { @@ -58,9 +61,9 @@ istream& operator>>(istream& s,Memory& M) int len; s >> len; - M.resize_s(len); + M.MS.minimum_size(len); s >> len; - M.resize_c(len); + M.MC.minimum_size(len); s.seekg(1, istream::cur); for (unsigned int i=0; i& proc; + typename T::MAC_Check MC; deque masks; public: - PrivateOutput(SubProcessor& proc) : proc(proc) { }; + PrivateOutput(SubProcessor& proc); + ~PrivateOutput(); - void start(int player, int target, int source); - void stop(int player, int dest, int source); - - T start(int player, const T& source); - typename T::clear stop(int player, const typename T::clear& masked); + void prepare_sending(const T& source, int player); + void exchange(); + typename T::clear finalize(int player); }; #endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ diff --git a/Processor/PrivateOutput.hpp b/Processor/PrivateOutput.hpp index 977e7e15d..d2cee8a14 100644 --- a/Processor/PrivateOutput.hpp +++ b/Processor/PrivateOutput.hpp @@ -7,13 +7,21 @@ #include "Processor.h" template -void PrivateOutput::start(int player, int target, int source) +PrivateOutput::PrivateOutput(SubProcessor& proc) : + proc(proc), MC(proc.MC.get_alphai()) { - proc.get_S_ref(target) = start(player, proc.get_S_ref(source)); + MC.init_open(proc.P); + MC.set_prep(proc.DataF); } template -T PrivateOutput::start(int player, const T& source) +PrivateOutput::~PrivateOutput() +{ + MC.Check(proc.P); +} + +template +void PrivateOutput::prepare_sending(const T& source, int player) { assert (player < proc.P.num_players()); open_type mask; @@ -24,26 +32,25 @@ T PrivateOutput::start(int player, const T& source) if (player == proc.P.my_num()) masks.push_back(mask); - return res; + MC.prepare_open(res); } template -void PrivateOutput::stop(int player, int dest, int source) +void PrivateOutput::exchange() { - auto& value = proc.get_C_ref(dest); - value = stop(player, proc.get_C_ref(source)); - if (proc.Proc) - value.output(proc.Proc->private_output, false); + MC.exchange(proc.P); } template -typename T::clear PrivateOutput::stop(int player, const typename T::clear& source) +typename T::clear PrivateOutput::finalize(int player) { - typename T::clear value; + auto res = MC.finalize_open(); + if (player == proc.P.my_num()) { - value = source - masks.front(); + res -= masks.front(); masks.pop_front(); } - return value; + + return res; } diff --git a/Processor/Processor.h b/Processor/Processor.h index c91b677bf..38ea7f258 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -71,6 +71,8 @@ class SubProcessor void conv2ds(const Instruction& instruction); void input_personal(const vector& args); + void send_personal(const vector& args); + void private_output(const vector& args); CheckVector& get_S() { @@ -110,7 +112,6 @@ class ArithmeticProcessor : public ProcessorBase ifstream private_input; ifstream public_input; ofstream public_output; - ofstream private_output; ofstream binary_output; int sent, rounds; @@ -172,9 +173,6 @@ class Processor : public ArithmeticProcessor SubProcessor Proc2; SubProcessor Procp; - typename sgf2n::PrivateOutput privateOutput2; - typename sint::PrivateOutput privateOutputp; - unsigned int PC; TempVars temp; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index c55a6dfc1..d74594b3d 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -4,9 +4,8 @@ #include "Processor/Processor.h" #include "Processor/Program.h" #include "GC/square64.h" +#include "SpecificPrivateOutput.h" -#include "Protocols/ReplicatedInput.hpp" -#include "Protocols/ReplicatedPrivateOutput.hpp" #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" #include "GC/ShareThread.hpp" @@ -63,7 +62,6 @@ Processor::Processor(int thread_num,Player& P, share_thread(DataF.DataFb, P, machine.get_bit_mac_key()), Procb(machine.bit_memories), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), - privateOutput2(Proc2),privateOutputp(Procp), external_clients(P.my_num()), binary_file_io(Binary_File_IO()) { @@ -74,7 +72,6 @@ Processor::Processor(int thread_num,Player& P, private_input_filename = (get_filename(PREP_DIR "Private-Input-",true)); private_input.open(private_input_filename.c_str()); public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out); - private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out); binary_output.open( get_parameterized_filename(P.my_num(), thread_num, PREP_DIR "Binary-Output"), ios_base::out); @@ -654,6 +651,37 @@ void SubProcessor::input_personal(const vector& args) S[args[i + 2] + j] = input.finalize(args[i + 1]); } +template +void SubProcessor::private_output(const vector& args) +{ + typename T::PrivateOutput output(*this); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + { + int player = args[i + 1]; + output.prepare_sending(S.at(args[i + 3] + j), player); + } + output.exchange(); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + C.at(args[i + 2] + j) = output.finalize(args[i + 1]); +} + +template +void SubProcessor::send_personal(const vector& args) +{ + octetStreams to_send(P), to_receive(P); + for (size_t i = 0; i < args.size(); i += 5) + if (args[i + 3] == P.my_num()) + for (int j = 0; j < args[i]; j++) + C[args[i + 4] + j].pack(to_send[args[i + 1]]); + P.send_receive_all(to_send, to_receive); + for (size_t i = 0; i < args.size(); i += 5) + if (args[i + 1] == P.my_num()) + for (int j = 0; j < args[i]; j++) + C[args[i + 2] + j].unpack(to_receive[args[i + 3]]); +} + template typename sint::clear Processor::get_inverse2(unsigned m) { diff --git a/Processor/Program.cpp b/Processor/Program.cpp index c33039428..dac73400b 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -23,7 +23,7 @@ void Program::compute_constants() max_mem[reg_type] = max(max_mem[reg_type], p[i].get_mem(RegType(reg_type))); } - writes_persistance |= p[i].opcode == WRITEFILESHARE; + writes_persistence |= p[i].opcode == WRITEFILESHARE; } } diff --git a/Processor/Program.h b/Processor/Program.h index a41c9e2a6..87a263f08 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -30,10 +30,10 @@ class Program public: - bool writes_persistance; + bool writes_persistence; Program(int nplayers) : offline_data_used(nplayers), - unknown_usage(false), writes_persistance(false) + unknown_usage(false), writes_persistence(false) { compute_constants(); } // Read in a program diff --git a/Processor/SpecificPrivateOutput.h b/Processor/SpecificPrivateOutput.h new file mode 100644 index 000000000..7878db1cd --- /dev/null +++ b/Processor/SpecificPrivateOutput.h @@ -0,0 +1,65 @@ +/* + * SpecificPrivateOutput.h + * + */ + +#ifndef PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ +#define PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ + +template +class SpecificPrivateOutput +{ + deque secrets; + vector pos; + Player& P; + vector active; + +public: + SpecificPrivateOutput(SubProcessor& proc) : + P(proc.P) + { + for (int i = 0; i < P.num_players(); i++) + pos.push_back(new typename T::PO(proc.P)); + active.resize(P.num_players()); + } + + ~SpecificPrivateOutput() + { + for (auto& x : pos) + delete x; + } + + void prepare_sending(const T& secret, int player) + { + pos[player]->prepare_sending(secret, player); + if (P.my_num() == player) + secrets.push_back(secret); + active[player] = true; + } + + void exchange() + { + for (int i = 0; i < this->P.num_players(); i++) + if (active[i]) + { + if (i == this->P.my_num()) + pos[i]->receive(); + else + pos[i]->send(i); + } + } + + typename T::clear finalize(int player) + { + if (player == this->P.my_num()) + { + T secret = secrets.front(); + secrets.pop_front(); + return pos[player]->finalize(secret); + } + else + return {}; + } +}; + +#endif /* PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ */ diff --git a/Programs/Source/falcon_alex.mpc b/Programs/Source/falcon_alex.mpc new file mode 100644 index 000000000..3c535248f --- /dev/null +++ b/Programs/Source/falcon_alex.mpc @@ -0,0 +1,100 @@ +from Compiler.ml import keras +import Compiler.ml as tf + +try: + n_epochs = int(program.args[1]) +except (ValueError, IndexError): + n_epochs = 10 + +try: + batch_size = int(program.args[2]) +except (ValueError, IndexError): + batch_size = 128 + +try: + n_threads = int(program.args[3]) +except (ValueError, IndexError): + n_threads = 36 + +#Instantiation +AlexNet = [] + +padding = 'same' +batchnorm = 'batchnorm' in program.args + +#1st Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=96, input_shape=(32,32,3), kernel_size=(11,11), strides=(4,4), padding=9)) +AlexNet.append(keras.layers.Activation('relu')) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=3, strides=(2,2))) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) + +#2nd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=1)) + +#3rd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#4th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#5th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#Passing it to a Fully Connected layer +# 1st Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#2nd Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#Output Layer +AlexNet.append(keras.layers.Dense(10)) + + +tf.set_n_threads(n_threads) +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) + +training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +if 'no_acc' not in program.args: + training_labels.input_from(0) + training_samples.input_from(0) + + test_labels.input_from(0) + test_samples.input_from(0) + +model = tf.keras.models.Sequential(AlexNet) + +model.compile_by_args(program) + +model.build(training_samples.sizes) +model.summary() + +opt = model.fit( + training_samples, + training_labels, + epochs=n_epochs, + batch_size=batch_size, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/keras_cifar_lenet.mpc b/Programs/Source/keras_cifar_lenet.mpc new file mode 100644 index 000000000..882d2e187 --- /dev/null +++ b/Programs/Source/keras_cifar_lenet.mpc @@ -0,0 +1,45 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +training_labels.input_from(0) +training_samples.input_from(0) + +test_labels.input_from(0) +test_samples.input_from(0) + +from Compiler import ml +tf = ml +ml.set_n_threads(36) + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +optim = tf.keras.optimizers.Adam(amsgrad=True) + +model.compile(optimizer=optim) + +opt = model.fit( + training_samples, + training_labels, + epochs=10, + batch_size=128, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/keras_mnist_dense.mpc b/Programs/Source/keras_mnist_dense.mpc index a525c0650..76b1e23f5 100644 --- a/Programs/Source/keras_mnist_dense.mpc +++ b/Programs/Source/keras_mnist_dense.mpc @@ -21,7 +21,8 @@ tf = ml layers = [ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), - tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(128), + tf.keras.layers.Activation('relu'), tf.keras.layers.Dense(10, activation='softmax') ] diff --git a/Programs/Source/keras_mnist_lenet.mpc b/Programs/Source/keras_mnist_lenet.mpc index 9fdac27fd..674cf4036 100644 --- a/Programs/Source/keras_mnist_lenet.mpc +++ b/Programs/Source/keras_mnist_lenet.mpc @@ -20,8 +20,21 @@ tf = ml layers = [ tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), +] + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ tf.keras.layers.MaxPooling2D(2), tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), +] + + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ tf.keras.layers.MaxPooling2D(2), tf.keras.layers.Flatten(), tf.keras.layers.Dropout(0.5), diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 9dc8a6851..37cd73d2d 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -21,6 +21,8 @@ elif 'debug' in program.args: n_test = 100 elif 'debug5000' in program.args: N = n_test = 5000 +elif 'mini' in program.args: + N = n_test = 10 else: N = 60000 n_test = 10000 @@ -39,6 +41,7 @@ except: batch_size = N N = min(N, 10000) +batch_size = min(batch_size, N) ml.Layer.back_batch_size = batch_size try: @@ -71,6 +74,9 @@ else: ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml), ml.Dense(N, n_inner, 10, debug=debug_ml)] +if 'batchnorm' in program.args: + layers.insert(1, ml.BatchNorm([N, n_inner])) + if 'dropout' in program.args: for i in range(len(layers) - 1, 0, -1): layers.insert(i, ml.Dropout(N, n_inner)) diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index 6ea76b260..04ca11ad6 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -53,7 +53,7 @@ except: ml.Layer.back_batch_size = batch_size layers = [ - ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'), + ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'), ml.MaxPool([N, 24, 24, 20]), ml.Relu([N, 12, 12, 20]), ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'), @@ -66,6 +66,12 @@ layers = [ layers += [ml.MultiOutput.from_args(program, n_examples, 10)] +if 'batchnorm' in program.args: + for arg in program.args: + assert not arg.startswith('dropout') + layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args)) + layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args)) + if 'dropout' in program.args or 'dropout2' in program.args: layers.insert(8, ml.Dropout(N, 500)) elif 'dropout.25' in program.args: diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index c3a919b3d..9c6f0b9c8 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -85,6 +85,12 @@ void Atlas::exchange() resharing.add_mine(e); } + for (size_t i = 0; i < min(masks.size(), size_t(P.num_players())); i++) + { + int j = (base_king + i) % P.num_players(); + resharing.add_sender(j); + } + resharing.exchange(); } diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index e67b28a97..1eebd3b73 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -27,7 +27,7 @@ HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, new HemiMatrixPrep(dims[0], dims[1], dims[2], - dynamic_cast&>(processor.DataF))}); + dynamic_cast(processor.DataF))}); return *matrix_preps.at(dims); } diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index e48d92571..ea5a7211c 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -18,17 +18,18 @@ template class HemiMatrixPrep : public BufferPrep> { typedef BufferPrep> super; + typedef typename T::LivePrep LivePrep; int n_rows, n_inner, n_cols; bool swapped; DataPositions* usage; - HemiPrep* prep; + LivePrep* prep; HemiMatrixPrep(const HemiMatrixPrep&) = delete; public: - HemiMatrixPrep(int n_rows, int n_inner, int n_cols, HemiPrep& prep) : + HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) : super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), prep(&prep) { diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 82b28431c..f42212995 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -87,11 +87,10 @@ void HemiMatrixPrep::buffer_triples() assert(prep); auto& multipliers = prep->get_multipliers(); - assert(prep->pairwise_machine); - auto& FTD = prep->pairwise_machine->setup_p.FieldD; - auto& pk = prep->pairwise_machine->pk; + auto& FTD = prep->get_FTD(); + auto& pk = prep->get_pk(); int n_matrices = FTD.num_slots() / n_rows; -#ifdef VERBOSE +#ifdef VERBOSE_HE fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner, n_inner, n_cols); fflush(stderr); @@ -103,20 +102,23 @@ void HemiMatrixPrep::buffer_triples() AddableVector> C(n_matrices); MatrixRandMultJob job(C, A, B); - if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + if (T::local_mul) { - auto& queues = BaseMachine::s().queues; - int start = queues.distribute(job, n_matrices); - job.begin = start; - job.end = n_matrices; - matrix_rand_mult(job); - queues.wrap_up(job); - } - else - { - job.begin = 0; - job.end = n_matrices; - matrix_rand_mult(job); + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + { + auto& queues = BaseMachine::s().queues; + int start = queues.distribute(job, n_matrices); + job.begin = start; + job.end = n_matrices; + matrix_rand_mult(job); + queues.wrap_up(job); + } + else + { + job.begin = 0; + job.end = n_matrices; + matrix_rand_mult(job); + } } #ifdef VERBOSE_HE @@ -130,26 +132,35 @@ void HemiMatrixPrep::buffer_triples() assert(prep->proc); auto& P = prep->proc->P; - Bundle bundle(P); - bundle.mine.store(diag.ciphertexts); - P.unchecked_broadcast(bundle); vector> others_ct; - for (auto& os : bundle) + + if (T::local_mul or OnlineOptions::singleton.direct) + { + Bundle bundle(P); + bundle.mine.store(diag.ciphertexts); + P.unchecked_broadcast(bundle); + for (auto& os : bundle) + { + others_ct.push_back({}); + os.get(others_ct.back(), Ciphertext(pk)); + } + } + else { - others_ct.push_back({}); - os.get(others_ct.back(), Ciphertext(pk)); + others_ct.push_back(diag.ciphertexts); + TreeSum().run(others_ct[0], P); } for (int j = 0; j < n_cols; j++) for (auto m : multipliers) { -#ifdef VERBOSE +#ifdef VERBOSE_HE fprintf(stderr, "column %d with party offset %d at %f\n", j, m->get_offset(), timer.elapsed()); fflush(stderr); #endif Ciphertext C(pk); - auto& multiplicands = others_ct[P.get_player(-m->get_offset())]; + auto& multiplicands = m->get_multiplicands(others_ct, pk); if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) { auto& queues = BaseMachine::s().queues; @@ -160,7 +171,7 @@ void HemiMatrixPrep::buffer_triples() CipherPlainMultJob job(products, multiplicands, multiplicands2, true); int start = queues.distribute(job, n_inner); #ifdef VERBOSE_HE - fprintf(stderr, "from %d in central thread\n", start); + fprintf(stderr, "from %d in central thread at %f\n", start, timer.elapsed()); fflush(stderr); #endif for (int i = start; i < n_inner; i++) @@ -185,7 +196,10 @@ void HemiMatrixPrep::buffer_triples() m->add(products[j], C, BOTH, n_inner); } - C += diag.dediag(products, n_matrices); + if (T::local_mul) + C += diag.dediag(products, n_matrices); + else + C = diag.dediag(products, n_matrices); for (int i = 0; i < n_matrices; i++) if (swapped) diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index c43b43e95..b2b510aa0 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -34,6 +34,9 @@ class HemiPrep : public SemiHonestRingPrep static void basic_setup(Player& P); static void teardown(); + static const FHE_PK& get_pk(); + static const FD& get_FTD(); + HemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index 6cdd75476..c456424e5 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -34,6 +34,20 @@ void HemiPrep::basic_setup(Player& P) T::clear::template init(); } +template +const FHE_PK& HemiPrep::get_pk() +{ + assert(pairwise_machine); + return pairwise_machine->pk; +} + +template +const typename T::clear::FD& HemiPrep::get_FTD() +{ + assert(pairwise_machine); + return pairwise_machine->setup().FieldD; +} + template HemiPrep::~HemiPrep() diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index d299fb18f..4a85cbe34 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -27,6 +27,7 @@ class HemiShare : public SemiShare typedef HemiPrep LivePrep; static const bool needs_ot = false; + static const bool local_mul = true; static true_type triple_matmul; HemiShare() diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index 9ff92fb0e..be0fac61d 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -140,12 +140,12 @@ void KeyGenProtocol::output_to(int player, vector& opened, vector& shares) { PrivateOutput po(*proc); - vector masked; for (auto& share : shares) - masked.push_back(po.start(player, share)); - MC->POpen(opened, masked, P); + po.prepare_sending(share, player); + po.exchange(); + opened.resize(shares.size()); for (auto& x : opened) - x = po.stop(player, x); + x = po.finalize(player); } template diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 571f391ef..2250417d0 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -52,6 +52,7 @@ class TreeSum virtual ~TreeSum(); void run(vector& values, const Player& P); + T run(const T& value, const Player& P); octetStream& get_buffer() { return os; } @@ -210,6 +211,14 @@ void TreeSum::run(vector& values, const Player& P) finish(values, P); } +template +T TreeSum::run(const T& value, const Player& P) +{ + vector values = {value}; + run(values, P); + return values[0]; +} + template size_t TreeSum::report_size(ReportType type) { @@ -244,14 +253,6 @@ void add_openings(vector& values, const Player& P, int sum_players, int last_ MC.player_timers[sender].start(); P.wait_receive(sender, oss[j]); MC.player_timers[sender].stop(); - if ((unsigned)oss[j].get_length() < values.size() * T::size()) - { - stringstream ss; - ss << "Not enough information received, expected " - << values.size() * T::size() << " bytes, got " - << oss[j].get_length(); - throw Processor_Error(ss.str()); - } MC.timers[SUM].start(); for (unsigned int i=0; i::Check(const Player& P) auto& vals = this->vals; auto& macs = this->macs; auto& popen_cnt = this->popen_cnt; + assert(int(macs.size()) <= popen_cnt); if (popen_cnt < 10) { diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index c7d477ad4..5a60281c6 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -12,6 +12,8 @@ using namespace std; #include "Networking/Player.h" #include "Tools/PointerVector.h" +template class Preprocessing; + /** * Abstract base class for opening protocols */ @@ -61,6 +63,8 @@ class MAC_Check_Base virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); virtual const Player& get_check_player(const Player& P) const { return P; } + + virtual void set_prep(Preprocessing&) {} }; #endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */ diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index 63bfe63ac..ff33a6eea 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -17,6 +17,7 @@ class MalRepRingShare : public MaliciousRep3Share> { typedef SignedZ2 T; typedef MaliciousRep3Share super; + typedef MalRepRingShare This; public: const static int BIT_LENGTH = K; @@ -26,7 +27,8 @@ class MalRepRingShare : public MaliciousRep3Share> typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index f98e9797f..e6f3a8a6a 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -13,6 +13,7 @@ template class Beaver; template class MaliciousRepPrepWithBits; template class MaliciousRepPO; template class MaliciousRepPrep; +template class SpecificPrivateOutput; namespace GC { @@ -30,8 +31,8 @@ class MaliciousRep3Share : public Rep3Share typedef HashMaliciousRepMC> MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput> Input; - typedef ::PrivateOutput> PrivateOutput; typedef MaliciousRepPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef Rep3Share Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MaliciousShamirPO.h b/Protocols/MaliciousShamirPO.h index 65003d108..5bffe4f8e 100644 --- a/Protocols/MaliciousShamirPO.h +++ b/Protocols/MaliciousShamirPO.h @@ -9,13 +9,14 @@ template class MaliciousShamirPO { +protected: Player& P; octetStream to_send; vector to_receive; vector shares; - MaliciousShamirMC MC; + typename T::Direct_MC MC; public: MaliciousShamirPO(Player& P); diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index 47592981f..fee8e8292 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -13,6 +13,7 @@ template class MaliciousRepPrepWithBits; template class MaliciousRepPrep; template class MaliciousShamirPO; +template class SpecificPrivateOutput; namespace GC { @@ -23,14 +24,15 @@ template class MaliciousShamirShare : public ShamirShare { typedef ShamirShare super; + typedef MaliciousShamirShare This; public: typedef Beaver> Protocol; typedef MaliciousShamirMC MAC_Check; typedef MAC_Check Direct_MC; typedef ShamirInput Input; - typedef ::PrivateOutput PrivateOutput; typedef MaliciousShamirPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ShamirShare Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index fa3bc9f03..c90a5e277 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -76,12 +76,6 @@ class MamaShare : public Share_, MamaMac> return string(1, T::type_char()); } - static void read_or_generate_mac_key(string, Player&, mac_key_type& key) - { - SeededPRNG G; - key.randomize(G); - } - MamaShare() { } diff --git a/Protocols/PostSacriRepFieldShare.h b/Protocols/PostSacriRepFieldShare.h index a7fed8afb..06196762b 100644 --- a/Protocols/PostSacriRepFieldShare.h +++ b/Protocols/PostSacriRepFieldShare.h @@ -15,6 +15,7 @@ template class PostSacriRepFieldShare : public MaliciousRep3Share { typedef MaliciousRep3Share super; + typedef PostSacriRepFieldShare This; public: typedef typename super::clear clear; @@ -23,7 +24,8 @@ class PostSacriRepFieldShare : public MaliciousRep3Share typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MaliciousRepPrepWithBits LivePrep; PostSacriRepFieldShare() diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index d4f2ab0fd..7cbd483c4 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -17,6 +17,7 @@ template class PostSacriRepRingShare : public Rep3Share2 { typedef Rep3Share2 super; + typedef PostSacriRepRingShare This; public: static const int BIT_LENGTH = K; @@ -33,7 +34,8 @@ class PostSacriRepRingShare : public Rep3Share2 typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h index e6a8eb525..09be88cb1 100644 --- a/Protocols/ProtocolSet.h +++ b/Protocols/ProtocolSet.h @@ -42,8 +42,13 @@ class ProtocolSet { } - ~ProtocolSet() + /** + * Run all protocol checks + */ + void check() { + protocol.check(); + output.Check(processor.P); } }; @@ -73,6 +78,15 @@ class BinaryProtocolSet *thread.protocol), input(output, prep, P) { } + + /** + * Run all protocol checks + */ + void check() + { + protocol.check(); + output.Check(protocol.P); + } }; /** @@ -102,6 +116,15 @@ class MixedProtocolSet arithmetic.protocol), input(arithmetic.input) { } + + /** + * Run all protocol checks + */ + void check() + { + arithmetic.check(); + binary.check(); + } }; #endif /* PROTOCOLS_PROTOCOLSET_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index e85065ac0..44853b79a 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -15,7 +15,8 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; -template class PrivateOutput; +template class ReplicatedPO; +template class SpecificPrivateOutput; template class RepShare : public FixedVec, public ShareInterface @@ -99,6 +100,7 @@ template class Rep3Share : public RepShare { typedef RepShare super; + typedef Rep3Share This; public: typedef T clear; @@ -107,7 +109,8 @@ class Rep3Share : public RepShare typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 23f28cf9b..e52d160bb 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -24,7 +24,8 @@ class Rep3Share2 : public Rep3Share> typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep2k LivePrep; typedef Rep3Share2 Honest; typedef SignedZ2 clear; diff --git a/Protocols/Rep4Input.h b/Protocols/Rep4Input.h index f1bc29af9..04acd0043 100644 --- a/Protocols/Rep4Input.h +++ b/Protocols/Rep4Input.h @@ -31,7 +31,6 @@ class Rep4Input : public InputBase void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); - void send_mine(); void exchange(); T finalize_mine(); diff --git a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp index 5600b45c7..48844396b 100644 --- a/Protocols/Rep4Input.hpp +++ b/Protocols/Rep4Input.hpp @@ -64,12 +64,6 @@ void Rep4Input::add_other(int player, int) results[player].push_back(res); } -template -void Rep4Input::send_mine() -{ - throw not_implemented(); -} - template void Rep4Input::exchange() { diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 67527a208..2357d0f5e 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -19,10 +19,6 @@ using namespace std; template class SubProcessor; template class ReplicatedMC; template class ReplicatedInput; -template class ReplicatedPrivateOutput; -template class Share; -template class Rep3Share; -template class MAC_Check_Base; template class Preprocessing; class Instruction; @@ -141,9 +137,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase void trunc_pr(const vector& regs, int size, U& proc, false_type); public: - typedef ReplicatedMC MAC_Check; - typedef ReplicatedInput Input; - static const bool uses_triples = false; Replicated(Player& P); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 374ed89b1..1a8a66b99 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -10,6 +10,7 @@ #include "Processor/Processor.h" #include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" #include "ReplicatedInput.h" #include "Rep3Share2k.h" @@ -162,14 +163,13 @@ void Replicated::prepare_mul(const T& x, } template -inline void Replicated::prepare_reshare(const typename T::clear& share, +void Replicated::prepare_reshare(const typename T::clear& share, int n) { - auto add_share = share; typename T::value_type tmp[2]; for (int i = 0; i < 2; i++) tmp[i].randomize(shared_prngs[i], n); - add_share += tmp[0] - tmp[1]; + auto add_share = share + tmp[0] - tmp[1]; add_share.pack(os[0], n); add_shares.push_back(add_share); } diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 916ee6b8f..b12f7f91f 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -56,16 +56,24 @@ BufferPrep::~BufferPrep() << " bit generation" << endl; #endif + auto field_type = T::clear::field_type(); + auto& my_usage = this->usage.files.at(field_type); + this->print_left("triples", triples.size() * T::default_length, type_string, this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) * T::default_length); + size_t used_bits = my_usage.at(DATA_BIT); + if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) + // add dabits with computation modulo power of two but without MAC + used_bits += my_usage.at(DATA_DABIT); + this->print_left("bits", bits.size(), type_string, used_bits); + #define X(KIND, TYPE) \ this->print_left(#KIND, KIND.size(), type_string, \ this->usage.files.at(T::clear::field_type()).at(TYPE)); X(squares, DATA_SQUARE) X(inverses, DATA_INVERSE) - X(bits, DATA_BIT) X(dabits, DATA_DABIT) #undef X @@ -601,17 +609,6 @@ void buffer_bits_from_players(vector>& player_bits, for (int i = 0; i < n_relevant_players; i++) for (auto& x : player_bits[i]) x = input.finalize((base_player + i) % P.num_players(), n_bits); -#if !defined(__clang__) && (__GNUC__ == 6) - // mitigate compiler bug - Bundle bundle(P); - P.unchecked_broadcast(bundle); -#endif -#ifdef DEBUG_BIT_SACRIFICE - typename T::MAC_Check MC; - for (int i = 0; i < n_relevant_players; i++) - for (auto& x : player_bits[i]) - assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1)); -#endif } template @@ -1164,18 +1161,18 @@ void BufferPrep::buffer_inputs_as_usual(int player, SubProcessor* proc) typename T::clear r; r.randomize(G); input.add_mine(r); - this->inputs[player].push_back({input.finalize_mine(), r}); + this->inputs[player].push_back({input.finalize(player), r}); } - input.send_mine(); + input.exchange(); } else { - octetStream os; - P.receive_player(player, os); - T share; + for (int i = 0; i < buffer_size; i++) + input.add_other(player); + input.exchange(); for (int i = 0; i < buffer_size; i++) { - input.finalize_other(player, share, os); + auto share = input.finalize(player); this->inputs[player].push_back({share, 0}); } } diff --git a/Protocols/ReplicatedPrivateOutput.h b/Protocols/ReplicatedPrivateOutput.h deleted file mode 100644 index b9e546ca2..000000000 --- a/Protocols/ReplicatedPrivateOutput.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * ReplicatedPrivateOutput.h - * - */ - -#ifndef PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ -#define PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ - -template -class SubProcessor; -template -class Share; - -template -class ReplicatedPrivateOutput -{ - SubProcessor& proc; - -public: - ReplicatedPrivateOutput(SubProcessor& proc); - - void start(int player, int target, int source); - void stop(int player, int source); -}; - -#endif /* PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ */ diff --git a/Protocols/ReplicatedPrivateOutput.hpp b/Protocols/ReplicatedPrivateOutput.hpp deleted file mode 100644 index d34872235..000000000 --- a/Protocols/ReplicatedPrivateOutput.hpp +++ /dev/null @@ -1,30 +0,0 @@ -/* - * ReplicatedPrivateOutput.cpp - * - */ - -#include "ReplicatedPrivateOutput.h" -#include "Processor/Processor.h" -#include "Math/FixedVec.h" -#include "Math/Integer.h" - -template -inline ReplicatedPrivateOutput::ReplicatedPrivateOutput( - SubProcessor& proc) : - proc(proc) -{ -} - -template -void ReplicatedPrivateOutput::start(int player, int target, - int source) -{ - (void)player, (void)target, (void)source; - throw runtime_error("not implemented, use PrivateOutput"); -} - -template -void ReplicatedPrivateOutput::stop(int player, int source) -{ - (void)player, (void)source; -} diff --git a/Protocols/Semi.h b/Protocols/Semi.h index e290ca0eb..5f63a9d62 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -71,6 +71,12 @@ class Semi : public SPDZ proc.get_S()[info.source_base + i] >> info.m; } } + + void buffer_random() + { + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->random.push_back(G.get()); + } }; #endif /* PROTOCOLS_SEMI_H_ */ diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 87a1e08e5..4fc265b7c 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -14,34 +14,33 @@ template class SemiMC; * Additive secret sharing input protocol */ template -class SemiInput : public IndividualInput +class SemiInput : public InputBase { - SeededPRNG secure_prng; + vector send_prngs; + vector recv_prngs; + Player& P; + vector> shares; public: - SemiInput(SubProcessor& proc, SemiMC& MC) : - IndividualInput(proc) + SemiInput(SubProcessor& proc, SemiMC&) : + SemiInput(&proc, proc.P) { - (void) MC; } - SemiInput(SubProcessor* proc, Player& P) : - IndividualInput(proc, P) - { - } + SemiInput(SubProcessor* proc, Player& P); SemiInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : - SemiInput(P) + SemiInput(0, P) { (void) MC, (void) prep; } - SemiInput(Player& P) : - IndividualInput(0, P) - { - } - + void reset(int player); void add_mine(const typename T::clear& input, int n_bits = -1); + void add_other(int player, int n_bits = -1); + void exchange(); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + T finalize_mine(); }; #endif /* PROTOCOLS_SEMIINPUT_H_ */ diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 28673250f..3ed1feefe 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -11,22 +11,64 @@ #include "ShamirInput.hpp" template -void SemiInput::add_mine(const typename T::clear& input, int n_bits) +SemiInput::SemiInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P) +{ + shares.resize(P.num_players()); + vector to_send(P.num_players()), to_receive; + for (int i = 0; i < P.num_players(); i++) + { + send_prngs.push_back({}); + to_send[i].append(send_prngs.back().get_seed(), SEED_SIZE); + } + P.send_receive_all(to_send, to_receive); + recv_prngs.resize(P.num_players()); + for (int i = 0; i < P.num_players(); i++) + if (i != P.my_num()) + recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE)); + this->reset_all(P); +} + +template +void SemiInput::reset(int player) +{ + shares[player].clear(); +} + +template +void SemiInput::add_mine(const typename T::clear& input, int) { auto& P = this->P; typename T::open_type sum, share; for (int i = 0; i < P.num_players(); i++) { - if (i < P.num_players() - 1) - share.randomize(secure_prng, n_bits); - else - share = input - sum; - sum += share; - if (i == P.my_num()) - this->shares.push_back(share); - else - share.pack(this->os[i], n_bits); + if (i != P.my_num()) + sum += send_prngs[i].template get(); } + shares[P.my_num()].push_back(input - sum); +} + +template +void SemiInput::add_other(int, int) +{ +} + +template +void SemiInput::exchange() +{ +} + +template +void SemiInput::finalize_other(int player, T& target, octetStream&, + int) +{ + target = recv_prngs[player].template get(); +} + +template +T SemiInput::finalize_mine() +{ + return shares[P.my_num()].next(); } #endif diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index f722886eb..402173e98 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -27,7 +27,6 @@ class Shamir : public ProtocolBase { typedef typename T::open_type::Scalar U; - octetStreams os; vector reconstruction; U rec_factor; ShamirInput* resharing; diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 9fe10bdea..8bfdf70ea 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -69,8 +69,6 @@ int Shamir::get_n_relevant_players() template void Shamir::reset() { - os.reset(P); - if (resharing == 0) { resharing = new ShamirInput(0, P); @@ -78,6 +76,9 @@ void Shamir::reset() for (int i = 0; i < P.num_players(); i++) resharing->reset(i); + + for (int i = 0; i < n_mul_players; i++) + resharing->add_sender(i); } template @@ -92,37 +93,27 @@ template void Shamir::prepare_mul(const T& x, const T& y, int n) { (void) n; - auto add_share = x * y * rec_factor; if (P.my_num() < n_mul_players) - resharing->add_mine(add_share); + resharing->add_mine(x * y * rec_factor); } template void Shamir::exchange() { - vector senders(P.num_players(), false); - for (int i = 0; i < n_mul_players; i++) - senders[i] = true; - P.send_receive_all(senders, resharing->os, os); + assert(resharing); + resharing->exchange(); } template void Shamir::start_exchange() { - if (P.my_num() < n_mul_players) - for (int offset = 1; offset < P.num_players(); offset++) - P.send_relative(offset, resharing->os[P.get_player(offset)]); + resharing->start_exchange(); } template void Shamir::stop_exchange() { - for (int offset = 1; offset < P.num_players(); offset++) - { - int receive_from = P.get_player(-offset); - if (receive_from < n_mul_players) - P.receive_player(receive_from, os[receive_from]); - } + resharing->stop_exchange(); } template @@ -136,15 +127,8 @@ template T Shamir::finalize(int n_relevant_players) { ShamirShare res = U(0); - if (P.my_num() < n_relevant_players) - res = resharing->finalize_mine(); for (int i = 0; i < n_relevant_players; i++) - if (i != P.my_num()) - { - T tmp; - resharing->finalize_other(i, tmp, os[i]); - res += tmp; - } + res += resharing->finalize(i); return res; } @@ -259,7 +243,7 @@ vector Shamir::get_randoms(PRNG& G, int t) input.reset_all(P); int buffer_size = OnlineOptions::singleton.batch_size; for (int i = 0; i < buffer_size; i += hyper.size()) - input.add_mine(G.get()); + input.add_from_all(G.get()); input.exchange(); vector inputs; vector random; diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 023467077..91e093091 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -21,10 +21,11 @@ class IndividualInput : public PrepLessInput protected: Player& P; octetStreams os; + vector senders; public: IndividualInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), P(P) + PrepLessInput(proc), P(P), senders(P.num_players()) { this->reset_all(P); } @@ -34,10 +35,14 @@ class IndividualInput : public PrepLessInput } void reset(int player); + void add_sender(int player); void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + + void start_exchange(); + void stop_exchange(); }; /** diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index d84b09a6b..6d9992ad7 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -20,6 +20,8 @@ void IndividualInput::reset(int player) this->i_share = 0; os.reset(P); } + + senders[player] = false; } template @@ -68,12 +70,20 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) else x.pack(this->os[i]); } + + this->senders[P.my_num()] = true; +} + +template +void IndividualInput::add_sender(int player) +{ + senders[player] = true; } template void IndividualInput::add_other(int player, int) { - (void) player; + add_sender(player); } template @@ -87,7 +97,26 @@ void IndividualInput::send_mine() template void IndividualInput::exchange() { - P.send_receive_all(os, InputBase::os); + P.send_receive_all(senders, os, InputBase::os); +} + +template +void IndividualInput::start_exchange() +{ + if (senders[P.my_num()]) + for (int offset = 1; offset < P.num_players(); offset++) + P.send_relative(offset, os[P.get_player(offset)]); +} + +template +void IndividualInput::stop_exchange() +{ + for (int offset = 1; offset < P.num_players(); offset++) + { + int receive_from = P.get_player(-offset); + if (senders[receive_from]) + P.receive_player(receive_from, InputBase::os[receive_from]); + } } template diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 8f76d6a79..6bda92dfc 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -33,9 +33,12 @@ class IndirectShamirMC : public MAC_Check_Base template class ShamirMC : public IndirectShamirMC { + typedef typename T::open_type open_type; typedef typename T::open_type::Scalar rec_type; vector reconstruction; + ShamirMC(const ShamirMC&); + void finalize(vector& values, const vector& S); protected: @@ -71,6 +74,7 @@ class ShamirMC : public IndirectShamirMC void Check(const Player& P) { (void)P; } vector get_reconstruction(const Player& P); + open_type reconstruct(const vector& shares); }; #endif /* PROTOCOLS_SHAMIRMC_H_ */ diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 6d6af9136..e3e7cd3ac 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -130,6 +130,19 @@ typename T::open_type ShamirMC::finalize_open() return res; } +template +typename T::open_type ShamirMC::reconstruct(const vector& shares) +{ + assert(reconstruction.size()); + typename T::open_type res; + for (size_t j = 0; j < reconstruction.size(); j++) + { + res += shares[j] * reconstruction[j]; + } + + return res; +} + template void IndirectShamirMC::exchange(const Player& P) { diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 6e818c39f..e7daabfcf 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -13,6 +13,8 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; +template class MaliciousShamirPO; +template class SpecificPrivateOutput; namespace GC { @@ -22,6 +24,8 @@ template class CcdSecret; template class ShamirShare : public T, public ShareInterface { + typedef ShamirShare This; + public: typedef T clear; typedef T open_type; @@ -34,7 +38,8 @@ class ShamirShare : public T, public ShareInterface typedef IndirectShamirMC MAC_Check; typedef ShamirMC Direct_MC; typedef ShamirInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef MaliciousShamirPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef ShamirShare Honest; diff --git a/Protocols/Share.h b/Protocols/Share.h index 743a2c614..92be4f144 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -55,6 +55,7 @@ class Share_ : public ShareInterface const static bool needs_ot = T::needs_ot; const static bool dishonest_majority = T::dishonest_majority; const static bool variable_players = T::variable_players; + const static bool has_mac = true; static int size() { return T::size() + V::size(); } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index ae6e7b7dd..444214e47 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -34,6 +34,7 @@ class ShareInterface static const bool has_trunc_pr = false; static const bool has_split = false; + static const bool has_mac = false; static const false_type triple_matmul; diff --git a/Protocols/SpdzWiseInput.h b/Protocols/SpdzWiseInput.h index e9597527d..4c5675e91 100644 --- a/Protocols/SpdzWiseInput.h +++ b/Protocols/SpdzWiseInput.h @@ -36,11 +36,8 @@ class SpdzWiseInput : public InputBase void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); - void send_mine(); void exchange(); T finalize(int player, int n_bits = -1); - T finalize_mine(); - void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; #endif /* PROTOCOLS_SPDZWISEINPUT_H_ */ diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index e0d508e51..7aaa14c92 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -85,21 +85,3 @@ T SpdzWiseInput::finalize(int player, int) { return shares[player].next(); } - -template -void SpdzWiseInput::send_mine() -{ - throw runtime_error("use exchange()"); -} - -template -T SpdzWiseInput::finalize_mine() -{ - throw runtime_error("use finalize()"); -} - -template -void SpdzWiseInput::finalize_other(int, T&, octetStream&, int) -{ - throw runtime_error("use finalize()"); -} diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9e953e730..9991dafb2 100644 --- a/Protocols/SpdzWiseMC.h +++ b/Protocols/SpdzWiseMC.h @@ -32,7 +32,7 @@ class SpdzWiseMC : public MAC_Check_Base { } - void init_open(const Player& P, int n) + void init_open(const Player& P, int n = 0) { inner_MC.init_open(P, n); } diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index 9cb86017a..1090fc08e 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -15,7 +15,6 @@ #include "Spdz2kPrep.hpp" #include "ShamirMC.hpp" #include "MaliciousRepPO.hpp" -#include "MaliciousShamirPO.hpp" #include "GC/RepPrep.hpp" template diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h new file mode 100644 index 000000000..de7406bba --- /dev/null +++ b/Protocols/TemiPrep.h @@ -0,0 +1,72 @@ +/* + * TemiPrep.h + * + */ + +#ifndef PROTOCOLS_TEMIPREP_H_ +#define PROTOCOLS_TEMIPREP_H_ + +#include "ReplicatedPrep.h" +#include "FHEOffline/TemiSetup.h" + +template class HemiMatrixPrep; + +template +class TemiMultiplier +{ + typedef typename T::clear::FD FD; + + vector multiplicands; + + Player& P; + +public: + TemiMultiplier(Player& P); + + vector& get_multiplicands( + vector>& ciphertexts, const FHE_PK& pk); + void add(Plaintext_& res, const Ciphertext& C, OT_ROLE role = BOTH, + int n_summands = 1); + + int get_offset() + { + return 0; + } +}; + +/** + * Semi-honest triple generation with semi-homomorphic encryption + */ +template +class TemiPrep : public SemiHonestRingPrep +{ + friend class HemiMatrixPrep; + + typedef typename T::clear::FD FD; + + static Lock lock; + static TemiSetup* setup; + + vector*> multipliers; + +public: + static void basic_setup(Player& P); + static void teardown(); + + static const FD& get_FTD(); + static const FHE_PK& get_pk(); + static const TemiSetup& get_setup(); + + TemiPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), + BitPrep(proc, usage), RingPrep(proc, usage), + SemiHonestRingPrep(proc, usage) + { + } + + void buffer_triples(); + + vector*>& get_multipliers(); +}; + +#endif /* PROTOCOLS_TEMIPREP_H_ */ diff --git a/Protocols/TemiPrep.hpp b/Protocols/TemiPrep.hpp new file mode 100644 index 000000000..1088a99cc --- /dev/null +++ b/Protocols/TemiPrep.hpp @@ -0,0 +1,129 @@ +/* + * TemiPrep.hppg + * + * + */ + +#ifndef PROTOCOLS_TEMIPREP_HPP_ +#define PROTOCOLS_TEMIPREP_HPP_ + +#include "TemiPrep.h" +#include "FHEOffline/SimpleMachine.h" + +#include "FHEOffline/DataSetup.hpp" + +template +TemiSetup* TemiPrep::setup; + +template +Lock TemiPrep::lock; + +template +void TemiPrep::basic_setup(Player& P) +{ + assert(not setup); + setup = new TemiSetup; + MachineBase machine; + setup->secure_init(P, T::clear::length()); + read_or_generate_secrets(*setup, P, machine, 1, true_type()); + T::clear::template init(); +} + +template +void TemiPrep::teardown() +{ + if (setup) + delete setup; +} + +template +const typename T::clear::FD& TemiPrep::get_FTD() +{ + assert(setup); + return setup->FieldD; +} + +template +inline const FHE_PK& TemiPrep::get_pk() +{ + assert(setup); + return setup->pk; +} + +template +const TemiSetup& TemiPrep::get_setup() +{ + assert(setup); + return *setup; +} + +template +void TemiPrep::buffer_triples() +{ + lock.lock(); + if (setup == 0) + { + PlainPlayer P(this->proc->P.N, "Temi" + T::type_string()); + basic_setup(P); + } + lock.unlock(); + + auto& P = this->proc->P; + auto& FieldD = setup->FieldD; + + Plaintext_ a(FieldD), b(FieldD), c(FieldD); + + SeededPRNG G; + a.randomize(G); + b.randomize(G); + + TreeSum ts; + auto C = ts.run(setup->pk.encrypt(a), P); + C = ts.run(C * b + setup->pk.template encrypt(FieldD), P); + c = SimpleDistDecrypt(P, *setup).reshare(C); + + for (unsigned i = 0; i < a.num_slots(); i++) + this->triples.push_back({{a.element(i), b.element(i), c.element(i)}}); +} + +template +vector*>& TemiPrep::get_multipliers() +{ + assert(setup); + assert( + OnlineOptions::singleton.batch_size + <= setup->params.get_matrix_dim()); + assert(this->proc); + if (multipliers.empty()) + multipliers.push_back(new TemiMultiplier(this->proc->P)); + return multipliers; +} + +template +TemiMultiplier::TemiMultiplier(Player& P) : P(P) +{ +} + +template +vector& TemiMultiplier::get_multiplicands( + vector >& ciphertexts, const FHE_PK& pk) +{ + multiplicands.clear(); + multiplicands.resize(ciphertexts[0].size(), pk); + for (size_t j = 0; j < multiplicands.size(); j++) + for (size_t i = 0; i < ciphertexts.size(); i++) + multiplicands[j] += ciphertexts[i].at(j); + return multiplicands; +} + +template +void TemiMultiplier::add(Plaintext_& res, const Ciphertext& C, + OT_ROLE, int) +{ + TreeSum ts; + SimpleDistDecrypt dd(P, TemiPrep::get_setup()); + auto zero = TemiPrep::get_pk().template encrypt(TemiPrep::get_FTD()); + res += dd.reshare(ts.run(C + zero, P)); +} + +#endif /* PROTOCOLS_TEMIPREP_HPP_ */ diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h new file mode 100644 index 000000000..f4f37dcd6 --- /dev/null +++ b/Protocols/TemiShare.h @@ -0,0 +1,42 @@ +/* + * TemiShare.h + * + */ + +#ifndef PROTOCOLS_TEMISHARE_H_ +#define PROTOCOLS_TEMISHARE_H_ + +#include "HemiShare.h" + +template class TemiPrep; +template class Hemi; + +template +class TemiShare : public HemiShare +{ + typedef TemiShare This; + typedef HemiShare super; + +public: + typedef SemiMC MAC_Check; + typedef DirectSemiMC Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef typename conditional, Beaver>::type Protocol; + typedef TemiPrep LivePrep; + + static const bool needs_ot = false; + static const bool local_mul = false; + + TemiShare() + { + } + template + TemiShare(const U& other) : + super(other) + { + } + +}; + +#endif /* PROTOCOLS_TEMISHARE_H_ */ diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 951cbfe74..45d92613f 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -317,7 +317,14 @@ void read_mac_key(const string& directory, int player_num, int nplayers, U& key) throw mac_key_error(filename); } - key.input(inpf,true); + try + { + key.input(inpf,true); + } + catch(exception&) + { + throw mac_key_error(filename); + } if (inpf.fail()) throw mac_key_error(filename); diff --git a/README.md b/README.md index bd1075121..99d0f0763 100644 --- a/README.md +++ b/README.md @@ -85,10 +85,31 @@ The following table lists all protocols that are fully supported. | --- | --- | --- | --- | --- | | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | -| Semi-honest, dishonest majority | [Semi / Hemi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +Modulo prime and modulo 2^k are the two settings that allow +integer-like computation. For k = 64, the latter corresponds to the +computation available on the widely used 64-bit processors. GF(2^n) +denotes Galois extension fields of order 2^n, which are different to +computation modulo 2^n. In particular, every element has an inverse, +which is not the case modulo 2^n. See [this +article](https://en.wikipedia.org/wiki/Finite_field) for an +introduction. Modulo prime and GF(2^n) are lumped together because the +protocols are very similar due to the mathematical properties. + +Bin. SS stands for binary secret sharing, that is secret sharing +modulo two. In some settings, this requires specific protocols as some +protocols require the domain size to be larger than two. In other +settings, the protocol is the same mathematically speaking, but a +specific implementation allows for optimizations such as using the +inherent parallelism of bit-wise operations on machine words. + +A security model specifies how many parties are "allowed" to misbehave +in what sense. Malicious means that not following the protocol will at +least be detected while semi-honest means that even corrupted parties +are assumed to follow the protocol. See [this paper](https://eprint.iacr.org/2020/300) for an explanation of the various security models and a high-level introduction to multi-party computation. @@ -257,7 +278,9 @@ compute the preprocessing time for a particular computation. add `AVX_OT = 0` in addition. - For optimal results on Linux on ARM, add `ARCH = -march=-march=armv8.2-a+crypto` to `CONFIG.mine`. This enables the - hardware support for AES. + hardware support for AES. See the [GCC + documentation](https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#AArch64-Options) + on available options. - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to a local, unversioned directory to store preprocessing data (the default is `Player-Data` in the current directory). - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. @@ -501,6 +524,7 @@ The following table shows all programs for dishonest-majority computation using | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | | `chaigear-party.x` | Adapted [HighGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `chaigear.sh` | | `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | +| `temi-party.x` | Adapted [CDN01](https://eprint.iacr.org/2000/055) | Mod prime | Semi-honest | `temi.sh` | | `soho-party.x` | Somewhat homomorphic encryption | Mod prime | Semi-honest | `soho.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | @@ -538,6 +562,11 @@ Hemi and Soho denote the stripped version version of LowGear and HighGear, respectively, for semi-honest security similar to Semi, that is, generating additively shared Beaver triples using semi-homomorphic encryption. +Temi in turn denotes the adaption of +[Cramer et al.](https://eprint.iacr.org/2000/055) to LWE-based +semi-homomorphic encryption. +Both Hemi and Temi use the diagonal packing by [Halevi and +Shoup](https://eprint.iacr.org/2014/106) for matrix multiplication. We will use MASCOT to demonstrate the use, but the other protocols work similarly. diff --git a/Scripts/prep-usage.py b/Scripts/prep-usage.py new file mode 100755 index 000000000..cb8ca6198 --- /dev/null +++ b/Scripts/prep-usage.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +import sys, os +import collections + +sys.path.append('.') + +from Compiler.program import * +from Compiler.instructions_base import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +res = collections.defaultdict(lambda: 0) +m = 0 + +tapename = next(Program.read_tapes(sys.argv[1])) +res = Tape.ReqNum() +for inst in Tape.read_instructions(tapename): + res.update(inst.get_usage()) + +for x in res.pretty(): + print(x) diff --git a/Scripts/temi.sh b/Scripts/temi.sh new file mode 100755 index 000000000..86f46c548 --- /dev/null +++ b/Scripts/temi.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player temi-party.x $* || exit 1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 10fe575f2..e8c02f6cb 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -59,7 +59,7 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py $compile_opts tutorial for i in rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ - atlas mal-shamir sy-shamir hemi semi \ + atlas mal-shamir sy-shamir hemi semi temi \ soho mascot; do test_vm $i $run_opts done diff --git a/Tools/Buffer.h b/Tools/Buffer.h index 941ec4256..ffd411233 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -86,6 +86,10 @@ octetStream check_file_signature(ifstream& file, const string& filename) { throw signature_mismatch(filename); } + catch (IO_Error&) + { + throw signature_mismatch(filename); + } if (file_signature() != file_spec) throw signature_mismatch(filename); return file_spec; diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index 96f69b0c5..f6f4ba2ec 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -35,8 +35,8 @@ wrong_gfp_size::wrong_gfp_size(const char* name, const bigint& p, { } -overflow::overflow(const char* name, size_t i, size_t n) : - runtime_error(string(name) + " overflow: " + to_string(i) + "/" + to_string(n)) +overflow::overflow(const string& name, size_t i, size_t n) : + runtime_error(name + " overflow: " + to_string(i) + "/" + to_string(n)) { } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index 18406cf6c..fff8b2de4 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -237,7 +237,7 @@ class mac_key_error: public runtime_error class overflow : public runtime_error { public: - overflow(const char* name, size_t i, size_t n); + overflow(const string& name, size_t i, size_t n); }; class unknown_input_type : public runtime_error diff --git a/Tools/octetStream.h b/Tools/octetStream.h index cd90b0e94..676382eaf 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -80,6 +80,8 @@ class octetStream size_t get_ptr() const { return ptr; } /// Length size_t get_length() const { return len; } + /// Length including size tag + size_t get_total_length() const { return len + sizeof(len); } /// Allocation size_t get_max_length() const { return mxlen; } /// Data pointer diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index 45e5f3371..962b27753 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -129,12 +129,12 @@ void run(int argc, char** argv) output.prepare_open(c); } output.exchange(P); + set.check(); cout << "result: "; for (int i = 0; i < n; i++) cout << output.finalize_open() << " "; cout << endl; - protocol.check(); - output.Check(P); + set.check(); } diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp index 532d705e4..a36949d6f 100644 --- a/Utils/mixed-example.cpp +++ b/Utils/mixed-example.cpp @@ -126,12 +126,12 @@ void run(char** argv) output.prepare_open(res); } output.exchange(P); - bit_output.Check(P); + set.check(); cout << "result: "; for (int i = 0; i < n; i++) cout << output.finalize_open() << " "; cout << endl; - output.Check(P); + set.check(); } diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 9cae6953f..83571c218 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -110,7 +110,7 @@ void run(char** argv, int prime_length) c = protocol.finalize_dotprod(n); // protocol check before revealing results - protocol.check(); + set.check(); output.init_open(P); output.prepare_open(c); @@ -120,5 +120,5 @@ void run(char** argv, int prime_length) cout << "result: " << result << endl; // result check after opening - output.Check(P); + set.check(); } diff --git a/doc/instructions.rst b/doc/instructions.rst index 1a833994e..fb62066ed 100644 --- a/doc/instructions.rst +++ b/doc/instructions.rst @@ -85,12 +85,10 @@ Compiler.instructions module .. automodule:: Compiler.instructions :members: :no-undoc-members: - :exclude-members: asm_input, inputmask, lts, print_char4_regint, - print_char_regint, protectmemc, sqrs, - start_grind, startprivateoutput, stop_grind, - stopprivateoutput, writesocketc, writesocketint, - protectmemint, protectmems, print_mem, - matmul_base, g2muls, inputmixed_base, raw_output + :exclude-members: asm_input, sqrs, + start_grind, stop_grind, + writesocketc, writesocketint, + matmul_base, inputmixed_base, raw_output Compiler.GC.instructions module ------------------------------- diff --git a/doc/low-level.rst b/doc/low-level.rst index c70bf5b65..7f5474fd4 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -309,6 +309,11 @@ Share Types - ``SpdzWiseShare`` - `SPDZ-wise `_. ``T`` must be ``MaliciousShamirShare`` or ``MaliciousRep3Share``. + * + - ``TemiShare`` + - Semi-honest protocol with Beaver multiplication based on + threshold semi-homomorphic encryption. ``T`` must be + ``gfp_`` or ``gf2n_short``. Protocol Setup diff --git a/doc/non-linear.rst b/doc/non-linear.rst index bcdbbd3ae..e5df4c204 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -88,7 +88,7 @@ The following table lists the matching arithmetic and binary protocols. cut-and-choose analysis by `Furukawa et al. `_ * - - Semi, Hemi, Soho, Semi2k + - Semi, Hemi, Temi, Soho, Semi2k - SemiBin (Beaver triples modulo 2 using OT) * - `Malicious Shamir `_ diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 1441e3524..21500c455 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -85,7 +85,37 @@ modulo the default 128-bit prime 00000025 -``Fake-Offline.x`` generates preprocessing data insecurely for a range -of protocols, and ``{mascot,cowgear,mal-shamir}-offline.x`` generate +The actual data is stored is by simple concatenation. For example, +triples are stored as repetitions of ``a, b, ab``, and daBits are +stored as repetitions of ``a, b`` where ``a`` is the arithmetic +share and ``b`` is the binary share. + +For protocols with MAC, the value share is stored before the MAC +share. + +Values are generally stored in little-endian order. Note the following +domain specifics: + +Modulo a prime + Values are stored in `Montgomery representation + `_ + with :math:`R` being the smallest power of :math:`2^{64}` larger than + the prime. For example, :math:`R = 2^{128}` for a 128-bit prime. + Furthermore, the values are stored in the smallest number of 8-byte + blocks necessary, all in little-endian order. + +Modulo a power of two: + Values are stored in the smallest number of 8-byte blocks necessary, + all in little-endian order. + +:math:`GF(2^n)` + Values are stored in blocks according to the storage size above, + all in little-endian order. + +For further details, have a look at ``Utils/Fake-Offline.cpp``, which +contains code that generates preprocessing data insecurely for a range +of protocols (underlying the binary ``Fake-Offline.x``). + +``{mascot,cowgear,mal-shamir}-offline.x`` generate sufficient preprocessing data for a specific high-level program with MASCOT, CowGear, and malicious Shamir secret sharing, respectively. diff --git a/doc/requirements.txt b/doc/requirements.txt index cd6467ed8..32add0c79 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1 +1,2 @@ breathe +sphinx-rtd-theme==0.5.2