From 78fe3d8bad4e654276eb3a11a88173195e672f44 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 9 Jul 2024 12:17:25 +1000 Subject: [PATCH] Maintenance. --- BMR/AndJob.cpp | 2 +- BMR/AndJob.h | 4 +- CHANGELOG.md | 17 + CONFIG | 2 +- Compiler/GC/instructions.py | 6 +- Compiler/GC/types.py | 105 ++-- Compiler/allocator.py | 8 +- Compiler/circuit.py | 18 +- Compiler/comparison.py | 56 +- Compiler/compilerLib.py | 48 +- Compiler/config.py | 1 + Compiler/dijkstra.py | 2 +- Compiler/floatingpoint.py | 189 +++--- Compiler/instructions.py | 85 +++ Compiler/instructions_base.py | 119 +++- Compiler/library.py | 367 +++++++++--- Compiler/ml.py | 228 ++++++-- Compiler/mpc_math.py | 61 +- Compiler/non_linear.py | 47 +- Compiler/oram.py | 18 +- Compiler/program.py | 120 +++- Compiler/sorting.py | 2 + Compiler/sqrt_oram.py | 37 +- Compiler/types.py | 550 +++++++++++------- Compiler/util.py | 4 +- ECDSA/Fake-ECDSA.cpp | 2 +- ECDSA/fake-spdz-ecdsa-party.cpp | 1 + ECDSA/preprocessing.hpp | 4 +- ECDSA/sy-rep-ecdsa-party.cpp | 1 + ExternalIO/Client.hpp | 7 + ExternalIO/client.py | 63 ++ ExternalIO/domains.py | 2 +- ExternalIO/personal-client-example.py | 19 + FHEOffline/Producer.cpp | 2 + GC/BitAdder.hpp | 6 + GC/FakeSecret.cpp | 4 +- GC/FakeSecret.h | 4 +- GC/Instruction.cpp | 4 +- GC/NoShare.h | 15 +- GC/PersonalPrep.hpp | 13 +- GC/Processor.h | 21 +- GC/Processor.hpp | 92 ++- GC/Rep4Prep.cpp | 1 + GC/Secret.h | 8 +- GC/Secret.hpp | 4 +- GC/Semi.cpp | 4 +- GC/Semi.h | 2 +- GC/SemiPrep.cpp | 3 + GC/SemiSecret.h | 11 +- GC/SemiSecret.hpp | 53 +- GC/ShareParty.hpp | 15 +- GC/ShareSecret.h | 13 +- GC/ShareSecret.hpp | 4 +- GC/ShareThread.hpp | 26 +- GC/ShiftableTripleBuffer.h | 2 + GC/TinySecret.h | 11 + GC/instructions.h | 11 +- License.txt | 2 +- Machines/OTMachine.cpp | 4 +- Machines/Tinier.cpp | 1 + Machines/TripleMachine.cpp | 1 + Machines/mama-party.cpp | 1 + Machines/sy-rep-field-party.cpp | 1 + Machines/sy-rep-ring-party.cpp | 1 + Machines/tinier-party.cpp | 1 + Machines/tiny-party.cpp | 1 + Makefile | 10 +- Math/BitVec.h | 8 +- Math/FixedVec.h | 5 + Math/Z2k.h | 15 +- Math/gf2n.cpp | 1 + Math/gfp.h | 17 +- Math/gfpvar.h | 2 +- Math/modp.h | 2 +- Math/modp.hpp | 4 +- Networking/CryptoPlayer.cpp | 6 +- Networking/Player.cpp | 51 +- Networking/Player.h | 31 +- Networking/Server.cpp | 9 +- Networking/ServerSocket.cpp | 6 +- Networking/sockets.cpp | 9 +- Networking/ssl_sockets.h | 10 +- OT/BaseOT.h | 14 +- OT/BitDiagonal.cpp | 5 +- OT/NPartyTripleGenerator.hpp | 19 +- OT/OTCorrelator.hpp | 5 + OT/OTExtensionWithMatrix.cpp | 63 +- OT/OTExtensionWithMatrix.h | 13 +- OT/OTMultiplier.h | 2 + OT/OTMultiplier.hpp | 36 +- OT/OTTripleSetup.cpp | 28 +- OT/OTTripleSetup.h | 29 +- Processor/BaseMachine.cpp | 25 +- Processor/BaseMachine.h | 39 +- Processor/Conv2dTuple.h | 6 +- Processor/Data_Files.h | 33 +- Processor/Data_Files.hpp | 12 +- Processor/DummyProtocol.h | 11 +- Processor/ExternalClients.cpp | 1 - Processor/Instruction.cpp | 10 +- Processor/Instruction.h | 11 +- Processor/Instruction.hpp | 59 +- Processor/Machine.hpp | 37 +- Processor/Memory.h | 4 +- Processor/Memory.hpp | 16 +- Processor/Online-Thread.h | 1 + Processor/Online-Thread.hpp | 66 ++- Processor/OnlineMachine.h | 2 + Processor/OnlineMachine.hpp | 19 + Processor/OnlineOptions.cpp | 15 + Processor/OnlineOptions.h | 6 + Processor/PrepBase.cpp | 8 +- Processor/PrepBase.h | 16 +- Processor/PrepBuffer.h | 44 ++ Processor/Processor.h | 34 +- Processor/Processor.hpp | 113 +++- Processor/ProcessorBase.cpp | 3 +- Processor/ProcessorBase.h | 4 +- Processor/Program.cpp | 18 + Processor/Program.h | 1 + Processor/RingOptions.cpp | 7 +- Processor/ThreadQueue.cpp | 4 +- Processor/ThreadQueue.h | 5 +- Programs/Circuits | 2 +- Programs/Source/bench-dt.mpc | 6 +- Programs/Source/falcon_alex.mpc | 2 +- Programs/Source/htmac.mpc | 6 +- Programs/Source/keras_mnist_lenet_predict.mpc | 1 + Programs/Source/mnist_A.mpc | 18 +- Programs/Source/personal_client_example.py | 11 + Programs/Source/prf_mimc.mpc | 2 +- Programs/Source/torch_densenet.py | 53 ++ Programs/Source/torch_mnist_lenet.mpc | 2 +- Programs/Source/torch_mnist_lenet_predict.mpc | 9 +- Programs/Source/torch_resnet.py | 47 ++ Programs/Source/torch_squeeze.py | 53 ++ Protocols/AtlasShare.h | 4 +- Protocols/Beaver.hpp | 2 +- Protocols/BufferScope.h | 8 +- Protocols/DabitSacrifice.hpp | 5 +- Protocols/DealerMatrixPrep.h | 5 + Protocols/DealerMatrixPrep.hpp | 3 +- Protocols/FakeProtocol.h | 16 +- Protocols/FakeShare.h | 3 +- Protocols/FakeShare.hpp | 2 +- Protocols/Hemi.h | 3 + Protocols/Hemi.hpp | 67 ++- Protocols/HemiMatrixPrep.h | 2 + Protocols/HemiMatrixPrep.hpp | 23 +- Protocols/MAC_Check.hpp | 3 +- Protocols/MAC_Check_Base.h | 2 +- Protocols/MaliciousRepPrep.hpp | 9 +- Protocols/MamaPrep.hpp | 2 + Protocols/NoShare.h | 2 + Protocols/PostSacriRepRingShare.h | 2 +- Protocols/Rep3Share.h | 1 + Protocols/Rep3Share2k.h | 2 +- Protocols/Rep3Shuffler.h | 8 +- Protocols/Rep3Shuffler.hpp | 10 +- Protocols/Rep4.h | 4 +- Protocols/Rep4.hpp | 24 +- Protocols/Rep4Input.hpp | 8 + Protocols/Rep4MC.hpp | 3 + Protocols/Rep4Share2k.h | 4 +- Protocols/RepRingOnlyEdabitPrep.hpp | 9 +- Protocols/Replicated.h | 17 +- Protocols/Replicated.hpp | 13 +- Protocols/ReplicatedMC.hpp | 1 + Protocols/ReplicatedPrep.h | 12 +- Protocols/ReplicatedPrep.hpp | 47 +- Protocols/RingOnlyPrep.hpp | 2 +- Protocols/SecureShuffle.h | 14 +- Protocols/SecureShuffle.hpp | 20 +- Protocols/Semi.h | 2 +- Protocols/Semi2kShare.h | 2 +- Protocols/SemiInput.h | 2 + Protocols/SemiInput.hpp | 12 +- Protocols/SemiShare.h | 1 + Protocols/Shamir.hpp | 4 +- Protocols/ShamirInput.hpp | 3 + Protocols/ShamirShare.h | 2 + Protocols/Share.h | 8 +- Protocols/Share.hpp | 28 + Protocols/ShareInterface.cpp | 10 + Protocols/ShareInterface.h | 9 +- Protocols/ShuffleSacrifice.hpp | 55 +- Protocols/Spdz2kPrep.hpp | 34 +- Protocols/SpdzWise.h | 6 +- Protocols/SpdzWise.hpp | 8 +- Protocols/SpdzWisePrep.hpp | 3 +- Protocols/SpdzWiseRep3Shuffler.h | 40 ++ Protocols/SpdzWiseRep3Shuffler.hpp | 68 +++ Protocols/SpdzWiseRingShare.h | 2 +- Protocols/SpdzWiseShare.hpp | 2 + Protocols/fake-stuff.h | 30 +- Protocols/fake-stuff.hpp | 62 +- README.md | 51 +- Scripts/compile-emulate.py | 3 +- Scripts/memory-usage.py | 7 +- Scripts/test_ecdsa.sh | 7 +- Scripts/torch_mnist_lenet_import.py | 20 +- Tools/Buffer.cpp | 22 +- Tools/Buffer.h | 25 +- Tools/CheckVector.h | 106 +++- Tools/Exceptions.cpp | 10 + Tools/Exceptions.h | 2 + Tools/NamedStats.cpp | 34 ++ Tools/NamedStats.h | 22 + Tools/SwitchableOutput.h | 7 + Tools/octetStream.cpp | 10 +- Tools/octetStream.h | 71 ++- Tools/pprint.h | 4 + Utils/Check-Offline-Z2k.cpp | 1 + Utils/Check-Offline.cpp | 96 +-- Utils/Fake-Offline.cpp | 78 +-- Utils/check-passive.cpp | 2 +- Utils/stream-fake-mascot-triples.cpp | 5 +- Yao/YaoEvalWire.cpp | 3 +- Yao/YaoEvalWire.h | 2 +- Yao/YaoGarbleWire.cpp | 2 +- Yao/YaoGarbleWire.h | 2 +- deps/libOTe | 2 +- doc/add-instruction.rst | 171 ++++++ doc/conf.py | 1 + doc/index.rst | 2 + doc/instructions.rst | 5 +- doc/io.rst | 88 ++- doc/journey.rst | 2 + doc/machine-learning.rst | 8 + doc/ml-quickstart.rst | 50 +- doc/multinode.rst | 3 + doc/optimization.rst | 148 +++++ doc/preprocessing.rst | 10 +- doc/troubleshooting.rst | 8 + 234 files changed, 4265 insertions(+), 1359 deletions(-) create mode 100755 ExternalIO/personal-client-example.py create mode 100644 Processor/PrepBuffer.h create mode 100644 Programs/Source/personal_client_example.py create mode 100644 Programs/Source/torch_densenet.py create mode 100644 Programs/Source/torch_resnet.py create mode 100644 Programs/Source/torch_squeeze.py create mode 100644 Protocols/SpdzWiseRep3Shuffler.h create mode 100644 Protocols/SpdzWiseRep3Shuffler.hpp create mode 100644 Tools/NamedStats.cpp create mode 100644 Tools/NamedStats.h create mode 100644 doc/add-instruction.rst create mode 100644 doc/optimization.rst diff --git a/BMR/AndJob.cpp b/BMR/AndJob.cpp index a72542df3..0189b3abe 100644 --- a/BMR/AndJob.cpp +++ b/BMR/AndJob.cpp @@ -15,7 +15,7 @@ int AndJob::run() #endif __m128i* prf_output = new __m128i[PAD_TO_8(ProgramParty::s().get_n_parties())]; auto gate = gates.begin(); - vector< GC::Secret >& S = *this->S; + auto& S = *this->S; const vector& args = *this->args; int i_gate = 0; for (size_t i = start; i < end; i += 4) diff --git a/BMR/AndJob.h b/BMR/AndJob.h index 0ae7994ba..13fbaf94b 100644 --- a/BMR/AndJob.h +++ b/BMR/AndJob.h @@ -15,7 +15,7 @@ using namespace std; class AndJob { - vector< GC::Secret >* S; + StackedVector< GC::Secret >* S; const vector* args; public: @@ -25,7 +25,7 @@ class AndJob AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {} - void reset(vector >& S, const vector& args, + void reset(StackedVector >& S, const vector& args, size_t start, gate_id_t gate_id, size_t n_gates, int n_parties) { this->S = &S; diff --git a/CHANGELOG.md b/CHANGELOG.md index 6abb6df67..adfd96802 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.9 (July 9, 2024) + +- Inference with non-sequential PyTorch networks +- SHA-3 for any input length (@hiddely) +- Improved client facilities +- Shuffling with malicious security for SPDZ-wise protocols by [Asharov et al.](https://ia.cr/2022/1595) +- More reusable bytecode via in-thread calling facility +- Recursive functions without return values +- Fewer rounds for parallel matrix multiplications (@vincent-ehrmanntraut) +- Optimized usage of SoftSpokenOT in semi-honest protocols +- More integrity checks on storage in MAC-based protocols +- Use C++17 +- Use glibc 2.18 for the binaries +- Fixed security bugs: remotely caused buffer overflows (#1382) +- Fixed security bug: Missing randomization before revealing to client +- Fixed security bug: Bias in Rep3 secure shuffling + ## 0.3.8 (December 14, 2023) - Functionality for multiple nodes per party diff --git a/CONFIG b/CONFIG index ba108166c..7194fd6c2 100644 --- a/CONFIG +++ b/CONFIG @@ -106,7 +106,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++17 -Werror CFLAGS += $(BREW_CFLAGS) CPPFLAGS = $(CFLAGS) LD = $(CXX) diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 1b53f9300..52b4db4cb 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -203,8 +203,10 @@ def dynamic_arg_format(cls, args): def add_usage(self, req_node): for i, n in self.bases(iter(self.args)): size = self.args[i + 1] - req_node.increment(('bit', 'triple'), size * (n - 3) // 2) - req_node.increment(('bit', 'mixed'), size) + n = (n - 3) // 2 + req_node.increment(('bit', 'triple'), size * n) + if n > 1: + req_node.increment(('bit', 'mixed'), size * ((n + 63) // 64)) def copy(self, size, subs): return type(self)(*self.get_new_args(size, subs)) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index ea1856792..0d82bbefa 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -13,7 +13,7 @@ from Compiler.types import vectorized_classmethod from Compiler.program import Tape, Program from Compiler.exceptions import * -from Compiler import util, oram, floatingpoint, library +from Compiler import util, oram, floatingpoint, library, comparison from Compiler import instructions_base import Compiler.GC.instructions as inst import operator @@ -21,6 +21,11 @@ import itertools from functools import reduce +class _binary: + def reveal_to(self, *args, **kwargs): + raise CompilerError( + '%s does not support revealing to indivual players' % type(self)) + class bits(Tape.Register, _structure, _bit): n = 40 unit = 64 @@ -149,6 +154,12 @@ def set_length(self, n): self.n = n def set_size(self, size): pass + def load_int(self, value): + n_limbs = math.ceil(self.n / self.unit) + for i in range(n_limbs): + self.conv_regint(min(self.unit, self.n - i * self.unit), + self[i], regint(value % 2 ** self.unit)) + value >>= self.unit def load_other(self, other): if isinstance(other, cint): assert(self.n == other.size) @@ -236,12 +247,14 @@ def _new_by_number(self, i, size=1): return res def if_else(self, x, y): """ - Vectorized oblivious selection:: + Bit-wise oblivious selection:: sb32 = sbits.get_type(32) print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal()) - This will output 1. + This will output 1 because it selects the two least + significant bits from 5 and the rest of the bits from 2. + """ return result_conv(x, y)(self & (x ^ y) ^ y) def zero_if_not(self, condition): @@ -268,6 +281,9 @@ def copy_from_part(self, source, base, size): self.bit_compose(source.bit_decompose()[base:base + size])) def vector_size(self): return self.n + @staticmethod + def size_for_mem(): + return 1 class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -302,13 +318,6 @@ def conv(cls, other): else: return super(cbits, cls).conv(other) types = {} - def load_int(self, value): - n_limbs = math.ceil(self.n / self.unit) - tmp = regint(size=n_limbs) - for i in range(n_limbs): - tmp[i].load_int(value % 2 ** self.unit) - value >>= self.unit - self.load_other(tmp) def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) def clear_op(self, other, c_inst, ci_inst, op): @@ -502,11 +511,7 @@ def load_int(self, value): if self.n <= 32: inst.ldbits(self, self.n, value) else: - size = math.ceil(self.n / self.unit) - tmp = regint(size=size) - for i in range(size): - tmp[i].load_int((value >> (i * 64)) % 2**64) - self.load_other(tmp) + bits.load_int(self, value) def load_other(self, other): if isinstance(other, cbits) and self.n == other.n: inst.convcbit2s(self.n, self, other) @@ -675,7 +680,7 @@ def bit_adder(*args, **kwargs): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(_vec, _bit): +class sbitvec(_vec, _bit, _binary): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. @@ -732,15 +737,16 @@ def get_type(cls, n): :py:obj:`v` and the columns by calling :py:obj:`elements`. """ class sbitvecn(cls, _structure): - @staticmethod - def malloc(size, creator_tape=None): - return sbit.malloc(size * n, creator_tape=creator_tape) + @classmethod + def malloc(cls, size, creator_tape=None): + return sbit.malloc( + size * cls.mem_size(), creator_tape=creator_tape) @staticmethod def n_elements(): return 1 @staticmethod def mem_size(): - return n + return sbits.get_type(n).mem_size() @classmethod def get_input_from(cls, player, size=1, f=0): """ Secret input from :py:obj:`player`. The input is decomposed @@ -780,38 +786,28 @@ def __init__(self, other=None, size=None): self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n assert size is None or size == self.v[0].n - @vectorized_classmethod - def load_mem(cls, address): - size = instructions_base.get_global_vector_size() - if size not in (None, 1): - assert isinstance(address, int) or len(address) == 1 - sb = sbits.get_type(size) - return cls.from_vec(sb.bit_compose( - sbit.load_mem(address + i + j * n) for j in range(size)) - for i in range(n)) - if not isinstance(address, int): - v = [sbit.load_mem(x, size=n).v[0] for x in address] - return cls(v) + @classmethod + def load_mem(cls, address, size=None): + if isinstance(address, int) or len(address) == 1: + address = [address + i for i in range(size or 1)] else: - return cls.from_vec(sbit.load_mem(address + i) - for i in range(n)) + assert size == None + return cls( + [sbits.get_type(n).load_mem(x) for x in address]) def store_in_mem(self, address): size = 1 for x in self.v: if not util.is_constant(x): size = max(size, x.n) - v = [sbits.get_type(size).conv(x) for x in self.v] - if not isinstance(address, int) and len(address) != 1: - v = self.elements() - assert len(v) == len(address) - for x, y in zip(v, address): - for i, xx in enumerate(x.bit_decompose(n)): - xx.store_in_mem(y + i) + if isinstance(address, int): + address = range(address, address + size) + elif len(address) == 1: + address = [address + i * self.mem_size() + for i in range(size)] else: - assert isinstance(address, int) or len(address) == 1 - for i in range(n): - for j, x in enumerate(v[i].bit_decompose()): - x.store_in_mem(address + i + j * n) + assert size == len(address) + for x, dest in zip(self.elements(), address): + x.store_in_mem(dest) @classmethod def two_power(cls, nn, size=1): return cls.from_vec( @@ -864,7 +860,7 @@ def __init__(self, elements=None, length=None, input_length=None): assert isinstance(elements, sint) if Program.prog.use_split(): x = elements.split_to_two_summands(length) - v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True) + v = sbitint.bit_adder(x[0], x[1]) else: prog = Program.prog if not prog.options.ring: @@ -877,6 +873,7 @@ def __init__(self, elements=None, length=None, input_length=None): length, prog.security) prog.use_edabit(backup) return + comparison.require_ring_size(length, 'A2B conversion') l = int(Program.prog.options.ring) r, r_bits = sint.get_edabit(length, size=elements.size) c = ((elements - r) << (l - length)).reveal() @@ -885,6 +882,8 @@ def __init__(self, elements=None, length=None, input_length=None): x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb) v = x.v self.v = v[:length] + elif isinstance(elements, sbitvec): + self.v = elements.v elif elements is not None and not (util.is_constant(elements) and \ elements == 0): self.v = sbits.trans(elements) @@ -1347,13 +1346,19 @@ def elements(self): def __add__(self, other): if util.is_zero(other): return self - a, b = self.expand(other) + try: + a, b = self.expand(other) + except: + return NotImplemented v = sbitint.bit_adder(a, b) return self.get_type(len(v)).from_vec(v) __radd__ = __add__ __sub__ = _bitint.__sub__ def __rsub__(self, other): - a, b = self.expand(other) + try: + a, b = self.expand(other) + except: + return NotImplemented return self.from_vec(b) - self.from_vec(a) def __mul__(self, other): if isinstance(other, sbits): @@ -1447,7 +1452,7 @@ def output(self): inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), cbits(0), cbits(0)) -class sbitfix(_fix): +class sbitfix(_fix, _binary): """ Secret signed fixed-point number in one binary register. Use :py:obj:`set_precision()` to change the precision. @@ -1515,7 +1520,7 @@ class cls(_fix): cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -class sbitfixvec(_fix, _vec): +class sbitfixvec(_fix, _vec, _binary): """ Vector of fixed-point numbers for parallel binary computation. Use :py:obj:`set_precision()` to change the precision. diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 2aaac2015..557b5a945 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -76,7 +76,7 @@ def alloc(self, size): self.top += size self.limit = max(self.limit, self.top) if res >= REG_MAX: - raise RegisterOverflowError() + raise RegisterOverflowError(size) return res def free(self, base, size): @@ -209,7 +209,8 @@ def dealloc_reg(self, reg, inst, free): for x in itertools.chain(dup.duplicates, base.duplicates): to_check.add(x) - if reg not in self.program.base_addresses: + if reg not in self.program.base_addresses \ + and not isinstance(inst, call_arg): free.free(base) if inst.is_vec() and base.vector: self.defined[base] = inst @@ -608,7 +609,8 @@ def keep_text_order(inst, n): # so this threshold should lead to acceptable compile times even on slower processors. first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4] second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5] - max_dependencies_per_matrix = 1500**2 + max_dependencies_per_matrix = \ + self.block.parent.program.budget if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix: if block.warn_about_mem and not block.parent.warned_about_mem: print('WARNING: Order of memory instructions not preserved due to long vector, errors possible') diff --git a/Compiler/circuit.py b/Compiler/circuit.py index e33f396fd..702f8a039 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -5,7 +5,7 @@ make Programs/Circuits -.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC +.. _`Bristol Fashion`: https://nigelsmart.github.io/MPC-Circuits """ import math @@ -15,6 +15,7 @@ from Compiler import util import itertools import struct +import os class Circuit: """ @@ -47,7 +48,12 @@ class Circuit: """ def __init__(self, name): + self.name = name self.filename = 'Programs/Circuits/%s.txt' % name + if not os.path.exists(self.filename): + if os.system('make Programs/Circuits'): + raise CompilerError('Cannot download circuit descriptions. ' + 'Make sure make and git are installed.') f = open(self.filename) self.functions = {} @@ -57,8 +63,9 @@ def __call__(self, *inputs): def run(self, *inputs): n = inputs[0][0].n, get_tape() if n not in self.functions: - self.functions[n] = function_block(lambda *args: - self.compile(*args)) + self.functions[n] = function_block( + lambda *args: self.compile(*args)) + self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n) flat_res = self.functions[n](*itertools.chain(*inputs)) res = [] i = 0 @@ -124,7 +131,7 @@ def compile(self, *all_inputs): def sha3_256(x): """ - This function implements SHA3-256 for inputs of up to 1080 bits:: + This function implements SHA3-256 for inputs of any length:: from circuit import sha3_256 a = sbitvec.from_vec([]) @@ -138,7 +145,8 @@ def sha3_256(x): for x in a, b, c, d, e, f, g, h: sha3_256(x).reveal_print_hex() - This should output the `test vectors + This should output the hashes of the above inputs, beginning with + the `test vectors `_ of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the 0 byte:: diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 41b36a476..a5d0ab542 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -76,13 +76,13 @@ def require_ring_size(k, op): program.curr_tape.require_bit_length(k) @instructions_base.cisc -def LTZ(s, a, k, kappa): +def LTZ(s, a, k): """ s = (a ?< 0) k: bit length of a """ - movs(s, program.non_linear.ltz(a, k, kappa)) + movs(s, program.non_linear.ltz(a, k)) def LtzRing(a, k): from .types import sint, _bitint @@ -105,14 +105,14 @@ def LtzRing(a, k): u = CarryOutRaw(a[::-1], b[::-1]) return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)) -def LessThanZero(a, k, kappa): +def LessThanZero(a, k): from . import types res = types.sint() - LTZ(res, a, k, kappa) + LTZ(res, a, k) return res @instructions_base.cisc -def Trunc(d, a, k, m, kappa, signed): +def Trunc(d, a, k, m, signed): """ d = a >> m @@ -124,7 +124,7 @@ def Trunc(d, a, k, m, kappa, signed): movs(d, a) return else: - movs(d, program.non_linear.trunc(a, k, m, kappa, signed)) + movs(d, program.non_linear.trunc(a, k, m, signed=signed)) def TruncRing(d, a, k, m, signed): program.curr_tape.require_bit_length(1) @@ -197,13 +197,13 @@ def TruncLeakyInRing(a, k, m, signed): shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False) masked = shifted >> n_shift u = sint() - BitLTL(u, masked, r_bits[:n_bits], 0) + BitLTL(u, masked, r_bits[:n_bits]) res = (u << n_bits) + masked - r if signed: res -= (1 << (n_bits - 1)) return res -def TruncRoundNearest(a, k, m, kappa, signed=False): +def TruncRoundNearest(a, k, m, signed=False): """ Returns a / 2^m, rounded to the nearest integer. @@ -212,12 +212,10 @@ def TruncRoundNearest(a, k, m, kappa, signed=False): """ if m == 0: return a - nl = program.non_linear - nl.check_security(kappa) return program.non_linear.trunc_round_nearest(a, k, m, signed) @instructions_base.cisc -def Mod2m(a_prime, a, k, m, kappa, signed): +def Mod2m(a_prime, a, k, m, signed): """ a_prime = a % 2^m @@ -225,8 +223,6 @@ def Mod2m(a_prime, a, k, m, kappa, signed): m: compile-time integer signed: True/False, describes a """ - nl = program.non_linear - nl.check_security(kappa) movs(a_prime, program.non_linear.mod2m(a, k, m, signed)) def Mod2mRing(a_prime, a, k, m, signed): @@ -237,13 +233,13 @@ def Mod2mRing(a_prime, a, k, m, signed): tmp = a + r_prime c_prime = (tmp << shift).reveal(False) >> shift u = sint() - BitLTL(u, c_prime, r_bin[:m], 0) + BitLTL(u, c_prime, r_bin[:m]) res = (u << m) + c_prime - r_prime if a_prime is not None: movs(a_prime, res) return res -def Mod2mField(a_prime, a, k, m, kappa, signed): +def Mod2mField(a_prime, a, k, m, signed): from .types import sint r_dprime = program.curr_block.new_reg('s') r_prime = program.curr_block.new_reg('s') @@ -255,7 +251,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): t = [program.curr_block.new_reg('s') for i in range(6)] c2m = program.curr_block.new_reg('c') c2k1 = program.curr_block.new_reg('c') - PRandM(r_dprime, r_prime, r, k, m, kappa) + PRandM(r_dprime, r_prime, r, k, m) ld2i(c2m, m) mulm(t[0], r_dprime, c2m) if signed: @@ -268,9 +264,9 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): asm_open(True, c, t[3]) modc(c_prime, c, c2m) if const_rounds: - BitLTC1(u, c_prime, r, kappa) + BitLTC1(u, c_prime, r) else: - BitLTL(u, c_prime, r, kappa) + BitLTL(u, c_prime, r) mulm(t[4], u, c2m) submr(t[5], c_prime, r_prime) adds(a_prime, t[5], t[4]) @@ -288,13 +284,15 @@ def MaskingBitsInRing(m, strict=False): r_bin = r return sint.bit_compose(r), r_bin -def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): +def PRandM(r_dprime, r_prime, b, k, m, use_dabit=True): """ r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1] r_prime = random secret integer in range [0, 2^m - 1] b = array containing bits of r_prime """ - program.curr_tape.require_bit_length(k + kappa) + assert k >= m + kappa = program.security + program.curr_tape.require_bit_length(k + kappa, reason='statistical masking as in https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation') from .types import sint if program.use_edabit() and not const_rounds: movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0]) @@ -329,7 +327,7 @@ def PRandInt(r, k): bit(t[1][i]) adds(t[2][i], t[0][i], t[1][i]) -def BitLTC1(u, a, b, kappa): +def BitLTC1(u, a, b): """ u = a 1 and k % 2 == 1: a.append(None) @@ -424,12 +422,12 @@ def CarryOutAux(a, kappa): if k > 1: for i in range(k//2): u[i] = carry(a[2*i+1], a[2*i], i != k//2-1) - return CarryOutAux(u[:k//2][::-1], kappa) + return CarryOutAux(u[:k//2][::-1]) else: return a[0][1] # carry out with carry-in bit c -def CarryOut(res, a, b, c=0, kappa=None): +def CarryOut(res, a, b, c=0): """ res = last carry bit in addition of a and b @@ -456,7 +454,7 @@ def CarryOutRaw(a, b, c=0): s[0] = d[-1][0].bit_and(c) s[1] = d[-1][1] + s[0] d[-1][1] = s[1] - return CarryOutAux(d[::-1], None) + return CarryOutAux(d[::-1]) def CarryOutRawLE(a, b, c=0): """ Little-endian version """ @@ -469,7 +467,7 @@ def CarryOutLE(a, b, c=0): CarryOut(res, a[::-1], b[::-1], c) return res -def BitLTL(res, a, b, kappa): +def BitLTL(res, a, b): """ res = a 1: @@ -313,6 +322,8 @@ def build_program(self, name=None): self.prog.use_split(int(os.getenv("PLAYERS", 2))) if self.options.execute in ("rep4-ring",): self.prog.use_split(4) + if self.options.execute.find("dealer") >= 0: + self.prog.use_edabit(True) def build_vars(self): from . import comparison, floatingpoint, instructions, library, types @@ -368,7 +379,14 @@ def build_vars(self): "cfloat", "squant", ]: - del self.VARS[i] + class dummy: + def __init__(self, *args): + raise CompilerError(self.error) + dummy.error = i + " not availabe with binary circuits" + if i in ("cint", "cfix"): + dummy.error += ". See https://mp-spdz.readthedocs.io/en/" \ + "latest/Compiler.html#Compiler.types." + i + self.VARS[i] = dummy else: self.sint = types.sint self.sfix = types.sfix @@ -503,13 +521,15 @@ def finalize_compile(self): return self.prog - @staticmethod - def executable_from_protocol(protocol): - match = { - "ring": "replicated-ring", - "rep-field": "replicated-field", - "replicated": "replicated-bin" - } + match = { + "ring": "replicated-ring", + "rep-field": "replicated-field", + "replicated": "replicated-bin" + } + + @classmethod + def executable_from_protocol(cls, protocol): + match = cls.match if protocol in match: protocol = match[protocol] if protocol.find("bmr") == -1: @@ -588,11 +608,19 @@ def run(i): for filename in glob.glob("Player-Data/*.0"): connection.put(filename, dest + "Player-Data") + def run_with_error(i): + try: + run(i) + except IOError: + print('IO error when copying files, does %s have enough space?' % + hostnames[i]) + raise + import threading import random threads = [] for i in range(len(hosts)): - threads.append(threading.Thread(target=run, args=(i,))) + threads.append(threading.Thread(target=run_with_error, args=(i,))) for thread in threads: thread.start() for thread in threads: diff --git a/Compiler/config.py b/Compiler/config.py index 9297a7e7e..ceabd9735 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -2,6 +2,7 @@ REG_MAX = 2 ** 32 USER_MEM = 8192 +MEM_MAX = 2 ** 64 P_VALUES = { 32: 2147565569, \ 64: 9223372036855103489, \ diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index f9f43f940..8257f28cb 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -558,7 +558,7 @@ def test_stupid_dijkstra_on_cycle(n, n_loops=None): @for_range(n) def f(i): M[i][(i+1)%n] = ExtInt(1) - M[i][(i-1)%n] = ExtInt(1) + M[i][(i-1+n)%n] = ExtInt(1) if n_loops is not None: stop_timer(1) start_timer() diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 0dcf4f818..e99febaa3 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -39,26 +39,25 @@ def maskRing(a, k): c = ((a + r_prime) << shift).reveal(False) >> shift return c, r -def maskField(a, k, kappa): +def maskField(a, k): r_dprime = types.sint() r_prime = types.sint() c = types.cint() r = [types.sint() for i in range(k)] - comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) + comparison.PRandM(r_dprime, r_prime, r, k, k) # always signed due to usage in equality testing a += two_power(k) asm_open(True, c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc -def EQZ(a, k, kappa): +def EQZ(a, k): prog = program.Program.prog if prog.use_split(): from GC.types import sbitvec v = sbitvec(a, k).v bit = util.tree_reduce(operator.and_, (~b for b in v)) return types.sintbit.conv(bit) - prog.non_linear.check_security(kappa) return prog.non_linear.eqz(a, k) def bits(a,m): @@ -99,12 +98,12 @@ def or_op(a, b, void=None): def mul_op(a, b, void=None): return a * b -def PreORC(a, kappa=None, m=None, raw=False): +def PreORC(a, m=None, raw=False): k = len(a) if k == 1: return [a[0]] prog = program.Program.prog - kappa = kappa or prog.security + kappa = prog.security m = m or k if isinstance(a[0], types.sgf2n): max_k = program.Program.prog.galois_length - 1 @@ -128,13 +127,13 @@ def PreORC(a, kappa=None, m=None, raw=False): t = [types.sint() for i in range(m)] b = comparison.PreMulC([a[i] + 1 for i in range(k)]) for i in range(m): - comparison.Mod2(t[i], b[k-1-i], k, kappa, False) + comparison.Mod2(t[i], b[k-1-i], k, False) p[m-1-i] = 1 - t[i] return p else: # not constant-round anymore - s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)] - t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw) + s = [PreORC(a[i:i+max_k], raw=raw) for i in range(0,k,max_k)] + t = PreORC([si[-1] for si in s[:-1]], raw=raw) return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0])[-m:] @@ -175,6 +174,41 @@ def PreOpL2(op, items): output[2 * i] = op(v[i - 1], items[2 * i]) return output +def PreOpL2_vec(op, *items): + """ Vectorized version of :py:func:`PreOpL2` """ + k = len(items[0]) + for x in items: + assert len(x) == k + if k == 1: + return items + half = k // 2 + other_half = (k + 1) // 2 - 1 + u = op([x.get_vector(base=0, size=half, skip=2) for x in items], + [x.get_vector(base=1, size=half, skip=2) for x in items]) + assert len(u) == len(items) + assert len(u[0]) == half + v = PreOpL2_vec(op, *u) + if other_half: + w = op([x.get_vector(base=0, size=other_half) for x in v], + [x.get_vector(base=2, size=other_half, skip=2) for x in items]) + if half == other_half: + res = [type(x).zip(x, y) for x, y in zip(v, w)] + for i in range(len(res)): + res[i] = type(res[i]).concat((items[i].get_vector(base=0, size=1), + res[i])) + else: + if other_half: + for i in range(len(w)): + w[i] = type(w[i]).concat((items[i].get_vector(base=0, size=1), + w[i])) + else: + w = [x.get_vector(base=0, size=1) for x in items] + res = [type(x).zip(x, y) for x, y in zip(w, v)] + assert len(res) == len(items) + for x in res: + assert len(x) == k + return res + def PreOpN(op, items): """ Naive PreOp algorithm """ k = len(items) @@ -184,9 +218,9 @@ def PreOpN(op, items): output[i] = op(output[i-1], items[i]) return output -def PreOR(a, kappa=None, raw=False): +def PreOR(a=None, raw=False): if comparison.const_rounds: - return PreORC(a, kappa, raw=raw) + return PreORC(a, raw=raw) else: return PreOpL(or_op, a) @@ -199,24 +233,24 @@ def KOpL(op, a): t2 = KOpL(op, a[k//2:]) return op(t1, t2) -def KORL(a, kappa=None): +def KORL(a): """ log rounds k-ary OR """ k = len(a) if k == 1: return a[0] else: - t1 = KORL(a[:k//2], kappa) - t2 = KORL(a[k//2:], kappa) + t1 = KORL(a[:k//2]) + t2 = KORL(a[k//2:]) return t1 + t2 - t1.bit_and(t2) -def KORC(a, kappa): - return PreORC(a, kappa, 1)[0] +def KORC(a): + return PreORC(a, 1)[0] -def KOR(a, kappa): +def KOR(a): if comparison.const_rounds: - return KORC(a, kappa) + return KORC(a) else: - return KORL(a, None) + return KORL(a) def KMul(a): if comparison.const_rounds: @@ -262,7 +296,7 @@ def BitAdd(a, b, bits_to_compute=None): s[k] = c[k-1] return s -def BitDec(a, k, m, kappa, bits_to_compute=None): +def BitDec(a, k, m, bits_to_compute=None): return program.Program.prog.non_linear.bit_dec(a, k, m) def BitDecRingRaw(a, k, m): @@ -270,7 +304,7 @@ def BitDecRingRaw(a, k, m): n_shift = int(program.Program.prog.options.ring) - m if program.Program.prog.use_split(): x = a.split_to_two_summands(m) - bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False) + bits = types._bitint.bit_adder(x[0], x[1]) return bits[:m] else: if program.Program.prog.use_edabit(): @@ -292,13 +326,14 @@ def BitDecRing(a, k, m): # reversing to reduce number of rounds return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1] -def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): +def BitDecFieldRaw(a, k, m, bits_to_compute=None): instructions_base.set_global_vector_size(a.size) r_dprime = types.sint() r_prime = types.sint() c = types.cint() r = [types.sint() for i in range(m)] - comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) + comparison.PRandM(r_dprime, r_prime, r, k, m) + kappa = program.Program.prog.security pow2 = two_power(k + kappa) asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) @@ -306,16 +341,16 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): return res @instructions_base.bit_cisc -def BitDecField(a, k, m, kappa, bits_to_compute=None): - res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute) +def BitDecField(a, k, m, bits_to_compute=None): + res = BitDecFieldRaw(a, k, m, bits_to_compute) return [types.sintbit.conv(bit) for bit in res] @instructions_base.ret_cisc -def Pow2(a, l, kappa): +def Pow2(a, l): comparison.program.curr_tape.require_bit_length(l - 1) m = int(ceil(log(l, 2))) - t = BitDec(a, m, m, kappa) + t = BitDec(a, m, m) return Pow2_from_bits(t) def Pow2_from_bits(bits): @@ -327,11 +362,12 @@ def Pow2_from_bits(bits): t[i] = t[i]*pow2k[i] + 1 - t[i] return KMul(t) -def B2U(a, l, kappa): - pow2a = Pow2(a, l, kappa) - return B2U_from_Pow2(pow2a, l, kappa), pow2a +def B2U(a, l): + pow2a = Pow2(a, l) + return B2U_from_Pow2(pow2a, l), pow2a -def B2U_from_Pow2(pow2a, l, kappa): +def B2U_from_Pow2(pow2a, l): + kappa = program.Program.prog.security r = [types.sint() for i in range(l)] t = types.sint() c = types.cint() @@ -353,17 +389,17 @@ def B2U_from_Pow2(pow2a, l, kappa): c = list(r_bits[0].bit_decompose_clear(c, l)) x = [r_bits[i].bit_xor(c[i]) for i in range(l)] #print ' '.join(str(b.value) for b in x) - y = PreOR(x, kappa) + y = PreOR(x) #print ' '.join(str(b.value) for b in y) return [types.sint.conv(1 - y[i]) for i in range(l)] -def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): +def Trunc(a, l, m, compute_modulo=False, signed=False): """ Oblivious truncation by secret m """ prog = program.Program.prog if util.is_constant(m) and not compute_modulo: # cheaper res = type(a)(size=a.size) - comparison.Trunc(res, a, l, m, kappa, signed=signed) + comparison.Trunc(res, a, l, m, signed=signed) return res if l == 1: if compute_modulo: @@ -371,9 +407,9 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): else: return a * (1 - m) if program.Program.prog.options.ring and not compute_modulo: - return TruncInRing(a, l, Pow2(m, l, kappa)) + return TruncInRing(a, l, Pow2(m, l)) else: - kappa = kappa or program.Program.prog.security + kappa = program.Program.prog.security r = [types.sint() for i in range(l)] r_dprime = types.sint(0) r_prime = types.sint(0) @@ -381,7 +417,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): c = types.cint() ci = [types.cint() for i in range(l)] d = types.sint() - x, pow2m = B2U(m, l, kappa) + x, pow2m = B2U(m, l) for i in range(l): bit(r[i]) t1 = two_power(i) * r[i] @@ -398,7 +434,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) - d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa) + d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l) if compute_modulo: b = c_dprime - r_prime + pow2m * d return b, pow2m @@ -429,33 +465,33 @@ def TruncInRing(to_shift, l, pow2m): def SplitInRing(a, l, m): if l == 1: return m.if_else(a, 0), m.if_else(0, a), 1 - pow2m = Pow2(m, l, None) + pow2m = Pow2(m, l) upper = TruncInRing(a, l, pow2m) lower = a - upper * pow2m return lower, upper, pow2m -def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa): - t = comparison.TruncRoundNearest(a, length, length - target_length, kappa) - overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa) +def TruncRoundNearestAdjustOverflow(a, length, target_length): + t = comparison.TruncRoundNearest(a, length, length - target_length) + overflow = t.greater_equal(two_power(target_length), target_length + 1) s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False) return s, overflow -def Int2FL(a, gamma, l, kappa=None): +def Int2FL(a, gamma, l): lam = gamma - 1 - s = a.less_than(0, gamma, security=kappa) - z = a.equal(0, gamma, security=kappa) + s = a.less_than(0, gamma) + z = a.equal(0, gamma) a = s.if_else(-a, a) - a_bits = a.bit_decompose(lam, security=kappa) + a_bits = a.bit_decompose(lam) a_bits.reverse() - b = PreOR(a_bits, kappa) + b = PreOR(a_bits) t = a * (1 + a.bit_compose(1 - b_i for b_i in b)) p = a.popcnt_bits(b) - lam if gamma - 1 > l: if types.sfloat.round_nearest: - v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa) + v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l) p = p + overflow else: - v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) + v = t.right_shift(gamma - l - 1, gamma - 1, signed=False) else: v = 2**(l-gamma+1) * t p = (p + gamma - 1 - l) * z.bit_not() @@ -466,32 +502,31 @@ def FLRound(x, mode): *mode*: 0 -> floor, 1 -> ceil, -1 > trunc """ v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen a = types.sint() - comparison.LTZ(a, p1, k, x.kappa) - b = p1.less_than(-l + 1, k, x.kappa) - v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True) - c = EQZ(v2, l, x.kappa) + comparison.LTZ(a, p1, k) + b = p1.less_than(-l + 1, k) + v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, compute_modulo=True) + c = EQZ(v2, l) if mode == -1: away_from_zero = 0 mode = x.s else: away_from_zero = mode + s1 - 2 * mode * s1 v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero - d = v.equal(two_power(l), l + 1, x.kappa) + d = v.equal(two_power(l), l + 1) v = d * two_power(l-1) + (1 - d) * v v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1 s = (1 - b * mode) * s1 - z = or_op(EQZ(v, l, x.kappa), z1) + z = or_op(EQZ(v, l), z1) v = v * (1 - z) p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z) return v, p, z, s @instructions_base.ret_cisc -def TruncPr(a, k, m, kappa=None, signed=True): +def TruncPr(a, k, m, signed=True): """ Probabilistic truncation [a/2^m + u] where Pr[u = 1] = (a % 2^m) / 2^m """ nl = program.Program.prog.non_linear - nl.check_security(kappa) return nl.trunc_pr(a, k, m, signed) def TruncPrRing(a, k, m, signed=True): @@ -540,16 +575,14 @@ def TruncPrRing(a, k, m, signed=True): res -= (1 << (k - m - 1)) return res -def TruncPrField(a, k, m, kappa=None): +def TruncPrField(a, k, m): if m == 0: return a - if kappa is None: - kappa = 40 b = two_power(k-1) + a r_prime, r_dprime = types.sint(), types.sint() comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)], - k, m, kappa, use_dabit=False) + k, m, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime c = (b + r).reveal(False) @@ -559,49 +592,49 @@ def TruncPrField(a, k, m, kappa=None): return d @instructions_base.ret_cisc -def SDiv(a, b, l, kappa, round_nearest=False): +def SDiv(a, b, l, round_nearest=False): theta = int(ceil(log(l / 3.5) / log(2))) alpha = two_power(2*l) w = types.cint(int(2.9142 * 2 ** l)) - 2 * b x = alpha - b * w y = a * w - y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False) + y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False) x2 = types.sint() - comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True) + comparison.Mod2m(x2, x, 2 * l + 1, l, signed=True) x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True) for i in range(theta-1): - y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, - round_nearest, + y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, + nearest=round_nearest, signed=False) - y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False) - x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest, + y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False) + x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, nearest=round_nearest, signed=False) - x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest, + x = x1 * x1 + x.round(2 * l + 1, l - 1, nearest=round_nearest, signed=False) x2 = types.sint() - comparison.Mod2m(x2, x, 2 * l, l, kappa, False) + comparison.Mod2m(x2, x, 2 * l, l, signed=False) x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True) - y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, - round_nearest, signed=False) - y = y.round(2 * l + 1, l + 1, kappa, round_nearest) + y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest, + signed=False) + y = y.round(2 * l + 1, l + 1, nearest=round_nearest) return y -def SDiv_mono(a, b, l, kappa): +def SDiv_mono(a, b, l): theta = int(ceil(log(l / 3.5) / log(2))) alpha = two_power(2*l) w = types.cint(int(2.9142 * two_power(l))) - 2 * b x = alpha - b * w y = a * w - y = TruncPr(y, 2 * l + 1, l + 1, kappa) + y = TruncPr(y, 2 * l + 1, l + 1) for i in range(theta-1): y = y * (alpha + x) # keep y with l bits - y = TruncPr(y, 3 * l, 2 * l, kappa) + y = TruncPr(y, 3 * l, 2 * l) x = x**2 # keep x with 2l bits - x = TruncPr(x, 4 * l, 2 * l, kappa) + x = TruncPr(x, 4 * l, 2 * l) y = y * (alpha + x) - y = TruncPr(y, 3 * l, 2 * l, kappa) + y = TruncPr(y, 3 * l, 2 * l) return y # LT bit comparison on shared bit values diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 230f62539..30c5aea99 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -488,6 +488,70 @@ class join_tape(base.Instruction): code = base.opcodes['JOIN_TAPE'] arg_format = ['int'] +class call_tape(base.DoNotEliminateInstruction): + """ Start tape/bytecode file in same thread. Arguments/return values + starting from :py:obj:`direction` are optional. + + :param: tape number (int) + :param: arg (regint) + :param: direction (0 for argument, 1 for return value) + :param: register type (see :py:obj:`vm_types`) + :param: register size (int) + :param: destination register + :param: source register + :param: (repeat from direction) + + """ + code = base.opcodes['CALL_TAPE'] + arg_format = tools.chain(['int', 'ci'], + tools.cycle(['int','int','int','*w','*'])) + + @staticmethod + def type_check(reg, type_id): + assert base.vm_types[reg.reg_type] == type_id + + def __init__(self, *args, **kwargs): + super(call_tape, self).__init__(*args, **kwargs) + for i in range(2, len(args), 5): + for reg in args[i + 3:i + 5]: + self.type_check(reg, args[i + 1]) + assert reg.size == args[i + 2] + assert args[i] in (0, 1) + assert args[i + 4 - args[i]].program == program.curr_tape + assert args[i + 3 + args[i]].program == program.tapes[args[0]] + + def get_def(self): + # hide registers from called tape + for i in range(2, len(self.args), 5): + if self.args[i]: + yield self.args[i + 3] + + def get_used(self): + # hide registers from called tape + yield self.args[1] + for i in range(2, len(self.args), 5): + if not self.args[i]: + yield self.args[i + 4] + + def add_usage(self, req_node): + req_node.num += program.tapes[self.args[0]].req_tree.aggregate() + +class call_arg(base.DoNotEliminateInstruction, base.VectorInstruction): + """ Pseudo instruction for arguments in connection with + :py:class:`call_tape`. + + :param: destination (register) + :param: register type (see :py:obj:`vm_types`) + + """ + code = base.opcodes['CALL_ARG'] + arg_format = ['*w','int'] + + def __init__(self, *args, **kwargs): + super(call_arg, self).__init__(*args, **kwargs) + for i in range(0, len(args), 2): + call_tape.type_check(args[i], args[i + 1]) + class crash(base.IOInstruction): """ Crash runtime if the value in the register is not zero. @@ -687,6 +751,27 @@ def __init__(self, *args): for i in range(1, len(args), 2): assert args[i] == len(args[i + 1]) +class zips(base.Instruction): + """ Zip vectors. + + :param: result (sint) + :param: operand (sint) + :param: operand (sint) + + """ + __slots__ = [] + code = base.opcodes['ZIPS'] + arg_format = ['sw','s','s'] + is_vec = lambda self: True + + def __init__(self, *args): + super(zips, self).__init__(*args) + assert len(args[0]) == len(args[1]) + len(args[2]) + assert len(args[1]) == len(args[2]) + + def get_code(self): + return super(zips, self).get_code(len(self.args[1])) + @base.gf2n @base.vectorize class mulc(base.MulBase): diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 9e88ec58b..6d36a480a 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -68,6 +68,8 @@ USE_MATMUL = 0x1F, ACTIVE = 0xE9, CMDLINEARG = 0xEB, + CALL_TAPE = 0xEC, + CALL_ARG = 0xED, # Addition ADDC = 0x20, ADDS = 0x21, @@ -85,6 +87,7 @@ PREFIXSUMS = 0x2D, PICKS = 0x2E, CONCATS = 0x2F, + ZIPS = 0x3F, # Multiplication/division MULC = 0x30, MULM = 0x31, @@ -223,6 +226,17 @@ ) +vm_types = dict( + ci = 0, + sb = 1, + cb = 2, + s = 4, + c = 5, + sg = 6, + cg = 7, +) + + def int_to_bytes(x): """ 32 bit int to big-endian 4 byte conversion. """ assert(x < 2**32 and x >= -2**32) @@ -491,7 +505,8 @@ def new_instructions(self, size, regs): reset_global_vector_size() program.curr_tape = old_tape for x, bl in tape.req_bit_length.items(): - old_tape.require_bit_length(bl - 1, x) + old_tape.require_bit_length( + bl - 1, x, tape.bit_length_reason if x == 'p' else '') from Compiler.allocator import Merger merger = Merger(block, program.options, tuple(program.to_merge)) @@ -516,40 +531,48 @@ def new_instructions(self, size, regs): inst.copy(size, subs) reset_global_vector_size() - def expand_to_function(self, size, new_regs): - key = size, program.curr_tape, \ - tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \ + class Arg: + def __init__(self, reg): + from Compiler.GC.types import bits + self.type = type(reg) + self.binary = isinstance(reg, bits) + self.reg = reg + def new(self, size): + if self.binary: + return self.type() + else: + return self.type(size=size) + def load(self): + return self.reg + def store(self, reg): + if self.type != type(None): + self.reg.update(reg) + def is_real(self): + return self.reg is not None + + def base_key(self, size, new_regs): + return size, tuple( + arg for arg, reg in zip(self.args, new_regs) if reg is None), \ tuple(type(reg) for reg in new_regs) + + @staticmethod + def get_name(key): + return '_'.join(['%s(%d)' % (function.__name__, key[0])] + + [str(x) for x in key[1]]) + + def expand_to_function(self, size, new_regs): + key = self.base_key(size, new_regs) + (program.curr_tape,) if key not in self.functions: + args = [self.Arg(x) for x in new_regs] from Compiler import library, types - from Compiler.GC.types import bits - class Arg: - def __init__(self, reg): - self.type = type(reg) - self.binary = isinstance(reg, bits) - self.reg = reg - # if reg is not None: - # program.base_addresses[reg] = None - def new(self): - if self.binary: - return self.type() - else: - return self.type(size=size) - def load(self): - return self.reg - def store(self, reg): - if self.type != type(None): - self.reg.update(reg) - args = [Arg(x) for x in new_regs] @library.function_block def f(): - res = [arg.new() for arg in args[:n_outputs]] - self.new_instructions(size, - res + [arg.load() for arg in args[n_outputs:]]) + res = [arg.new(size) for arg in args[:n_outputs]] + self.new_instructions( + size, res + [arg.load() for arg in args[n_outputs:]]) for reg, arg in zip(res, args): arg.store(reg) - f.name = '_'.join(['%s(%d)' % (function.__name__, size)] + - [str(x) for x in key[2]]) + f.name = self.get_name(key) self.functions[key] = f, args f, args = self.functions[key] for i in range(len(new_regs) - n_outputs): @@ -558,6 +581,31 @@ def f(): for i in range(n_outputs): new_regs[i].link(args[i].load()) + def expand_to_tape(self, size, new_regs): + key = self.base_key(size, new_regs) + args = [self.Arg(x) for x in new_regs] + if key not in self.functions: + from Compiler import library, types + @library.function_call_tape + def f(*in_args): + res = [arg.new(size) for arg in args[:n_outputs]] + in_args = list(in_args) + my_args = list(res) + for arg in args[n_outputs:]: + if arg.is_real(): + my_args.append(in_args.pop(0)) + else: + my_args.append(arg.reg) + self.new_instructions(size, my_args) + return res + f.name = self.get_name(key) + self.functions[key] = f + f = self.functions[key] + in_args = filter(lambda arg: arg.is_real(), args[n_outputs:]) + res = util.tuplify(f(*(arg.load() for arg in in_args))) + for i in range(n_outputs): + new_regs[i].link(res[i]) + def expand_merged(self, skip): if function.__name__ in skip: good = True @@ -595,7 +643,11 @@ def expand_merged(self, skip): raise if program.cisc_to_function and \ (program.curr_tape.singular or program.n_running_threads): - self.expand_to_function(size, new_regs) + if (program.options.garbled or program.options.binary or \ + not program.use_tape_calls) and not program.force_cisc_tape: + self.expand_to_function(size, new_regs) + else: + self.expand_to_tape(size, new_regs) else: self.new_instructions(size, new_regs) program.curr_block.n_rounds += self.n_rounds - 1 @@ -795,6 +847,12 @@ class ClearIntAF(RegisterArgFormat): reg_type = RegType.ClearInt name = 'regint' +class AnyRegAF(RegisterArgFormat): + reg_type = '*' + @staticmethod + def check(arg): + assert isinstance(arg, program.curr_tape.Register) + class IntArgFormat(ArgFormat): n_bits = 32 @@ -898,6 +956,8 @@ def __str__(self): 'sgw': SecretGF2NAF, 'ci': ClearIntAF, 'ciw': ClearIntAF, + '*': AnyRegAF, + '*w': AnyRegAF, 'i': ImmediateModpAF, 'ig': ImmediateGF2NAF, 'int': IntArgFormat, @@ -938,6 +998,7 @@ def __init__(self, *args, **kwargs): Instruction.count += 1 if Instruction.count % 100000 == 0: print("Compiled %d lines at" % self.__class__.count, time.asctime()) + sys.stdout.flush() if Instruction.count > 10 ** 7: print("Compilation produced more that 10 million instructions. " "Consider using './compile.py -l' or replacing for loops " diff --git a/Compiler/library.py b/Compiler/library.py index 967a74122..ca1f890e7 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -107,7 +107,7 @@ def print_plain_str(ss): elif isinstance(val, cfloat): val.print_float_plain() elif isinstance(val, (list, tuple, Array, SubMultiArray)): - print_str(*_expand_to_print(val)) + print_str(*_expand_to_print(val), print_secrets=print_secrets) else: try: val.output() @@ -314,7 +314,7 @@ def get_cmdline_arg(idx): return localint(res) def make_array(l, t=None): - if isinstance(l, Tape.Register): + if isinstance(l, types._structure): res = Array(len(l), t or type(l)) res[:] = l else: @@ -334,13 +334,12 @@ def start(self): return self def join(self): self.thread.join() - instructions.program.free(self.base, 'ci') - for reg_type,addr in self.bases.items(): - get_program().free(addr, reg_type.reg_type) + if self.base is not None: + instructions.program.free(self.base, 'ci') class Function: def __init__(self, function, name=None, compile_args=[]): - self.type_args = {} + self.last_key = None self.function = function self.name = name if name is None: @@ -348,46 +347,40 @@ def __init__(self, function, name=None, compile_args=[]): self.compile_args = compile_args def __call__(self, *args): args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args) - from .types import _types - get_reg_type = lambda x: \ - regint if isinstance(x, int) else _types.get(x.reg_type, type(x)) - key = len(args), get_tape() - if key not in self.type_args: + runtime_args = [] + reg_args = [] + key = self.base_key(), + for i,arg in enumerate(args): + if isinstance(arg, types._vectorizable): + key += (arg.shape, arg.value_type) + else: + arg = MemValue(arg) + reg_args.append(arg) + t = arg.value_type + key += (arg.size, t) + runtime_args.append(arg) + if key != self.last_key: # first call - type_args = collections.defaultdict(list) - for i,arg in enumerate(args): - if not isinstance(arg, types._vectorizable): - type_args[get_reg_type(arg)].append(i) + outer_runtime_args = runtime_args def wrapped_function(*compile_args): - base = get_arg() - bases = dict((t, regint.load_mem(base + i)) \ - for i,t in enumerate(sorted(type_args, - key=lambda x: - x.reg_type))) - runtime_args = list(args) - for t in sorted(type_args, key=lambda x: x.reg_type): - i = 0 - for i_arg in type_args[t]: - runtime_args[i_arg] = t.load_mem(bases[t] + i) - i += util.mem_size(t) - return self.function(*(list(compile_args) + runtime_args)) + addresses = regint.Array(len(outer_runtime_args), + address=get_arg()) + runtime_args = [] + for i, arg in enumerate(outer_runtime_args): + if isinstance(arg, MemValue): + arg = arg.value_type.load_mem( + address=addresses[i], size=arg.size) + runtime_args.append(arg) + self.result = self.function( + *(list(compile_args) + runtime_args)) + return self.result self.on_first_call(wrapped_function) - self.type_args[key] = type_args - type_args = self.type_args[key] - base = instructions.program.malloc(len(type_args), 'ci') - bases = dict((t, get_program().malloc(len(type_args[t]), t)) \ - for t in type_args) - for i,reg_type in enumerate(sorted(type_args, - key=lambda x: x.reg_type)): - store_in_mem(bases[reg_type], base + i) - j = 0 - for i_arg in type_args[reg_type]: - if get_reg_type(args[i_arg]) != reg_type: - raise CompilerError('type mismatch: "%s" not of type "%s"' % - (args[i_arg], reg_type)) - store_in_mem(args[i_arg], bases[reg_type] + j) - j += util.mem_size(reg_type) - return self.on_call(base, bases) + self.last_key = key + addresses = regint.Array(len(runtime_args)) + for i, arg in enumerate(reg_args): + addresses[i] = arg.address + return self.on_call(addresses._address, + [(arg.value_type, arg.address) for arg in reg_args]) class FunctionTape(Function): # not thread-safe @@ -401,6 +394,113 @@ def on_first_call(self, wrapped_function): single_thread=self.single_thread) def on_call(self, base, bases): return FunctionTapeCall(self.thread, base, bases) + @staticmethod + def base_key(): + pass + +class FunctionCallTape(FunctionTape): + def __init__(self, *args, **kwargs): + super(FunctionTape, self).__init__(*args, **kwargs) + self.instances = {} + def __call__(self, *args, **kwargs): + key = () + def process_for_key(arg): + nonlocal key + if isinstance(arg, types._vectorizable): + key += (arg.value_type, tuple(arg.shape)) + elif isinstance(arg, Tape.Register): + key += (type(arg), arg.size) + elif isinstance(arg, list): + key += (tuple(arg), 'l') + else: + key += (arg,) + for arg in args: + process_for_key(arg) + for name, arg in sorted(kwargs.items()): + key += (name, 'kw') + process_for_key(arg) + if key not in self.instances: + my_args = [] + def wrapped_function(): + actual_call_args = [] + def process_for_call(arg): + if isinstance(arg, Tape.Register): + my_arg = arg.same_type() + call_arg(my_arg, base.vm_types[my_arg.reg_type]) + my_args.append(my_arg) + return my_arg + elif isinstance(arg, types._vectorizable): + my_arg = arg.same_shape(address=regint()) + call_arg(my_arg.address, base.vm_types['ci']) + my_args.append(my_arg) + my_arg = arg.same_shape( + address=MemValue(my_arg.address)) + return my_arg + actual_call_args.append(my_arg) + else: + my_args.append(arg) + return arg + for arg in args: + actual_call_args.append(process_for_call(arg)) + actual_call_kwargs = {} + for name, arg in sorted(kwargs.items()): + actual_call_kwargs[name] = process_for_call(arg) + self.result = self.function(*actual_call_args, + **actual_call_kwargs) + if self.result is not None: + self.result = list(tuplify(self.result)) + for i, res in enumerate(self.result): + if util.is_constant(res): + self.result[i] = regint(res) + self.on_first_call(wrapped_function, key, my_args) + for name, arg in sorted(kwargs.items()): + args += arg, + return self.on_call(*self.instances[key], args) + def on_first_call(self, wrapped_function, key, inside_args): + program = get_program() + program.curr_tape + tape_handle = len(program.tapes) + # entry for recursion + self.instances[key] = tape_handle, None, inside_args + assert tape_handle == program.new_tape( + wrapped_function, name=self.name, args=self.compile_args, + single_thread=get_tape().singular, finalize=False, + thread_pool=get_tape().free_threads) + tape = program.tapes[tape_handle] + if self.result is not None: + self.result = list(tuplify(self.result)) + for reg in self.result: + reg.can_eliminate = False + tape.return_values.append(reg) + assert not tape.purged + get_program().finalize_tape(tape) + self.instances[key] = tape_handle, self.result, inside_args + def on_call(self, tape_handle, result, inside_args, args): + tape = get_program().tapes[tape_handle] + if tape.ran_threads and tape.free_threads != get_tape().free_threads: + raise CompilerError( + 'cannot call thread-running tape from another thread') + assert len(inside_args) == len(args) + out_result = [] + call_args = [] + if result is not None: + out_result = [reg.same_type() for reg in result] + for x, y in zip(out_result, result): + call_args += [ + 1, instructions_base.vm_types[x.reg_type], + x.size_for_mem(), x, y] + for x, y in zip(inside_args, args): + if isinstance(x, Tape.Register): + call_args += [ + 0, instructions_base.vm_types[x.reg_type], + x.size_for_mem(), x, y] + elif isinstance(x, types._vectorizable): + call_args += [0, base.vm_types['ci'], 1, + x.address, regint.conv(y.address)] + call_tape(tape_handle, regint(0), + *call_args) + break_point('call-%s' % self.name) + return untuplify(tuple(out_result)) def function_tape(function): return FunctionTape(function) @@ -413,18 +513,108 @@ def wrapper(function): def single_thread_function_tape(function): return FunctionTape(function, single_thread=True) -def memorize(x): +def function_call_tape(function): + if get_program().use_tape_calls: + return FunctionCallTape(function) + else: + return function + +def method_call_tape(function): + tapes = {} + def wrapper(self, *args, **kwargs): + def use(name): + x = self.__dict__[name] + return not isinstance(x, types.MultiArray) or \ + x.array._address is not None + key = (type(self),) + tuple(filter(use, sorted(self.__dict__))) + member_key = key[1:] + if key not in tapes: + def f(*args, **kwargs): + class Dummy(type(self)): + __init__ = lambda self: None + dummy = Dummy() + members = args[:len(member_key)] + real_args = args[len(member_key):] + addresses = {} + for name, member in zip(member_key, members): + dummy.__dict__[name] = member + if isinstance(member, types._vectorizable): + addresses[name] = member.address + res = function(dummy, *real_args, **kwargs) + for name, member in zip(member_key, members): + new_member = dummy.__dict__[name] + desc = '%s in %s.%s' % (name, type(self).__name__, + function.__name__) + if id(new_member) != id(member): + raise CompilerError('cannot change members ' + 'in method tape (%s)' % desc) + if isinstance(member, types._vectorizable) and \ + id(new_member.address) != id(addresses[name]): + raise CompilerError('cannot change memory address ' + 'in method tape (%s)' % desc) + if set(member_key) != set(dummy.__dict__): + raise CompilerError('cannot add members ' + 'in method tape (%s)' % desc) + return res + f.__name__ = '%s-%s' % (type(self).__name__, function.__name__) + tapes[key] = function_call_tape(f) + members = tuple(self.__dict__[x] for x in member_key) + res = tapes[key](*(members + args), **kwargs) + return res + return wrapper + +def function(function): + """ Create a run-time function. The arguments can be memory or basic + types, and return values can be basic types:: + + @function + def f(x, y, z): + y.write(1) + z[0] = 2 + return x + 3 + + a = MemValue(sint(0)) + b = sint.Array(10) + c = f(sint(4), a, b) + + print_ln('%s %s %s', a.reveal(), b[0].reveal(), c.reveal()) + + This should output:: + + 1 2 7 + + You can use run-time functions recursively but without return + values in this case. + + """ + return FunctionCallTape(function) + +def memorize(x, write=True): if isinstance(x, (tuple, list)): - return tuple(memorize(i) for i in x) + return tuple(memorize(i, write=write) for i in x) + elif x is None: + return else: - return MemValue(x) + return MemValue(x, write=write) def unmemorize(x): if isinstance(x, (tuple, list)): return tuple(unmemorize(i) for i in x) + elif x is None: + return else: return x.read() +def write_mem(dest, source): + if isinstance(dest, (tuple, list)): + assert len(dest) == len(source) + for x, y in zip(dest, source): + write_mem(x, y) + elif dest is None: + return + else: + dest.write(source) + class FunctionBlock(Function): def on_first_call(self, wrapped_function): p_return_address = get_tape().program.malloc(1, 'ci') @@ -470,6 +660,10 @@ def on_call(self, base, bases): if self.result is not None: return unmemorize(self.result) + @staticmethod + def base_key(): + return get_tape() + def function_block(function): return FunctionBlock(function) @@ -907,11 +1101,15 @@ def f(j): r = reducer(tuplify(loop_body(j)), mem_state) write_state_to_memory(r) state = mem_state - for i,x in enumerate(state): - if use_array: - mem_state[i] = x - else: - mem_state[i].write(x) + if use_array and len(state) and \ + isinstance(types._register, types._vectorizable): + mem_state[:] = state.get_vector() + else: + for i,x in enumerate(state): + if use_array: + mem_state[i] = x + else: + mem_state[i].write(x) def returner(): return untuplify(tuple(state)) return returner @@ -987,7 +1185,7 @@ def multithread(n_threads, n_items=None, max_size=None): .. code:: - @multithread(8, 25) + @multithread(3, 25) def f(base, size): ... """ @@ -1077,7 +1275,7 @@ def f(i): return loop_body(base + i) prog = get_program() thread_args = [] - if prog.curr_tape == prog.tapes[0]: + if prog.curr_tape.singular: prog.n_running_threads = n_threads if not util.is_zero(thread_rounds): prog.prevent_breaks = False @@ -1288,9 +1486,21 @@ def decorator(loop_body): return loop_body return decorator -def _run_and_link(function, g=None): +def _run_and_link(function, g=None, lock_lists=True): if g is None: g = function.__globals__ + if lock_lists: + class A(list): + def __init_(self, l): + self[:] = l + def __setitem__(*args): + raise Exception('you cannot change lists in branches, ' + 'use Array or MultiArray instead') + __delitem__ = append = clear = extend = insert = __setitem__ + pop = remove = reverse = sort = __setitem__ + for x in g: + if isinstance(g[x], list): + g[x] = A(g[x]) pre = copy.copy(g) res = function() _link(pre, g) @@ -1570,14 +1780,25 @@ def listen_for_clients(port): """ instructions.listen(regint.conv(port)) -def accept_client_connection(port): +def accept_client_connection(port, players=None): """ Accept client connection on specific port base. :param port: port base (int/regint/cint) + :param players: subset of players (default: all) :returns: client id + """ res = regint() - instructions.acceptclientconnection(res, regint.conv(port)) + if players is None: + instructions.acceptclientconnection(res, regint.conv(port)) + else: + @if_e(sum(regint(players) == + get_player_id()._v.expand_to_vector(len(players)))) + def _(): + res.update(accept_client_connection(port)) + @else_ + def _(): + res.update(-1) return res def init_client_connection(host, port, my_id, relative_port=True): @@ -1697,14 +1918,14 @@ def cint_cint_division(a, b, k, f): from Compiler.program import Program @instructions_base.ret_cisc -def sint_cint_division(a, b, k, f, kappa, nearest=False): +def sint_cint_division(a, b, k, f, nearest=False): """ type(a) = sint, type(b) = cint """ theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) sign_b = cint(1) - 2 * cint(b.less_than(0, k)) - sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa) + sign_a = sint(1) - 2 * comparison.LessThanZero(a, k) absolute_b = b * sign_b absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) @@ -1714,20 +1935,20 @@ def sint_cint_division(a, b, k, f, kappa, nearest=False): W = w0 for i in range(1, theta): - A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True) + A = (A * W).round(2 * k, f, nearest=nearest, signed=True) temp = (B * W + 2 * (f - 1)) >> f W = two - temp B = temp return (sign_a * sign_b) * A -def IntDiv(a, b, k, kappa=None): +def IntDiv(a, b, k): l = 2 * k + 1 b = a.conv(b) return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k, - kappa, nearest=True) + nearest=True) @instructions_base.ret_cisc -def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): +def FPDiv(a, b, k, f, simplex_flag=False, nearest=False): """ Goldschmidt method as presented in Catrina10, """ @@ -1750,40 +1971,40 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): base.set_global_vector_size(b.size) alpha = b.get_type(2 * k).two_power(2*f, size=b.size) - w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) + w = AppRcr(b, k, f, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() y = a.extend(l_y) * w - y = y.round(l_y, f, kappa, nearest, signed=True) + y = y.round(l_y, f, nearest, signed=True) for i in range(theta - 1): x = x.extend(2 * k) y = y.extend(l_y) * (alpha + x).extend(l_y) x = x * x - y = y.round(l_y, 2*f, kappa, nearest, signed=True) - x = x.round(2*k, 2*f, kappa, nearest, signed=True) + y = y.round(l_y, 2*f, nearest, signed=True) + x = x.round(2*k, 2*f, nearest, signed=True) x = x.extend(2 * k) y = y.extend(l_y) * (alpha + x).extend(l_y) - y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True) + y = y.round(l_y, 3 * f - res_f, nearest, signed=True) return y -def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False): +def AppRcr(b, k, f, simplex_flag=False, nearest=False): """ Approximate reciprocal of [b]: Given [b], compute [1/b] """ alpha = b.get_type(2 * k)(int(2.9142 * 2**k)) - c, v = b.Norm(k, f, kappa, simplex_flag) + c, v = b.Norm(k, f, simplex_flag) #v should be 2**{k - m} where m is the length of the bitwise repr of [b] d = alpha - 2 * c w = d * v - w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True) + w = w.round(2 * k + 1, 2 * (k - f), nearest, signed=True) # now w * 2 ^ {-f} should be an initial approximation of 1/b return w -def Norm(b, k, f, kappa, simplex_flag=False): +def Norm(b, k, f, simplex_flag=False): """ Computes secret integer values [c] and [v_prime] st. 2^{k-1} <= c < 2^k and c = b*v_prime @@ -1799,8 +2020,8 @@ def Norm(b, k, f, kappa, simplex_flag=False): absolute_val = sign * b #next 2 lines actually compute the SufOR for little indian encoding - bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1] - suffixes = PreOR(bits, kappa)[::-1] + bits = absolute_val.bit_decompose(k, maybe_mixed=True)[::-1] + suffixes = PreOR(bits)[::-1] z = [0] * k for i in range(k - 1): diff --git a/Compiler/ml.py b/Compiler/ml.py index 63f815380..eb541c177 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -203,6 +203,20 @@ def wrapper(*args, **kwargs): copy_doc(wrapper, function) return wrapper +def _layer_method_call_tape(function): + function = method_call_tape(function) + def wrapper(self, *args, **kwargs): + self._Y.alloc() + if self.inputs and len(self.inputs) == 1: + backup = self.inputs + del self.inputs + res = function(self, *args, **kwargs) + self.inputs = backup + return res + else: + return function(self, *args, **kwargs) + return wrapper + class Tensor(MultiArray): def __init__(self, *args, **kwargs): kwargs['alloc'] = False @@ -259,6 +273,7 @@ def Y(self): def Y(self, value): self._Y = value + @_layer_method_call_tape def forward(self, batch=None, training=None): if batch is None: batch = Array.create_from(regint(0)) @@ -1045,7 +1060,8 @@ def f_prime_part(self, base, size): def _forward(self, batch=[0]): n_per_item = reduce(operator.mul, self.X.sizes[1:]) - @multithread(self.n_threads, len(batch) * n_per_item) + @multithread(self.n_threads, len(batch) * n_per_item, + max_size=program.budget) def _(base, size): self.Y.assign_vector(self.f_part(base, size), base) @@ -1195,6 +1211,7 @@ class MaxPool(PoolBase): list/tuple of integers """ + @_layer_method_call_tape def forward(self, batch=None, training=False): if batch is None: batch = Array.create_from(regint(0)) @@ -1252,26 +1269,38 @@ def __init__(self, inputs, dimension): self.dimension = dimension shapes = [inp.shape for inp in inputs] assert dimension == 3 - assert len(shapes) == 2 - assert len(shapes[0]) == len(shapes[1]) + assert len(shapes[0]) == 4 + for shape in shapes: + assert len(shape) == len(shapes[0]) shape = [] for i in range(len(shapes[0])): if i == dimension: - shape.append(shapes[0][i] + shapes[1][i]) + shape.append(sum(x[i] for x in shapes)) else: - assert shapes[0][i] == shapes[1][i] shape.append(shapes[0][i]) self.Y = Tensor(shape, sfix) + self.bases = [sum(x[dimension] for x in shapes[:k]) + for k in range(len(shapes))] + self.addresses = Array.create_from(regint(list( + x.Y.address for x in inputs))) def _forward(self, batch=[0]): assert len(batch) == 1 @for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3]) def _(i, j): - X = [x.Y[batch[0]] for x in self.inputs] - self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector()) - self.Y[batch[0]][i][j].assign_part_vector( - X[1][i][j].get_vector(), - len(X[0][i][j])) + if len(set(self.bases)) == 1: + @for_range(len(self.inputs)) + def _(k): + self.Y[batch[0]][i][j].assign_part_vector( + MultiArray( + self.inputs[0].shape, + address=self.addresses[k])[i][j].get_vector(), + k * self.bases[1]) + else: + X = [x.Y[batch[0]] for x in self.inputs] + for k in range(len(self.inputs)): + self.Y[batch[0]][i][j].assign_part_vector( + X[k][i][j].get_vector(), self.bases[k]) class Add(NoVariableLayer): """ Fixed-point addition layer. @@ -1334,12 +1363,14 @@ class BatchNorm(Layer): def __init__(self, shape, approx=True, args=None): assert len(shape) in (2, 3, 4) + self.Y = sfix.Tensor(shape) if len(shape) == 4: shape = [shape[0], shape[1] * shape[2], shape[3]] elif len(shape) == 2: shape = [shape[0], 1, shape[1]] - tensors = (Tensor(shape, sfix) for i in range(4)) - self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + self.my_Y = sfix.Tensor(shape, address=self.Y.address) + tensors = (Tensor(shape, sfix) for i in range(3)) + self.X, self.nabla_X, self.nabla_Y = tensors arrays = (sfix.Array(shape[2]) for i in range(4)) self.var, self.mu, self.weights, self.bias = arrays arrays = (sfix.Array(shape[2]) for i in range(4)) @@ -1374,11 +1405,11 @@ def _output(self, batch, mu, var): [len(batch), self.X.sizes[1]]) def _(i, j): tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:] - self.Y[i][j][:] = self.bias[:] + tmp + self.my_Y[i][j][:] = self.bias[:] + tmp + @_layer_method_call_tape def forward(self, batch, training=False): if training or not self.is_trained: - self.is_trained = True d = self.X.sizes[1] d_in = self.X.sizes[2] s = sfix.Array(d_in) @@ -1411,7 +1442,7 @@ def _(base, size): print_ln('%s at %s: %s', str(x), k, x[k].reveal()) print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k, self.X[i][j][k].reveal(), - self.Y[i][j][k].reveal()) + self.my_Y[i][j][k].reveal()) else: self._output(batch, self.mu_hat, self.var_hat) @@ -1471,6 +1502,10 @@ def _(i, j): self.nabla_Y[i][j][k].reveal(), self.nabla_X[i][j][k].reveal()) + def reveal_parameters_to_binary(self): + for param in self.thetas() + (self.mu_hat, self.var_hat): + param.reveal().binary_output() + class QuantBase(object): bias_before_reduction = True @@ -1498,11 +1533,12 @@ def const_div(self, acc, n): class FixBase: bias_before_reduction = False - @staticmethod - def new_squant(): - class _(sfix): - params = None - return _ + class my_squant(sfix): + params = None + + @classmethod + def new_squant(cls): + return cls.my_squant def input_params_from(self, player): pass @@ -1553,7 +1589,8 @@ def init_temp(cls, layers): cls.temp_inputs = sfix.Array(size) def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, - padding='SAME', tf_weight_format=False, inputs=None): + padding='SAME', tf_weight_format=False, inputs=None, + weight_type=None): super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs) self.weight_shape = weight_shape @@ -1582,7 +1619,11 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, else: self.padding = padding - self.weight_squant = self.new_squant() + if weight_type: + self.weight_squant = weight_type + else: + self.weight_squant = self.new_squant() + self.bias_squant = self.new_squant() self.weights = Tensor(weight_shape, self.weight_squant) @@ -1645,8 +1686,7 @@ def reduction(self, batch_length=1): n_summands = self.n_summands() #start_timer(2) n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:]) - @multithread(self.n_threads, n_outputs, - 1000 if sfix.round_nearest else 10 ** 6) + @multithread(self.n_threads, n_outputs, max_size=program.budget) def _(base, n_per_thread): res = self.input_squant().unreduced( sint.load_mem(unreduced.address + base, @@ -1709,7 +1749,7 @@ def _(i, j): res += self.bias.expand_to_vector(j, res.size).v else: res += self.bias.expand_to_vector(j, res.size).v << \ - self.input_squant.f + self.weight_squant.f addresses = regint.inc(res.size, self.unreduced[i * part_size].address + j, n_channels_out) @@ -2029,7 +2069,7 @@ def _(out_y, out_x, c): self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, - padding=0): + padding=0, **kwargs): """ More convenient interface to :py:class:`FixConv2d`. :param input_shape: input shape (tuple/list of four int) @@ -2050,7 +2090,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, padding = padding.upper() if isinstance(padding, str) \ else padding return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape, - stride, padding) + stride, padding, **kwargs) def easyMaxPool(input_shape, kernel_size, stride=None, padding=0): """ More convenient interface to :py:class:`MaxPool`. @@ -2102,6 +2142,7 @@ def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1), PoolBase.__init__(self, input_shape, [1] + list(strides) + [1], [1] + list(filter_size) + [1], padding) self.pool_size = reduce(operator.mul, filter_size) + self.d_out = self.Y.shape[-1] if output_shape: assert self.Y.shape == list(output_shape) @@ -2322,6 +2363,10 @@ def f(start, batch_size, batch): self.forward(batch=batch, run_last=False) part = self.layers[-1].eval(batch_size, top=top) res.assign_part_vector(part.get_vector(), start) + if self.output_stats: + for layer in self.layers[:-1]: + print_ln(layer) + self.stat(' Y', layer.Y) self.run_in_batches(f, data, batch_size or len(self.layers[1].X)) return res @@ -2567,9 +2612,9 @@ def f(start, batch_size, batch): def run_in_batches(self, f, data, batch_size, truth=None): batch_size = min(batch_size, data.sizes[0]) - training_data = self.layers[0].X.address + training_data = self.layers[0]._X.array._address training_truth = self.layers[-1].Y.address - self.layers[0].X.address = data.address + self.layers[0]._X.address = data.address if truth: self.layers[-1].Y.address = truth.address N = data.sizes[0] @@ -2748,13 +2793,15 @@ def trainable_variables(self): return list(self.thetas) def reveal_model_to_binary(self): - input_shape = self.layers[0].X.shape + """ Reveal model and store it in the binary output file, see + :ref:`reveal-model` for details. """ + input_shape = self.layers[0]._X.shape for layer in self.layers: if len(input_shape) == 4 and isinstance(layer, DenseBase): layer.reveal_parameters_to_binary(reshape=input_shape[1:]) else: layer.reveal_parameters_to_binary() - input_shape = layer.Y.shape + input_shape = layer._Y.shape class Adam(Optimizer): """ Adam/AMSgrad optimizer. @@ -3046,7 +3093,7 @@ def trainable_variables(self): def summary(self): self.opt.summary() - def build(self, input_shape, batch_size=128): + def build(self, input_shape, batch_size=128, program=None): data_input_shape = input_shape if self.opt != None and \ input_shape == self.opt.layers[0]._X.sizes and \ @@ -3120,8 +3167,11 @@ def build(self, input_shape, batch_size=128): if layers[-1].d_out == 1: layers.append(Output(data_input_shape[0])) else: - layers.append( - MultiOutput(data_input_shape[0], layers[-1].d_out)) + shape = data_input_shape[0], layers[-1].d_out + if program: + layers.append(MultiOutput.from_args(program, *shape)) + else: + layers.append(MultiOutput(*shape)) if self.optimizer[1]: raise Exception('use keyword arguments for optimizer') opt = self.optimizer[0] @@ -3186,11 +3236,11 @@ def predict(self, x, batch_size=None): batch_size = min(batch_size, self.batch_size) return self.opt.eval(x, batch_size=batch_size) -def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None, - regression=False): - """ Convert a PyTorch Sequential object to MP-SPDZ layers. +def layers_from_torch(model, data_input_shape, batch_size, input_via=None, + regression=False, layer_args={}, program=None): + """ Convert a PyTorch Module object to MP-SPDZ layers. - :param sequence: PyTorch Sequential object + :param model: PyTorch Module object :param data_input_shape: input shape (list of four int) :param batch_size: batch size (int) :param input_via: player to input model data via (default: don't) @@ -3198,17 +3248,29 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None, """ layers = [] + named_layers = {} def mul(x): return reduce(operator.mul, x) - def process(item): - nonlocal input_shape + import torch + + def process(item, inputs, input_shape, args): + if item == torch.cat: + if len(inputs) > 1: + layers.append( + Concat(inputs, dimension=len(inputs[0].shape) - 1)) + return + elif item == operator.add: + layers.append(Add(inputs)) + return + elif item == torch.flatten: + return + # single-input layers from here + if inputs and len(inputs) > 1: + raise CompilerError('multi-input layer %s not supported' % item) name = type(item).__name__ - if name == 'Sequential': - for x in item: - process(x) - elif name == 'Linear': + if name == 'Linear': assert mul(input_shape[1:]) == item.in_features assert item.bias is not None layers.append(Dense(input_shape[0], item.in_features, @@ -3239,7 +3301,7 @@ def process(item): elif name == 'Conv2d': layers.append(easyConv2d(input_shape, batch_size, item.out_channels, item.kernel_size, item.stride, - item.padding)) + item.padding, **layer_args.get(item, {}))) input_shape = layers[-1].Y.shape if input_via is not None: shapes = [x.shape for x in @@ -3247,11 +3309,14 @@ def process(item): import numpy swapped = numpy.moveaxis( numpy.array(item.weight.detach()), 1, -1) - layers[-1].weights = sfix.input_tensor_via(input_via, swapped) - layers[-1].bias = sfix.input_tensor_via( - input_via, item.bias.detach()) + layers[-1].weights = \ + layers[-1].weights.value_type.input_tensor_via( + input_via, swapped) assert layers[-1].weights.shape == shapes[0] - assert layers[-1].bias.shape == shapes[1] + if isinstance(item.bias, torch.Tensor): + layers[-1].bias = sfix.input_tensor_via( + input_via, item.bias.detach()) + assert layers[-1].bias.shape == shapes[1] elif name == 'MaxPool2d': layers.append(easyMaxPool(input_shape, item.kernel_size, item.stride, item.padding)) @@ -3260,10 +3325,24 @@ def process(item): layers.append(FixAveragePool2d(input_shape, None, item.kernel_size, item.stride, item.padding)) input_shape = layers[-1].Y.shape - elif name == 'ReLU': + elif name == 'AdaptiveAvgPool2d' or \ + item == torch.nn.functional.adaptive_avg_pool2d: + if name == 'AdaptiveAvgPool2d': + output = item.output_size + else: + output = args[1] + for i in (0, 1): + assert input_shape[1 + i] % output[i] == 0 + stride = [input_shape[1 + i] // output[i] for i in (0, 1)] + kernel_size = [input_shape[1 + i] - (output[i] - 1) * stride[i] + for i in (0, 1)] + layers.append(FixAveragePool2d(input_shape, None, kernel_size, + stride, padding=0)) + input_shape = layers[-1].Y.shape + elif name == 'ReLU' or item == torch.nn.functional.relu: layers.append(Relu(input_shape)) elif name == 'Flatten': - pass + return elif name == 'BatchNorm2d': layers.append(BatchNorm(layers[-1].Y.sizes)) if input_via is not None: @@ -3277,16 +3356,52 @@ def process(item): alpha=item.p)) input_shape = layers[-1].Y.sizes else: - raise CompilerError('unknown PyTorch module: ' + name) + raise CompilerError('unknown PyTorch module: %s' % item) + layers[-1].inputs = inputs input_shape = data_input_shape + [1] * (4 - len(data_input_shape)) - process(sequence) + + torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes) + for i, layer in enumerate(torch_layers[1:-1]): + if layer.op == 'call_module': + target = model + for attr in layer.target.split('.'): + target = getattr(target, attr) + else: + target = layer.target + if not layers: + assert layer.args == (torch_layers[i],) + inputs = None + else: + if len(layer.args) < 2 or (layer.args[1] != 1 and + layer.args[1] != (1, 1)): + args = layer.args + elif isinstance(layer.args[0], list): + args = layer.args[0] + else: + args = layer.args[0], + inputs = [named_layers[x] for x in args] + if len(inputs) == 1: + if isinstance(inputs[0], (Dropout, BatchNorm)): + input_shape = inputs[0].inputs[0].Y.shape + else: + input_shape = inputs[0]._Y.shape + else: + input_shape = None + process(target, inputs, input_shape, layer.args) + if layers: + named_layers[layer] = layers[-1] + if regression: layers.append(LinearOutput(data_input_shape[0], layers[-1].d_out)) elif layers[-1].d_out == 1: layers.append(Output(data_input_shape[0])) else: - layers.append(MultiOutput(data_input_shape[0], layers[-1].d_out)) + shape = data_input_shape[0], layers[-1].d_out + if program: + layers.append(MultiOutput.from_args(program, *shape)) + else: + layers.append(MultiOutput(*shape)) return layers class OneLayerSGD: @@ -3335,6 +3450,9 @@ def predict(self, X): class SGDLogistic(OneLayerSGD): """ Logistic regression using SGD. + The member :py:obj:`opt` refers to the internal instance of + :py:class:`Optimizer`, which allows to use the funcionality + therein. :param n_epochs: number of epochs :param batch_size: batch size @@ -3460,6 +3578,8 @@ def _(i): def mr(A, n_iterations, stop=False): """ Iterative matrix inverse approximation. + This is based on the conjugate gradients algorithm in Section + 10.2.4 of `these lecture notes `_. :param A: matrix to invert :param n_iterations: maximum number of iterations diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 9e3d40972..2208fc212 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -1,7 +1,9 @@ """ Module for math operations. -Implements trigonometric and logarithmic functions. +Most of the functionality is due to `Aly and Smart +`_ with some optimizations by +`Keller and Sun `_. This has to imported explicitly. """ @@ -98,7 +100,7 @@ # @return truncated sint value of x def trunc(x): if isinstance(x, types._fix): - return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True) + return x.v.right_shift(x.f, x.k, signed=True) elif type(x) is types.sfloat: v, p, z, s = floatingpoint.FLRound(x, 0) #return types.sfloat(v, p, z, s, x.err) @@ -106,19 +108,6 @@ def trunc(x): return x -## -# loads integer to fractional type (sint) -# @param x: coefficient to be truncated. -# -# @return returns sfix, sfloat loaded value -def load_sint(x, l_type): - if l_type is types.sfix: - return types.sfix.from_sint(x) - elif l_type is types.sfloat: - return x - return x - - ## # evaluates a Polynomial to a given x in a privacy preserving manner. # Inputs can be of any kind of register, secret or otherwise. @@ -448,7 +437,7 @@ def log2_fx(x, use_division=True): if isinstance(x, types._fix): # transforms sfix to f*2^n, where f is [o.5,1] bounded # obtain number bounded by [0,5 and 1] by transforming input to sfloat - v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa) + v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f) p -= x.f vlen = x.f v = x._new(v, k=x.k, f=x.f) @@ -473,7 +462,7 @@ def log2_fx(x, use_division=True): return a # *(1-(f.z))*(1-f.s)*(1-f.error) -def pow_fx(x, y): +def pow_fx(x, y, zero_output=False): """ Returns the value of the expression :math:`x^y` where both inputs are secret shared. It uses :py:func:`log2_fx` together with @@ -494,7 +483,7 @@ def pow_fx(x, y): # obtains y * log2(x) exp = y * log2_x # returns 2^(y*log2(x)) - return exp2_fx(exp) + return exp2_fx(exp, zero_output) def log_fx(x, b): @@ -535,8 +524,8 @@ def abs_fx(x): # # @return floored sint value of x def floor_fx(x): - return load_sint(x.v.right_shift(x.f, bit_length=x.k, security=x.kappa, - signed=True), type(x)) + return type(x)(x.v.right_shift(x.f, bit_length=x.k, signed=True), + k=x.k, f=x.f) ### sqrt methods @@ -743,13 +732,13 @@ def lin_app_SQ(b, k, f): c, v, m, W = norm_SQ(types.sint(b), k) # c is now escalated - w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k + w = alpha * c + beta # equation before b and reduction by order of k # m even or odd determination m_bit = types.sint() - comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False) - m = load_sint(m_bit, types.sfix) + comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), signed=False) + m = m_bit # w times v this way both terms have 2^3k and can be symplified w = w * v @@ -774,7 +763,7 @@ def lin_app_SQ(b, k, f): def sqrt_fx(x_l, k, f): factor = 1.0 / (2.0 ** f) - x = load_sint(x_l, types.sfix) * factor + x = x_l * factor theta = int(math.ceil(math.log(k/5.4))) @@ -912,29 +901,29 @@ def tanh(x): # next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427 -def Sep(x): +def Sep(x, sfix=types.sfix): b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True)))) bb = b[:] while len(bb) < 2 * x.f - 1: bb.insert(0, type(b[0])(0)) t = x.v * (1 + x.v.bit_compose(b_i.bit_not() for b_i in bb[-2 * x.f + 1:])) - u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False)) + u = sfix._new(t.right_shift(x.f, 2 * x.k, signed=False)) b += [b[0].long_one()] return u, [b[i + 1] - b[i] for i in reversed(range(x.k))] -def SqrtComp(z, old=False): - f = types.sfix.f +def SqrtComp(z, old=False, sfix=types.sfix): + f = sfix.f k = len(z) if isinstance(z[0], types.sint): - return types.sfix._new(sum(z[i] * types.cfix( + return sfix._new(sum(z[i] * types.cfix( 2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k))) k_prime = k // 2 f_prime = f // 2 - c1 = types.sfix(2 ** ((f + 1) / 2 + 1)) - c0 = types.sfix(2 ** (f / 2 + 1)) + c1 = sfix(2 ** ((f + 1) / 2 + 1)) + c0 = sfix(2 ** (f / 2 + 1)) a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)] - tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime]))) + tmp = sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime]))) if old: b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2 else: @@ -942,11 +931,15 @@ def SqrtComp(z, old=False): return types.sint.conv(b).if_else(c1, c0) * tmp @types.vectorize +@instructions_base.sfix_cisc def InvertSqrt(x, old=False): """ Reciprocal square root approximation by `Lu et al. `_ """ - u, z = Sep(x) + class my_sfix(types.sfix): + f = x.f + k = x.k + u, z = Sep(x, sfix=my_sfix) c = 3.14736 + u * (4.63887 * u - 5.77789) - return c * SqrtComp(z, old=old) + return c * SqrtComp(z, old=old, sfix=my_sfix) diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 66e82908d..147af9f20 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -4,14 +4,6 @@ from . import comparison, program class NonLinear: - kappa = None - - def set_security(self, kappa): - pass - - def check_security(self, kappa): - pass - def mod2m(self, a, k, m, signed): """ a_prime = a % 2^m @@ -45,18 +37,16 @@ def trunc_pr(self, a, k, m, signed=True): def trunc_round_nearest(self, a, k, m, signed): res = sint() - comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, self.kappa, - signed) + comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, signed) return res - def trunc(self, a, k, m, kappa, signed): - self.check_security(kappa) + def trunc(self, a, k, m, signed): if m == 0: return a return self._trunc(a, k, m, signed) - def ltz(self, a, k, kappa=None): - return -self.trunc(a, k, k - 1, kappa, True) + def ltz(self, a, k): + return -self.trunc(a, k, k - 1, True) class Masking(NonLinear): def eqz(self, a, k): @@ -68,28 +58,19 @@ def eqz(self, a, k): class Prime(Masking): """ Non-linear functionality modulo a prime with statistical masking. """ - def __init__(self, kappa): - self.set_security(kappa) - - def set_security(self, kappa): - self.kappa = kappa - - def check_security(self, kappa): - assert self.kappa == kappa or kappa is None - def _mod2m(self, a, k, m, signed): res = sint() if m == 1: - Mod2(res, a, k, self.kappa, signed) + Mod2(res, a, k, signed) else: - Mod2mField(res, a, k, m, self.kappa, signed) + Mod2mField(res, a, k, m, signed) return res def _mask(self, a, k): - return maskField(a, k, self.kappa) + return maskField(a, k) def _trunc_pr(self, a, k, m, signed=None): - return TruncPrField(a, k, m, self.kappa) + return TruncPrField(a, k, m) def _trunc(self, a, k, m, signed=None): a_prime = self.mod2m(a, k, m, signed) @@ -99,12 +80,12 @@ def _trunc(self, a, k, m, signed=None): def bit_dec(self, a, k, m, maybe_mixed=False): if maybe_mixed: - return BitDecFieldRaw(a, k, m, self.kappa) + return BitDecFieldRaw(a, k, m) else: - return BitDecField(a, k, m, self.kappa) + return BitDecField(a, k, m) def kor(self, d): - return KOR(d, self.kappa) + return KOR(d) class KnownPrime(NonLinear): """ Non-linear functionality modulo a prime known at compile time. """ @@ -144,13 +125,13 @@ def eqz(self, a, k): a += two_power(k) return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) - def ltz(self, a, k, kappa=None): + def ltz(self, a, k): if k + 1 < self.prime.bit_length(): # https://dl.acm.org/doi/10.1145/3474123.3486757 # "negative" values wrap around when doubling, thus becoming odd return self.mod2m(2 * a, k + 1, 1, False) else: - return super(KnownPrime, self).ltz(a, k, kappa) + return super(KnownPrime, self).ltz(a, k) class Ring(Masking): """ Non-linear functionality modulo a power of two known at compile time. @@ -189,5 +170,5 @@ def trunc_round_nearest(self, a, k, m, signed): else: return super(Ring, self).trunc_round_nearest(a, k, m, signed) - def ltz(self, a, k, kappa=None): + def ltz(self, a, k): return LtzRing(a, k) diff --git a/Compiler/oram.py b/Compiler/oram.py index 04d084bc6..d4862e669 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -877,7 +877,7 @@ def _read(self, index): demux_array(bit_decompose(index, self.index_size), \ self.index_vector) t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size)) - @map_sum(get_n_threads(self.size), n_parallel, self.size, \ + @map_sum(get_n_threads(self.size), None, self.size, \ self.value_length + 1, t) def f(i): entry = self.ram[i] @@ -897,7 +897,7 @@ def _write(self, index, *new_value): new_value = make_array( new_value, self.value_type.get_type( max(x or 0 for x in self.entry_size))) - @for_range_multithread(get_n_threads(self.size), n_parallel, self.size) + @for_range_multithread(get_n_threads(self.size), None, self.size) def f(i): entry = self.ram[i] access_here = self.index_vector[i] @@ -917,7 +917,7 @@ def _access(self, index, write, new_empty, *new_value): max(x or 0 for x in self.entry_size))) new_empty = MemValue(new_empty) write = MemValue(write) - @map_sum(get_n_threads(self.size), n_parallel, self.size, \ + @map_sum(get_n_threads(self.size), None, self.size, \ self.value_length + 1, [self.value_type.bit_type] + \ [self.value_type] * self.value_length) def f(i): @@ -1340,8 +1340,8 @@ def _(i): half = (empty_positions[i]+1 - parity) // 2 half_max = self.bucket_size // 2 - bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0] - bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0] + bits = floatingpoint.B2U(half, half_max)[0] + bits2 = floatingpoint.B2U(half+parity, half_max)[0] # (doesn't work) #bits2 = [0] * half_max ## second half with parity bit @@ -1350,7 +1350,8 @@ def _(i): #bits2[0] = (1 - bits[0]) * parity bucket_bits = [b for sl in zip(bits2,bits) for b in sl] else: - bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0] + bucket_bits = floatingpoint.B2U(empty_positions[i]+1, + self.bucket_size)[0] assert len(bucket_bits) == self.bucket_size for j, b in enumerate(bucket_bits): pos_bits[i * self.bucket_size + j] = [b, leaf] @@ -1376,8 +1377,7 @@ def _(i): Program.prog.curr_tape.start_new_basicblock() bucket_sizes = Array(2**self.D, regint) - for i in range(2**self.D): - bucket_sizes[i] = 0 + bucket_sizes.assign_all(0) @for_range_opt(len(entries)) def _(k): @@ -1697,7 +1697,7 @@ class OneLevelORAM(TreeORAM): class BinaryORAM: def __init__(self, size, value_type=None, **kwargs): - import circuit_oram + from Compiler import circuit_oram from Compiler.GC import types n_bits = int(get_program().options.binary) self.value_type = value_type or types.sbitintvec.get_type(n_bits) diff --git a/Compiler/program.py b/Compiler/program.py index 72bde7ec5..a29b6ec93 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -18,7 +18,7 @@ import Compiler.instructions import Compiler.instructions_base import Compiler.instructions_base as inst_base -from Compiler.config import REG_MAX, USER_MEM, COST +from Compiler.config import REG_MAX, USER_MEM, COST, MEM_MAX from Compiler.exceptions import CompilerError from Compiler.instructions_base import RegType @@ -103,6 +103,38 @@ def __init__(self, args, options=defaults, name=None): self.bit_length = int(options.binary) or int(options.field) if options.prime: self.prime = int(options.prime) + print("WARNING: --prime/-P activates code that usually isn't " + "the most efficient variant. Consider using --field/-F " + "and set the prime only during the actual computation.") + if not self.rabbit_gap() and self.prime > 2 ** 50: + print("The chosen prime is particularly inefficient. " + "Consider using a prime that is closer to a power " + "of two", end='') + try: + import gmpy2 + bad_prime = self.prime + self.prime = 2 ** int( + round(math.log(self.prime, 2))) + 1 + while True: + if self.prime > 2 ** 59: + # LWE compatibility + step = 2 ** 15 + else: + step = 1 + if self.prime < bad_prime: + self.prime += step + else: + self.prime -= step + if gmpy2.is_prime(self.prime): + break + assert self.rabbit_gap() + print(", for example, %d." % self.prime) + self.prime = bad_prime + except ImportError: + print(".") + if options.execute: + print("Use '-- --prime ' to specify the prime for " + "execution only.") max_bit_length = int(options.prime).bit_length() - 2 if self.bit_length > max_bit_length: raise CompilerError( @@ -111,7 +143,7 @@ def __init__(self, args, options=defaults, name=None): self.bit_length = self.bit_length or max_bit_length self.non_linear = KnownPrime(self.prime) else: - self.non_linear = Prime(self.security) + self.non_linear = Prime() if not self.bit_length: self.bit_length = 64 print("Default bit length for compilation:", self.bit_length) @@ -197,6 +229,8 @@ def __init__(self, args, options=defaults, name=None): self.cisc_to_function = True if not self.options.cisc: self.options.cisc = not self.options.optimize_hard + self.use_tape_calls = True + self.force_cisc_tape = False Program.prog = self from . import comparison, instructions, instructions_base, types @@ -278,7 +312,8 @@ def set_ring_size(self, ring_size): self.non_linear = Ring(ring_size) self.options.ring = str(ring_size) - def new_tape(self, function, args=[], name=None, single_thread=False): + def new_tape(self, function, args=[], name=None, single_thread=False, + finalize=True, **kwargs): """ Create a new tape from a function. See :py:func:`~Compiler.library.multithread` and @@ -309,11 +344,12 @@ def g(): self.curr_tape tape_index = len(self.tapes) self.tape_stack.append(self.curr_tape) - self.curr_tape = Tape(name, self) + self.curr_tape = Tape(name, self, **kwargs) self.curr_tape.singular = single_thread self.tapes.append(self.curr_tape) function(*args) - self.finalize_tape(self.curr_tape) + if finalize: + self.finalize_tape(self.curr_tape) if self.tape_stack: self.curr_tape = self.tape_stack.pop() return tape_index @@ -346,6 +382,7 @@ def run_tapes(self, args): thread_numbers = [] while len(thread_numbers) < len(args): free_threads = self.curr_tape.free_threads + self.curr_tape.ran_threads = True if free_threads: thread_numbers.append(min(free_threads)) free_threads.remove(thread_numbers[-1]) @@ -417,7 +454,10 @@ def write_bytes(self): def finalize_tape(self, tape): if not tape.purged: + curr_tape = self.curr_tape + self.curr_tape = tape tape.optimize(self.options) + self.curr_tape = curr_tape tape.write_bytes() if self.options.asmoutfile: tape.write_str(self.options.asmoutfile + "-" + tape.name) @@ -472,16 +512,18 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None, use_freed=Tru self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) - if addr + size >= 2**64: - raise CompilerError("allocation exceeded for type '%s'" % mem_type) + if addr + size >= MEM_MAX: + raise CompilerError( + "allocation exceeded for type '%s' after adding %d" % \ + (mem_type, size)) self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool if single_size: - from .library import get_thread_number, runtime_error_if + from .library import get_arg, runtime_error_if bak = self.curr_tape.active_basicblock self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0] - tn = get_thread_number() - runtime_error_if(tn > self.n_running_threads, "malloc") - res = addr + single_size * (tn - 1) + arg = get_arg() + runtime_error_if(arg >= self.n_running_threads, "malloc") + res = addr + single_size * arg self.curr_tape.active_basicblock = bak self.base_addresses[res] = addr return res @@ -577,7 +619,6 @@ def set_bit_length(self, bit_length): def set_security(self, security): changed = self._security != security self._security = security - self.non_linear.set_security(security) if changed: print("Changed statistical security for comparison etc. to", security) @@ -783,7 +824,7 @@ def read_domain_size(cls, schedule): class Tape: """A tape contains a list of basic blocks, onto which instructions are added.""" - def __init__(self, name, program): + def __init__(self, name, program, thread_pool=None): """Set prime p and the initial instructions and registers.""" self.program = program name += "-%d" % program.get_tape_counter() @@ -800,12 +841,15 @@ def __init__(self, name, program): self.merge_opens = True self.if_states = [] self.req_bit_length = defaultdict(lambda: 0) + self.bit_length_reason = None self.function_basicblocks = {} self.functions = [] self.singular = True - self.free_threads = set() + self.free_threads = set() if thread_pool is None else thread_pool self.loop_breaks = [] self.warned_about_mem = False + self.return_values = [] + self.ran_threads = False class BasicBlock(object): def __init__(self, parent, name, scope, exit_condition=None, @@ -880,7 +924,8 @@ def relevant(inst): self.usage_instructions = list(filter(relevant, self.instructions)) else: self.usage_instructions = [] - if len(self.usage_instructions) > 1000: + if len(self.usage_instructions) > 1000 and \ + self.parent.program.verbose: print("Retaining %d instructions" % len(self.usage_instructions)) del self.instructions self.purged = True @@ -1107,6 +1152,9 @@ def optimize(self, options): if addr.program == self and self.basicblocks: allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool) + for reg in self.return_values: + allocator.alloc_reg(reg, self.basicblocks[-1].alloc_pool) + seen = set() def alloc(block): @@ -1214,7 +1262,10 @@ def alloc_loop(block): Compiler.instructions.reqbl(bl, add_to_prog=False) ) if self.program.verbose: - print("Tape requires prime bit length", self.req_bit_length["p"]) + print("Tape requires prime bit length", + self.req_bit_length["p"], + ('for %s' % self.bit_length_reason + if self.bit_length_reason else '')) print("Tape requires galois bit length", self.req_bit_length["2"]) @unpurged @@ -1287,6 +1338,7 @@ def write_bytes(self, filename=None): if "Bytecode" not in filename: filename = self.program.programs_dir + "/Bytecode/" + filename print("Writing to", filename) + sys.stdout.flush() f = open(filename, "wb") h = hashlib.sha256() for i in self._get_instructions(): @@ -1395,13 +1447,12 @@ def __repr__(self): return repr(dict(self)) class ReqNode(object): - __slots__ = ["num", "_children", "name", "blocks", "aggregated"] - def __init__(self, name): self._children = [] self.name = name self.blocks = [] self.aggregated = None + self.num = None @property def children(self): @@ -1411,12 +1462,17 @@ def children(self): def aggregate(self, *args): if self.aggregated is not None: return self.aggregated + self.recursion = self.num is not None + if self.recursion: + return Tape.ReqNum() self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) res = reduce( lambda x, y: x + y.aggregate(self.name), self.children, self.num ) + if self.recursion: + res *= float('inf') self.aggregated = res return res @@ -1442,7 +1498,7 @@ def aggregate(self, name): n_reps = self.aggregator([1]) n_rounds = res["all", "round"] n_invs = res["all", "inv"] - if (n_invs / n_rounds) * 1000 < n_reps: + if (n_invs / n_rounds) * 1000 < n_reps and Program.prog.verbose: print( self.nodes[0].blocks[0].name, "blowing up rounds: ", @@ -1468,15 +1524,19 @@ def open_scope(self, aggregator, scope=False, name=""): def close_scope(self, outer_scope, parent_req_node, name): self.start_new_basicblock(outer_scope, name, req_node=parent_req_node) - def require_bit_length(self, bit_length, t="p"): + def require_bit_length(self, bit_length, t="p", reason=None): if t == "p": if self.program.prime: if bit_length >= self.program.prime.bit_length() - 1: raise CompilerError( "required bit length %d too much for %d" % (bit_length, self.program.prime) + + ('(for %s)' % reason if reason else '') ) - self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t]) + bit_length += 1 + if bit_length > self.req_bit_length[t]: + self.req_bit_length[t] = bit_length + self.bit_length_reason = reason else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) @@ -1498,6 +1558,21 @@ def __bool__(self): "In some cases, you can fix this by using 'compile.py -l'." ) + def __int__(self): + raise CompilerError( + "It is impossible to convert run-time types to compile-time " + "Python types like int or float. The reason for this is that " + "%s objects are only a placeholder during the execution in " + "Python, the actual value of which is only defined in the " + "virtual machine at a later time. See " + "https://mp-spdz.readthedocs.io/en/latest/journey.html " + "to get an understanding of the overall design. " + "In rare cases, you can fix this by using 'compile.py -l'." % \ + type(self).__name__ + ) + + __float__ = __int__ + class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned @@ -1619,6 +1694,9 @@ def __len__(self): def copy(self): return Tape.Register(self.reg_type, Program.prog.curr_tape) + def same_type(self): + return type(self)(size=self.size) + def link(self, other): if Program.prog.options.noreallocate: raise CompilerError("reallocation necessary for linking, " diff --git a/Compiler/sorting.py b/Compiler/sorting.py index c8cb87e89..f4f38caba 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -1,5 +1,6 @@ import itertools from Compiler import types, library, instructions +from Compiler import comparison, util def dest_comp(B): Bt = B.transpose() @@ -20,6 +21,7 @@ def reveal_sort(k, D, reverse=False): backward order """ + comparison.require_ring_size(util.log2(len(k)) + 1, 'sorting') assert len(k) == len(D) library.break_point() shuffle = types.sint.get_secure_shuffle(len(k)) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 0fecc67d6..28b063e9d 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -109,7 +109,7 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type self.shuffle_used = cint.Array(self.n) # Random permutation on the data self.shufflei = Array.create_from( - [self.index_type(i) for i in range(self.n)]) + self.index_type(regint.inc(self.n))) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter @@ -122,7 +122,7 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type # Note that self.shuffle_the_shuffle mutates this field # Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading self.permutation = Array.create_from( - [self.index_type(i) for i in range(self.n)]) + self.index_type(regint.inc(self.n))) # We allow the caller to postpone the initialization of the shuffle # This is the most expensive operation, and can be done in a thread (only if you know what you're doing) # Note that if you do not initialize, the ORAM is insecure @@ -256,7 +256,6 @@ def _(): return result - @lib.method_block def write(self, index: T, *value: T): global trace, n_parallel if trace: @@ -271,7 +270,12 @@ def write(self, index: T, *value: T): else: raise Exception("Cannot handle type of value passed") print(self.entry_length, value, type(value),len(value)) - value = MemValue(value) + + self._write(index, *value) + + @lib.method_block + def _write(self, index: T, *value: T): + value = MemValue(self.value_type(value)) index = MemValue(index) # Refresh if we have performed T (period) accesses @@ -513,14 +517,14 @@ def _(): # Since the underlying memory of the position map is already aligned in # this packed structure, we can simply overwrite the memory while # maintaining the structure. - self.position_map.reinitialize(*self.permutation) + self.position_map.reinitialize(self.permutation) - def reinitialize(self, *data: T): + def reinitialize(self, data: T): # Note that this method is only used during refresh, and as such is # only called with a permutation as data. # The logical addresses of some previous permutation are irrelevant and must be reset - self.shufflei.assign([self.index_type(i) for i in range(self.n)]) + self.shufflei.assign_vector(self.index_type(regint.inc(self.n))) # Reset the clock self.t.write(0) # Reset shuffle_used @@ -530,10 +534,10 @@ def reinitialize(self, *data: T): # This structure is preserved while overwriting the values using # assign_vector self.shuffle.assign_vector(self.value_type( - data, size=self.n * self.entry_length)) + data[:], size=self.n * self.entry_length)) # Note that this updates self.permutation (see constructor for explanation) self.shuffle_the_shuffle() - self.position_map.reinitialize(*self.permutation) + self.position_map.reinitialize(self.permutation) def _reset_shuffle_used(self): global allow_memory_allocation @@ -568,7 +572,7 @@ def get_position(self, logical_address: _secret, fake: B) -> Any: print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)', self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) - def reinitialize(self, *permutation: T): + def reinitialize(self, permutation: T): """Reinitialize this PositionMap. Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`), @@ -613,9 +617,10 @@ def __init__(self, permutation: Array, period: int, packed_size = int(math.ceil(self.n / pack)) packed_structure = MultiArray( (packed_size, pack), value_type=value_type) - for i in range(packed_size): + @lib.for_range(packed_size) + def _(i): packed_structure[i] = Array.create_from( - permutation[i*pack:(i+1)*pack]) + permutation.get_vector(base=i * pack, size=pack)) SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth, @@ -720,8 +725,8 @@ def _(i): return p.reveal() - def reinitialize(self, *permutation: T): - SqrtOram.reinitialize(self, *permutation) + def reinitialize(self, permutation: T): + SqrtOram.reinitialize(self, permutation) class LinearPositionMap(PositionMap): @@ -790,8 +795,8 @@ def _(): return p.reveal() - def reinitialize(self, *data: T): - self.physical.assign_vector(data) + def reinitialize(self, data : T): + self.physical.assign(data) global allow_memory_allocation if allow_memory_allocation: diff --git a/Compiler/types.py b/Compiler/types.py index 29254268a..9ca5f1a00 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -27,13 +27,22 @@ sint.get_input_from(0, size=10) * sint.get_input_from(1, size=10) +The following types are available in arithmetic circuits and with +reduced functionality in binary circuits. + .. autosummary:: :nosignatures: sint - cint - regint sfix + regint + +The following types are only available in arithmetic circuits. + +.. autosummary:: + :nosignatures: + + cint cfix sfloat sgf2n @@ -84,18 +93,19 @@ class ClientMessageType: class MPCThread(object): def __init__(self, target, name, args = [], runtime_arg = 0, - single_thread = False): + single_thread = False, finalize = True): """ Create a thread from a callable object. """ if not callable(target): raise CompilerError('Target %s for thread %s is not callable' % (target,name)) self.name = name - self.tape = Tape(program.name + '-' + name, program) self.target = target self.args = args self.runtime_arg = runtime_arg self.running = 0 self.tape_handle = program.new_tape(target, args, name, - single_thread=single_thread) + single_thread=single_thread, + finalize=finalize) + self.tape = program.tapes[self.tape_handle] self.run_handles = [] def start(self, runtime_arg = None): @@ -182,8 +192,7 @@ def vectorized_function(cls, *args, **kwargs): def vectorize_init(function): def vectorized_init(*args, **kwargs): size = None - if len(args) > 1 and (isinstance(args[1], _register) or \ - isinstance(args[1], sfloat)): + if len(args) > 1 and isinstance(args[1], (_register, sfloat, cfix)): size = args[1].size if 'size' in kwargs and kwargs['size'] is not None \ and kwargs['size'] != size: @@ -306,7 +315,7 @@ def mul_no_reduce(self, other, res_params=None): def reduce_after_mul(self): return self - def pow2(self, bit_length=None, security=None): + def pow2(self, bit_length=None): return 2**self def min(self, other): @@ -372,7 +381,15 @@ def ripple_carry_adder(*args, **kwargs): return intbitint.ripple_carry_adder(*args, **kwargs) def if_else(self, a, b): - """ MUX on bit in arithmetic circuits. + """ MUX on bit in arithmetic circuits:: + + print_ln('%s', sint(0).if_else(sint(2), sint(3)).reveal()) + print_ln('%s', sint(1).if_else(sint(4), sint(5)).reveal()) + + This will output:: + + 3 + 4 :param a/b: any type supporting the necessary operations :return: a if :py:obj:`self` is 1, b if :py:obj:`self` is 0, undefined otherwise @@ -537,29 +554,35 @@ def Matrix(cls, rows, columns, *args, **kwargs): return Matrix(rows, columns, cls, *args, **kwargs) @classmethod - def Tensor(cls, shape): + def Tensor(cls, shape, **kwargs): """ Type-dependent tensor of any dimension:: a = sfix.Tensor([10, 10]) """ if len(shape) == 1: - return Array(shape[0], cls) + return Array(shape[0], cls, **kwargs) elif len(shape) == 2: - return Matrix(*shape, cls) + return Matrix(*shape, cls, **kwargs) else: - return MultiArray(shape, cls) + return MultiArray(shape, cls, **kwargs) @classmethod def row_matrix_mul(cls, row, matrix, res_params=None): - return sum(row[k].mul_no_reduce(matrix[k].get_vector(), - res_params) \ - for k in range(len(row))).reduce_after_mul() + res = type(row[0].mul_no_reduce( + matrix[0][0], res_params=res_params))(0, size=matrix.sizes[1]) + @library.for_range_opt(len(row)) + def _(k): + res.iadd(row[k].mul_no_reduce(matrix[k].get_vector(), res_params)) + return res.reduce_after_mul() @staticmethod def mem_size(): return 1 + def size_for_mem(self): + return self.size + class _secret_structure(_structure): @classmethod def input_tensor_from(cls, player, shape): @@ -724,6 +747,7 @@ def hard_conv(cls, val): @vectorized_classmethod @set_instruction_type + @read_mem_value def _load_mem(cls, address, direct_inst, indirect_inst): if isinstance(address, _register): if address.size > 1: @@ -748,6 +772,7 @@ def _expand_address(address, size): else: return address + @read_mem_value @set_instruction_type def _store_in_mem(self, address, direct_inst, indirect_inst): if isinstance(address, _register): @@ -858,8 +883,8 @@ def get_type(cls, length): def two_power(n, size=None): return floatingpoint.two_power(n) - def Norm(self, k, f, kappa=None, simplex_flag=False): - return library.Norm(self, k, f, kappa=kappa, simplex_flag=simplex_flag) + def Norm(self, k, f, simplex_flag=False): + return library.Norm(self, k, f, simplex_flag=simplex_flag) class _clear(_arithmetic_register): """ Clear domain-dependent type. """ @@ -1021,6 +1046,10 @@ class cint(_clear, _int): division. ``**`` requires the exponent to be compile-time integer or the base to be two. + This type is restricted to arithmetic circuits due to the fact + that only arithmetic protocols offer communication-less + public-private integer operations. + :param val: initialization (cint/regint/int/cgf2n or list thereof) :param size: vector size (int), defaults to 1 or size of list @@ -1066,8 +1095,8 @@ def store_in_mem(self, address): self._store_in_mem(address, stmc, stmci) @staticmethod - def in_immediate_range(value): - if program.options.ring: + def in_immediate_range(value, regint=False): + if program.options.ring and not regint: if abs(value) > 2 ** int(program.options.ring): raise CompilerError('value outside range for domain') return value < 2**31 and value >= -2**31 @@ -1104,6 +1133,11 @@ def load_int(self, val): elif chunk: sum += sign * chunk + def load_other(self, val): + if isinstance(val, cfix): + val = val.v.round(val.k, val.f) + super(cint, self).load_other(val) + @vectorize def to_regint(self, n_bits=64, dest=None): """ Convert to regint. @@ -1279,9 +1313,10 @@ def right_shift(self, other, bit_length=None): :param other: cint/regint/int """ return self >> other - def round(self, k, m, kappa=None, nearest=None, signed=False): + def round(self, k, m, nearest=None, signed=False): if signed: self += 2 ** (k - 1) + self += 2 ** (m - 1) res = self >> m if signed: res -= 2 ** (k - m - 1) @@ -1292,7 +1327,7 @@ def greater_than(self, other, bit_length=None): return self > other @vectorize - def bit_decompose(self, bit_length=None, kappa=None, maybe_mixed=None): + def bit_decompose(self, bit_length=None, maybe_mixed=None): """ Clear bit decomposition. :param bit_length: number of bits (default is global bit length) @@ -1563,7 +1598,7 @@ def __init__(self, val=None, size=None): super(regint, self).__init__(self.reg_type, val=val, size=size) def load_int(self, val): - if cint.in_immediate_range(val): + if cint.in_immediate_range(val, regint=True): ldint(self, val) else: lower = val % 2**32 @@ -1904,6 +1939,22 @@ def read_fix(cls, player, f, k, precision): fixinput(player, tmp, f, precision) return cls(player, cfix._new(tmp, f=f, k=k)) + @classmethod + def read_int_from_socket(cls, player, socket, size=1): + return cls.read_from_socket(player, socket, cint, size) + + @classmethod + def read_fix_from_socket(cls, player, socket, size=1): + return cls.read_from_socket(player, socket, cfix, size) + + @classmethod + def read_from_socket(cls, player, socket, type, size): + tmp = type(size=size) + @library.if_(player == library.get_player_id()._v) + def _(): + tmp.link(type.read_from_socket(socket, size=size)) + return cls(player, tmp) + def binary_output(self): """ Write binary output to ``Player-Data/Binary-Output-P-`` if @@ -2384,18 +2435,19 @@ class sint(_secret, _int): operator is a compile-time power of two. Most non-linear operations require compile-time parameters for bit - length and statistical security. They default to the global - parameters set by :py:meth:`program.set_bit_length` and - :py:meth:`program.set_security`. The acceptable minimum for statistical - security is considered to be 40. The defaults for the parameters + length. It defaults to the global parameters set by + :py:meth:`program.set_bit_length`, and its default is output at the beginning of the compilation. If the computation domain is modulo a power of two, the - operands will be truncated to the bit length, and the security - parameter does not matter. Modulo prime, the behaviour is + operands will be truncated to the bit length. + Modulo prime, the behaviour is undefined and potentially insecure if the operands are longer than the bit length. + See :ref:`nonlinear` for an overview of how non-linear + computation is implemented. + :param val: initialization (sint/cint/regint/int/cgf2n or list thereof, sbits/sbitvec/sfix, or :py:class:`personal`) :param size: vector size (int), defaults to 1 or size of list @@ -2416,7 +2468,11 @@ class sint(_secret, _int): PreOp = staticmethod(floatingpoint.PreOpL) PreOR = staticmethod(floatingpoint.PreOR) - get_type = staticmethod(lambda n: sint) + + @classmethod + def get_type(cls, n): + cls.require_bit_length(n or 0) + return cls @staticmethod def require_bit_length(n_bits): @@ -2494,7 +2550,8 @@ def get_edabit(cls, n_bits, strict=False): else: a = [sint.get_random_bit() for i in range(n_bits)] return sint.bit_compose(a), a - program.curr_tape.require_bit_length(n_bits - 1) + assert n_bits > 0 + program.curr_tape.require_bit_length(n_bits - 2) whole = cls() size = get_global_vector_size() from Compiler.GC.types import sbits, sbitvec @@ -2558,6 +2615,8 @@ def reveal_to_clients(cls, clients, values): for value in values: assert(value.size == values[0].size) + r = sint.get_random(size=value.size) + value += r - r.reveal() if program.active: r = sint.get_random() to_send += [value, r, value * r] @@ -2717,7 +2776,7 @@ def __abs__(self): @read_mem_value @type_comp @vectorize - def __lt__(self, other, bit_length=None, security=None): + def __lt__(self, other, bit_length=None): """ Secret comparison (signed). :param other: sint/cint/regint/int @@ -2725,42 +2784,39 @@ def __lt__(self, other, bit_length=None, security=None): :return: 0/1 (sintbit) """ res = sintbit() comparison.LTZ(res, self - other, - (bit_length or program.bit_length) + 1, - security) + (bit_length or program.bit_length) + 1) return res @read_mem_value @type_comp @vectorize - def __gt__(self, other, bit_length=None, security=None): + def __gt__(self, other, bit_length=None): res = sintbit() comparison.LTZ(res, other - self, - (bit_length or program.bit_length) + 1, - security) + (bit_length or program.bit_length) + 1) return res @read_mem_value @type_comp - def __le__(self, other, bit_length=None, security=None): - return 1 - self.greater_than(other, bit_length, security) + def __le__(self, other, bit_length=None): + return 1 - self.greater_than(other, bit_length) @read_mem_value @type_comp - def __ge__(self, other, bit_length=None, security=None): - return 1 - self.less_than(other, bit_length, security) + def __ge__(self, other, bit_length=None): + return 1 - self.less_than(other, bit_length) @read_mem_value @type_comp @vectorize - def __eq__(self, other, bit_length=None, security=None): + def __eq__(self, other, bit_length=None): return sintbit.conv( - floatingpoint.EQZ(self - other, bit_length or program.bit_length, - security)) + floatingpoint.EQZ(self - other, bit_length or program.bit_length)) @read_mem_value @type_comp - def __ne__(self, other, bit_length=None, security=None): - return 1 - self.equal(other, bit_length, security) + def __ne__(self, other, bit_length=None): + return 1 - self.equal(other, bit_length) less_than = __lt__ greater_than = __gt__ @@ -2787,7 +2843,7 @@ def __mod__(self, modulus): @vectorize @read_mem_value - def mod2m(self, m, bit_length=None, security=None, signed=True): + def mod2m(self, m, bit_length=None, signed=True): """ Secret modulo power of two. :param m: secret or public integer (sint/cint/regint/int) @@ -2800,9 +2856,10 @@ def mod2m(self, m, bit_length=None, security=None, signed=True): if m >= bit_length: return self res = sint() - comparison.Mod2m(res, self, bit_length, m, security, signed) + comparison.Mod2m(res, self, bit_length, m, signed=signed) else: - res, pow2 = floatingpoint.Trunc(self, bit_length, m, security, True) + res, pow2 = floatingpoint.Trunc(self, bit_length, m, + compute_modulo=True) return res @vectorize @@ -2815,25 +2872,24 @@ def __rpow__(self, base): return NotImplemented @vectorize - def pow2(self, bit_length=None, security=None): + def pow2(self, bit_length=None): """ Secret power of two. :param bit_length: bit length of input (default: global bit length) """ - return floatingpoint.Pow2(self, bit_length or program.bit_length, \ - security) + return floatingpoint.Pow2(self, bit_length or program.bit_length) - def __lshift__(self, other, bit_length=None, security=None): + def __lshift__(self, other, bit_length=None): """ Secret left shift. :param other: secret or public integer (sint/cint/regint/int) :param bit_length: bit length of input (default: global bit length) """ - return self * util.pow2_value(other, bit_length, security) + return self * util.pow2_value(other, bit_length) @vectorize @read_mem_value - def __rshift__(self, other, bit_length=None, security=None, signed=True): + def __rshift__(self, other, bit_length=None, signed=True): """ Secret right shift. :param other: secret or public integer (sint/cint/regint/int) @@ -2844,12 +2900,12 @@ def __rshift__(self, other, bit_length=None, security=None, signed=True): if other == 0: return self res = sint() - comparison.Trunc(res, self, bit_length, other, security, signed) + comparison.Trunc(res, self, bit_length, other, signed) return res elif isinstance(other, sint): - return floatingpoint.Trunc(self, bit_length, other, security) + return floatingpoint.Trunc(self, bit_length, other) else: - return floatingpoint.Trunc(self, bit_length, sint(other), security) + return floatingpoint.Trunc(self, bit_length, sint(other)) left_shift = __lshift__ right_shift = __rshift__ @@ -2869,23 +2925,22 @@ def __rrshift__(self, other): return floatingpoint.Trunc(other, program.bit_length, self) @vectorize - def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False): + def bit_decompose(self, bit_length=None, maybe_mixed=False): """ Secret bit decomposition. """ if bit_length == 0: return [] bit_length = bit_length or program.bit_length - program.non_linear.check_security(security) return program.non_linear.bit_dec(self, bit_length, bit_length, maybe_mixed) - def TruncMul(self, other, k, m, kappa=None, nearest=False): - return (self * other).round(k, m, kappa, nearest, signed=True) + def TruncMul(self, other, k, m, nearest=False): + return (self * other).round(k, m, nearest, signed=True) - def TruncPr(self, k, m, kappa=None, signed=True): - return floatingpoint.TruncPr(self, k, m, kappa, signed=signed) + def TruncPr(self, k, m, signed=True): + return floatingpoint.TruncPr(self, k, m, signed=signed) @vectorize - def round(self, k, m, kappa=None, nearest=False, signed=False): + def round(self, k, m, nearest=False, signed=False): """ Truncate and maybe round secret :py:obj:`k`-bit integer by :py:obj:`m` bits. :py:obj:`m` can be secret if :py:obj:`nearest` is false, in which case the truncation will be @@ -2895,19 +2950,18 @@ def round(self, k, m, kappa=None, nearest=False, signed=False): :param k: int :param m: secret or compile-time integer (sint/int) - :param kappa: statistical security parameter (int) :param nearest: bool :param signed: bool """ secret = isinstance(m, sint) if nearest: if secret: raise NotImplementedError() - return comparison.TruncRoundNearest(self, k, m, kappa, + return comparison.TruncRoundNearest(self, k, m, signed=signed) else: if secret: - return floatingpoint.Trunc(self, k, m, kappa) - return self.TruncPr(k, m, kappa, signed=signed) + return floatingpoint.Trunc(self, k, m) + return self.TruncPr(k, m, signed=signed) def __truediv__(self, other): """ Secret fixed-point division. @@ -2924,7 +2978,7 @@ def __rtruediv__(self, other): return sfix._new(other) / sfix._new(self) @vectorize - def int_div(self, other, bit_length=None, security=None): + def int_div(self, other, bit_length=None): """ Secret integer division. Note that the domain bit length needs to be about four times the bit length. @@ -2932,10 +2986,9 @@ def int_div(self, other, bit_length=None, security=None): :param bit_length: bit length of input (default: global bit length) """ k = bit_length or program.bit_length - kappa = security - tmp = library.IntDiv(self, other, k, kappa) + tmp = library.IntDiv(self, other, k) res = type(self)() - comparison.Trunc(res, tmp, 2 * k, k, kappa, True) + comparison.Trunc(res, tmp, 2 * k, k, signed=True) return res @vectorize @@ -3100,14 +3153,14 @@ def get_reverse_vector(self): picks(res, self, self.size - 1, -1) return res - def get_vector(self, base=0, size=None): + def get_vector(self, base=0, size=None, skip=1): if size is None: size = len(self) - base if base == 0 and size == len(self): return self assert base + size <= len(self) res = type(self)(size=size) - picks(res, self, base, 1) + picks(res, self, base, skip) return res @classmethod @@ -3118,6 +3171,12 @@ def concat(cls, parts): concats(res, *args) return res + @classmethod + def zip(cls, *parts): + res = cls(size=sum(len(part) for part in parts)) + zips(res, *parts) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -3586,7 +3645,10 @@ def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') from util import bit_not, bit_and, bit_xor - a, b = self.expand(other) + try: + a, b = self.expand(other) + except: + return NotImplemented n = 1 for x in (a + b): try: @@ -3617,7 +3679,7 @@ def __lshift__(self, other): def __rshift__(self, other): return self.compose(self.bit_decompose()[other:]) - def bit_decompose(self, n_bits=None, security=None): + def bit_decompose(self, n_bits=None): if self.bits is None: self.bits = self.force_bit_decompose(self.n_bits) if n_bits is None: @@ -3662,7 +3724,7 @@ def __ge__(self, other): def __gt__(self, other): return (self <= other).bit_not() - def __eq__(self, other, bit_length=None, security=None): + def __eq__(self, other, bit_length=None): diff = self ^ other diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), @@ -3892,6 +3954,11 @@ class cfix(_number, _structure): an sfix. It also support comparisons (``==, !=, <, <=, >, >=``), returning either :py:class:`regint` or :py:class:`sbitint`. + Similarly to :py:class:`Compiler.types.cint`, this type is + restricted to arithmetic circuits due to the fact that only + arithmetic protocols offer communication-less public-private + integer operations. + :param v: cfix/float/int """ @@ -4038,13 +4105,7 @@ def conv(cls, other): if isinstance(other, cls): return other else: - try: - res = cfix() - res.load_int(other) - return res - except (TypeError, CompilerError): - pass - return cls(other) + return cls(other) def store_in_mem(self, address): """ Store in memory by public address. """ @@ -4139,7 +4200,7 @@ def __eq__(self, other): if isinstance(other, cfix): return self.v == other.v elif isinstance(other, sfix): - return other.v.equal(self.v, self.k, other.kappa) + return other.v.equal(self.v, self.k) else: raise NotImplementedError @@ -4153,7 +4214,7 @@ def __lt__(self, other): elif isinstance(other, sfix): if(self.k != other.k or self.f != other.f): raise TypeError('Incompatible fixed point types in comparison') - return other.v.greater_than(self.v, self.k, other.kappa) + return other.v.greater_than(self.v, self.k) else: raise NotImplementedError @@ -4164,7 +4225,7 @@ def __le__(self, other): if isinstance(other, cfix): return 1 - (self > other) elif isinstance(other, sfix): - return other.v.greater_equal(self.v, self.k, other.kappa) + return other.v.greater_equal(self.v, self.k) else: raise NotImplementedError @@ -4175,7 +4236,7 @@ def __gt__(self, other): if isinstance(other, cfix): return other.__lt__(self) elif isinstance(other, sfix): - return other.v.less_than(self.v, self.k, other.kappa) + return other.v.less_than(self.v, self.k) else: raise NotImplementedError @@ -4186,7 +4247,7 @@ def __ge__(self, other): if isinstance(other, cfix): return 1 - (self < other) elif isinstance(other, sfix): - return other.v.less_equal(self.v, self.k, other.kappa) + return other.v.less_equal(self.v, self.k) else: raise NotImplementedError @@ -4197,7 +4258,7 @@ def __ne__(self, other): if isinstance(other, cfix): return self.v != other.v elif isinstance(other, sfix): - return other.v.not_equal(self.v, self.k, other.kappa) + return other.v.not_equal(self.v, self.k) else: raise NotImplementedError @@ -4218,7 +4279,6 @@ def __truediv__(self, other): assert self.k == other.k assert self.f == other.f return sfix._new(library.FPDiv(self.v, other.v, self.k, self.f, - other.kappa, nearest=sfix.round_nearest), k=self.k, f=self.f) else: @@ -4261,11 +4321,13 @@ def binary_output(self, player=None): def link(self, other): self.v.link(other.v) + def update(self, other): + self.v.update(other.v) + class _single(_number, _secret_structure): """ Representation as single integer preserving the order """ """ E.g. fixed-point numbers """ __slots__ = ['v'] - kappa = None round_nearest = False """ Whether to round deterministically to nearest instead of probabilistically, e.g. after fixed-point multiplication. """ @@ -4318,6 +4380,9 @@ def read_from_socket(cls, client_id, n=1): def write_to_socket(cls, client_id, values): cls.int_type.write_to_socket(client_id, [x.v for x in values]) + def write_fully_to_socket(self, client_id): + self.v.write_fully_to_socket(client_id) + @vectorized_classmethod def load_mem(cls, address, mem_type=None): """ Load from memory by public address. """ @@ -4331,11 +4396,7 @@ def conv(cls, other): elif isinstance(other, (list, tuple)): return type(other)(cls.conv(x) for x in other) else: - try: - return cls.from_sint(other) - except (TypeError, CompilerError): - pass - return cls(other) + return cls(other) @classmethod def coerce(cls, other): @@ -4452,7 +4513,7 @@ def __eq__(self, other): :rtype: same as internal representation""" other = self.coerce(other) if isinstance(other, (cfix, _single)): - return self.v.equal(other.v, self.k, self.kappa) + return self.v.equal(other.v, self.k) else: raise NotImplementedError @@ -4460,7 +4521,7 @@ def __eq__(self, other): def __le__(self, other): other = self.coerce(other) if isinstance(other, (cfix, _single)): - return self.v.less_equal(other.v, self.k, self.kappa) + return self.v.less_equal(other.v, self.k) else: raise NotImplementedError @@ -4468,7 +4529,7 @@ def __le__(self, other): def __lt__(self, other): other = self.coerce(other) if isinstance(other, (cfix, _single)): - return self.v.less_than(other.v, self.k, self.kappa) + return self.v.less_than(other.v, self.k) else: raise NotImplementedError @@ -4476,7 +4537,7 @@ def __lt__(self, other): def __ge__(self, other): other = self.coerce(other) if isinstance(other, (cfix, _single)): - return self.v.greater_equal(other.v, self.k, self.kappa) + return self.v.greater_equal(other.v, self.k) else: raise NotImplementedError @@ -4484,7 +4545,7 @@ def __ge__(self, other): def __gt__(self, other): other = self.coerce(other) if isinstance(other, (cfix, _single)): - return self.v.greater_than(other.v, self.k, self.kappa) + return self.v.greater_than(other.v, self.k) else: raise NotImplementedError @@ -4492,7 +4553,7 @@ def __gt__(self, other): def __ne__(self, other): other = self.coerce(other) if isinstance(other, (cfix, _single)): - return self.v.not_equal(other.v, self.k, self.kappa) + return self.v.not_equal(other.v, self.k) else: raise NotImplementedError @@ -4575,12 +4636,12 @@ def conv(cls, other): @classmethod def _new(cls, other, k=None, f=None): - res = cls(k=k, f=f) + res = cls(k=k, f=f, initialize=False) res.v = cls.int_type.conv(other) return res @vectorize_init - def __init__(self, _v=None, k=None, f=None, size=None): + def __init__(self, _v=None, k=None, f=None, size=None, initialize=True): if k is None: k = self.k else: @@ -4600,7 +4661,10 @@ def adjust(v): v >>= f_diff return v if _v is None: - self.v = self.int_type(0) + if initialize: + self.v = self.int_type(0) + else: + return elif isinstance(_v, self.int_type): self.load_int(_v) elif isinstance(_v, cfix.scalars): @@ -4673,8 +4737,7 @@ def mul(self, other): max_f = max(self.f, other.f) min_f = min(self.f, other.f) val = self.v.TruncMul(other.v, k + min_f, min_f, - self.kappa, - self.round_nearest) + nearest=self.round_nearest) if 'vec' not in self.__dict__: return self._new(val, k=k, f=max_f) else: @@ -4698,18 +4761,20 @@ def __truediv__(self, other): if util.is_constant_float(other): assert other != 0 log = math.ceil(math.log(abs(other), 2)) + if 2 ** log == other and log < self.f: + return self * 2 ** -log other_length = self.f + log if other_length >= self.k - 1: factor = 2 ** (self.k - other_length - 2) self *= factor other *= factor - if 2 ** log == other: - return self * 2 ** -log + if util.is_zero(self): + return 0 other = self.coerce(other) assert self.k == other.k assert self.f == other.f if isinstance(other, (_fix, cfix)): - v = library.FPDiv(self.v, other.v, self.k, self.f, self.kappa, + v = library.FPDiv(self.v, other.v, self.k, self.f, nearest=self.round_nearest) else: raise TypeError('Incompatible fixed point types in division') @@ -4725,7 +4790,8 @@ def __rtruediv__(self, other): @vectorize def compute_reciprocal(self): """ Secret fixed-point reciprocal. """ - return type(self)(library.FPDiv(cint(2) ** self.f, self.v, self.k, self.f, self.kappa, True)) + return type(self)(library.FPDiv(cint(2) ** self.f, self.v, self.k, + self.f, nearest=True)) def reveal(self): """ Reveal secret fixed-point number. @@ -4767,7 +4833,8 @@ class sfix(_fix): Note that the default precision (16 bits after the dot, 31 bits in total) only allows numbers up to :math:`2^{31-16-1} \\approx - 16000`. You can increase this using :py:func:`set_precision`. + 16000` with the smallest non-zero number being :math:`2^{-16}`. + You can change this using :py:func:`set_precision`. :params _v: int/float/regint/cint/sint/sfloat """ @@ -4777,6 +4844,13 @@ class sfix(_fix): get_type = staticmethod(lambda n: sint) default_type = sint + @classmethod + def get_prec_type(cls, f, k=None): + class sfix_prec(cls): + pass + sfix_prec.set_precision(f, k) + return sfix_prec + @vectorized_classmethod def get_input_from(cls, player, binary=False, n_bytes=None): """ Secret fixed-point input. @@ -4894,14 +4968,19 @@ def pre_mul(self): return self.v def unreduced(self, v, other=None, res_params=None, n_summands=1): - return unreduced_sfix(v, self.k + self.f, self.f, self.kappa) + assert res_params is None or \ + (res_params.k == self.k and res_params.f == self.f) + if other is None: + return unreduced_sfix(v, self.k + self.f, self.f) + else: + return unreduced_sfix(v, self.k + other.f, other.f) @staticmethod def multipliable(v, k, f, size): return cfix._new(cint.conv(v, size=size), k, f) def dot(self, other): - """ Dot product with :py:class:`sint:`. """ + """ Dot product with :py:class:`sint`. """ if isinstance(other, sint): return self._new(sint.dot_product(self.v, other), k=self.k, f=self.f) else: @@ -4948,6 +5027,17 @@ def concat(cls, parts): int_parts.append(part.v) return cls._new(cls.int_type.concat(int_parts), k=k, f=f) + @classmethod + def zip(cls, *parts): + int_parts = [] + f = parts[0].f + k = parts[0].k + for part in parts: + assert part.f == f + assert part.k == k + int_parts.append(part.v) + return cls._new(cls.int_type.zip(*int_parts), k=k, f=f) + def __repr__(self): return '' % (self.f, self.k, self.v) @@ -4956,13 +5046,12 @@ class unreduced_sfix(_single): @classmethod def _new(cls, v): - return cls(v, sfix.k + sfix.f, sfix.f, sfix.kappa) + return cls(v, sfix.k + sfix.f, sfix.f) - def __init__(self, v, k, m, kappa): + def __init__(self, v, k, m): self.v = v self.k = k self.m = m - self.kappa = kappa assert self.k is not None assert self.m is not None @@ -4971,14 +5060,13 @@ def __add__(self, other): return self assert self.k == other.k assert self.m == other.m - assert self.kappa == other.kappa - return unreduced_sfix(self.v + other.v, self.k, self.m, self.kappa) + return unreduced_sfix(self.v + other.v, self.k, self.m) __radd__ = __add__ @vectorize def reduce_after_mul(self): - v = sfix.int_type.round(self.v, self.k, self.m, self.kappa, + v = sfix.int_type.round(self.v, self.k, self.m, nearest=sfix.round_nearest, signed=True) return sfix._new(v, k=self.k - self.m, f=self.m) @@ -5180,13 +5268,13 @@ def reduce(self, unreduced): int_mult = util.expand(int_mult, size) tmp = unreduced.v * int_mult + shifted_Z shifted = tmp.round(self.max_length, n_shift, - kappa=squant.kappa, nearest=squant.round_nearest, + nearest=squant.round_nearest, signed=True) if squant.clamp: length = max(self.k, self.max_length - n_shift) + 1 top = (1 << self.k) - 1 - over = shifted.greater_than(top, length, squant.kappa) - under = shifted.less_than(0, length, squant.kappa) + over = shifted.greater_than(top, length) + under = shifted.less_than(0, length) shifted = over.if_else(top, shifted) shifted = under.if_else(0, shifted) return squant._new(shifted, params=self) @@ -5223,7 +5311,6 @@ class sfloat(_number, _secret_structure): # single precision vlen = 24 plen = 8 - kappa = None round_nearest = False @staticmethod @@ -5327,14 +5414,14 @@ def __init__(self, v, p=None, z=None, s=None, size=None): elif isinstance(v, sfix): f = v.f v, p, z, s = floatingpoint.Int2FL(v.v, v.k, - self.vlen, self.kappa) + self.vlen) p = p - f elif util.is_constant_float(v): v, p, z, s = self.convert_float(v, self.vlen, self.plen) else: v, p, z, s = floatingpoint.Int2FL(sint.conv(v), program.bit_length, - self.vlen, self.kappa) + self.vlen) if isinstance(v, int): if not ((v >= 2**(self.vlen-1) and v < 2**(self.vlen)) or v == 0): raise CompilerError('Floating point number malformed: significand') @@ -5392,9 +5479,9 @@ def add(self, other): s2 = other.s z1 = self.z z2 = other.z - a = p1.less_than(p2, self.plen, self.kappa) - b = floatingpoint.EQZ(p1 - p2, self.plen, self.kappa) - c = v1.less_than(v2, self.vlen, self.kappa) + a = p1.less_than(p2, self.plen) + b = floatingpoint.EQZ(p1 - p2, self.plen) + c = v1.less_than(v2, self.vlen) ap1 = a*p1 ap2 = a*p2 aneg = 1 - a @@ -5410,10 +5497,9 @@ def add(self, other): vmin = bneg*(av1 + v2 - av2) + b*(cv1 + v2 - cv2) s3 = s1 + s2 - 2 * s1 * s2 comparison.LTZ(d, self.vlen + pmin - pmax + sfloat.round_nearest, - self.plen, self.kappa) + self.plen) pow_delta = floatingpoint.Pow2((1 - d) * (pmax - pmin), - self.vlen + 1 + sfloat.round_nearest, - self.kappa) + self.vlen + 1 + sfloat.round_nearest) # deviate from paper for more precision #v3 = 2 * (vmax - s3) + 1 v3 = vmax @@ -5429,25 +5515,24 @@ def add(self, other): to_trunc *= two_power(self.vlen + sfloat.round_nearest) v = to_trunc * floatingpoint.Inv(pow_delta) comparison.Trunc(t, v, 2 * self.vlen + 1 + sfloat.round_nearest, - self.vlen - 1, self.kappa, False) + self.vlen - 1, signed=False) v = t u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest, - self.vlen + 2 + sfloat.round_nearest, self.kappa, + self.vlen + 2 + sfloat.round_nearest, list(range(1 + sfloat.round_nearest, self.vlen + 2 + sfloat.round_nearest))) # using u[0] doesn't seem necessary - h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1], self.kappa) + h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1]) p0 = self.vlen + 1 - sum(h) pow_p0 = 1 + sum([two_power(i) * (1 - h[i]) for i in range(len(h))]) if self.round_nearest: t2, overflow = \ floatingpoint.TruncRoundNearestAdjustOverflow(pow_p0 * v, self.vlen + 3, - self.vlen, - self.kappa) + self.vlen) p0 = p0 - overflow else: - comparison.Trunc(t2, pow_p0 * v, self.vlen + 2, 2, self.kappa, False) + comparison.Trunc(t2, pow_p0 * v, self.vlen + 2, 2, signed=False) v = t2 # deviate for more precision #p = pmax - p0 + 1 - d @@ -5455,7 +5540,7 @@ def add(self, other): zz = self.z*other.z zprod = 1 - self.z - other.z + zz v = zprod*t2 + self.z*v2 + other.z*v1 - z = floatingpoint.EQZ(v, self.vlen, self.kappa) + z = floatingpoint.EQZ(v, self.vlen) p = (zprod*p + self.z*p2 + other.z*p1)*(1 - z) s = (1 - b)*(a*other.s + aneg*self.s) + b*(c*other.s + cneg*self.s) s = zprod*s + (other.z - zz)*self.s + (self.z - zz)*other.s @@ -5477,12 +5562,13 @@ def mul(self, other): comparison.ld2i(c2expl, self.vlen) if sfloat.round_nearest: v1 = comparison.TruncRoundNearest(self.v*other.v, 2*self.vlen, - self.vlen-1, self.kappa) + self.vlen-1) else: - comparison.Trunc(v1, self.v*other.v, 2*self.vlen, self.vlen-1, self.kappa, False) + comparison.Trunc(v1, self.v*other.v, 2*self.vlen, self.vlen-1, + signed=False) t = v1 - c2expl - comparison.LTZ(b, t, self.vlen+1, self.kappa) - comparison.Trunc(v2, b*v1 + v1, self.vlen+1, 1, self.kappa, False) + comparison.LTZ(b, t, self.vlen+1) + comparison.Trunc(v2, b*v1 + v1, self.vlen+1, 1, signed=False) z1, z2, s1, s2, p1, p2 = (x.expand_to_vector() for x in \ (self.z, other.z, self.s, other.s, self.p, other.p)) @@ -5510,10 +5596,10 @@ def __truediv__(self, other): :param other: sfloat/float/sfix/sint/cint/regint/int """ other = self.conv(other) v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1), - self.vlen, self.kappa, self.round_nearest) - b = v.less_than(two_power(self.vlen-1), self.vlen + 1, self.kappa) - overflow = v.greater_equal(two_power(self.vlen), self.vlen + 1, self.kappa) - underflow = v.less_than(two_power(self.vlen-2), self.vlen + 1, self.kappa) + self.vlen, round_nearest=self.round_nearest) + b = v.less_than(two_power(self.vlen-1), self.vlen + 1) + overflow = v.greater_equal(two_power(self.vlen), self.vlen + 1) + underflow = v.less_than(two_power(self.vlen-2), self.vlen + 1) v = (v + b * v) * (1 - overflow) * (1 - underflow) + \ overflow * (2**self.vlen - 1) + \ underflow * (2**(self.vlen-1)) * (1 - self.z) @@ -5544,9 +5630,9 @@ def __lt__(self, other): z2 = other.z s1 = self.s s2 = other.s - a = self.p.less_than(other.p, self.plen, self.kappa) - c = floatingpoint.EQZ(self.p - other.p, self.plen, self.kappa) - d = ((1 - 2*self.s)*self.v).less_than((1 - 2*other.s)*other.v, self.vlen + 1, self.kappa) + a = self.p.less_than(other.p, self.plen) + c = floatingpoint.EQZ(self.p - other.p, self.plen) + d = ((1 - 2*self.s)*self.v).less_than((1 - 2*other.s)*other.v, self.vlen + 1) cd = c*d ca = c*a b1 = cd + a - ca @@ -5578,8 +5664,8 @@ def __eq__(self, other): other = self.conv(other) # the sign can be both ways for zeroes both_zero = self.z * other.z - return floatingpoint.EQZ(self.v - other.v, self.vlen, self.kappa) * \ - floatingpoint.EQZ(self.p - other.p, self.plen, self.kappa) * \ + return floatingpoint.EQZ(self.v - other.v, self.vlen) * \ + floatingpoint.EQZ(self.p - other.p, self.plen) * \ (1 - self.s - other.s + 2 * self.s * other.s) * \ (1 - both_zero) + both_zero @@ -5592,17 +5678,17 @@ def __ne__(self, other): del op def log2(self): - up = self.v.greater_than(1 << (self.vlen - 1), self.vlen, self.kappa) + up = self.v.greater_than(1 << (self.vlen - 1), self.vlen) return self.p + self.vlen - 1 + up def round_to_int(self): """ Secret floating-point rounding to integer. :return: sint """ - direction = self.p.greater_equal(-self.vlen, self.plen, self.kappa) - right = self.v.right_shift(-self.p - 1, self.vlen + 1, self.kappa) - up = right.mod2m(1, self.vlen + 1, self.kappa) - right = right.right_shift(1, self.vlen + 1, self.kappa) + up + direction = self.p.greater_equal(-self.vlen, self.plen) + right = self.v.right_shift(-self.p - 1, self.vlen + 1) + up = right.mod2m(1, self.vlen + 1) + right = right.right_shift(1, self.vlen + 1) + up abs_value = direction * right return self.s.if_else(-abs_value, abs_value) @@ -5691,12 +5777,6 @@ def reveal_to_clients(self, clients): """ self.value_type.reveal_to_clients(clients, [self.get_vector()]) - @staticmethod - def _cmp_fail(*args): - raise CompilerError('equality of data structures is not implemented') - - __eq__ = __ne__ = __le__ = __lt__ = __gt__ = __ge__ = _cmp_fail - class Array(_vectorizable): """ Array accessible by public index. That is, ``a[i]`` works for an @@ -5791,7 +5871,7 @@ def get_address(self, index, size=None): if isinstance(index, (_secret, _single)): raise CompilerError('need cleartext index') key = str(index), size or 1 - if not util.is_constant(index): + if isinstance(index, _clear): index = regint.conv(index) if self.length is not None: from .GC.types import cbits @@ -5929,9 +6009,9 @@ def __iter__(self): for i in range(self.length): yield self[i] - def same_shape(self): + def same_shape(self, **kwargs): """ Array of same length and type. """ - return Array(self.length, self.value_type) + return Array(self.length, self.value_type, **kwargs) def assign(self, other, base=0): """ Assignment. @@ -5986,7 +6066,7 @@ def _(base, size): self.assign_vector(self.value_type(value, size=size), base) else: v = mem_value.read() - if isinstance(v, sint): + if isinstance(v, (sint, sfix)): self.assign_vector(v.expand_to_vector(size), base=base) else: @library.for_range_opt(size) @@ -6033,7 +6113,7 @@ def get_slice_addresses(self, slice): assert len(slice) <= self.total_size() base = regint.inc(len(slice), slice.address, 1, 1) inc = regint.inc(len(slice), self.address, 1, 1, 1) - addresses = slice.value_type.load_mem(base) + inc + addresses = regint.conv(slice.value_type.load_mem(base)) + inc return addresses def get_slice_vector(self, slice): @@ -6192,6 +6272,24 @@ def __pow__(self, value): :param other: compile-time integer (int) """ return self.get_vector() ** value + def __eq__(self, other): + return self.get_vector() == other + + def __ne__(self, other): + return self.get_vector() != other + + def __lt__(self, other): + return self.get_vector() < other + + def __le__(self, other): + return self.get_vector() <= other + + def __gt__(self, other): + return self.get_vector() > other + + def __ge__(self, other): + return self.get_vector() >= other + __radd__ = __add__ __rmul__ = __mul__ @@ -6214,6 +6312,11 @@ def __itruediv__(self, other): def __neg__(self): return -self.get_vector() + def dot(self, other): + """ Dot product with another array. """ + M = Matrix(1, len(self), self.value_type, address=self.address) + return M.dot(other) + def shuffle(self): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) @@ -6368,7 +6471,10 @@ def __init__(self, sizes, value_type, address, index, debug=None): self.sizes = tuple(sizes) self.value_type = _get_type(value_type) if address is not None: - self.address = address + index * self.total_size() + if not util.is_zero(index): + self.address = address + index * self.total_size() + else: + self.address = address else: self.address = None self.sub_cache = {} @@ -6486,7 +6592,7 @@ def assign(self, other, base=0): try: if self.value_type.n_elements() > 1: assert self.sizes == other.sizes - self.assign_vector(other.get_vector()) + self.assign_vector(other.get_vector(), base=base) except: for i, x in enumerate(other): self[base + i].assign(x) @@ -6606,12 +6712,12 @@ def assign_vector_by_indices(self, vector, *indices): addresses = self.get_addresses(*indices) vector.store_in_mem(addresses) - def same_shape(self): + def same_shape(self, **kwargs): """ :return: new multidimensional array with same shape and basic type """ if len(self.sizes) == 2: - return Matrix(*self.sizes, self.value_type) + return Matrix(*self.sizes, self.value_type, **kwargs) else: - return MultiArray(self.sizes, self.value_type) + return MultiArray(self.sizes, self.value_type, **kwargs) def get_part(self, start, size): """ Part multi-array. @@ -6733,6 +6839,30 @@ def __sub__(self, other): return self.from_vector( self.sizes, self.get_vector() - other.get_vector()) + def __eq__(self, other): + return self.from_vector( + self.sizes, self.get_vector() == other) + + def __ne__(self, other): + return self.from_vector( + self.sizes, self.get_vector() != other) + + def __lt__(self, other): + return self.from_vector( + self.sizes, self.get_vector() < other) + + def __le__(self, other): + return self.from_vector( + self.sizes, self.get_vector() <= other) + + def __gt__(self, other): + return self.from_vector( + self.sizes, self.get_vector() > other) + + def __ge__(self, other): + return self.from_vector( + self.sizes, self.get_vector() >= other) + def iadd(self, other): """ Element-wise addition in place. @@ -6803,6 +6933,11 @@ class t(self.value_type): res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: + # force matmuls for smaller sizes + a, c = res_matrix.sizes + if a * c / (a + c) < 2 and \ + self.value_type == other.value_type: + raise AttributeError() self.value_type.direct_matrix_mul skip_reduce = set((sint, sfix)) == \ set((self.value_type, other.value_type)) @@ -7104,7 +7239,12 @@ def secure_shuffle(self): `_ or Section 3.2 of `Asharov et al. `_ if applicable. """ - self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) + if self.total_size() < 2 ** 28: + self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) + else: + perm = sint.get_secure_shuffle(len(self)) + self.secure_permute(perm) + delshuffle(perm) def secure_permute(self, permutation, reverse=False, n_threads=None): """ Securely permute rows (first index). See @@ -7121,7 +7261,7 @@ def _(i): self.set_column(i, self.get_column(i).secure_permute( permutation, reverse=reverse)) - def sort(self, key_indices=None, n_bits=None): + def sort(self, key_indices=None, n_bits=None, batcher=False): """ Sort sub-arrays (different first index) in place. This uses `radix sort `_. @@ -7129,6 +7269,7 @@ def sort(self, key_indices=None, n_bits=None): ``(1, 2)`` to sort three-dimensional array ``a`` by keys ``a[*][1][2]``. Default is ``(0, ..., 0)`` of correct length. :param n_bits: number of bits in keys (default: global bit length) + :param batcher: whether to use Batcher's odd-even merge sorting """ if key_indices is None: @@ -7136,7 +7277,7 @@ def sort(self, key_indices=None, n_bits=None): if len(key_indices) != len(self.sizes) - 1: raise CompilerError('length of key_indices has to be one less ' 'than the dimension') - if program.options.binary: + if program.options.binary or batcher: assert len(self.sizes) == 2 library.loopy_odd_even_merge_sort(self, key_indices=key_indices) return @@ -7209,8 +7350,10 @@ def reveal_to_binary_output(self, player=None): self.get_vector().reveal_to(player).binary_output() def __str__(self): - return '%s multi-array of lengths %s at %s' % (self.value_type, - self.sizes, self.address) + return '%s multi-array of lengths %s at %s' % ( + self.value_type, self.sizes, + '' if self.array._address is None else self.address) + __repr__ = __str__ class MultiArray(SubMultiArray): """ @@ -7405,7 +7548,7 @@ def iadd(self, other): store_in_mem = lambda self,address: self.read().store_in_mem(address) -class MemValue(_mem): +class MemValue(_mem, _vectorizable): """ Single value in memory. This is useful to transfer information between threads. Operations are automatically read from memory if required, this means you can use any operation with @@ -7423,23 +7566,23 @@ def if_necessary(cls, value): else: return cls(value) - def __init__(self, value, address=None): + def __init__(self, value, address=None, write=True): self.last_write_block = None + if isinstance(value, MemValue): + value = value.read() if isinstance(value, int): self.value_type = regint value = regint(value) - elif isinstance(value, MemValue): - self.value_type = value.value_type else: self.value_type = type(value) self.deleted = False + self.size = value.size_for_mem() if address is None: - self.address = self.value_type.malloc(value.size) - self.size = value.size - self.write(value) + self.address = self.value_type.malloc(self.size) + if write: + self.write(value) else: self.address = address - self.size = 1 def delete(self): self.value_type.free(self.address) @@ -7474,7 +7617,7 @@ def write(self, value): except: raise CompilerError('Cannot store %s as MemValue of %s' % \ (type(value), self.value_type)) - if value.size != self.size: + if value.size_for_mem() != self.size: raise CompilerError('size mismatch') self.register = value if not isinstance(self.register, self.value_type): @@ -7490,18 +7633,18 @@ def reveal(self): :return: relevant clear type """ return self.read().reveal() - less_than = lambda self,other,bit_length=None,security=None: \ - self.read().less_than(other,bit_length,security) - greater_than = lambda self,other,bit_length=None,security=None: \ - self.read().greater_than(other,bit_length,security) - less_equal = lambda self,other,bit_length=None,security=None: \ - self.read().less_equal(other,bit_length,security) - greater_equal = lambda self,other,bit_length=None,security=None: \ - self.read().greater_equal(other,bit_length,security) - equal = lambda self,other,bit_length=None,security=None: \ - self.read().equal(other,bit_length,security) - not_equal = lambda self,other,bit_length=None,security=None: \ - self.read().not_equal(other,bit_length,security) + less_than = lambda self,other,bit_length=None: \ + self.read().less_than(other,bit_length) + greater_than = lambda self,other,bit_length=None: \ + self.read().greater_than(other,bit_length) + less_equal = lambda self,other,bit_length=None: \ + self.read().less_equal(other,bit_length) + greater_equal = lambda self,other,bit_length=None: \ + self.read().greater_equal(other,bit_length) + equal = lambda self,other,bit_length=None: \ + self.read().equal(other,bit_length) + not_equal = lambda self,other,bit_length=None: \ + self.read().not_equal(other,bit_length) pow2 = lambda self,*args,**kwargs: self.read().pow2(*args, **kwargs) mod2m = lambda self,*args,**kwargs: self.read().mod2m(*args, **kwargs) @@ -7524,6 +7667,11 @@ def expand_to_vector(self, size=None): addresses = regint.inc(size, self.address, 0) return self.value_type.load_mem(addresses) + shape = property(lambda self: ('mv', self.size)) + + def same_shape(self, address=None): + return type(self)(self.value_type(size=self.size), address=address) + def __repr__(self): return 'MemValue(%s,%s)' % (self.value_type, self.address) diff --git a/Compiler/util.py b/Compiler/util.py index 6c3c3ce59..6e9f43554 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -28,11 +28,11 @@ def greater_than(a, b, bits): else: return a.greater_than(b, bits) -def pow2_value(a, bit_length=None, security=None): +def pow2_value(a, bit_length=None): if is_constant_float(a): return 2**a else: - return a.pow2(bit_length, security) + return a.pow2(bit_length) def mod2m(a, b, bits, signed): if isinstance(a, int): diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 510db51ac..226044f37 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -15,7 +15,7 @@ int main() { P256Element::init(); - P256Element::Scalar key; + KeySetup> key; string prefix = PREP_DIR "ECDSA/"; mkdir_p(prefix.c_str()); write_online_setup(prefix, P256Element::Scalar::pr()); diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index ea19c8ee3..1ce82ff35 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -44,6 +44,7 @@ int main(int argc, const char** argv) typedef Share pShare; string prefix = get_prep_sub_dir(PREP_DIR "ECDSA/", 2); read_mac_key(prefix, N, keyp); + pShare::set_mac_key(keyp); pShare::MAC_Check::setup(P); Share::MAC_Check::setup(P); diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 0a5e0ab9c..f13b87dae 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -37,6 +37,9 @@ void preprocessing(vector>& tuples, int buffer_size, EcdsaOptions opts) { bool prep_mul = opts.prep_mul; + if (prep_mul) + proc.protocol.init_mul(); + Timer timer; timer.start(); Player& P = proc.P; @@ -77,7 +80,6 @@ void preprocessing(vector>& tuples, int buffer_size, MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { - protocol.init_mul(); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); protocol.start_exchange(); diff --git a/ECDSA/sy-rep-ecdsa-party.cpp b/ECDSA/sy-rep-ecdsa-party.cpp index 88447d7f6..fda4e35e8 100644 --- a/ECDSA/sy-rep-ecdsa-party.cpp +++ b/ECDSA/sy-rep-ecdsa-party.cpp @@ -23,6 +23,7 @@ #include "Protocols/SpdzWisePrep.hpp" #include "Protocols/SpdzWiseInput.hpp" #include "Protocols/SpdzWiseShare.hpp" +#include "Protocols/SpdzWiseRep3Shuffler.hpp" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index c401d86b1..4edba785d 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -24,6 +24,13 @@ Client::Client(const vector& hostnames, int port_base, "P" + to_string(i), "C" + to_string(my_client_id), true); if (i == 0) specification.Receive(sockets[0]); + else + { + octetStream spec; + spec.Receive(sockets[i]); + if (spec != specification) + throw runtime_error("inconsistent specification"); + } } } diff --git a/ExternalIO/client.py b/ExternalIO/client.py index 84a392590..1980576ae 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -31,6 +31,15 @@ def set_keepalive_osx(sock, after_idle_sec=1, interval_sec=3, max_fails=5): sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec) class Client: + """Client to servers running secure computation. Works both as a client + to all parties or a trusted client to a single party. + + :param hostnames: hostnames or IP addresses to connect to + :param port_base: port number for first hostname, + increases by one for every additional hostname + :param my_client_id: number to identify client + + """ def __init__(self, hostnames, port_base, my_client_id): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) name = 'C%d' % my_client_id @@ -62,6 +71,11 @@ def __init__(self, hostnames, port_base, my_client_id): self.specification = octetStream() self.specification.Receive(self.sockets[0]) + for sock in self.sockets[1:]: + specification = octetStream() + specification.Receive(sock) + if specification.buf != self.specification.buf: + raise Exception('inconsistent specification') type = self.specification.get_int(4) if type == ord('R'): self.domain = Z2(self.specification.get_int(4)) @@ -99,6 +113,12 @@ def receive_triples(self, T, n): return triples def send_private_inputs(self, values): + """ Send inputs privately to the computation servers. + This assumes that the client is connected to all servers. + + :param values: list of input values + + """ T = self.domain triples = self.receive_triples(T, len(values)) os = octetStream() @@ -109,10 +129,46 @@ def send_private_inputs(self, values): os.Send(socket) def receive_outputs(self, n): + """ Receive outputs privately from the computation servers. + This assumes that the client is connected to all servers. + + :param n: number of outputs + + """ T = self.domain triples = self.receive_triples(T, n) return [int(self.clear_domain(triple[0].v)) for triple in triples] + def send_public_inputs(self, values): + """ Send values in the clear. This works for public inputs + to all servers or to send shares to a single server. + + :param values: list of values + + """ + os = octetStream() + for value in values: + self.domain(value).pack(os) + for socket in self.sockets: + os.Send(socket) + + def receive_plain_values(self, socket=None): + """ Receive values in the clear. This works for public inputs + to all servers or to send shares to a single server. + + :param socket: socket to use (need to specify it there is more than one) + + """ + if socket is None: + if len(self.sockets) != 1: + raise Exception('need to specify socket') + socket = self.sockets[0] + os = octetStream() + os.Receive(socket) + assert len(os) % self.domain.size() == 0 + return [int(os.get(self.domain)) + for i in range(len(os) // self.domain.size())] + class octetStream: def __init__(self, value=None): self.buf = b'' @@ -123,6 +179,8 @@ def __init__(self, value=None): def get_length(self): return len(self.buf) + __len__ = get_length + def reset_write_head(self): self.buf = b'' self.ptr = 0 @@ -164,6 +222,11 @@ def get_bigint(self): else: return 0 + def get(self, type): + res = type() + res.unpack(self) + return res + def consume(self, length): self.ptr += length assert self.ptr <= len(self.buf) diff --git a/ExternalIO/domains.py b/ExternalIO/domains.py index 1c4a2a9df..d154a9dd7 100644 --- a/ExternalIO/domains.py +++ b/ExternalIO/domains.py @@ -7,7 +7,7 @@ def __init__(self, value=0): def __int__(self): res = self.v % self.modulus - return res if res < self.modulus / 2 else res - self.modulus + return int(res if res < self.modulus / 2 else res - self.modulus) def __add__(self, other): try: diff --git a/ExternalIO/personal-client-example.py b/ExternalIO/personal-client-example.py new file mode 100755 index 000000000..20fb4ea05 --- /dev/null +++ b/ExternalIO/personal-client-example.py @@ -0,0 +1,19 @@ +#!/usr/bin/python3 + +import sys, random + +sys.path.insert(0, 'ExternalIO') + +from client import * + +party = int(sys.argv[1]) + +client = Client(['localhost'], 15000 + party, 0) + +n = 1000 + +if party < 2: + client.send_public_inputs(random.gauss(0, 1) * 2 ** 16 for i in range(n)) + +x = [client.receive_plain_values() for i in range(2)] +client.send_public_inputs(a + b for a, b in zip(*x)) diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index c3ab59ebf..753a77506 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -13,6 +13,8 @@ #include "SimpleMachine.h" #include "Tools/mkpath.h" +#include "Protocols/Share.hpp" + template Producer::Producer(int output_thread, bool write_output) : n_slots(0), output_thread(output_thread), write_output(write_output), diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index a3f821a0d..37f0eb523 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -8,6 +8,8 @@ #include "BitAdder.h" +#include "Protocols/BufferScope.h" + #include template @@ -69,6 +71,9 @@ void BitAdder::add(vector >& res, size_t n_items = end - begin; + if (OnlineOptions::singleton.has_option("verbose_and")) + fprintf(stderr, "%lu ANDs in bit adder\n", length * n_items * n_bits); + if (supply) { #ifdef VERBOSE_EDA @@ -85,6 +90,7 @@ void BitAdder::add(vector >& res, vector carries(n_items); vector a(n_items), b(n_items); auto& protocol = proc.protocol; + BufferScope scope(proc.DataF, n_items * length * n_bits); for (int i = 0; i < n_bits; i++) { assert(summands[i].size() == 2); diff --git a/GC/FakeSecret.cpp b/GC/FakeSecret.cpp index f69130009..e3ecdd113 100644 --- a/GC/FakeSecret.cpp +++ b/GC/FakeSecret.cpp @@ -24,14 +24,14 @@ void FakeSecret::load_clear(int n, const Integer& x) *this = x; } -void FakeSecret::bitcom(Memory& S, const vector& regs) +void FakeSecret::bitcom(StackedVector& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) *this ^= (S[regs[i]] << i); } -void FakeSecret::bitdec(Memory& S, const vector& regs) const +void FakeSecret::bitdec(StackedVector& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) S[regs[i]] = (*this >> i) & 1; diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 668b5a967..b7ded479a 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -142,8 +142,8 @@ class FakeSecret : public ShareInterface, public BitVec template void store(Memory& mem, size_t address) { mem[address] = *this; } - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; + void bitcom(StackedVector& S, const vector& regs); + void bitdec(StackedVector& S, const vector& regs) const; template void xor_(int n, const FakeSecret& x, const T& y) diff --git a/GC/Instruction.cpp b/GC/Instruction.cpp index 6be1eb1ab..bdd8014a2 100644 --- a/GC/Instruction.cpp +++ b/GC/Instruction.cpp @@ -84,8 +84,8 @@ void Instruction::parse(istream& s, int pos) ostringstream os; os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl; os << "This virtual machine executes binary circuits only." << endl; - os << "Use 'compile.py -B'." << endl; - throw Invalid_Instruction(os.str()); + os << "Use 'compile.py -B'."; + exit_error(os.str()); break; } } diff --git a/GC/NoShare.h b/GC/NoShare.h index c3d795e7b..8ad05cf3e 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -52,7 +52,7 @@ class NoValue : public ValueInterface static DataFieldType field_type() { - throw not_implemented(); + return DATA_GF2; } static void init_minimum(int) @@ -80,7 +80,8 @@ class NoValue : public ValueInterface bool operator!=(NoValue) const { return false; } - bool operator==(int) { fail(); return false; } + bool operator==(int) const { fail(); return false; } + bool operator==(NoValue) const { fail(); return false; } bool get_bit(int) { fail(); return 0; } @@ -92,6 +93,8 @@ class NoValue : public ValueInterface void input(istream&, bool) { fail(); } void output(ostream&, bool) {} + + void pack(octetStream&) const { fail(); } }; inline ostream& operator<<(ostream& o, NoValue) @@ -169,8 +172,8 @@ class NoShare : public ShareInterface void load_clear(Integer, Integer) { fail(); } void random_bit() { fail(); } - void bitdec(vector&, const vector&) const { fail(); } - void bitcom(vector&, const vector&) const { fail(); } + void bitdec(StackedVector&, const vector&) const { fail(); } + void bitcom(StackedVector&, const vector&) const { fail(); } void assign(const char*) { fail(); } @@ -190,6 +193,8 @@ class NoShare : public ShareInterface NoShare& operator+=(const NoShare&) { fail(); return *this; } + bool operator==(NoShare) const { fail(); return false; } + NoShare get_bit(int) const { fail(); return {}; } void xor_bit(int, NoShare) const { fail(); } @@ -201,6 +206,8 @@ class NoShare : public ShareInterface void input(istream&, bool) { fail(); } void output(ostream&, bool) { fail(); } + + void pack(octetStream&) const { fail(); } }; } /* namespace GC */ diff --git a/GC/PersonalPrep.hpp b/GC/PersonalPrep.hpp index 44c4080e0..6bed352a2 100644 --- a/GC/PersonalPrep.hpp +++ b/GC/PersonalPrep.hpp @@ -79,10 +79,11 @@ template void PersonalPrep::buffer_personal_triples(vector>& triples, size_t begin, size_t end) { -#ifdef VERBOSE_EDA - fprintf(stderr, "personal triples %zu to %zu\n", begin, end); RunningTimer timer; -#endif + bool verbose = OnlineOptions::singleton.has_option("verbose_eda"); + if (verbose) + fprintf(stderr, "personal triples %zu to %zu\n", begin, end); + auto& party = ShareThread::s(); auto& MC = party.MC->get_part_MC(); auto& P = *party.P; @@ -102,9 +103,9 @@ void PersonalPrep::buffer_personal_triples(vector>& triples, input.exchange(); for (size_t i = begin; i < end; i++) triples[i][2] = input.finalize(input_player, T::default_length); -#ifdef VERBOSE_EDA - fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed()); -#endif + + if (verbose) + fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed()); } } diff --git a/GC/Processor.h b/GC/Processor.h index da9655b28..8671a6843 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -16,6 +16,7 @@ using namespace std; #include "Math/Integer.h" #include "Processor/ProcessorBase.h" #include "Processor/Instruction.h" +#include "Tools/CheckVector.h" namespace GC { @@ -38,9 +39,9 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching // rough measure for the memory usage size_t complexity; - Memory S; - Memory C; - Memory I; + StackedVector S; + StackedVector C; + StackedVector I; Timer xor_timer; @@ -78,8 +79,8 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching template void store_clear_in_dynamic(const vector& args, U& dynamic_memory); - template - void mem_op(int n, Memory& dest, const Memory& source, + template + void mem_op(int n, U& dest, const V& source, Integer dest_address, Integer source_address); void xors(const vector& args); @@ -105,6 +106,9 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching template void convcbit2s(const BaseInstruction& instruction); + void convcbitvec(const BaseInstruction& instruction, StackedVector& Ci, + Player* P); + void print_reg(int reg, int n, int size); void print_reg_plain(Clear& value); void print_reg_signed(unsigned n_bits, Integer value); @@ -114,6 +118,13 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching void print_float_prec(int n); void incint(const BaseInstruction& instruction); + + void push_stack(); + void push_args(const vector& args); + void pop_stack(const vector& results); + + template + void call_tape(const BaseInstruction& instruction, U& dynamic_memory); }; template diff --git a/GC/Processor.hpp b/GC/Processor.hpp index c3b208500..2bcb0e182 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -18,7 +18,7 @@ using namespace std; #include "Math/BitVec.h" #include "GC/Machine.hpp" -#include "Processor/ProcessorBase.hpp" +#include "Processor/Processor.hpp" #include "Processor/IntInput.hpp" #include "Math/bigint.hpp" @@ -53,9 +53,9 @@ template template void Processor::reset(const U& program, int arg) { - S.resize(program.num_reg(SBIT), "registers"); - C.resize(program.num_reg(CBIT), "registers"); - I.resize(program.num_reg(INT), "registers"); + S.resize(program.num_reg(SBIT)); + C.resize(program.num_reg(CBIT)); + I.resize(program.num_reg(INT)); set_arg(arg); PC = 0; } @@ -202,14 +202,14 @@ void GC::Processor::store_clear_in_dynamic(const vector& args, } template -template -void Processor::mem_op(int n, Memory& dest, const Memory& source, +template +void Processor::mem_op(int n, U& dest, const V& source, Integer dest_address, Integer source_address) { dest.check_index(dest_address + n - 1); source.check_index(source_address + n - 1); - auto d = &dest[dest_address]; - auto s = &source[source_address]; + auto d = &dest[dest_address.get()]; + auto s = &source[source_address.get()]; for (int i = 0; i < n; i++) { *d++ = *s++; @@ -388,6 +388,30 @@ void Processor::convcbit2s(const BaseInstruction& instruction) min(size_t(unit), instruction.get_n() - i * unit)); } +template +void Processor::convcbitvec(const BaseInstruction& instruction, + StackedVector& Ci, Player* P) +{ + vector bits; + auto n = instruction.get_n(); + bits.reserve(n); + for (size_t i = 0; i < instruction.get_n(); i++) + { + int i1 = i / GC::Clear::N_BITS; + int i2 = i % GC::Clear::N_BITS; + auto bit = C[instruction.get_r(1) + i1].get_bit(i2); + bits.push_back(bit); + } + + if (P) + sync(bits, *P); + else if (not T::symmetric) + sync(bits, *Thread::s().P); + + for (size_t i = 0; i < n; i++) + Ci[instruction.get_r(0) + i] = bits[i]; +} + template void Processor::print_reg(int reg, int n, int size) { @@ -417,7 +441,7 @@ void Processor::print_reg_signed(unsigned n_bits, Integer reg) { if (n_bits <= Clear::N_BITS) { - auto value = C[reg]; + auto value = C[reg.get()]; unsigned n_shift = 0; if (n_bits > 1) n_shift = sizeof(value.get()) * 8 - n_bits; @@ -477,6 +501,56 @@ void Processor::incint(const BaseInstruction& instruction) } } +template +void GC::Processor::push_stack() +{ + S.push_stack(); + C.push_stack(); +} + +template +void GC::Processor::push_args(const vector& args) +{ + S.push_args(args, SBIT); + C.push_args(args, CBIT); +} + +template +void GC::Processor::pop_stack(const vector& results) +{ + S.pop_stack(results, SBIT); + C.pop_stack(results, CBIT); +} + +template +template +void Processor::call_tape(const BaseInstruction& instruction, U& dynamic_memory) +{ + auto new_arg = I.at(instruction.get_r(1)).get(); + + PC_stack.push_back(PC); + arg_stack.push_back(this->arg); + push_stack(); + I.push_stack(); + + auto& tape = machine->progs.at(instruction.get_r(0)); + reset(tape, new_arg); + + auto& args = instruction.get_start(); + push_args(args); + I.push_args(args, INT); + + tape.execute(*this, dynamic_memory, PC); + + pop_stack(args); + I.pop_stack(args, INT); + + PC = PC_stack.back(); + PC_stack.pop_back(); + this->arg = arg_stack.back(); + arg_stack.pop_back(); +} + } /* namespace GC */ #endif diff --git a/GC/Rep4Prep.cpp b/GC/Rep4Prep.cpp index 242be3422..ab9f0e824 100644 --- a/GC/Rep4Prep.cpp +++ b/GC/Rep4Prep.cpp @@ -7,6 +7,7 @@ #include "Protocols/Rep4.hpp" #include "Protocols/Rep4Input.hpp" +#include "Protocols/Replicated.hpp" #include "Protocols/ReplicatedPrep.hpp" namespace GC diff --git a/GC/Secret.h b/GC/Secret.h index b6c3b3f65..54addd235 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -76,6 +76,10 @@ class Secret static const bool actual_inputs = T::actual_inputs; + static const bool symmetric = true; + + static bool real_shares(const Player&) { return true; } + static int threshold(int nplayers) { return T::threshold(nplayers); } static Secret input(party_id_t from, const int128& input, int n_bits = -1); @@ -148,9 +152,9 @@ class Secret Secret operator>>(int i) const; template - void bitcom(Memory& S, const vector& regs); + void bitcom(StackedVector& S, const vector& regs); template - void bitdec(Memory& S, const vector& regs) const; + void bitdec(StackedVector& S, const vector& regs) const; Secret operator+(const Secret& x) const; Secret& operator+=(const Secret& x) { *this = *this + x; return *this; } diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 68d794cf5..9e8c2df67 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -197,7 +197,7 @@ Secret Secret::operator>>(int i) const template template -void Secret::bitcom(Memory& S, const vector& regs) +void Secret::bitcom(StackedVector& S, const vector& regs) { registers.clear(); for (unsigned int i = 0; i < regs.size(); i++) @@ -210,7 +210,7 @@ void Secret::bitcom(Memory& S, const vector& regs) template template -void Secret::bitdec(Memory& S, const vector& regs) const +void Secret::bitdec(StackedVector& S, const vector& regs) const { if (regs.size() > registers.size()) throw overflow("not enough bits for bit decomposition", regs.size(), diff --git a/GC/Semi.cpp b/GC/Semi.cpp index 0a0e3f917..1efdaf3f1 100644 --- a/GC/Semi.cpp +++ b/GC/Semi.cpp @@ -17,7 +17,7 @@ namespace GC void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, bool repeat) { - if (repeat and OnlineOptions::singleton.live_prep) + if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1)) { this->triples.push_back({{}}); auto& triple = this->triples.back(); @@ -35,6 +35,8 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n) { + if (n == -1) + n = SemiSecret::default_length; super::prepare_mul(x.mask(n), y.mask(n), n); } diff --git a/GC/Semi.h b/GC/Semi.h index 654112531..6148d7b58 100644 --- a/GC/Semi.h +++ b/GC/Semi.h @@ -24,7 +24,7 @@ class Semi : public Beaver void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, bool repeat); - void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n); + void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n = -1); }; } /* namespace GC */ diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 02cf31a5f..79db31bfe 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -79,6 +79,9 @@ array SemiPrep::get_mixed_triple(int n) if (mixed_triples.empty()) { assert(this->triple_generator); + this->triple_generator->set_batch_size( + BaseMachine::batch_size(DATA_MIXED, + this->buffer_size)); this->triple_generator->generateMixedTriples(); for (auto& x : this->triple_generator->mixedTriples) { diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index 4110b9a49..30d7dfdf2 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -38,6 +38,13 @@ class SemiSecretBase : public V, public ShareSecret static const int default_length = sizeof(BitVec) * 8; + static const bool symmetric = V::symmetric; + + static bool real_shares(const Player& P) + { + return V::real_shares(P); + } + static string type_string() { return "binary secret"; } static string phase_name() { return "Binary computation"; } @@ -64,8 +71,8 @@ class SemiSecretBase : public V, public ShareSecret void load_clear(int n, const Integer& x); - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; + void bitcom(StackedVector& S, const vector& regs); + void bitdec(StackedVector& S, const vector& regs) const; void xor_(int n, const T& x, const T& y) { *this = BitVec(x ^ y).mask(n); } diff --git a/GC/SemiSecret.hpp b/GC/SemiSecret.hpp index b147cce36..b869c87a1 100644 --- a/GC/SemiSecret.hpp +++ b/GC/SemiSecret.hpp @@ -70,30 +70,40 @@ void SemiSecret::andrsvec(Processor& processor, assert(protocol); protocol->init_mul(); auto it = args.begin(); + int total_bits = 0, total_ops = 0; while (it < args.end()) { int n_args = (*it++ - 3) / 2; int size = *it++; + total_bits += n_args * size; it += n_args; int base = *it++; - assert(n_args <= N_BITS); for (int i = 0; i < size; i += N_BITS) { - square64 square; - for (int j = 0; j < n_args; j++) - square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get(); - int n_ops = min(N_BITS, size - i); - square.transpose(n_args, n_ops); - for (int j = 0; j < n_ops; j++) + for (int k = 0; k < n_args; k += N_BITS) { - long bit = processor.S.at(base + i / N_BITS).get_bit(j); - auto y_ext = SemiSecret(bit).extend_bit(); - protocol->prepare_mult(square.rows[j], y_ext, n_args, true); + int left = min(N_BITS, n_args - k); + square64 square; + for (int j = 0; j < left; j++) + square.rows[j] = processor.S.at( + *(it + k + j) + i / N_BITS).get(); + int n_ops = min(N_BITS, size - i); + total_ops += n_ops; + square.transpose(left, n_ops); + for (int j = 0; j < n_ops; j++) + { + long bit = processor.S.at(base + i / N_BITS).get_bit(j); + auto y_ext = SemiSecret(bit).extend_bit(); + protocol->prepare_mult(square.rows[j], y_ext, left, true); + } } } it += n_args; } + if (OnlineOptions::singleton.has_option("verbose_and")) + fprintf(stderr, "%d/%d repeat ANDs\n", total_bits, total_ops); + protocol->exchange(); it = args.begin(); @@ -103,13 +113,18 @@ void SemiSecret::andrsvec(Processor& processor, int size = *it++; for (int i = 0; i < size; i += N_BITS) { - int n_ops = min(N_BITS, size - i); - square64 square; - for (int j = 0; j < n_ops; j++) - square.rows[j] = protocol->finalize_mul(n_args).get(); - square.transpose(n_ops, n_args); - for (int j = 0; j < n_args; j++) - processor.S.at(*(it + j) + i / N_BITS) = square.rows[j]; + for (int base = 0; base < n_args; base += N_BITS) + { + int left = min(N_BITS, n_args - base); + int n_ops = min(N_BITS, size - i); + square64 square; + for (int j = 0; j < n_ops; j++) + square.rows[j] = protocol->finalize_mul(left).get(); + square.transpose(n_ops, left); + for (int j = 0; j < left; j++) + processor.S.at(*(it + base + j) + i / N_BITS) = + square.rows[j]; + } } it += 2 * n_args + 1; } @@ -123,7 +138,7 @@ void SemiSecretBase::load_clear(int n, const Integer& x) } template -void SemiSecretBase::bitcom(Memory& S, const vector& regs) +void SemiSecretBase::bitcom(StackedVector& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) @@ -131,7 +146,7 @@ void SemiSecretBase::bitcom(Memory& S, const vector& regs) } template -void SemiSecretBase::bitdec(Memory& S, +void SemiSecretBase::bitdec(StackedVector& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index 57beaec0d..5a151f1bd 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -108,18 +108,9 @@ ShareParty::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt, else P = new PlainPlayer(this->N, "shareparty"); - try - { - read_mac_key( - get_prep_sub_dir(PREP_DIR, network_opts.nplayers), - this->N, - this->mac_key); - } - catch (exception& e) - { - SeededPRNG G; - this->mac_key.randomize(G); - } + T::read_or_generate_mac_key( + get_prep_sub_dir(PREP_DIR, network_opts.nplayers), + *P, this->mac_key); T::MC::setup(*P); diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 64189822f..6deea9c80 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -46,6 +46,9 @@ class ShareSecret static const bool is_real = true; static const bool actual_inputs = true; + static const bool symmetric = true; + + static bool real_shares(const Player&) { return true; } static ShareThread& get_party() { @@ -118,6 +121,7 @@ class RepSecretBase : public FixedVec, public ShareSecret typedef BitVec open_type; typedef NoShare mac_type; typedef NoValue mac_key_type; + typedef NoShare mac_share_type; typedef NoShare bit_type; @@ -151,6 +155,11 @@ class RepSecretBase : public FixedVec, public ShareSecret { } + static GC::NoValue get_mac_key() + { + throw runtime_error("no MAC"); + } + template static string proto_fake_opts() { @@ -166,8 +175,8 @@ class RepSecretBase : public FixedVec, public ShareSecret { } - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; + void bitcom(StackedVector& S, const vector& regs); + void bitdec(StackedVector& S, const vector& regs) const; void xor_(int n, const This& x, const This& y) { *this = (x ^ y).mask(n); } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 510e1a825..2d3f134d4 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -54,7 +54,7 @@ void ReplicatedSecret::load_clear(int n, const Integer& x) } template -void RepSecretBase::bitcom(Memory& S, const vector& regs) +void RepSecretBase::bitcom(StackedVector& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) @@ -62,7 +62,7 @@ void RepSecretBase::bitcom(Memory& S, const vector& regs) } template -void RepSecretBase::bitdec(Memory& S, const vector& regs) const +void RepSecretBase::bitdec(StackedVector& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) S[regs[i]] = (*this >> i) & 1; diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 61fd30583..88ca4fa40 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -94,23 +94,35 @@ void ShareThread::and_(Processor& processor, processor.check_args(args, 4); protocol->init_mul(); T x_ext, y_ext; + int total_bits = 0; for (size_t i = 0; i < args.size(); i += 4) { int n_bits = args[i]; + total_bits += n_bits; int left = args[i + 2]; int right = args[i + 3]; for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++) { int n = min(T::default_length, n_bits - j * T::default_length); + + if (not repeat and n == T::default_length) + { + protocol->prepare_mul(processor.S[left + j], processor.S[right + j]); + continue; + } + + processor.S[left + j].mask(x_ext, n); if (repeat) processor.S[right].extend_bit(y_ext, n); else processor.S[right + j].mask(y_ext, n); - processor.S[left + j].mask(x_ext, n); protocol->prepare_mult(x_ext, y_ext, n, repeat); } } + if (OnlineOptions::singleton.has_option("verbose_and")) + fprintf(stderr, "%d%s ANDs\n", total_bits, repeat ? " repeat" : ""); + protocol->exchange(); for (size_t i = 0; i < args.size(); i += 4) @@ -121,6 +133,13 @@ void ShareThread::and_(Processor& processor, { int n = min(T::default_length, n_bits - j * T::default_length); auto& res = processor.S[out + j]; + + if (not repeat and n == T::default_length) + { + res = protocol->finalize_mul(); + continue; + } + protocol->finalize_mult(res, n); res.mask(res, n); } @@ -136,10 +155,12 @@ void ShareThread::andrsvec(Processor& processor, const vector& args) protocol->init_mul(); auto it = args.begin(); T x_ext, y_ext; + int total_bits = 0; while (it < args.end()) { int n_args = (*it++ - 3) / 2; int size = *it++; + total_bits += size * n_args; it += n_args; int base = *it++; for (int i = 0; i < size; i += N_BITS) @@ -155,6 +176,9 @@ void ShareThread::andrsvec(Processor& processor, const vector& args) it += n_args; } + if (OnlineOptions::singleton.has_option("verbose_and")) + fprintf(stderr, "%d repeat ANDs\n", total_bits); + protocol->exchange(); it = args.begin(); diff --git a/GC/ShiftableTripleBuffer.h b/GC/ShiftableTripleBuffer.h index af66cfb22..a3299fb0e 100644 --- a/GC/ShiftableTripleBuffer.h +++ b/GC/ShiftableTripleBuffer.h @@ -32,6 +32,8 @@ class ShiftableTripleBuffer array get_triple_no_count(int n_bits) { int max_n_bits = T::default_length; + if (n_bits == -1) + n_bits = max_n_bits; assert(n_bits <= max_n_bits); assert(n_bits > 0); array res; diff --git a/GC/TinySecret.h b/GC/TinySecret.h index ada5ca885..30d6c44be 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -53,9 +53,15 @@ class VectorSecret : public Secret static const bool malicious = T::malicious; static const bool expensive_triples = T::expensive_triples; static const bool randoms_for_opens = false; + static const bool symmetric = true; static const int default_length = 64; + static bool real_shares(const Player&) + { + return true; + } + static int size() { return part_type::size() * default_length; @@ -72,6 +78,11 @@ class VectorSecret : public Secret T::read_or_generate_mac_key(directory, P, key); } + static typename T::mac_type get_mac_key() + { + return T::get_mac_key(); + } + template static void reveal_inst(U& processor, const vector& args) { diff --git a/GC/instructions.h b/GC/instructions.h index 5f6149b1f..dfd883c31 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -80,14 +80,15 @@ X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \ X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ - X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ + X(CONVCBIT, Proc.write_Ci(R0, Proc.sync(PC1.get()))) \ X(CONVCINTVEC, Proc.convcintvec(instruction)) \ - X(CONVCBITVEC, Proc.convcbitvec(instruction)) \ + X(CONVCBITVEC, Proc.Procb.convcbitvec(instruction, Proc.get_Ci(), &Proc.P)) \ X(CONVCBIT2S, PROC.convcbit2s(instruction)) \ X(DABIT, Proc.dabit(INST)) \ X(EDABIT, Proc.edabit(INST)) \ X(SEDABIT, Proc.edabit(INST, true)) \ X(SPLIT, Proc.split(INST)) \ + X(CALL_ARG, ) \ #define GC_INSTRUCTIONS \ X(INPUTB, T::inputb(PROC, EXTRA)) \ @@ -101,6 +102,7 @@ X(CONVCINT, C0 = PI1) \ X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \ X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \ + X(CONVCBITVEC, PROC.convcbitvec(instruction, Ci, 0)) \ X(PRINTCHR, PROC.print_chr(IMM)) \ X(PRINTSTR, PROC.print_str(IMM)) \ X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \ @@ -146,8 +148,11 @@ X(NPLAYERS, I0 = Thread::s().P->num_players()) \ X(THRESHOLD, I0 = T::threshold(Thread::s().P->num_players())) \ X(PLAYERID, I0 = Thread::s().P->my_num()) \ - X(CRASH, if (I0.get()) throw crash_requested()) \ + X(CRASH, if (I0.get() and T::actual_inputs) throw crash_requested()) \ X(ACTIVE, ) \ + X(LDTN, I0 = BaseMachine::thread_num) \ + X(CALL_TAPE, PROC.call_tape(INST, MD)) \ + X(CALL_ARG, ) \ #define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS diff --git a/License.txt b/License.txt index 9c8f81b1c..4ec881f69 100644 --- a/License.txt +++ b/License.txt @@ -1,4 +1,4 @@ -The Software is copyright (c) 2023, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. +The Software is copyright (c) 2024, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence. diff --git a/Machines/OTMachine.cpp b/Machines/OTMachine.cpp index 961dfbc5f..4f7218993 100644 --- a/Machines/OTMachine.cpp +++ b/Machines/OTMachine.cpp @@ -232,9 +232,9 @@ OTMachine::OTMachine(int argc, const char** argv) gettimeofday(&baseOTstart, NULL); // swap role for base OTs if (opt.isSet("-r")) - bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role)); + bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role)); else - bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role)); + bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role)); cout << "real mode " << opt.isSet("-r") << endl; BaseOT& bot = *bot_; bot.exec_base(); diff --git a/Machines/Tinier.cpp b/Machines/Tinier.cpp index 6a2ff874c..f91031b67 100644 --- a/Machines/Tinier.cpp +++ b/Machines/Tinier.cpp @@ -14,6 +14,7 @@ #include "GC/TinierSharePrep.hpp" #include "GC/CcdPrep.hpp" #include "GC/PersonalPrep.hpp" +#include "Protocols/Share.hpp" //template class GC::ShareParty>; template class GC::CcdPrep>; diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index 45c62e5fa..f2d03de5d 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -15,6 +15,7 @@ #include "Math/BitVec.h" #include "GC/TinierSecret.h" +#include "Protocols/Share.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/MascotPrep.hpp" #include "Math/Z2k.hpp" diff --git a/Machines/mama-party.cpp b/Machines/mama-party.cpp index 942f22f4d..1b206891a 100644 --- a/Machines/mama-party.cpp +++ b/Machines/mama-party.cpp @@ -39,5 +39,6 @@ int main(int argc, const char** argv) return run<2, 1>(machine); cerr << "Not compiled for choice of parameters" << endl; + cerr << "Try using '-lgp 128'" << endl; exit(1); } diff --git a/Machines/sy-rep-field-party.cpp b/Machines/sy-rep-field-party.cpp index a457e3b09..17d3f9222 100644 --- a/Machines/sy-rep-field-party.cpp +++ b/Machines/sy-rep-field-party.cpp @@ -23,6 +23,7 @@ #include "Protocols/SpdzWisePrep.hpp" #include "Protocols/SpdzWiseInput.hpp" #include "Protocols/SpdzWiseShare.hpp" +#include "Protocols/SpdzWiseRep3Shuffler.hpp" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" diff --git a/Machines/sy-rep-ring-party.cpp b/Machines/sy-rep-ring-party.cpp index 45faca6f1..9ba4312bb 100644 --- a/Machines/sy-rep-ring-party.cpp +++ b/Machines/sy-rep-ring-party.cpp @@ -22,6 +22,7 @@ #include "Protocols/SpdzWisePrep.hpp" #include "Protocols/SpdzWiseInput.hpp" #include "Protocols/SpdzWiseShare.hpp" +#include "Protocols/SpdzWiseRep3Shuffler.hpp" #include "Protocols/PostSacrifice.hpp" #include "Protocols/MalRepRingPrep.hpp" #include "Protocols/MaliciousRepPrep.hpp" diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 1ea00ffe3..d37f5e173 100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -26,6 +26,7 @@ #include "Protocols/Beaver.hpp" #include "Protocols/MascotPrep.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/Share.hpp" int main(int argc, const char** argv) { diff --git a/Machines/tiny-party.cpp b/Machines/tiny-party.cpp index f83f839f5..f19759836 100644 --- a/Machines/tiny-party.cpp +++ b/Machines/tiny-party.cpp @@ -26,6 +26,7 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/MascotPrep.hpp" +#include "Protocols/Share.hpp" int main(int argc, const char** argv) { diff --git a/Makefile b/Makefile index 8ec6c993e..9f5cbf3f8 100644 --- a/Makefile +++ b/Makefile @@ -52,7 +52,7 @@ endif endif # used for dependency generation -OBJS = $(patsubst %.cpp,%.o,$(wildcard */*.cpp)) $(STATIC_OTE) +OBJS = $(patsubst %.cpp,%.o,$(wildcard */*.cpp */*/*.cpp)) $(STATIC_OTE) DEPS := $(wildcard */*.d */*/*.d) # never delete @@ -150,13 +150,17 @@ static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a $(MAKE) static-dir $(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl +static/%.x: Machines/BMR/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a local/lib/liblibOTe.a + $(MAKE) static-dir + $(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl + static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl static-dir: @ mkdir static 2> /dev/null; true -static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) static/emulate.x +static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) $(patsubst Machines/BMR/%.cpp, static/%.x, $(wildcard Machines/BMR/*-party.cpp)) static/emulate.x Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o $(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS) @@ -352,7 +356,7 @@ cmake: wget https://github.com/Kitware/CMake/releases/download/v3.24.1/cmake-3.24.1.tar.gz tar xzvf cmake-3.24.1.tar.gz cd cmake-3.24.1; \ - ./bootstrap --parallel=8 --prefix=../local && make && make install + ./bootstrap --parallel=8 --prefix=../local && make -j8 && make install mac-setup: mac-machine-setup brew install openssl boost libsodium gmp yasm ntl cmake diff --git a/Math/BitVec.h b/Math/BitVec.h index ca63c24cb..909de8f65 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -67,8 +67,8 @@ class BitVec_ : public IntBase { if (n == -1) pack(os); - else if (n == 1) - os.store_bit(this->a); + else if (n < 8) + os.store_bits(this->a, n); else os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } @@ -77,8 +77,8 @@ class BitVec_ : public IntBase { if (n == -1) unpack(os); - else if (n == 1) - this->a = os.get_bit(); + else if (n < 8) + this->a = os.get_bits(n); else this->a = os.get_int(DIV_CEIL(n, 8)); } diff --git a/Math/FixedVec.h b/Math/FixedVec.h index a412c7e04..d38dc0e4f 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -161,6 +161,11 @@ class FixedVec return equal(1); } + bool operator==(const FixedVec& other) const + { + return equal(other); + } + bool operator!=(const FixedVec& other) const { return not equal(other); diff --git a/Math/Z2k.h b/Math/Z2k.h index 2c6704d34..c9c16bab6 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -306,7 +306,7 @@ class SignedZ2 : public Z2 return operator*(SignedZ2<64>(other)); } - void output(ostream& s, bool human = true) const; + void output(ostream& s, bool human = true, bool signed_ = true) const; }; template @@ -479,12 +479,17 @@ SignedZ2 abs(const SignedZ2& x) } template -void SignedZ2::output(ostream& s, bool human) const +void SignedZ2::output(ostream& s, bool human, bool signed_) const { if (human) { - bigint::tmp = *this; - s << bigint::tmp; + if (signed_) + { + bigint::tmp = *this; + s << bigint::tmp; + } + else + Z2::output(s, human); } else Z2::output(s, false); @@ -493,7 +498,7 @@ void SignedZ2::output(ostream& s, bool human) const template ostream& operator<<(ostream& o, const SignedZ2& x) { - x.output(o, true); + x.output(o, true, false); return o; } diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index ba638d974..56b76b3ab 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -510,6 +510,7 @@ gf2n_short::gf2n_short(const int128& a) // Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40) void expand_byte(gf2n_short& a,int b) { + gf2n_short::init_field(40); gf2n_short x,xp; x = (32+1); xp.assign_one(); diff --git a/Math/gfp.h b/Math/gfp.h index 313c74fb6..7cd7351bd 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -107,7 +107,7 @@ class gfp_ : public ValueInterface static void write_setup(string dir) { write_online_setup(dir, pr()); } static void check_setup(string dir); - static string fake_opts() { return " -lgp " + to_string(length()); } + static string fake_opts() { return " -P " + to_string(pr()); } /** * Get the prime modulus @@ -229,18 +229,25 @@ class gfp_ : public ValueInterface // faster randomization, see implementation for explanation void almost_randomize(PRNG& G); - void output(ostream& s,bool human) const - { a.output(s,ZpD,human); } + /** + * Output. + * @param s output stream + * @param human human-readable or binary + * @param signed_ signed representation (range `[-p/2,p/2]` instead of `[0,p]`) + */ + void output(ostream& s, bool human, bool signed_ = false) const + { a.output(s,ZpD, human, signed_); } void input(istream& s,bool human) { a.input(s,ZpD,human); } /** - * Human-readable output in the range `[-p/2, p/2]`. + * Human-readable output in the range `[0, p]`. * @param s output stream * @param x value */ friend ostream& operator<<(ostream& s,const gfp_& x) - { x.output(s,true); + { + x.output(s, true, false); return s; } /** diff --git a/Math/gfpvar.h b/Math/gfpvar.h index e2cafb365..5e823d0ef 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -82,7 +82,7 @@ class gfpvar_ { write_setup(get_prep_sub_dir(nplayers)); } - static string fake_opts() { return " -lgp " + to_string(length()); } + static string fake_opts() { return " -P " + to_string(pr()); } gfpvar_(); gfpvar_(int other); diff --git a/Math/modp.h b/Math/modp.h index 2ca1e1047..6e13e0e3c 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -132,7 +132,7 @@ class modp_ // - Can do in human or machine only format (later should be faster) // - If human output appends a space to help with reading // and also convert back/forth from Montgomery if needed - void output(ostream& s,const Zp_Data& ZpD,bool human) const; + void output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_ = false) const; void input(istream& s,const Zp_Data& ZpD,bool human); template diff --git a/Math/modp.hpp b/Math/modp.hpp index 32faf9766..0a12360d4 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -327,12 +327,12 @@ void Power(modp_& ans,const modp_& x,const bigint& exp,const Zp_Data& ZpD) template -void modp_::output(ostream& s,const Zp_Data& ZpD,bool human) const +void modp_::output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_) const { if (human) { bigint te; to_bigint(te, ZpD); - if (te < ZpD.pr / 2) + if (te < ZpD.pr / 2 or not signed_) s << te; else s << (te - ZpD.pr); diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 6192fcdf5..53bc35879 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -10,12 +10,12 @@ void check_ssl_file(string filename) { if (not ifstream(filename)) - throw runtime_error("Cannot access " + filename + exit_error("Cannot access " + filename + ". Have you set up SSL?\n" "You can use `Scripts/setup-ssl.sh `."); } -void ssl_error(string side, string other, string me) +void ssl_error(string side, string other, string me, exception& e) { cerr << side << "-side handshake with " << other << " failed. Make sure both sides " @@ -48,6 +48,8 @@ void ssl_error(string side, string other, string me) cerr << "/"; } cerr << endl; + cerr << "SSL error: " << e.what() << endl; + exit(1); } CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : diff --git a/Networking/Player.cpp b/Networking/Player.cpp index d70e5d639..620ea0f3c 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -7,6 +7,7 @@ #include "Networking/Server.h" #include "Networking/ServerSocket.h" #include "Networking/Exchanger.h" +#include "Processor/OnlineOptions.h" #include #include @@ -78,7 +79,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante } } if (nplayers_wanted > 0 and nplayers_wanted != nplayers) - throw runtime_error("not enough hosts in " + filename); + exit_error("not enough hosts in " + filename); #ifdef DEBUG_NETWORKING cerr << "Got list of " << nplayers << " players from file: " << endl; for (unsigned int i = 0; i < names.size(); i++) @@ -127,7 +128,17 @@ void Names::setup_names(const char *servername, int my_port) int socket_num; int pn = portnum_base; - set_up_client_socket(socket_num, servername, pn); + + try + { + set_up_client_socket(socket_num, servername, pn); + } + catch (exception& e) + { + exit_error( + string("cannot reach coordination server: ") + e.what()); + } + octetStream("P" + to_string(player_no)).Send(socket_num); #ifdef DEBUG_NETWORKING cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl; @@ -155,11 +166,11 @@ void Names::setup_names(const char *servername, int my_port) } catch (exception& e) { - throw runtime_error(string("error in network setup: ") + e.what()); + exit_error(string("error in network setup: ") + e.what()); } if (names.size() != ports.size()) - throw runtime_error("invalid network setup"); + exit_error("invalid network setup"); nplayers = names.size(); #ifdef VERBOSE for (int i = 0; i < nplayers; i++) @@ -288,7 +299,15 @@ void PlainPlayer::setup_sockets(const vector& names, "Setting up send to self socket to %s:%d with id %s\n", localhost, ports[i], pn.c_str()); #endif - set_up_client_socket(sockets[i],localhost,ports[i]); + try + { + set_up_client_socket(sockets[i],localhost,ports[i]); + } + catch (exception& e) + { + exit_error("cannot connect to myself, " + "maybe check your firewall configuration"); + } } else { #ifdef DEBUG_NETWORKING fprintf(stderr, "Setting up client to %s:%d with id %s\n", @@ -762,7 +781,7 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other) { sent += other.sent; for (auto it = other.begin(); it != other.end(); it++) - (*this)[it->first] += it->second; + map::operator[](it->first) += it->second; return *this; } @@ -786,7 +805,7 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const NamedCommStats res = *this; res.sent = sent - other.sent; for (auto it = other.begin(); it != other.end(); it++) - res[it->first] -= it->second; + res.map::operator[](it->first) -= it->second; return res; } @@ -818,9 +837,25 @@ Timer& NamedCommStats::add_to_last_round(const string& name, size_t length) } } -void PlayerBase::reset_stats() +Timer& CommStatsWithName::add_length_only(size_t length) +{ + if (OnlineOptions::singleton.has_option("verbose_comm")) + fprintf(stderr, "%s %zu bytes in same round\n", name.c_str(), length); + return stats.add_length_only(length); +} + +Timer& CommStatsWithName::add(const octetStream& os) +{ + if (OnlineOptions::singleton.has_option("verbose_comm")) + fprintf(stderr, "%s %zu bytes\n", name.c_str(), os.get_length()); + return stats.add(os); +} + +void Player::reset_stats() { comm_stats.reset(); + for (auto& x : thread_stats) + x.reset(); } NamedCommStats Player::total_comm() const diff --git a/Networking/Player.h b/Networking/Player.h index d31e1ebac..40e113bc1 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -141,18 +141,28 @@ struct CommStats } Timer& add_length_only(size_t length) { -#ifdef VERBOSE_COMM - cout << "add " << length << endl; -#endif data += length; return timer; } Timer& add(const octetStream& os) { return add(os.get_length()); } - void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } CommStats& operator+=(const CommStats& other); CommStats& operator-=(const CommStats& other); }; +class CommStatsWithName +{ + const string& name; + CommStats& stats; + +public: + CommStatsWithName(const string& name, CommStats& stats) : + name(name), stats(stats) {} + + Timer& add_length_only(size_t length); + Timer& add(const octetStream& os); + void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } +}; + class NamedCommStats : public map { public: @@ -167,14 +177,8 @@ class NamedCommStats : public map void print(bool newline = false); void reset(); Timer& add_to_last_round(const string& name, size_t length); -#ifdef VERBOSE_COMM - CommStats& operator[](const string& name) - { - auto& res = map::operator[](name); - cout << name << " after " << res.data << endl; - return res; - } -#endif + CommStatsWithName operator[](const string& name) + { return {name, map::operator[](name)}; } }; /** @@ -209,8 +213,6 @@ class PlayerBase virtual void send_receive_all(const vector&, vector&) const { throw not_implemented(); } - - void reset_stats(); }; /** @@ -394,6 +396,7 @@ class Player : public PlayerBase { receive_player(i, o); } NamedCommStats total_comm() const; + void reset_stats(); }; /** diff --git a/Networking/Server.cpp b/Networking/Server.cpp index f8b545b9f..8f6d8d01a 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -168,8 +168,13 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, cerr << "Starting networking for " << my_num << "/" << nplayers << " with server on " << hostname << ":" << (portnum) << endl; #endif - assert(my_num >= 0); - assert(my_num < nplayers); + if (my_num < 0 or my_num >= nplayers) + { + cerr << "Player number " << my_num << " outside range: 0-" + << nplayers - 1 << endl; + exit(1); + } + Server* server = 0; pthread_t thread; if (my_num == 0) diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index d863efdf3..c57657a46 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -205,7 +205,7 @@ int ServerSocket::get_connection_socket(const string& id) while (clients.find(id) == clients.end()) { if (data_signal.wait(CONNECTION_TIMEOUT) == ETIMEDOUT) - throw runtime_error("Timed out waiting for peer. See " + exit_error("Timed out waiting for peer. See " "https://mp-spdz.readthedocs.io/en/latest/networking.html " "for details on networking."); } @@ -230,7 +230,7 @@ void AnonymousServerSocket::init() void AnonymousServerSocket::process_client(const string& client_id) { if (clients.find(client_id) != clients.end()) - throw runtime_error("client " + client_id + " already connected"); + exit_error("client " + client_id + " already connected"); client_connection_queue.push(client_id); } @@ -242,7 +242,7 @@ int AnonymousServerSocket::get_connection_socket(string& client_id) { int res = data_signal.wait(CONNECTION_TIMEOUT); if (res == ETIMEDOUT) - throw runtime_error("timed out while waiting for client"); + exit_error("timed out while waiting for client"); else if (res) throw runtime_error("waiting error"); } diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index 0c1dcfa47..6572b12a5 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -14,7 +14,7 @@ void error(const char *str) gethostname(err,1000); strcat(err," : "); strcat(err,str); - throw runtime_error(string() + err + " : " + strerror(old_errno)); + exit_error(string() + err + " : " + strerror(old_errno)); } void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) @@ -62,7 +62,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) { for (rp = ai; rp != NULL; rp = rp->ai_next) cerr << "Family on offer: " << ai->ai_family << endl; - runtime_error(string("No AF_INET for ") + (char*)hostname + " on " + (char*)my_name); + exit_error(string("No AF_INET for ") + (char*)hostname + " on " + (char*)my_name); } @@ -106,10 +106,11 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) if (fl < 0) { - throw runtime_error( + exit_error( string() + "cannot connect from " + my_name + " to " + hostname + ":" + to_string(Portnum) + " after " + to_string(attempts) - + " attempts in one minute because " + strerror(connect_errno) + ". " + + " attempts in " + to_string(CONNECTION_TIMEOUT) + + " seconds because " + strerror(connect_errno) + ". " "https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#" "connection-failures has more information on port requirements."); } diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index a9ce63130..3e176d22a 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -21,7 +21,7 @@ typedef boost::asio::io_service ssl_service; void check_ssl_file(string filename); -void ssl_error(string side, string other, string server); +void ssl_error(string side, string other, string server, exception& e); class ssl_ctx : public boost::asio::ssl::context { @@ -62,9 +62,9 @@ class ssl_socket : public boost::asio::ssl::stream try { handshake(ssl_socket::client); - } catch (...) + } catch (exception& e) { - ssl_error("Client", other, me); + ssl_error("Client", other, me, e); throw; } else @@ -72,9 +72,9 @@ class ssl_socket : public boost::asio::ssl::stream try { handshake(ssl_socket::server); - } catch (...) + } catch (exception& e) { - ssl_error("Server", other, me); + ssl_error("Server", other, me, e); throw; } diff --git a/OT/BaseOT.h b/OT/BaseOT.h index b8e4b876d..e136f7d3c 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -47,12 +47,12 @@ class BaseOT vector receiver_outputs; TwoPartyPlayer* P; /// Number of OTs - int nOT, ot_length; + int nOT; /// Which role(s) on this side OT_ROLE ot_role; - BaseOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) - : P(player), nOT(nOT), ot_length(ot_length), ot_role(role) + BaseOT(int nOT, TwoPartyPlayer* player, OT_ROLE role=BOTH) + : P(player), nOT(nOT), ot_role(role) { receiver_inputs.resize(nOT); sender_inputs.resize(nOT); @@ -69,14 +69,12 @@ class BaseOT } BaseOT(TwoPartyPlayer* player, OT_ROLE role) : - BaseOT(128, 128, player, role) + BaseOT(128, player, role) { } virtual ~BaseOT() {} - int length() { return ot_length; } - /// Set choice bits void set_receiver_inputs(const BitVector& new_inputs) { @@ -126,8 +124,8 @@ class BaseOT class FakeOT : public BaseOT { public: - FakeOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) : - BaseOT(nOT, ot_length, player, role) {} + FakeOT(int nOT, TwoPartyPlayer* player, OT_ROLE role=BOTH) : + BaseOT(nOT, player, role) {} void exec_base(bool new_receiver_inputs=true); }; diff --git a/OT/BitDiagonal.cpp b/OT/BitDiagonal.cpp index e6c4293a1..7f21b8a85 100644 --- a/OT/BitDiagonal.cpp +++ b/OT/BitDiagonal.cpp @@ -8,12 +8,13 @@ void BitDiagonal::pack(octetStream& os) const { for (int i = 0; i < N_ROWS; i++) - os.store_int(rows[i].get_bit(i), 1); + os.store_bit(rows[i].get_bit(i)); + os.append(0); } void BitDiagonal::unpack(octetStream& os) { *this = {}; for (int i = 0; i < N_ROWS; i++) - rows[i] = os.get_int(1) << i; + rows[i] = RowType(os.get_bit()) << i; } diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 1dce9cc78..4d723d1f7 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -72,6 +72,11 @@ Spdz2kTripleGenerator::Spdz2kTripleGenerator(const OTTripleSetup& setup, template void OTTripleGenerator::set_batch_size(int batch_size) { + // limit to ~1 GB + batch_size = min(batch_size, int(1e7 / sizeof(T) / sizeof(T))); + if (OnlineOptions::singleton.has_option("verbose_ot")) + fprintf(stderr, "OT batch size %d (share size %d)\n", batch_size, + int(sizeof(T))); nTriplesPerLoop = DIV_CEIL(batch_size, nloops); nTriples = nTriplesPerLoop * nloops; nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify; @@ -198,9 +203,9 @@ void NPartyTripleGenerator::generate() { outputFile.open(ss.str().c_str()); if (machine.generateMACs or not T::clear::invertible) - file_signature().output(outputFile); + file_signature(this->mac_key).output(outputFile); else - file_signature().output(outputFile); + file_signature>().output(outputFile); } if (machine.generateBits) @@ -250,6 +255,7 @@ void NPartyTripleGenerator::generateInputs(int player) inputs.resize(toCheck); auto mac_key = this->get_mac_key(); SemiInput> input(0, globalPlayer); + input.maybe_init(globalPlayer); input.reset_all(globalPlayer); vector secrets(toCheck); if (mine) @@ -528,8 +534,10 @@ void OTTripleGenerator::generateMixedTriples() machine.set_passive(); machine.output = false; - int n = multiple_minimum(100 * nPreampTriplesPerLoop, - T::open_type::size_in_bits()); + int n = multiple_minimum(nPreampTriplesPerLoop, 8); + + if (OnlineOptions::singleton.has_option("verbose_mixed")) + fprintf(stderr, "generating %d mixed triples\n", n); valueBits.resize(2); valueBits[0].resize(n); @@ -556,6 +564,9 @@ void OTTripleGenerator::generateMixedTriples() template void OTTripleGenerator::plainTripleRound(int k) { + if (OnlineOptions::singleton.has_option("verbose_triples")) + fprintf(stderr, "generating %d triples\n", nPreampTriplesPerLoop); + typedef typename U::open_type T; if (not (machine.amplify or machine.output)) diff --git a/OT/OTCorrelator.hpp b/OT/OTCorrelator.hpp index d6c19761b..38a635ccd 100644 --- a/OT/OTCorrelator.hpp +++ b/OT/OTCorrelator.hpp @@ -78,6 +78,11 @@ void OTCorrelator::correlate(int start, int slice, Slice t1Slice(t1, start, slice); Slice uSlice(u, start, slice); + if (OnlineOptions::singleton.has_option("verbose_correlate")) + fprintf(stderr, "correlate %d matrices of size %d*%d, %u bits\n", slice, + int(U::PartType::n_rows()), int(U::PartType::n_columns()), + newReceiverInput.size()); + // create correlation if (ot_role & RECEIVER) { diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index b778b7195..fdd1dfa0a 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -20,7 +20,7 @@ osuCrypto::IOService ot_extension_ios; OTExtensionWithMatrix OTExtensionWithMatrix::setup(TwoPartyPlayer& player, int128 delta, OT_ROLE role, bool passive) { - BaseOT baseOT(128, 128, &player, INV_ROLE(role)); + BaseOT baseOT(128, &player, INV_ROLE(role)); PRNG G; G.ReSeed(); baseOT.set_receiver_inputs(delta); @@ -30,6 +30,11 @@ OTExtensionWithMatrix OTExtensionWithMatrix::setup(TwoPartyPlayer& player, OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player, bool passive) : OTCorrelator(baseOT, player, passive) +{ + init_me(); +} + +void OTExtensionWithMatrix::init_me() { G.ReSeed(); nsubloops = 1; @@ -37,6 +42,7 @@ OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* pla #ifndef USE_KOS channel = 0; #endif + softspoken_k = 2; } OTExtensionWithMatrix::~OTExtensionWithMatrix() @@ -47,27 +53,43 @@ OTExtensionWithMatrix::~OTExtensionWithMatrix() #endif } +bool OTExtensionWithMatrix::use_kos() +{ +#ifdef USE_KOS + return true; +#else + return OnlineOptions::singleton.has_option("use_kos"); +#endif +} + void OTExtensionWithMatrix::protocol_agreement() { if (agreed) return; Bundle bundle(*player); -#ifdef USE_KOS - bundle.mine = string("KOS15"); -#else - bundle.mine = string("SoftSpokenOT"); -#endif + if (use_kos()) + bundle.mine = string("KOS15"); + else + bundle.mine = string("SoftSpokenOT"); + + if (OnlineOptions::singleton.has_option("high_softspoken")) + softspoken_k = 8; + + bundle.mine.store(softspoken_k); + player->unchecked_broadcast(bundle); try { bundle.compare(*player); + agreed = true; } catch (mismatch_among_parties&) { cerr << "Parties compiled with different OT extensions" << endl; cerr << "Set \"USE_KOS\" to the same value on all parties" << endl; + cerr << "and make sure that the SoftSpokenOT parameter is the same" << endl; exit(1); } } @@ -104,16 +126,27 @@ void OTExtensionWithMatrix::transfer(int nOTs, #endif } -void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newReceiverInput) +void OTExtensionWithMatrix::extend(int nOTs_requested, + const BitVector& newReceiverInput, bool hash) { protocol_agreement(); + if (use_kos()) + { + extend_correlated(nOTs_requested, newReceiverInput); + if (hash) + hash_outputs(nOTs_requested); + return; + } + #ifdef USE_KOS - extend_correlated(nOTs_requested, newReceiverInput); - hash_outputs(nOTs_requested); + assert(use_kos()); #else resize(nOTs_requested); + if (nOTs_requested == 0) + return; + if (not channel) channel = new osuCrypto::Channel(ot_extension_ios, new PlayerCtSocket(*player)); @@ -141,14 +174,18 @@ void OTExtensionWithMatrix::soft_sender(size_t n) if (not (ot_role & SENDER)) return; + if (OnlineOptions::singleton.has_option("verbose_ot")) + fprintf(stderr, "%zu OTs as sender\n", n); + osuCrypto::PRNG prng(osuCrypto::sysRandomSeed()); - osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(2); + osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(softspoken_k); vector outputs; for (auto& x : G_receiver) { outputs.push_back(x.get_doubleword()); } + sender.malicious = not passive_only; sender.setBaseOts(outputs, {baseReceiverInput.get_ptr(), sender.baseOtCount()}, prng, *channel); @@ -171,8 +208,11 @@ void OTExtensionWithMatrix::soft_receiver(size_t n, if (not (ot_role & RECEIVER)) return; + if (OnlineOptions::singleton.has_option("verbose_ot")) + fprintf(stderr, "%zu OTs as receiver\n", n); + osuCrypto::PRNG prng(osuCrypto::sysRandomSeed()); - osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(2); + osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(softspoken_k); vector> inputs; for (auto& x : G_sender) @@ -181,6 +221,7 @@ void OTExtensionWithMatrix::soft_receiver(size_t n, for (int i = 0; i < 2; i++) inputs.back()[i] = x[i].get_doubleword(); } + recver.malicious = not passive_only; recver.setBaseOts(inputs, prng, *channel); // Choose which messages should be received. diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index e6eab6da0..7bbf41af2 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -63,6 +63,10 @@ class OTExtensionWithMatrix : public OTCorrelator bool agreed; + int softspoken_k; + + void init_me(); + public: PRNG G; @@ -76,21 +80,18 @@ class OTExtensionWithMatrix : public OTCorrelator : OTCorrelator(player, role, passive), nsubloops(nsubloops) { - G.ReSeed(); - agreed = false; -#ifndef USE_KOS - channel = 0; -#endif + init_me(); } OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player, bool passive); ~OTExtensionWithMatrix(); + bool use_kos(); void protocol_agreement(); void transfer(int nOTs, const BitVector& receiverInput, int nloops); - void extend(int nOTs, const BitVector& newReceiverInput); + void extend(int nOTs, const BitVector& newReceiverInput, bool hash = true); void extend_correlated(const BitVector& newReceiverInput); void extend_correlated(int nOTs, const BitVector& newReceiverInput); void transpose(int start = 0, int slice = -1); diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 64b78412c..c83e9af29 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -77,6 +77,8 @@ class OTMultiplier : public OTMultiplierMac& generator, int thread_num); virtual ~OTMultiplier(); + + void init(); void multiply(); }; diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index 69636cfe7..fbafbefbb 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -88,11 +88,10 @@ OTMultiplier::~OTMultiplier() } template -void OTMultiplier::multiply() +void OTMultiplier::init() { keyBits.set(generator.get_mac_key()); rot_ext.extend(keyBits.size(), keyBits); - this->outbox.push({}); senderOutput.resize(keyBits.size()); for (size_t j = 0; j < keyBits.size(); j++) { @@ -106,10 +105,18 @@ void OTMultiplier::multiply() assert(receiverOutput.size() >= keyBits.size()); receiverOutput.resize(keyBits.size()); init_authenticator(keyBits, senderOutput, receiverOutput); +} +template +void OTMultiplier::multiply() +{ + this->outbox.push({}); MultJob job; while (this->inbox.pop(job)) { + if (receiverOutput.empty()) + init(); + if (job.input) { if (job.player == generator.my_num @@ -155,13 +162,14 @@ void SemiMultiplier::multiplyForBits() otCorrelator.set_role(role); BitVector aBits = this->generator.valueBits[0]; - rot_ext.extend_correlated(aBits); + rot_ext.extend(aBits.size(), aBits, not rot_ext.use_kos()); typedef typename T::Rectangle X; vector >& baseSenderOutputs = otCorrelator.matrices; Matrix& baseReceiverOutput = otCorrelator.senderOutputMatrices[0]; - rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput, + rot_ext.use_kos()); int n_squares = otCorrelator.receiverOutputMatrix.squares.size(); otCorrelator.setup_for_correlation(aBits, baseSenderOutputs, @@ -201,12 +209,13 @@ void SemiMultiplier::multiplyForMixed() this->generator.players[this->thread_num], BOTH, true); BitVector aBits = this->generator.valueBits[0]; - rot_ext.extend_correlated(aBits); + rot_ext.extend(aBits.size(), aBits, not rot_ext.use_kos()); auto& baseSenderOutputs = otCorrelator.matrices; auto& baseReceiverOutput = otCorrelator.senderOutputMatrices[0]; - rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput, + rot_ext.use_kos()); if (this->generator.get_player().num_players() == 2) { @@ -265,16 +274,17 @@ void OTMultiplier::multiplyForTriples() //timers["Extension"].start(); if (generator.machine.use_extension) { -#ifdef USE_KOS - rot_ext.extend_correlated(aBits); -#else - rot_ext.extend(aBits.size(), aBits); - corr_hash = false; -#endif + if (rot_ext.use_kos()) + rot_ext.extend_correlated(aBits); + else + { + rot_ext.extend(aBits.size(), aBits); + corr_hash = false; + } } else { - BaseOT bot(aBits.size(), -1, generator.players[thread_num]); + BaseOT bot(aBits.size(), generator.players[thread_num]); bot.set_receiver_inputs(aBits); bot.exec_base(false); for (size_t i = 0; i < aBits.size(); i++) diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp index f2bdabe23..24f1bf2dd 100644 --- a/OT/OTTripleSetup.cpp +++ b/OT/OTTripleSetup.cpp @@ -1,5 +1,11 @@ #include "OTTripleSetup.h" +void* run_ot(void* job) +{ + ((OTTripleSetup::SetupJob*)job)->run(); + return 0; +} + void OTTripleSetup::setup() { timeval baseOTstart, baseOTend; @@ -12,13 +18,13 @@ void OTTripleSetup::setup() } //baseReceiverInput.randomize(G); + vector threads; for (int i = 0; i < nparties - 1; i++) - { - baseOTs[i]->set_receiver_inputs(base_receiver_inputs); - baseOTs[i]->exec_base(false); - baseSenderInputs[i] = baseOTs[i]->sender_inputs; - baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs; - } + threads.push_back({*this, i}); + for (int i = 0; i < nparties - 1; i++) + pthread_create(&threads[i].thread, 0, run_ot, &threads[i]); + for (int i = 0; i < nparties - 1; i++) + pthread_join(threads[i].thread, 0); gettimeofday(&baseOTend, NULL); #ifdef VERBOSE_BASEOT double basetime = timeval_diff(&baseOTstart, &baseOTend); @@ -34,6 +40,14 @@ void OTTripleSetup::setup() // (since Sender finishes baseOTs before Receiver) } +void OTTripleSetup::run(int i) +{ + baseOTs[i]->set_receiver_inputs(base_receiver_inputs); + baseOTs[i]->exec_base(false); + baseSenderInputs[i] = baseOTs[i]->sender_inputs; + baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs; +} + void OTTripleSetup::close_connections() { for (size_t i = 0; i < players.size(); i++) @@ -47,7 +61,7 @@ OTTripleSetup OTTripleSetup::get_fresh() OTTripleSetup res = *this; for (int i = 0; i < nparties - 1; i++) { - BaseOT bot(nbase, 128, 0); + BaseOT bot(nbase, 0); bot.sender_inputs = baseSenderInputs[i]; bot.receiver_outputs = baseReceiverOutputs[i]; bot.set_seeds(); diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index 4f28bcd94..11c9d9b6f 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -13,6 +13,8 @@ */ class OTTripleSetup { + void run(int i); + BitVector base_receiver_inputs; vector baseOTs; @@ -22,8 +24,27 @@ class OTTripleSetup int nbase; public: + class SetupJob + { + OTTripleSetup& setup; + int i; + + public: + pthread_t thread; + + SetupJob(OTTripleSetup& setup, int i) : + setup(setup), i(i), thread(0) + { + } + + void run() + { + setup.run(i); + } + }; + map timers; - vector players; + vector players; vector< vector< array > > baseSenderInputs; vector< vector > baseReceiverOutputs; @@ -56,16 +77,16 @@ class OTTripleSetup else other_player = i; - players.push_back(new OffsetPlayer(N, N.get_offset(other_player))); + players.push_back(new VirtualTwoPartyPlayer(N, other_player)); // sets up a pair of base OTs, playing both roles if (real_OTs) { - baseOTs[i] = new BaseOT(nbase, 128, players[i]); + baseOTs[i] = new BaseOT(nbase, players[i]); } else { - baseOTs[i] = new FakeOT(nbase, 128, players[i]); + baseOTs[i] = new FakeOT(nbase, players[i]); } } diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 2752a5da0..c2005afbc 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -70,6 +70,29 @@ int BaseMachine::bucket_size(size_t usage) return res; } +int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols) +{ + unsigned res = min(100, OnlineOptions::singleton.batch_size); + if (has_program()) + res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols)); + return res; +} + +int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols) +{ + if (has_program()) + { + auto res = s().progs[0].get_offline_data_used().matmuls[ + {n_rows, n_inner, n_cols}]; + if (res) + return res; + else + return -1; + } + else + return -1; +} + BaseMachine::BaseMachine() : nthreads(0) { if (sodium_init() == -1) @@ -287,7 +310,7 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats) rounds += x.second.rounds; cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds << " rounds (party " << P.my_num() << " only"; - if (nthreads > 1) + if (multithread) cerr << "; rounds counted double due to multi-threading"; if (not OnlineOptions::singleton.verbose) cerr << "; use '-v' for more details"; diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index e522dcfe1..bd7d7a0a1 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -44,6 +44,7 @@ class BaseMachine string progname; int nthreads; + bool multithread; ThreadQueues queues; @@ -66,10 +67,14 @@ class BaseMachine template static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0); template + static int input_batch_size(int player, int buffer_size = 0); + template static int edabit_batch_size(int n_bits, int buffer_size = 0); static int edabit_bucket_size(int n_bits); static int triple_bucket_size(DataFieldType type); static int bucket_size(size_t usage); + static int matrix_batch_size(int n_rows, int n_inner, int n_cols); + static int matrix_requirement(int n_rows, int n_inner, int n_cols); BaseMachine(); virtual ~BaseMachine() {} @@ -105,6 +110,10 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) template int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) { + if (OnlineOptions::singleton.has_option("debug_batch_size")) + fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size, + fallback); + int n_opts; int n = 0; int res = 0; @@ -114,7 +123,7 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) else if (fallback > 0) n_opts = fallback; else - n_opts = OnlineOptions::singleton.batch_size; + n_opts = OnlineOptions::singleton.batch_size * T::default_length; if (buffer_size <= 0 and has_program()) { @@ -132,7 +141,6 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) { n = buffer_size; buffer_size = 0; - n_opts = OnlineOptions::singleton.batch_size; } if (n > 0 and not (buffer_size > 0)) @@ -161,16 +169,33 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) else res = n_opts; -#ifdef DEBUG_BATCH_SIZE - cerr << DataPositions::dtype_names[type] << " " << T::type_string() - << " res=" << res << " n=" - << n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl; -#endif + if (OnlineOptions::singleton.has_option("debug_batch_size")) + cerr << DataPositions::dtype_names[type] << " " << T::type_string() + << " res=" << res << " n=" << n << " n_opts=" << n_opts + << " buffer_size=" << buffer_size << endl; assert(res > 0); return res; } +template +int BaseMachine::input_batch_size(int player, int buffer_size) +{ + if (buffer_size) + return buffer_size; + + if (has_program()) + { + auto res = + s().progs[0].get_offline_data_used( + ).inputs[player][T::clear::field_type()]; + if (res > 0) + return res; + } + + return OnlineOptions::singleton.batch_size; +} + template int BaseMachine::edabit_batch_size(int n_bits, int buffer_size) { diff --git a/Processor/Conv2dTuple.h b/Processor/Conv2dTuple.h index 8e265ab36..113530e94 100644 --- a/Processor/Conv2dTuple.h +++ b/Processor/Conv2dTuple.h @@ -29,10 +29,12 @@ class Conv2dTuple Conv2dTuple(const vector& args, int start); + array matrix_dimensions(); + template - void pre(vector& S, typename T::Protocol& protocol); + void pre(StackedVector& S, typename T::Protocol& protocol); template - void post(vector& S, typename T::Protocol& protocol); + void post(StackedVector& S, typename T::Protocol& protocol); template void run_matrix(SubProcessor& processor); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index ae948294e..4c20e09ab 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -12,8 +12,10 @@ #include "Networking/Player.h" #include "Protocols/edabit.h" #include "PrepBase.h" +#include "PrepBuffer.h" #include "EdabitBuffer.h" #include "Tools/TimerWithComm.h" +#include "Tools/CheckVector.h" #include #include @@ -104,8 +106,6 @@ class Preprocessing : public PrepBase protected: static const bool use_part = false; - DataPositions& usage; - bool do_count; void count(Dtype dtype, int n = 1) @@ -115,9 +115,9 @@ class Preprocessing : public PrepBase template void get_edabits(bool strict, size_t size, T* a, - vector& Sb, const vector& regs, false_type); + StackedVector& Sb, const vector& regs, false_type); template - void get_edabits(bool, size_t, T*, vector&, + void get_edabits(bool, size_t, T*, StackedVector&, const vector&, true_type) { throw not_implemented(); } @@ -126,6 +126,8 @@ class Preprocessing : public PrepBase T get_random_from_inputs(int nplayers); public: + int buffer_size; + template static Preprocessing* get_new(Machine& machine, DataPositions& usage, SubProcessor* proc); @@ -135,7 +137,8 @@ class Preprocessing : public PrepBase static Preprocessing* get_live_prep(SubProcessor* proc, DataPositions& usage); - Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {} + Preprocessing(DataPositions& usage) : + PrepBase(usage), do_count(true), buffer_size(0) {} virtual ~Preprocessing() {} virtual void set_protocol(typename T::Protocol&) {}; @@ -151,7 +154,7 @@ class Preprocessing : public PrepBase virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); } virtual void get_input_no_count(T&, typename T::open_type&, int) { throw not_implemented() ; } - virtual void get_no_count(vector&, DataTag, const vector&, int) + virtual void get_no_count(StackedVector&, DataTag, const vector&, int) { throw not_implemented(); } void get(Dtype dtype, T* a); @@ -159,7 +162,7 @@ class Preprocessing : public PrepBase void get_two(Dtype dtype, T& a, T& b); void get_one(Dtype dtype, T& a); void get_input(T& a, typename T::open_type& x, int i); - void get(vector& S, DataTag tag, const vector& regs, int vector_size); + void get(StackedVector& S, DataTag tag, const vector& regs, int vector_size); /// Get fresh random multiplication triple virtual array get_triple(int n_bits); @@ -174,7 +177,7 @@ class Preprocessing : public PrepBase virtual void get_dabit(T& a, typename T::bit_type& b); virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); } virtual void get_edabits(bool strict, size_t size, T* a, - vector& Sb, const vector& regs) + StackedVector& Sb, const vector& regs) { get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); } virtual void get_edabit_no_count(bool, int, edabit&) { throw runtime_error("no edaBits"); } @@ -201,11 +204,11 @@ class Sub_Data_Files : public Preprocessing static int tuple_length(int dtype); - BufferOwner buffers[N_DTYPE]; - vector> input_buffers; - BufferOwner, RefInputTuple> my_input_buffers; - map > extended; - BufferOwner, dabit> dabit_buffer; + array, N_DTYPE> buffers; + vector> input_buffers; + PrepBuffer, RefInputTuple, T> my_input_buffers; + map > extended; + PrepBuffer, dabit, T> dabit_buffer; map> edabit_buffers; map> my_edabits; @@ -284,7 +287,7 @@ class Sub_Data_Files : public Preprocessing } void setup_extended(const DataTag& tag, int tuple_size = 0); - void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size); + void get_no_count(StackedVector& S, DataTag tag, const vector& regs, int vector_size); void get_dabit_no_count(T& a, typename T::bit_type& b); part_type& get_part(); @@ -397,7 +400,7 @@ inline void Preprocessing::get_input(T& a, typename T::open_type& x, int i) } template -inline void Preprocessing::get(vector& S, DataTag tag, +inline void Preprocessing::get(StackedVector& S, DataTag tag, const vector& regs, int vector_size) { usage.count(T::clear::field_type(), tag, vector_size); diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 9cce036e2..3cc2b7e2b 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -143,14 +143,14 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, { if (T::clear::allows(Dtype(dtype))) { - buffers[dtype].setup( + buffers[dtype].setup(num_players, PrepBase::get_filename(prep_data_dir, Dtype(dtype), type_short, my_num, thread_num), tuple_length(dtype), type_string, DataPositions::dtype_names[dtype]); } } - dabit_buffer.setup( + dabit_buffer.setup(num_players, PrepBase::get_filename(prep_data_dir, DATA_DABIT, type_short, my_num, thread_num), dabit::size(), type_string, DataPositions::dtype_names[DATA_DABIT]); @@ -161,10 +161,10 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, string filename = PrepBase::get_input_filename(prep_data_dir, type_short, i, my_num, thread_num); if (i == my_num) - my_input_buffers.setup(filename, + my_input_buffers.setup(num_players, filename, InputTuple::size(), type_string); else - input_buffers[i].setup(filename, + input_buffers[i].setup(num_players, filename, T::size(), type_string); } @@ -344,14 +344,14 @@ void Sub_Data_Files::setup_extended(const DataTag& tag, int tuple_size) { stringstream ss; ss << prep_data_dir << tag.get_string() << "-" << T::type_short() << "-P" << my_num; - buffer.setup(ss.str(), tuple_length); + buffer.setup(num_players, ss.str(), tuple_length); } buffer.check_tuple_length(tuple_length); } template -void Sub_Data_Files::get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size) +void Sub_Data_Files::get_no_count(StackedVector& S, DataTag tag, const vector& regs, int vector_size) { setup_extended(tag, regs.size()); for (int j = 0; j < vector_size; j++) diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 028b65ac9..c3252c3a5 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -12,6 +12,7 @@ using namespace std; #include "Math/BitVec.h" #include "Data_Files.h" #include "Protocols/Replicated.h" +#include "Protocols/ReplicatedPrep.h" #include "Protocols/MAC_Check_Base.h" #include "Processor/Input.h" @@ -109,7 +110,7 @@ class DummyProtocol : public ProtocolBase }; template -class DummyLivePrep : public Preprocessing +class DummyLivePrep : public BufferPrep { public: static const bool homomorphic = true; @@ -133,16 +134,16 @@ class DummyLivePrep : public Preprocessing } DummyLivePrep(DataPositions& usage, GC::ShareThread&) : - Preprocessing(usage) + BufferPrep(usage) { } DummyLivePrep(DataPositions& usage, bool = true) : - Preprocessing(usage) + BufferPrep(usage) { } DummyLivePrep(SubProcessor*, DataPositions& usage) : - Preprocessing(usage) + BufferPrep(usage) { } @@ -165,7 +166,7 @@ class DummyLivePrep : public Preprocessing { fail(); } - void get_no_count(vector&, DataTag, const vector&, int) + void get_no_count(StackedVector&, DataTag, const vector&, int) { fail(); } diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index 1bf8136e8..80b5e2fbd 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -81,7 +81,6 @@ int ExternalClients::init_client_connection(const string& host, int portnum, auto socket = new client_socket(io_service, *peer_ctxs[my_client_id], plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id), true); - if (party_num == 0) { octetStream specification; specification.Receive(socket); diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 999fb9eaa..512507d7d 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -15,7 +15,7 @@ #include template -void Instruction::execute_clear_gf2n(vector& registers, +void Instruction::execute_clear_gf2n(StackedVector& registers, MemoryPart& memory, ArithmeticProcessor& Proc) const { auto& C2 = registers; @@ -30,7 +30,7 @@ void Instruction::execute_clear_gf2n(vector& registers, } template -void Instruction::gbitdec(vector& registers) const +void Instruction::gbitdec(StackedVector& registers) const { for (int j = 0; j < size; j++) { @@ -44,7 +44,7 @@ void Instruction::gbitdec(vector& registers) const } template -void Instruction::gbitcom(vector& registers) const +void Instruction::gbitcom(StackedVector& registers) const { for (int j = 0; j < size; j++) { @@ -124,7 +124,7 @@ ostream& operator<<(ostream& s, const Instruction& instr) return s; } -template void Instruction::execute_clear_gf2n(vector& registers, +template void Instruction::execute_clear_gf2n(StackedVector& registers, MemoryPart& memory, ArithmeticProcessor& Proc) const; -template void Instruction::execute_clear_gf2n(vector& registers, +template void Instruction::execute_clear_gf2n(StackedVector& registers, MemoryPart& memory, ArithmeticProcessor& Proc) const; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 265f18a29..c89880624 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -15,6 +15,7 @@ template class Machine; template class Processor; template class SubProcessor; template class MemoryPart; +template class StackedVector; class ArithmeticProcessor; class SwitchableOutput; @@ -73,6 +74,8 @@ enum USE_MATMUL = 0x1F, ACTIVE = 0xE9, CMDLINEARG = 0xEB, + CALL_TAPE = 0xEC, + CALL_ARG = 0xED, // Addition ADDC = 0x20, ADDS = 0x21, @@ -90,6 +93,7 @@ enum PREFIXSUMS = 0x2D, PICKS = 0x2E, CONCATS = 0x2F, + ZIPS = 0x3F, // Multiplication/division/other arithmetic MULC = 0x30, MULM = 0x31, @@ -351,6 +355,7 @@ class BaseInstruction string str; public: + BaseInstruction() : opcode(0), size(0), n(0) {} virtual ~BaseInstruction() {}; int get_r(int i) const { return r[i]; } @@ -391,13 +396,13 @@ class Instruction : public BaseInstruction void execute(Processor& Proc) const; template - void execute_clear_gf2n(vector& registers, MemoryPart& memory, + void execute_clear_gf2n(StackedVector& registers, MemoryPart& memory, ArithmeticProcessor& Proc) const; template - void gbitdec(vector& registers) const; + void gbitdec(StackedVector& registers) const; template - void gbitcom(vector& registers) const; + void gbitcom(StackedVector& registers) const; void execute_regint(ArithmeticProcessor& Proc, MemoryPart& Mi) const; diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 90f4db526..dc71ac5cd 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -92,6 +92,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case DIVINT: case CONDPRINTPLAIN: case INPUTMASKREG: + case ZIPS: get_ints(r, s, 3); break; // instructions with 2 register operands @@ -245,6 +246,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CONDPRINTSTRB: case RANDOMS: case GENSECSHUFFLE: + case CALL_ARG: r[0]=get_int(s); n = get_int(s); break; @@ -330,6 +332,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // read from file, input is opcode num_args, // start_file_posn (read), end_file_posn(write) var1, var2, ... case READFILESHARE: + case CALL_TAPE: num_var_args = get_int(s) - 2; r[0] = get_int(s); r[1] = get_int(s); @@ -588,6 +591,7 @@ int BaseInstruction::get_reg_type() const case ACCEPTCLIENTCONNECTION: case GENSECSHUFFLE: case CMDLINEARG: + case CALL_TAPE: return INT; case PREP: case GPREP: @@ -596,7 +600,6 @@ int BaseInstruction::get_reg_type() const case USE_EDABIT: case USE_MATMUL: case RUN_TAPE: - case CISC: // those use r[] not for registers return NONE; case LDI: @@ -639,6 +642,8 @@ int BaseInstruction::get_reg_type() const case PRIVATEOUTPUT: case FIXINPUT: return CINT; + case CALL_ARG: + return n; default: if (is_gf2n_instruction()) { @@ -706,6 +711,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const } else return 0; + case CALL_TAPE: + { + int res = 0; + for (auto it = start.begin(); it < start.end(); it += 5) + if (it[1] == reg_type) + res = max(res, (*it ? it[3] : it[4]) + it[2]); + return res; + } default: if (get_reg_type() != reg_type) return 0; @@ -713,6 +726,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const switch (opcode) { + case CISC: + { + int res = 0; + for (auto it = start.begin(); it < start.end(); it += *it) + { + assert(it + *it <= start.end()); + res = max(res, it[1] + it[2]); + } + return res; + } + case MULS: + skip = 4; + offset = 1; + size_offset = -1; + break; case DOTPRODS: { int res = 0; @@ -737,7 +765,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const return res; } case MATMULSM: - return r[0] + start[0] * start[2]; + { + int res = 0; + for (auto it = start.begin(); it < start.end(); it += 12) + { + res = max(res, *it + *(it + 3) * *(it + 5)); + } + return res; + } case CONV2DS: { unsigned res = 0; @@ -956,6 +991,17 @@ inline void Instruction::execute(Processor& Proc) const } return; } + case ZIPS: + { + auto& S = Proc.Procp.get_S(); + auto dest = S.begin() + r[0]; + for (int i = 0; i < get_size(); i++) + { + *dest++ = S[r[1] + i]; + *dest++ = S[r[2] + i]; + } + return; + } case DIVC: Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2])); break; @@ -1232,6 +1278,9 @@ inline void Instruction::execute(Processor& Proc) const case JOIN_TAPE: Proc.machine.join_tape(r[0]); break; + case CALL_TAPE: + Proc.call_tape(r[0], Proc.read_Ci(r[1]), start); + break; case CRASH: if (Proc.read_Ci(r[0])) throw crash_requested(); @@ -1277,7 +1326,6 @@ inline void Instruction::execute(Processor& Proc) const // get client connection at port number n + my_num()) int client_handle = Proc.external_clients.get_client_connection( Proc.read_Ci(r[1])); - if (Proc.P.my_num() == 0) { octetStream os; os.store(int(sint::open_type::type_char())); @@ -1421,7 +1469,8 @@ void Program::execute(Processor& Proc) const #endif #ifdef OUTPUT_INSTRUCTIONS - cerr << instruction << endl; + if (OnlineOptions::singleton.has_option("output_instructions")) + cerr << instruction << endl; #endif Proc.PC++; @@ -1461,7 +1510,7 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c for (int i = 0; i < size; i++) { if (p == 0 or (*p == 0 and s == 0)) - out << v[i]; + out.signed_output(v[i]); else if (s == 0) out << bigint::get_float(v[i], p[i], {}, {}); else diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 1bb285eb3..97151cc20 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -235,6 +235,7 @@ size_t Machine::load_program(const string& threadname, M2.minimum_size(SGF2N, CGF2N, progs[i], threadname); Mp.minimum_size(SINT, CINT, progs[i], threadname); Mi.minimum_size(NONE, INT, progs[i], threadname); + bit_memories.reset(progs[i]); return progs.back().size(); } @@ -340,14 +341,15 @@ void Machine::fill_matmul(int thread_number, int tape_number, auto subdim = it->first; subdim[1] = min(subdim[1] - j, max_inner); subdim[2] = min(subdim[2] - k, max_cols); - auto& source = - dynamic_cast&>(source_proc.protocol).get_matrix_prep( + auto& source_proto = dynamic_cast&>(source_proc.protocol); + auto& source = source_proto.get_matrix_prep( subdim, source_proc); auto& dest = dynamic_cast&>(tinfo[thread_number].processor->Procp.protocol).get_matrix_prep( subdim, tinfo[thread_number].processor->Procp); - for (int i = 0; i < it->second; i++) - dest.push_triple(source.get_triple_no_count(-1)); + if (not source_proto.use_plain_matmul(subdim, source_proc)) + for (int i = 0; i < it->second; i++) + dest.push_triple(source.get_triple_no_count(-1)); } } } @@ -434,7 +436,13 @@ pair Machine::stop_threads() auto comm_stats = total_comm(); if (OnlineOptions::singleton.verbose) - queues.print_breakdown(); + { + NamedStats total; + for (auto queue : queues) + total += queue->stats; + total.print(); + queues.print_breakdown(); + } for (auto& queue : queues) delete queue; @@ -464,7 +472,7 @@ void Machine::run(const string& progname) finish_timer.start(); // actual usage - bool multithread = nthreads > 1; + multithread = nthreads > 1; auto res = stop_threads(); DataPositions& pos = res.first; @@ -518,14 +526,15 @@ void Machine::run(const string& progname) cerr << "Full broadcast" << endl; #endif -#ifdef CHOP_MEMORY - // Reduce memory size to speed up - unsigned max_size = 1 << 20; - if (M2.size_s() > max_size) - M2.resize_s(max_size); - if (Mp.size_s() > max_size) - Mp.resize_s(max_size); -#endif + if (not OnlineOptions::singleton.has_option("output_full_memory")) + { + // Reduce memory size to speed up + unsigned max_size = 1 << 20; + if (M2.size_s() > max_size) + M2.resize_s(max_size); + if (Mp.size_s() > max_size) + Mp.resize_s(max_size); + } // Write out the memory to use next time ofstream outf(memory_filename(), ios::out | ios::binary); diff --git a/Processor/Memory.h b/Processor/Memory.h index 649c745b1..0bfefdef8 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -44,10 +44,10 @@ class MemoryPart virtual const T& at(size_t i) const = 0; template - void indirect_read(const Instruction& inst, vector& regs, + void indirect_read(const Instruction& inst, StackedVector& regs, const U& indices); template - void indirect_write(const Instruction& inst, vector& regs, + void indirect_write(const Instruction& inst, StackedVector& regs, const U& indices); void minimum_size(size_t size); diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index 23a41b1fa..5781061c3 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -6,21 +6,21 @@ template template void MemoryPart::indirect_read(const Instruction& inst, - vector& regs, const U& indices) + StackedVector& regs, const U& indices) { size_t n = inst.get_size(); auto dest = regs.begin() + inst.get_r(0); auto start = indices.begin() + inst.get_r(1); -#ifdef CHECK_SIZE +#ifndef NO_CHECK_SIZE assert(start + n <= indices.end()); assert(dest + n <= regs.end()); #endif - long size = this->size(); + size_t size = this->size(); const T* data = this->data(); for (auto it = start; it < start + n; it++) { #ifndef NO_CHECK_SIZE - if (*it >= size) + if (size_t(it->get()) >= size) throw overflow(T::type_string() + " memory read", it->get(), size); #endif *dest++ = data[it->get()]; @@ -30,21 +30,21 @@ void MemoryPart::indirect_read(const Instruction& inst, template template void MemoryPart::indirect_write(const Instruction& inst, - vector& regs, const U& indices) + StackedVector& regs, const U& indices) { size_t n = inst.get_size(); auto source = regs.begin() + inst.get_r(0); auto start = indices.begin() + inst.get_r(1); -#ifdef CHECK_SIZE +#ifndef NO_CHECK_SIZE assert(start + n <= indices.end()); assert(source + n <= regs.end()); #endif - long size = this->size(); + size_t size = this->size(); T* data = this->data(); for (auto it = start; it < start + n; it++) { #ifndef NO_CHECK_SIZE - if (*it >= size) + if (size_t(it->get()) >= size) throw overflow(T::type_string() + " memory write", it->get(), size); #endif data[it->get()] = *source++; diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index 577ab9f44..8368cadc0 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -33,6 +33,7 @@ class thread_info const char* name); void Sub_Main_Func(); + void Main_Func_With_Purge(); }; #endif diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 5bce952fe..045e9f3e7 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -340,6 +340,13 @@ void thread_info::Sub_Main_Func() cerr << endl; #endif + if (num == 0 and OnlineOptions::singleton.verbose + and machine.queues.size() > 1) + { + cerr << "Main thread communication:" << endl; + P.total_comm().print(); + } + // wind down thread by thread machine.stats += Proc.stats; queues->timers["wait"] = wait_timer + queues->wait_timer; @@ -347,11 +354,43 @@ void thread_info::Sub_Main_Func() queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer; queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"]; + NamedStats stats; + stats["integer multiplications"] = Proc.Procp.protocol.counter; + stats["integer multiplication rounds"] = Proc.Procp.protocol.rounds; + stats["probabilistic truncations"] = Proc.Procp.protocol.trunc_pr_counter; + stats["probabilistic truncation rounds"] = Proc.Procp.protocol.trunc_rounds; + stats["ANDs"] = Proc.share_thread.protocol->bit_counter; + stats["AND rounds"] = Proc.share_thread.protocol->rounds; + stats["integer openings"] = MCp->values_opened; + stats["integer inputs"] = Proc.Procp.input.values_input; + for (auto x : Proc.Procp.shuffler.stats) + stats["shuffles of length " + to_string(x.first)] = x.second; + + try + { + auto proc = dynamic_cast&>(Proc.DataF.DataFp).bit_part_proc; + if (proc) + stats["ANDs in preprocessing"] = proc->protocol.bit_counter; + } + catch (...) + { + } + + try + { + auto protocol = dynamic_cast&>(Proc.DataF.DataFp).protocol; + if (protocol) + stats["integer multiplications in preprocessing"] = protocol->counter; + } + catch (...) + { + } + // prevent faulty usage message Proc.DataF.set_usage(actual_usage); delete processor; - queues->finished(actual_usage, P.total_comm()); + queues->finished(actual_usage, P.total_comm(), stats); delete MC2; delete MCp; @@ -367,6 +406,27 @@ template void* thread_info::Main_Func(void* ptr) { auto& ti = *(thread_info*)(ptr); + if (OnlineOptions::singleton.has_option("throw_exceptions")) + ti.Main_Func_With_Purge(); + else + { + try + { + ti.Main_Func_With_Purge(); + } + catch (exception& e) + { + cerr << "Fatal error: " << e.what() << endl; + exit(1); + } + } + return 0; +} + +template +void thread_info::Main_Func_With_Purge() +{ + auto& ti = *this; #ifdef INSECURE ti.Sub_Main_Func(); #else @@ -383,12 +443,10 @@ void* thread_info::Main_Func(void* ptr) } catch (...) { - thread_info* ti = (thread_info*)ptr; - ti->purge_preprocessing(ti->machine->get_N(), ti->thread_num); + purge_preprocessing(machine->get_N(), thread_num); throw; } #endif - return 0; } diff --git a/Processor/OnlineMachine.h b/Processor/OnlineMachine.h index a6b493080..eb9e2d1fa 100644 --- a/Processor/OnlineMachine.h +++ b/Processor/OnlineMachine.h @@ -36,6 +36,8 @@ class OnlineMachine template int run(); + template + int run_with_error(); Player* new_player(const string& id_base); diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 078a39199..b337c3a7a 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -173,6 +173,25 @@ Player* OnlineMachine::new_player(const string& id_base) template int OnlineMachine::run() +{ + if (online_opts.has_option("throw_exception")) + return run_with_error(); + else + { + try + { + return run_with_error(); + } + catch (exception& e) + { + cerr << "Fatal error: " << e.what() << endl; + exit(1); + } + } +} + +template +int OnlineMachine::run_with_error() { #ifndef INSECURE try diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index ad0a0403d..b1677e89f 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -112,6 +112,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-B", // Flag token. "--bucket-size" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + -1, // Number of args expected. + ',', // Delimiter if expecting multiple args. + "Further options", // Help description. + "-o", // Flag token. + "--options" // Flag token. + ); if (security) opt.add( @@ -138,6 +147,12 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, verbose = opt.isSet("--verbose"); #endif + opt.get("--options")->getStrings(options); + +#ifdef THROW_EXCEPTIONS + options.push_back("throw_exceptions"); +#endif + if (security) { opt.get("-S")->getInt(security_parameter); diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 7e32c3176..00408c7af 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -37,6 +37,7 @@ class OnlineOptions bool receive_threads; std::string disk_memory; vector args; + vector options; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, @@ -67,6 +68,11 @@ class OnlineOptions lgp = numBits(prime); return get_prep_sub_dir(PREP_DIR, nplayers, lgp); } + + bool has_option(const string& option) + { + return find(options.begin(), options.end(), option) != options.end(); + } }; #endif /* PROCESSOR_ONLINEOPTIONS_H_ */ diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index 5403b6f7b..0d467614c 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -40,6 +40,11 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir, + to_string(my_num) + get_suffix(thread_num); } +PrepBase::PrepBase(DataPositions& usage) : + usage(usage) +{ +} + void PrepBase::print_left(const char* name, size_t n, const string& type_string, size_t used, bool large) { @@ -72,7 +77,8 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, cerr << " edaBits of size " << n_bits << " left" << endl; } - if (n * n_batch > used / 10) + if (n * n_batch > used / 10 + and n * n_batch > size_t(usage.files[DATA_INT][DATA_DABIT]) / 10) { cerr << "Significant amount of unused edaBits of size " << n_bits << ". "; diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index 8e68b1ad7..78aa7332c 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -12,8 +12,13 @@ using namespace std; #include "Math/field_types.h" #include "Tools/TimerWithComm.h" +class DataPositions; + class PrepBase { +protected: + DataPositions& usage; + public: static string get_suffix(int thread_num); @@ -25,12 +30,15 @@ class PrepBase static string get_edabit_filename(const string& prep_data_dir, int n_bits, int my_num, int thread_num = 0); - static void print_left(const char* name, size_t n, + TimerWithComm prep_timer; + + PrepBase(DataPositions& usage); + + void print_left(const char* name, size_t n, const string& type_string, size_t used, bool large = false); - static void print_left_edabits(size_t n, size_t n_batch, bool strict, - int n_bits, size_t used, bool malicious); - TimerWithComm prep_timer; + void print_left_edabits(size_t n, size_t n_batch, bool strict, + int n_bits, size_t used, bool malicious); }; #endif /* PROCESSOR_PREPBASE_H_ */ diff --git a/Processor/PrepBuffer.h b/Processor/PrepBuffer.h new file mode 100644 index 000000000..9f9d5339b --- /dev/null +++ b/Processor/PrepBuffer.h @@ -0,0 +1,44 @@ +/* + * PrepBuffer.h + * + */ + +#ifndef PROCESSOR_PREPBUFFER_H_ +#define PROCESSOR_PREPBUFFER_H_ + +#include "Tools/Buffer.h" + +template +class PrepBuffer : public BufferOwner +{ + int num_players; + string fake_opts; + +public: + PrepBuffer() : + num_players(0) + { + } + + void setup(int num_players, const string& filename, int tuple_length, + const string& type_string = "", const char* data_type = "") + { + this->num_players = num_players; + fake_opts = V::template proto_fake_opts(); + BufferOwner::setup(filename, tuple_length, type_string, data_type); + } + + void input(U& a) + { + try + { + BufferOwner::input(a); + } + catch (exception& e) + { + throw prep_setup_error(e.what(), num_players, fake_opts); + } + } +}; + +#endif /* PROCESSOR_PREPBUFFER_H_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index 08f4cd269..612ff00f1 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -21,18 +21,22 @@ #include "GC/Processor.h" #include "GC/ShareThread.h" #include "Protocols/SecureShuffle.h" +#include "Tools/NamedStats.h" class Program; +// synchronize in asymmetric protocols +template +void sync(vector& x, Player& P); + template class SubProcessor { - CheckVector C; - CheckVector S; + StackedVector C; + StackedVector S; DataPositions bit_usage; - - typename T::Protocol::Shuffler shuffler; + NamedStats stats; void resize(size_t size) { C.resize(size); S.resize(size); } @@ -62,6 +66,8 @@ class SubProcessor typename BT::LivePrep bit_prep; vector personal_bit_preps; + typename T::Protocol::Shuffler shuffler; + SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC, Preprocessing& DataF, Player& P); SubProcessor(typename T::MAC_Check& MC, Preprocessing& DataF, Player& P, @@ -76,8 +82,8 @@ class SubProcessor void muls(const vector& reg); void mulrs(const vector& reg); void dotprods(const vector& reg, int size); - void matmuls(const vector& source, const Instruction& instruction); - void matmulsm(const MemoryPart& source, const Instruction& instruction); + void matmuls(const StackedVector& source, const Instruction& instruction); + void matmulsm(const MemoryPart& source, const vector& args); void matmulsm_finalize_batch(vector::const_iterator startMatmul, int startI, int startJ, vector::const_iterator endMatmul, @@ -96,12 +102,12 @@ class SubProcessor void send_personal(const vector& args); void private_output(const vector& args); - CheckVector& get_S() + StackedVector& get_S() { return S; } - CheckVector& get_C() + StackedVector& get_C() { return C; } @@ -116,13 +122,17 @@ class SubProcessor return C[i]; } - void inverse_permutation(const Instruction &instruction, int handle); + void inverse_permutation(const Instruction &instruction, int handle); + + void push_stack(); + void push_args(const vector& args); + void pop_stack(const vector& results); }; class ArithmeticProcessor : public ProcessorBase { protected: - CheckVector Ci; + StackedVector Ci; ofstream public_output; ofstream binary_output; @@ -174,7 +184,7 @@ class ArithmeticProcessor : public ProcessorBase { return Ci[i]; } void write_Ci(size_t i, const long& x) { Ci[i]=x; } - CheckVector& get_Ci() + StackedVector& get_Ci() { return Ci; } virtual ofstream& get_public_output() @@ -292,6 +302,8 @@ class Processor : public ArithmeticProcessor ofstream& get_public_output(); ofstream& get_binary_output(); + void call_tape(int tape_number, int arg, const vector& results); + private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index aab2ba9d0..d6cdd7567 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -25,9 +25,8 @@ SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& template SubProcessor::SubProcessor(typename T::MAC_Check& MC, Preprocessing& DataF, Player& P, ArithmeticProcessor* Proc) : - shuffler(*this), Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC), - bit_prep(bit_usage) + bit_prep(bit_usage), shuffler(*this) { DataF.set_proc(this); protocol.init(DataF, MC); @@ -113,7 +112,7 @@ Processor::Processor(int thread_num,Player& P, secure_prng.ReSeed(); shared_prng.SeedGlobally(P, false); - setup_redirection(P.my_num(), thread_num, opts, out); + setup_redirection(P.my_num(), thread_num, opts, out, sint::real_shares(P)); Procb.out = out; } @@ -158,6 +157,7 @@ void Processor::reset(const Program& program,int arg) Procp.get_S().resize(program.num_reg(SINT)); Procp.get_C().resize(program.num_reg(CINT)); Ci.resize(program.num_reg(INT)); + this->arg = arg; Procb.reset(program); } @@ -209,17 +209,6 @@ void Processor::edabit(const Instruction& instruction, bool strict) &Procp.get_S_ref(instruction.get_r(0)), Procb.S, regs); } -template -void Processor::convcbitvec(const Instruction& instruction) -{ - for (size_t i = 0; i < instruction.get_n(); i++) - { - int i1 = i / GC::Clear::N_BITS; - int i2 = i % GC::Clear::N_BITS; - Ci[instruction.get_r(0) + i] = Procb.C[instruction.get_r(1) + i1].get_bit(i2); - } -} - template void Processor::convcintvec(const Instruction& instruction) { @@ -314,10 +303,9 @@ void Processor::write_socket(const RegType reg_type, } } -#ifdef VERBOSE_COMM - cerr << "send " << socket_stream.get_length() << " to client " << socket_id - << endl; -#endif + if (OnlineOptions::singleton.has_option("verbose_comm")) + fprintf(stderr, "Send %zu bytes to client %d\n", socket_stream.get_length(), + socket_id); try { TimeScope _(client_stats.add(socket_stream.get_length())); @@ -362,7 +350,10 @@ void Processor::read_socket_vector(int client_id, for (int j = 0; j < size; j++) for (int i = 0; i < m; i++) get_Cp_ref(registers[i] + j) = - socket_stream.get(); + socket_stream.get(); + + if (socket_stream.left()) + throw runtime_error("unexpected data"); } // Receive vector of field element shares over private channel @@ -562,7 +553,7 @@ void SubProcessor::dotprods(const vector& reg, int size) } template -void SubProcessor::matmuls(const vector& source, +void SubProcessor::matmuls(const StackedVector& source, const Instruction& instruction) { protocol.init_dotprod(); @@ -604,12 +595,10 @@ void SubProcessor::matmuls(const vector& source, template void SubProcessor::matmulsm(const MemoryPart& source, - const Instruction& instruction) + const vector& start) { assert(Proc); - auto& start = instruction.get_start(); - auto batchStartMatrix = start.begin(); int batchStartI = 0; int batchStartJ = 0; @@ -816,7 +805,7 @@ Conv2dTuple::Conv2dTuple(const vector& arguments, int start) } template -void Conv2dTuple::pre(vector& S, typename T::Protocol& protocol) +void Conv2dTuple::pre(StackedVector& S, typename T::Protocol& protocol) { for (int i_batch = 0; i_batch < batch_size; i_batch ++) { @@ -857,7 +846,7 @@ void Conv2dTuple::pre(vector& S, typename T::Protocol& protocol) } template -void Conv2dTuple::post(vector& S, typename T::Protocol& protocol) +void Conv2dTuple::post(StackedVector& S, typename T::Protocol& protocol) { for (int i_batch = 0; i_batch < batch_size; i_batch ++) { @@ -1049,17 +1038,85 @@ void Processor::fixinput(const Instruction& instruction) template long Processor::sync(long x) const +{ + vector tmp = {x}; + ::sync(tmp, P); + return tmp[0].get(); +} + +template +void sync(vector& x, Player& P) { if (not sint::symmetric) { + octetStream os; // send number to dealer if (P.my_num() == 0) - P.send_long(P.num_players() - 1, x); + { + os.store(x); + P.send_to(P.num_players() - 1, os); + } if (not sint::real_shares(P)) - return P.receive_long(0); + { + P.receive_player(0, os); + os.get(x); + } } +} - return x; +template +void SubProcessor::push_stack() +{ + S.push_stack(); + C.push_stack(); +} + +template +void SubProcessor::push_args(const vector& args) +{ + auto char2 = T::clear::characteristic_two; + S.push_args(args, char2 ? SGF2N : SINT); + C.push_args(args, char2 ? CGF2N : CINT); +} + +template +void SubProcessor::pop_stack(const vector& results) +{ + auto char2 = T::clear::characteristic_two; + S.pop_stack(results, char2 ? SGF2N : SINT); + C.pop_stack(results, char2 ? CGF2N : CINT); +} + +template +void Processor::call_tape(int tape_number, int arg, + const vector& args) +{ + PC_stack.push_back(PC); + arg_stack.push_back(this->arg); + Procp.push_stack(); + Proc2.push_stack(); + Procb.push_stack(); + Ci.push_stack(); + + auto& tape = machine.progs.at(tape_number); + reset(tape, arg); + + Procp.push_args(args); + Proc2.push_args(args); + Procb.push_args(args); + Ci.push_args(args, INT); + + tape.execute(*this); + + Procp.pop_stack(args); + Proc2.pop_stack(args); + Procb.pop_stack(args); + Ci.pop_stack(args, INT); + + PC = PC_stack.back(); + PC_stack.pop_back(); + this->arg = arg_stack.back(); + arg_stack.pop_back(); } #endif diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.cpp index 0fa1ab529..9038ef8a6 100644 --- a/Processor/ProcessorBase.cpp +++ b/Processor/ProcessorBase.cpp @@ -27,11 +27,12 @@ void ProcessorBase::open_input_file(int my_num, int thread_num, } void ProcessorBase::setup_redirection(int my_num, int thread_num, - OnlineOptions& opts, SwitchableOutput& out) + OnlineOptions& opts, SwitchableOutput& out, bool real) { // only output on party 0 if not interactive bool always_stdout = opts.cmd_private_output_file == "."; bool output = my_num == 0 or opts.interactive or always_stdout; + output &= real; out.activate(output); if (not (opts.cmd_private_output_file.empty() or always_stdout)) diff --git a/Processor/ProcessorBase.h b/Processor/ProcessorBase.h index d33dea42f..b765b77b0 100644 --- a/Processor/ProcessorBase.h +++ b/Processor/ProcessorBase.h @@ -28,6 +28,8 @@ class ProcessorBase protected: // Optional argument to tape Integer arg; + vector arg_stack; + vector PC_stack; string get_parameterized_filename(int my_num, int thread_num, const string& prefix); @@ -61,7 +63,7 @@ class ProcessorBase T get_input(istream& is, const string& input_filename, const int* params); void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts, - SwitchableOutput& out); + SwitchableOutput& out, bool real = true); }; #endif /* PROCESSOR_PROCESSORBASE_H_ */ diff --git a/Processor/Program.cpp b/Processor/Program.cpp index f9cb5c579..6774b8813 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -28,6 +28,24 @@ void Program::compute_constants() } void Program::parse(string filename) +{ + if (OnlineOptions::singleton.has_option("throw_exceptions")) + parse_with_error(filename); + else + { + try + { + parse_with_error(filename); + } + catch(exception& e) + { + cerr << "Error in bytecode: " << e.what() << endl; + exit(1); + } + } +} + +void Program::parse_with_error(string filename) { ifstream pinp(filename); if (pinp.fail()) diff --git a/Processor/Program.h b/Processor/Program.h index 2c8470f8c..7783da687 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -42,6 +42,7 @@ class Program // Read in a program void parse(string filename); + void parse_with_error(string filename); void parse(istream& s); DataPositions get_offline_data_used() const { return offline_data_used; } diff --git a/Processor/RingOptions.cpp b/Processor/RingOptions.cpp index ec9e9f066..2ac542e9a 100644 --- a/Processor/RingOptions.cpp +++ b/Processor/RingOptions.cpp @@ -30,9 +30,14 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) int RingOptions::ring_size_from_opts_or_schedule(string progname) { + int r = BaseMachine::ring_size_from_schedule(progname); if (R_is_set) + { + if (r and r != R) + cerr << "Different -R option in compilation and run-time: " << r + << " vs " << R << endl; return R; - int r = BaseMachine::ring_size_from_schedule(progname); + } if (r == 0) r = R; cerr << "Trying to run " << r << "-bit computation" << endl; diff --git a/Processor/ThreadQueue.cpp b/Processor/ThreadQueue.cpp index ced871b5d..becf3ab14 100644 --- a/Processor/ThreadQueue.cpp +++ b/Processor/ThreadQueue.cpp @@ -33,10 +33,12 @@ void ThreadQueue::finished(const ThreadJob& job) out.push(job); } -void ThreadQueue::finished(const ThreadJob& job, const NamedCommStats& new_comm_stats) +void ThreadQueue::finished(const ThreadJob& job, + const NamedCommStats& new_comm_stats, const NamedStats& stats) { finished(job); set_comm_stats(new_comm_stats); + this->stats = stats; } void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats) diff --git a/Processor/ThreadQueue.h b/Processor/ThreadQueue.h index c9640b7ae..6f100d344 100644 --- a/Processor/ThreadQueue.h +++ b/Processor/ThreadQueue.h @@ -7,6 +7,7 @@ #define PROCESSOR_THREADQUEUE_H_ #include "ThreadJob.h" +#include "Tools/NamedStats.h" class ThreadQueue { @@ -20,6 +21,7 @@ class ThreadQueue map timers; Timer wait_timer; + NamedStats stats; ThreadQueue() : left(0) @@ -34,7 +36,8 @@ class ThreadQueue void schedule(const ThreadJob& job); ThreadJob next(); void finished(const ThreadJob& job); - void finished(const ThreadJob& job, const NamedCommStats& comm_stats); + void finished(const ThreadJob& job, const NamedCommStats& comm_stats, + const NamedStats& stats = {}); ThreadJob result(); void set_comm_stats(const NamedCommStats& new_comm_stats); diff --git a/Programs/Circuits b/Programs/Circuits index 908452826..cdd592769 160000 --- a/Programs/Circuits +++ b/Programs/Circuits @@ -1 +1 @@ -Subproject commit 908452826cbd67f9757850a4943871bf66574e48 +Subproject commit cdd5927692c04670593c9d9c922e65f0fbe4f203 diff --git a/Programs/Source/bench-dt.mpc b/Programs/Source/bench-dt.mpc index 4c8c64c90..f3287dd83 100644 --- a/Programs/Source/bench-dt.mpc +++ b/Programs/Source/bench-dt.mpc @@ -24,8 +24,10 @@ decision_tree.max_leaves = 2000 if 'nearest' in program.args: sfix.round_nearest = True -layers = decision_tree.TreeTrainer( - train[1], train[0], n_levels, binary=binary, n_threads=n_threads).train() +trainer = decision_tree.TreeTrainer( + train[1], train[0], n_levels, binary=binary, n_threads=n_threads) +trainer.time = 'time' in program.args +layers = trainer.train() #decision_tree.output_decision_tree(layers) diff --git a/Programs/Source/falcon_alex.mpc b/Programs/Source/falcon_alex.mpc index 26422b86e..df13e382f 100644 --- a/Programs/Source/falcon_alex.mpc +++ b/Programs/Source/falcon_alex.mpc @@ -88,7 +88,7 @@ model = tf.keras.models.Sequential(AlexNet) model.compile_by_args(program) -model.build(training_samples.sizes) +model.build(training_samples.sizes, program=program) model.summary() opt = model.fit( diff --git a/Programs/Source/htmac.mpc b/Programs/Source/htmac.mpc index 310e20764..b16fdf81c 100644 --- a/Programs/Source/htmac.mpc +++ b/Programs/Source/htmac.mpc @@ -38,9 +38,9 @@ from random import randint import sys program.bit_length = 128 -n_parallel = int(sys.argv[2]) -n_total = int(sys.argv[3]) -nmessages = int(sys.argv[4]) +n_parallel = int(program.args[1]) +n_total = int(program.args[2]) +nmessages = int(program.args[3]) use_mimc_prf = True # Use just one PRF diff --git a/Programs/Source/keras_mnist_lenet_predict.mpc b/Programs/Source/keras_mnist_lenet_predict.mpc index 100dd564a..3bd63bb7c 100644 --- a/Programs/Source/keras_mnist_lenet_predict.mpc +++ b/Programs/Source/keras_mnist_lenet_predict.mpc @@ -36,6 +36,7 @@ model.build(test_samples.sizes) start = 0 for var in model.trainable_variables: var.assign_all(0) +# activate to use the model output by keras_mnist_lenet # start = var.read_from_file(start) guesses = model.predict(test_samples) diff --git a/Programs/Source/mnist_A.mpc b/Programs/Source/mnist_A.mpc index 18d4f369d..2184969fe 100644 --- a/Programs/Source/mnist_A.mpc +++ b/Programs/Source/mnist_A.mpc @@ -82,22 +82,8 @@ def _(i): sgd.run(batch_size) stop_timer(1) - def get_correct(Y, n): - n_correct = regint(0) - for i in range(n): - n_correct += (Y[i].reveal() > 0).bit_xor( - layers[-2].Y[i][0][0][0].reveal() < 0) - return n_correct - - sgd.forward(N) - - n_correct = get_correct(layers[-1].Y, N) + n_correct, loss = sgd.reveal_correctness(layers[0].X, layers[-1].Y) print_ln('train_acc: %s (%s/%s)', cfix(n_correct) / N, n_correct, N) - training_address = layers[0].X.address - layers[0].X.address = X.address - sgd.forward(n_test) - layers[0].X.address = training_address - - n_correct = get_correct(Y, n_test) + n_correct, loss = sgd.reveal_correctness(X, Y) print_ln('acc: %s (%s/%s)', cfix(n_correct) / n_test, n_correct, n_test) diff --git a/Programs/Source/personal_client_example.py b/Programs/Source/personal_client_example.py new file mode 100644 index 000000000..f7b491488 --- /dev/null +++ b/Programs/Source/personal_client_example.py @@ -0,0 +1,11 @@ +listen_for_clients(15000) +socket = accept_client_connection(15000) + +n = 1000 + +for i in range(2): + x = personal.read_fix_from_socket(i, socket, n) + sfix(x).write_fully_to_socket(socket) + +res = sum(sfix.read_from_socket(socket, n)) +print_ln('%s', res.reveal()) diff --git a/Programs/Source/prf_mimc.mpc b/Programs/Source/prf_mimc.mpc index d03fa0056..8bd089a8d 100644 --- a/Programs/Source/prf_mimc.mpc +++ b/Programs/Source/prf_mimc.mpc @@ -2,7 +2,7 @@ from Compiler import instructions_base import sys program.bit_length = 128 -nparallel = int(sys.argv[2]) +nparallel = int(program.args[1]) instructions_base.set_global_vector_size(nparallel) use_cubes = True diff --git a/Programs/Source/torch_densenet.py b/Programs/Source/torch_densenet.py new file mode 100644 index 000000000..0ddadbf12 --- /dev/null +++ b/Programs/Source/torch_densenet.py @@ -0,0 +1,53 @@ +# this tests the pretrained DenseNet in secure computation + +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) +MultiArray.disable_index_checks() +Array.check_indices = False + +from Compiler import ml + +try: + ml.set_n_threads(int(program.args[2])) +except: + pass + +import torchvision +import torch +import numpy +import requests +import io +import PIL + +from torchvision import transforms + +model = getattr(torchvision.models.densenet, 'densenet' + program.args[1])( + weights='DEFAULT') + +r = requests.get('https://github.com/pytorch/hub/raw/master/images/dog.jpg') +input_image = PIL.Image.open(io.BytesIO(r.content)) +preprocess = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) +input_tensor = preprocess(input_image) +input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model + +with torch.no_grad(): + output = int(model(input_batch).argmax()) + print('Model says %d' % output) + +secret_input = sfix.input_tensor_via( + 0, numpy.moveaxis(input_batch.numpy(), 1, -1)) + +layers = ml.layers_from_torch( + model, secret_input.shape, 1, input_via=0, + layer_args={model.features.conv0: {'weight_type': sfix.get_prec_type(32)}}) + +optimizer = ml.Optimizer(layers) +optimizer.output_stats = 'output_stats' in program.args + +print_ln('Secure computation says %s', + optimizer.eval(secret_input, top=True)[0].reveal()) diff --git a/Programs/Source/torch_mnist_lenet.mpc b/Programs/Source/torch_mnist_lenet.mpc index 75ccf24d6..353cebde8 100644 --- a/Programs/Source/torch_mnist_lenet.mpc +++ b/Programs/Source/torch_mnist_lenet.mpc @@ -40,7 +40,7 @@ from Compiler import ml ml.set_n_threads(int(program.args[2])) -layers = ml.layers_from_torch(net, data[0][1].shape, 128) +layers = ml.layers_from_torch(net, data[0][1].shape, 128, program=program) layers[0].X = data[0][1] layers[-1].Y = data[0][0] diff --git a/Programs/Source/torch_mnist_lenet_predict.mpc b/Programs/Source/torch_mnist_lenet_predict.mpc index 8e8b54cb1..465c69eef 100644 --- a/Programs/Source/torch_mnist_lenet_predict.mpc +++ b/Programs/Source/torch_mnist_lenet_predict.mpc @@ -17,7 +17,7 @@ for train in True, False: import torch import torch.nn as nn -net = nn.Sequential( +layers = [ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.MaxPool2d(2), @@ -29,7 +29,12 @@ net = nn.Sequential( nn.Linear(800, 500), nn.ReLU(), nn.Linear(500, 10) -) +] + +if 'bn' in program.args: + layers.insert(3, nn.BatchNorm2d(20)) + +net = nn.Sequential(*layers) # train for a bit transform = torchvision.transforms.Compose( diff --git a/Programs/Source/torch_resnet.py b/Programs/Source/torch_resnet.py new file mode 100644 index 000000000..5a0e72e6f --- /dev/null +++ b/Programs/Source/torch_resnet.py @@ -0,0 +1,47 @@ +# this tests the pretrained ResNet in secure computation + +program.options_from_args() + +from Compiler import ml + +try: + ml.set_n_threads(int(program.args[2])) +except: + pass + +import torchvision +import torch +import numpy +import requests +import io +import PIL + +from torchvision import transforms + +model = getattr(torchvision.models.resnet, 'resnet' + program.args[1])( + weights='DEFAULT') + +r = requests.get('https://github.com/pytorch/hub/raw/master/images/dog.jpg') +input_image = PIL.Image.open(io.BytesIO(r.content)) +preprocess = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) +input_tensor = preprocess(input_image) +input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model + +with torch.no_grad(): + output = int(model(input_batch).argmax()) + print('Model says %d' % output) + +secret_input = sfix.input_tensor_via( + 0, numpy.moveaxis(input_batch.numpy(), 1, -1)) + +layers = ml.layers_from_torch(model, secret_input.shape, 1, input_via=0) + +optimizer = ml.Optimizer(layers) + +print_ln('Secure computation says %s', + optimizer.eval(secret_input, top=True)[0].reveal()) diff --git a/Programs/Source/torch_squeeze.py b/Programs/Source/torch_squeeze.py new file mode 100644 index 000000000..4aed19406 --- /dev/null +++ b/Programs/Source/torch_squeeze.py @@ -0,0 +1,53 @@ +# this tests the pretrained SqueezeNet in secure computation + +program.options_from_args() + +from Compiler import ml + +try: + ml.set_n_threads(int(program.args[1])) +except: + pass + +import torchvision +import torch +import numpy +import requests +import io +import PIL + +from torchvision import transforms + +model = torchvision.models.get_model('SqueezeNet1_1', weights='DEFAULT') + +r = requests.get('https://github.com/pytorch/hub/raw/master/images/dog.jpg') +input_image = PIL.Image.open(io.BytesIO(r.content)) +preprocess = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) +input_tensor = preprocess(input_image) +input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model + +with torch.no_grad(): + output = int(model(input_batch).argmax()) + print('Model says %d' % output) + +secret_input = sfix.input_tensor_via( + 0, numpy.moveaxis(input_batch.numpy(), 1, -1)) + +layer_args = {} +if 'first_conv_high' in program.args: + layer_args[getattr(model.features, '0')] = \ + {'weight_type': sfix.get_prec_type(32)} + +layers = ml.layers_from_torch(model, secret_input.shape, 1, input_via=0, + layer_args=layer_args) + +optimizer = ml.Optimizer(layers) +optimizer.output_stats = 'output_stats' in program.args + +print_ln('Secure computation says %s', + optimizer.eval(secret_input, top=True)[0].reveal()) diff --git a/Protocols/AtlasShare.h b/Protocols/AtlasShare.h index bea233a59..41c672ea4 100644 --- a/Protocols/AtlasShare.h +++ b/Protocols/AtlasShare.h @@ -24,7 +24,7 @@ class AtlasShare : public ShamirShare public: typedef Atlas Protocol; - typedef ::Input Input; + typedef ShamirInput Input; typedef IndirectShamirMC MAC_Check; typedef ShamirMC Direct_MC; typedef ::PrivateOutput PrivateOutput; @@ -35,6 +35,8 @@ class AtlasShare : public ShamirShare typedef GC::AtlasSecret bit_type; #endif + const static int bit_generation_threshold = 2; + static string alt() { return ""; diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 8c89f420b..69f1f7323 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -83,7 +83,7 @@ void Beaver::stop_exchange() template T Beaver::finalize_mul(int n) { - (void) n; + this->add_mul(n); typename T::open_type masked[2]; T& tmp = (*triple)[2]; for (int k = 0; k < 2; k++) diff --git a/Protocols/BufferScope.h b/Protocols/BufferScope.h index 6cb2b82c9..e5519e183 100644 --- a/Protocols/BufferScope.h +++ b/Protocols/BufferScope.h @@ -9,15 +9,17 @@ template class BufferPrep; template class Preprocessing; +#include "Processor/OnlineOptions.h" + template class BufferScope { - BufferPrep& prep; + T& prep; int bak; public: - BufferScope(Preprocessing & prep, int buffer_size) : - prep(dynamic_cast&>(prep)) + BufferScope(T& prep, int buffer_size) : + prep(prep) { bak = this->prep.buffer_size; this->prep.buffer_size = buffer_size; diff --git a/Protocols/DabitSacrifice.hpp b/Protocols/DabitSacrifice.hpp index 26d24def7..0ceabceb3 100644 --- a/Protocols/DabitSacrifice.hpp +++ b/Protocols/DabitSacrifice.hpp @@ -117,6 +117,9 @@ void DabitSacrifice::sacrifice_and_check_bits(vector >& dabits, vector >& check_dabits, SubProcessor& proc, ThreadQueues* queues) { + if (OnlineOptions::singleton.has_option("verbose_dabit")) + fprintf(stderr, "checking %zu daBits\n", check_dabits.size()); + vector> to_check; sacrifice_without_bit_check(to_check, check_dabits, proc, queues); typename T::Protocol protocol(proc.P); @@ -134,7 +137,7 @@ void DabitSacrifice::sacrifice_and_check_bits(vector >& dabits, } else { - BufferScope scope(proc.DataF, multiplicands.size()); + BufferScope scope(proc.DataF, multiplicands.size()); protocol.multiply(products, multiplicands, 0, multiplicands.size(), proc); } vector check_for_zero; diff --git a/Protocols/DealerMatrixPrep.h b/Protocols/DealerMatrixPrep.h index 787397255..dc02c14ef 100644 --- a/Protocols/DealerMatrixPrep.h +++ b/Protocols/DealerMatrixPrep.h @@ -26,6 +26,11 @@ class DealerMatrixPrep : public BufferPrep> { } + int minimum_batch() + { + return -1; + } + void buffer_triples(); }; diff --git a/Protocols/DealerMatrixPrep.hpp b/Protocols/DealerMatrixPrep.hpp index 29e4c1efd..5ec2adcbf 100644 --- a/Protocols/DealerMatrixPrep.hpp +++ b/Protocols/DealerMatrixPrep.hpp @@ -51,7 +51,8 @@ void DealerMatrixPrep::buffer_triples() vector senders(P.num_players()); senders.back() = true; octetStreams os(P), to_receive(P); - int batch_size = min(100, OnlineOptions::singleton.batch_size); + int batch_size = BaseMachine::matrix_batch_size(n_rows, n_inner, n_cols); + assert(batch_size > 0); if (not T::real_shares(P)) { SeededPRNG G; diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index f462862a4..246c9cb2a 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -20,11 +20,13 @@ class FakeShuffle public: typedef ShuffleStore store_type; + map stats; + FakeShuffle(SubProcessor&) { } - FakeShuffle(vector& a, size_t n, int unit_size, size_t output_base, + FakeShuffle(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor&) { apply(a, n, unit_size, output_base, input_base, 0, 0); @@ -35,7 +37,7 @@ class FakeShuffle return store.add(); } - void apply(vector& a, size_t n, int unit_size, size_t output_base, + void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, int, bool) { auto source = a.begin() + input_base; @@ -52,7 +54,7 @@ class FakeShuffle } } - void inverse_permutation(vector&, size_t, size_t, size_t) + void inverse_permutation(StackedVector&, size_t, size_t, size_t) { } }; @@ -270,7 +272,7 @@ class FakeProtocol : public ProtocolBase { ltz_stats[args[i + 4]] += args[i + 1]; assert(i + args[i] <= args.size()); - assert(args[i] == 6); + assert(args[i] >= 5); for (int j = 0; j < args[i + 1]; j++) { auto& res = processor.get_S()[args[i + 2] + j]; @@ -284,7 +286,7 @@ class FakeProtocol : public ProtocolBase for (size_t i = 0; i < args.size(); i += args[i]) { assert(i + args[i] <= args.size()); - assert(args[i] == 6); + assert(args[i] >= 5); for (int j = 0; j < args[i + 1]; j++) { auto& res = processor.get_S()[args[i + 2] + j]; @@ -297,10 +299,10 @@ class FakeProtocol : public ProtocolBase for (size_t i = 0; i < args.size(); i += args[i]) { assert(i + args[i] <= args.size()); - assert(args[i] == 8); + assert(args[i] == 7); int k = args[i + 4]; int m = args[i + 5]; - int s = args[i + 7]; + int s = args[i + 6]; assert((s == 0) or (s == 1)); for (int j = 0; j < args[i + 1]; j++) { diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index 7a8d424fd..fd2419016 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -21,6 +21,7 @@ class FakeShare : public T, public ShareInterface public: typedef T open_type; typedef T clear; + typedef This share_type; typedef FakePrep LivePrep; typedef FakeMC MAC_Check; @@ -62,7 +63,7 @@ class FakeShare : public T, public ShareInterface { } - static void split(vector& dest, const vector& regs, + static void split(StackedVector& dest, const vector& regs, int n_bits, const This* source, int n_inputs, GC::FakeSecret::Protocol& protocol); }; diff --git a/Protocols/FakeShare.hpp b/Protocols/FakeShare.hpp index 7994f5104..19e08143f 100644 --- a/Protocols/FakeShare.hpp +++ b/Protocols/FakeShare.hpp @@ -8,7 +8,7 @@ #include "GC/square64.h" template -void FakeShare::split(vector& dest, +void FakeShare::split(StackedVector& dest, const vector& regs, int n_bits, const This* source, int n_inputs, GC::FakeSecret::Protocol&) { diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 2bceb9f31..623663b85 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -30,6 +30,9 @@ class Hemi : public T::BasicProtocol typename T::MatrixPrep& get_matrix_prep(const array& dimensions, SubProcessor& processor); + bool use_plain_matmul(const array dimensions, + SubProcessor& processor); + ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 807f5fc79..bc4b4d3a5 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -32,17 +32,48 @@ typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, return *matrix_preps.at(dims); } +template +bool Hemi::use_plain_matmul(const array dim, SubProcessor& processor) +{ + if (OnlineOptions::singleton.has_option("force_matrix_triples")) + return false; + + auto& prep = get_matrix_prep(dim, processor); + int savings = (dim[0] * dim[2]) / (dim[0] + dim[2]) + 1; + int requirement = BaseMachine::matrix_requirement(dim[0], dim[1], dim[2]); + + if (OnlineOptions::singleton.has_option("verbose_matrix")) + fprintf(stderr, "savings=%d minimum_batch=%d requirement=%d\n", savings, + prep.minimum_batch(), requirement); + + return HemiOptions::singleton.plain_matmul + or not OnlineOptions::singleton.live_prep + or prep.minimum_batch() / savings > requirement; +} + template void Hemi::matmulsm(SubProcessor& processor, MemoryPart& source, const Instruction& instruction) { - if (HemiOptions::singleton.plain_matmul - or not OnlineOptions::singleton.live_prep) + auto& dim = instruction.get_start(); + + vector plain_args, complex_args; + + for (auto it = dim.begin(); it < dim.end(); it += 12) { - processor.matmulsm(source, instruction); - return; + array real_dims({it[3], it[4], it[5]}); + + if (use_plain_matmul(real_dims, processor)) + plain_args.insert(plain_args.end(), it, it + 12); + else + complex_args.insert(complex_args.end(), it, it + 12); } + if (not plain_args.empty()) + processor.matmulsm(source, plain_args); + + auto& S = processor.get_S(); + // Perform the matrix multiplications in sequence. // They are not merged into one communication round since that would require multiple matrix_preps to // merge rounds. @@ -50,10 +81,10 @@ void Hemi::matmulsm(SubProcessor& processor, MemoryPart& source, // which is not implemented yet. auto Proc = processor.Proc; assert(Proc); - auto& S = processor.get_S(); - auto& start = instruction.get_start(); - for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) { + for (auto matmulArgs = complex_args.begin(); + matmulArgs < complex_args.end(); matmulArgs += 12) + { auto C = S.begin() + matmulArgs[0]; size_t firstFactorBase = Proc->get_Ci().at(matmulArgs[1]).get(); size_t secondFactorBase = Proc->get_Ci().at(matmulArgs[2]).get(); @@ -141,26 +172,32 @@ template void Hemi::conv2ds(SubProcessor& processor, const Instruction& instruction) { - if (HemiOptions::singleton.plain_matmul - or not OnlineOptions::singleton.live_prep) - { - processor.conv2ds(instruction); - return; - } - auto& args = instruction.get_start(); vector tuples; for (size_t i = 0; i < args.size(); i += 15) + { tuples.push_back(Conv2dTuple(args, i)); + if (use_plain_matmul(tuples.back().matrix_dimensions(), processor)) + { + processor.conv2ds(instruction); + return; + } + } for (auto& tuple : tuples) tuple.run_matrix(processor); } +inline +array Conv2dTuple::matrix_dimensions() +{ + return {1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}; +} + template void Conv2dTuple::run_matrix(SubProcessor& processor) { auto& S = processor.get_S(); - array dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}}); + array dim = matrix_dimensions(); ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); if (not T::real_shares(processor.P)) diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index d6912bb83..e1f476bcf 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -43,6 +43,8 @@ class HemiMatrixPrep : public BufferPrep> this->P = &prep.proc->P; } + int minimum_batch(); + void set_protocol(typename ShareMatrix::Protocol&) { } diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 062c22391..781941eed 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -87,20 +87,29 @@ inline void matrix_rand_mult(ThreadJob, false_type) } template -void HemiMatrixPrep::buffer_triples() +int HemiMatrixPrep::minimum_batch() { + assert(prep); + return prep->get_FTD().num_slots() / n_rows; +} +template +void HemiMatrixPrep::buffer_triples() +{ assert(prep); auto& multipliers = prep->get_multipliers(); auto& FTD = prep->get_FTD(); auto& pk = prep->get_pk(); - int n_matrices = FTD.num_slots() / n_rows; -#ifdef VERBOSE_HE - fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner, - n_inner, n_cols); - fflush(stderr); + int n_matrices = minimum_batch(); + + if (OnlineOptions::singleton.has_option("verbose_he")) + { + fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, + n_rows, n_inner, n_inner, n_cols); + fflush(stderr); + } + RunningTimer timer; -#endif AddableVector> A(n_matrices, {n_rows, n_inner}), B(n_matrices, {n_inner, n_cols}); SeededPRNG G; diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index 4c93b8e65..a38a1c921 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -69,8 +69,6 @@ Tree_MAC_Check::Tree_MAC_Check(const typename U::mac_key_type::Scalar& ai, in { popen_cnt=0; this->alphai=ai; - vals.reserve(2 * POPEN_MAX); - macs.reserve(2 * POPEN_MAX); } template @@ -89,6 +87,7 @@ template void Tree_MAC_Check::init_open(const Player&, int n) { macs.reserve(macs.size() + n); + vals.reserve(vals.size() + n); this->secrets.clear(); this->values.clear(); this->secrets.reserve(n); diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index fed190ef4..5b8553c7e 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -31,7 +31,7 @@ class MAC_Check_Base PointerVector values; public: - int values_opened; + size_t values_opened; static void setup(Player&) {} static void teardown() {} diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 0a98e8e27..177c7e732 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -7,6 +7,7 @@ #define PROTOCOLS_MALICIOUSREPPREP_HPP_ #include "MaliciousRepPrep.h" +#include "BufferScope.h" #include "Tools/Subroutines.h" #include "Processor/OnlineOptions.h" @@ -78,6 +79,9 @@ void MaliciousRepPrep::buffer_triples() auto& triples = this->triples; auto buffer_size = BaseMachine::batch_size(DATA_TRIPLE, this->buffer_size); + if (OnlineOptions::singleton.has_option("verbose_triples")) + fprintf(stderr, "creating %d triples (%d)\n", buffer_size, + this->buffer_size); auto& honest_proc = this->honest_proc; assert(honest_proc != 0); Player& P = honest_proc->P; @@ -85,6 +89,7 @@ void MaliciousRepPrep::buffer_triples() check_triples.reserve(buffer_size); auto& honest_prot = honest_proc->protocol; honest_prot.init_mul(); + BufferScope scope(honest_prot, 3 * buffer_size); for (int i = 0; i < buffer_size; i++) { check_triples.push_back({}); @@ -148,7 +153,7 @@ void MaliciousRepPrep::buffer_squares() assert(honest_proc); Player& P = honest_proc->P; squares.clear(); - honest_prep.buffer_size = buffer_size; + BufferScope scope(honest_prep, buffer_size); for (int i = 0; i < buffer_size; i++) { T a, b; @@ -191,7 +196,7 @@ void MaliciousBitOnlyRepPrep::buffer_bits() this->buffer_size); assert(honest_proc); Player& P = honest_proc->P; - honest_prep.buffer_size = buffer_size; + BufferScope scope(honest_prep, buffer_size); bits.clear(); for (int i = 0; i < buffer_size; i++) { diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index 6a3d0fcbd..7dde4cc33 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -62,8 +62,10 @@ void MamaPrep::buffer_triples() for (int k = 0; k < 3; k++) triples.back()[k] = x.byIndex(k, 0); } +#ifdef VERBOSE cerr << "Got " << triple_generator->uncheckedTriples.size() << " triples" << endl; +#endif } if (use_shuffling) diff --git a/Protocols/NoShare.h b/Protocols/NoShare.h index a8874b8fd..b8bce46da 100644 --- a/Protocols/NoShare.h +++ b/Protocols/NoShare.h @@ -21,6 +21,8 @@ class NoShare : public ShareInterface typedef NoShare This; public: + typedef This share_type; + // type for clear values in relevant domain typedef T clear; typedef clear open_type; diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index cac57ffc6..30f0574b2 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -62,7 +62,7 @@ class PostSacriRepRingShare : public Rep3Share2 } template - static void split(vector& dest, const vector& regs, int n_bits, + static void split(StackedVector& dest, const vector& regs, int n_bits, const super* source, int n_inputs, typename bit_type::Protocol& protocol) { diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index cd321b26e..1e2852a11 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -29,6 +29,7 @@ class RepShare : public FixedVec, public ShareInterface public: typedef T clear; typedef T open_type; + typedef This share_type; const static bool needs_ot = false; const static bool dishonest_majority = false; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 0fc2e50e4..ea4a79e06 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -44,7 +44,7 @@ class Rep3Share2 : public Rep3Share> } template - static void split(vector& dest, const vector& regs, int n_bits, + static void split(StackedVector& dest, const vector& regs, int n_bits, const Rep3Share2* source, int n_inputs, typename U::Protocol& protocol) { diff --git a/Protocols/Rep3Shuffler.h b/Protocols/Rep3Shuffler.h index 94d86c9c5..fecfe7c31 100644 --- a/Protocols/Rep3Shuffler.h +++ b/Protocols/Rep3Shuffler.h @@ -19,17 +19,19 @@ class Rep3Shuffler SubProcessor& proc; public: - Rep3Shuffler(vector& a, size_t n, int unit_size, size_t output_base, + map stats; + + Rep3Shuffler(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc); Rep3Shuffler(SubProcessor& proc); int generate(int n_shuffle, store_type& store); - void apply(vector& a, size_t n, int unit_size, size_t output_base, + void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, shuffle_type& shuffle, bool reverse); - void inverse_permutation(vector& stack, size_t n, size_t output_base, + void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, size_t input_base); }; diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp index f3a29c84d..52a1e9dcb 100644 --- a/Protocols/Rep3Shuffler.hpp +++ b/Protocols/Rep3Shuffler.hpp @@ -9,7 +9,7 @@ #include "Rep3Shuffler.h" template -Rep3Shuffler::Rep3Shuffler(vector& a, size_t n, int unit_size, +Rep3Shuffler::Rep3Shuffler(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc) : proc(proc) { @@ -38,14 +38,14 @@ int Rep3Shuffler::generate(int n_shuffle, store_type& store) for (int j = 0; j < n_shuffle; j++) { int k = proc.protocol.shared_prngs[i].get_uint(n_shuffle - j); - swap(perm[k], perm[k + j]); + swap(perm[j], perm[k + j]); } } return res; } template -void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, +void Rep3Shuffler::apply(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, shuffle_type& shuffle, bool reverse) { @@ -57,6 +57,8 @@ void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, if (shuffle.empty()) throw runtime_error("shuffle has been deleted"); + stats[n / unit_size] += unit_size; + vector to_shuffle; for (size_t i = 0; i < n; i++) to_shuffle.push_back(a[input_base + i]); @@ -120,7 +122,7 @@ void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, } template -void Rep3Shuffler::inverse_permutation(vector&, size_t, size_t, size_t) +void Rep3Shuffler::inverse_permutation(StackedVector&, size_t, size_t, size_t) { throw runtime_error("inverse permutation not implemented"); } diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index b661fa3aa..0a7e1fd5a 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -65,6 +65,8 @@ class Rep4 : public ProtocolBase template T finalize_mul(int n_bits, false_type); + void must_check(); + public: static const bool uses_triples = false; @@ -94,7 +96,7 @@ class Rep4 : public ProtocolBase void trunc_pr(const vector& regs, int size, SubProcessor& proc); template - void split(vector& dest, const vector& regs, int n_bits, + void split(StackedVector& dest, const vector& regs, int n_bits, const U* source, int n_inputs); int get_n_relevant_players() { return 2; } diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index 45d42ab4a..10dc5e1f4 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -36,19 +36,25 @@ Rep4::Rep4(Player& P, prngs_type& prngs) : template Rep4::~Rep4() { + check(); + for (auto& x : receive_hashes) for (auto& y : x) - if (y.size > 0) - { - check(); - return; - } + assert(y.size == 0); for (auto& x : send_hashes) for (auto& y : x) - if (y.size > 0) + assert(y.size == 0); +} + +template +void Rep4::check() +{ + for (auto& x : channels) + for (auto y : x) + if (y) { - check(); + must_check(); return; } } @@ -288,7 +294,7 @@ T Rep4::finalize_dotprod(int) } template -void Rep4::check() +void Rep4::must_check() { octetStreams to_send(P); for (int i = 1; i < 4; i++) @@ -472,7 +478,7 @@ void Rep4::trunc_pr(const vector& regs, int size, template template -void Rep4::split(vector& dest, const vector& regs, int n_bits, +void Rep4::split(StackedVector& dest, const vector& regs, int n_bits, const U* source, int n_inputs) { assert(regs.size() / n_bits == 2); diff --git a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp index 48844396b..a711167dc 100644 --- a/Protocols/Rep4Input.hpp +++ b/Protocols/Rep4Input.hpp @@ -22,6 +22,8 @@ template Rep4Input::~Rep4Input() { check(); + for (auto& hash : hashes) + assert(hash.size == 0); } template @@ -76,6 +78,12 @@ void Rep4Input::exchange() template void Rep4Input::check() { + bool check_needed = false; + for (auto& hash : hashes) + check_needed |= hash.size != 0; + if (not check_needed) + return; + octetStream os[2][2]; for (int i = 0; i < 2; i++) { diff --git a/Protocols/Rep4MC.hpp b/Protocols/Rep4MC.hpp index 22dceff11..7b2fffda2 100644 --- a/Protocols/Rep4MC.hpp +++ b/Protocols/Rep4MC.hpp @@ -32,6 +32,9 @@ void Rep4MC::exchange(const Player& P) template void Rep4MC::Check(const Player& P) { + if (check_hash.size == 0) + return; + octetStream left; check_hash.final(left); P.pass_around(left, -1); diff --git a/Protocols/Rep4Share2k.h b/Protocols/Rep4Share2k.h index 596cb2013..806439f6d 100644 --- a/Protocols/Rep4Share2k.h +++ b/Protocols/Rep4Share2k.h @@ -43,7 +43,7 @@ class Rep4Share2 : public Rep4Share> } template - static void split(vector& dest, const vector& regs, int n_bits, + static void split(StackedVector& dest, const vector& regs, int n_bits, const Rep4Share2* source, int n_inputs, Rep4& protocol) { int n_split = regs.size() / n_bits; @@ -54,7 +54,7 @@ class Rep4Share2 : public Rep4Share> } template - static void split(vector& dest, const vector& regs, int n_bits, + static void split(StackedVector& dest, const vector& regs, int n_bits, const Rep4Share2* source, int n_inputs, Player& P) { int my_num = P.my_num(); diff --git a/Protocols/RepRingOnlyEdabitPrep.hpp b/Protocols/RepRingOnlyEdabitPrep.hpp index 5213cf491..da2fab0b7 100644 --- a/Protocols/RepRingOnlyEdabitPrep.hpp +++ b/Protocols/RepRingOnlyEdabitPrep.hpp @@ -13,20 +13,21 @@ void RepRingOnlyEdabitPrep::buffer_edabits(int n_bits, ThreadQueues*) assert(this->proc); int dl = T::bit_type::default_length; int buffer_size = DIV_CEIL(BaseMachine::edabit_batch_size(n_bits, this->buffer_size), dl) * dl; - vector wholes; - wholes.resize(buffer_size); + StackedVector swholes; + swholes.resize(buffer_size); Instruction inst; inst.r[0] = 0; inst.n = n_bits; inst.size = buffer_size; - this->proc->protocol.randoms_inst(wholes, inst); + this->proc->protocol.randoms_inst(swholes, inst); + vector wholes(swholes.begin(), swholes.end()); auto& P = this->proc->P; vector regs(P.num_players() * n_bits); for (size_t i = 0; i < regs.size(); i++) regs[i] = i * buffer_size / dl; typedef typename T::bit_type bt; - vector bits(n_bits * P.num_players() * buffer_size); + StackedVector bits(n_bits * P.num_players() * buffer_size); T::split(bits, regs, n_bits, wholes.data(), wholes.size(), *GC::ShareThread < bt > ::s().protocol); diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 1f1176ff7..f9db5cfe8 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -54,17 +54,20 @@ class ProtocolBase protected: vector random; - int trunc_pr_counter; - int rounds, trunc_rounds; - int dot_counter; - int bit_counter; + void add_mul(int n); public: typedef T share_type; typedef SecureShuffle Shuffler; - int counter; + long trunc_pr_counter; + long rounds, trunc_rounds; + long dot_counter; + long bit_counter; + long counter; + + int buffer_size; ProtocolBase(); virtual ~ProtocolBase(); @@ -107,12 +110,12 @@ class ProtocolBase { (void) regs, (void) size; (void) proc; throw runtime_error("trunc_pr not implemented"); } virtual void randoms(T&, int) { throw runtime_error("randoms not implemented"); } - virtual void randoms_inst(vector&, const Instruction&); + virtual void randoms_inst(StackedVector&, const Instruction&); template void matmulsm(SubProcessor & proc, MemoryPart& source, const Instruction& instruction) - { proc.matmulsm(source, instruction); } + { proc.matmulsm(source, instruction.get_start()); } template void conv2ds(SubProcessor& proc, const Instruction& instruction) diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index dc6324451..fd807364e 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -23,6 +23,7 @@ ProtocolBase::ProtocolBase() : trunc_pr_counter(0), rounds(0), trunc_rounds(0), dot_counter(0), bit_counter(0), counter(0) { + buffer_size = OnlineOptions::singleton.batch_size; } template @@ -213,10 +214,16 @@ void Replicated::stop_exchange() } template -inline T Replicated::finalize_mul(int n) +void ProtocolBase::add_mul(int n) { this->counter++; - this->bit_counter += n; + this->bit_counter += n < 0 ? T::default_length : n; +} + +template +inline T Replicated::finalize_mul(int n) +{ + this->add_mul(n); T result; result[0] = add_shares.next(); result[1].unpack(os[1], n); @@ -262,7 +269,7 @@ T Replicated::get_random() } template -void ProtocolBase::randoms_inst(vector& S, +void ProtocolBase::randoms_inst(StackedVector& S, const Instruction& instruction) { for (int j = 0; j < instruction.get_size(); j++) diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index 4d875a3b2..195b0dc83 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -34,6 +34,7 @@ void ReplicatedMC::prepare(const vector& S) to_send.reserve(S.size() * T::value_type::size()); for (auto& x : S) x[0].pack(to_send); + this->values_opened += S.size(); } template diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 0633cbdc4..5279c9cf0 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -90,8 +90,6 @@ class BufferPrep : public Preprocessing public: typedef T share_type; - int buffer_size; - /// Key-independent setup if necessary (cryptosystem parameters) static void basic_setup(Player& P) { (void) P; } /// Generate keys if necessary @@ -118,7 +116,7 @@ class BufferPrep : public Preprocessing void get_two_no_count(Dtype dtype, T& a, T& b); void get_one_no_count(Dtype dtype, T& a); void get_input_no_count(T& a, typename T::open_type& x, int i); - void get_no_count(vector& S, DataTag tag, const vector& regs, + void get_no_count(StackedVector& S, DataTag tag, const vector& regs, int vector_size); virtual void get_dabit_no_count(T& a, typename T::bit_type& b); @@ -154,12 +152,12 @@ class BitPrep : public virtual BufferPrep protected: int base_player; - typename T::Protocol* protocol; - void buffer_ring_bits_without_check(vector& bits, PRNG& G, int buffer_size); public: + typename T::Protocol* protocol; + BitPrep(SubProcessor* proc, DataPositions& usage); ~BitPrep(); @@ -180,8 +178,6 @@ class RingPrep : public virtual BitPrep { typedef typename T::bit_type::part_type BT; - SubProcessor* bit_part_proc; - protected: void buffer_dabits_without_check(vector>& dabits, int buffer_size = -1, ThreadQueues* queues = 0); @@ -214,6 +210,8 @@ class RingPrep : public virtual BitPrep typename BT::Input& bit_input, int input_player, int begin, int end); public: + SubProcessor* bit_part_proc; + RingPrep(SubProcessor* proc, DataPositions& usage); virtual ~RingPrep(); diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index c2801c4cf..8b1a4f116 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -62,8 +62,7 @@ class InScope template BufferPrep::BufferPrep(DataPositions& usage) : Preprocessing(usage), n_bit_rounds(0), - proc(0), P(0), - buffer_size(0) + proc(0), P(0) { } @@ -200,6 +199,7 @@ void generate_triples_initialized(vector>& triples, int n_triples, U* protocol, int n_bits = -1) { triples.resize(n_triples); + BufferScope scope(*protocol, 2 * triples.size()); for (size_t i = 0; i < triples.size(); i++) { auto& triple = triples[i]; @@ -220,6 +220,8 @@ void BufferPrep::get_three_no_count(Dtype dtype, T& a, T& b, T& c) if (triples.empty()) { + if (OnlineOptions::singleton.has_option("verbose_triples")) + fprintf(stderr, "out of %s triples\n", T::type_string().c_str()); InScope in_scope(this->do_count, false, *this); buffer_triples(); assert(not triples.empty()); @@ -403,7 +405,7 @@ template template void SemiHonestRingPrep::buffer_bits(true_type, false_type) { - if (this->protocol->get_n_relevant_players() > 10 + if (this->protocol->get_n_relevant_players() > T::bit_generation_threshold or OnlineOptions::singleton.bits_from_squares or T::dishonest_majority) buffer_bits_from_squares(*this); @@ -470,6 +472,9 @@ template void MaliciousRingPrep::buffer_personal_dabits_without_check( int input_player, vector>& to_check, int buffer_size) { + if (OnlineOptions::singleton.has_option("verbose_dabit")) + fprintf(stderr, "generating %d personal dabits\n", buffer_size); + assert(this->proc != 0); auto& P = this->proc->P; auto &party = GC::ShareThread::s(); @@ -508,9 +513,9 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, vector& sums, vector >& bits, SubProcessor& proc, int input_player, int begin, int end) { -#ifdef VERBOSE_EDA - fprintf(stderr, "generate personal edaBits %d to %d\n", begin, end); -#endif + if (OnlineOptions::singleton.has_option("verbose_eda")) + fprintf(stderr, "generate personal edaBits %d to %d\n", begin, end); + InScope in_scope(this->do_count, false, *this); assert(this->proc != 0); auto& P = proc.P; @@ -520,6 +525,8 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, bit_input.reset_all(P); assert(begin % BT::default_length == 0); int buffer_size = end - begin; + BufferScope _(this->proc->DataF, buffer_size); + BufferScope __(proc.DataF, buffer_size); buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, input_player, buffer_size); input.exchange(); @@ -680,7 +687,7 @@ void BitPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits; auto stat = proc->P.total_comm(); - BufferScope _(*this, buffer_size); + BufferScope _(*this, buffer_size); buffer_bits_from_players(player_bits, G, *proc, this->base_player, buffer_size, 1); auto& prot = *protocol; @@ -727,6 +734,9 @@ void SemiRep3Prep::buffer_dabits(ThreadQueues*) BT::default_length); int n_bits = n_blocks * BT::default_length; + if (OnlineOptions::singleton.has_option("verbose_dabit")) + fprintf(stderr, "generating %d daBits\n", n_bits); + vector b(n_blocks); vector> a(n_bits); @@ -746,7 +756,8 @@ void SemiRep3Prep::buffer_dabits(ThreadQueues*) // the first multiplication vector first(n_bits), second(n_bits); - typename T::Input input(P); + typename T::Input& input = this->proc->input; + input.reset_all(P); if (P.my_num() == 0) { @@ -957,7 +968,7 @@ void RingPrep::buffer_sedabits_from_edabits(int n_bits, false_type) fprintf(stderr, "sedabit buffer size %zu\n", buffer_size); #endif auto& loose = this->edabits[{false, n_bits}]; - BufferScope scope(*this, buffer_size * edabitvec::MAX_SIZE); + BufferScope scope(*this, buffer_size * edabitvec::MAX_SIZE); while (loose.size() < buffer_size) this->buffer_edabits(false, n_bits); sanitize<0>(loose, n_bits); @@ -990,15 +1001,14 @@ template void RingPrep::sanitize(vector>& edabits, int n_bits, int player, int begin, int end) { -#ifdef VERBOSE_EDA - fprintf(stderr, "sanitize edaBits %d to %d in %d\n", begin, end, - BaseMachine::thread_num); -#endif + if (OnlineOptions::singleton.has_option("verbose_eda")) + fprintf(stderr, "sanitize edaBits %d to %d in %d\n", begin, end, + BaseMachine::thread_num); vector dabits; typedef typename T::bit_type::part_type::small_type BT; vector to_open; - BufferScope scope(*this, (end - begin)); + BufferScope scope(*this, (end - begin)); for (int i = begin; i < end; i++) { auto& x = edabits[i]; @@ -1047,7 +1057,7 @@ void RingPrep::sanitize(vector>& edabits, int n_bits) vector dabits; typedef typename T::bit_type::part_type BT; vector to_open; - BufferScope scope(*this, edabits.size() * edabits[0].size()); + BufferScope scope(*this, edabits.size() * edabits[0].size()); #ifdef DEBUG_BATCH_SIZE cerr << this->dabits.size() << " daBits left before" << endl; @@ -1163,6 +1173,7 @@ void BufferPrep::get_input_no_count(T& a, typename T::open_type& x, int i) { InScope in_scope(this->do_count, false, *this); buffer_inputs(i); + assert(not inputs.empty()); } a = inputs[i].back().share; x = inputs[i].back().value; @@ -1256,7 +1267,7 @@ void BufferPrep::buffer_edabits_with_queues(bool strict, int n_bits) template template void Preprocessing::get_edabits(bool strict, size_t size, T* a, - vector& Sb, const vector& regs, false_type) + StackedVector& Sb, const vector& regs, false_type) { int n_bits = regs.size(); edabit eb; @@ -1345,7 +1356,7 @@ void BufferPrep::buffer_inputs_as_usual(int player, SubProcessor* proc) } template -void BufferPrep::get_no_count(vector& S, DataTag tag, +void BufferPrep::get_no_count(StackedVector& S, DataTag tag, const vector& regs, int vector_size) { (void) S, (void) tag, (void) regs, (void) vector_size; @@ -1377,7 +1388,7 @@ T BufferPrep::get_random() template void BufferPrep::buffer_extra(Dtype type, int n_items) { - BufferScope scope(*this, n_items); + BufferScope scope(*this, n_items); switch (type) { diff --git a/Protocols/RingOnlyPrep.hpp b/Protocols/RingOnlyPrep.hpp index cf1d0675d..6f77cbd4f 100644 --- a/Protocols/RingOnlyPrep.hpp +++ b/Protocols/RingOnlyPrep.hpp @@ -18,7 +18,7 @@ void RingOnlyPrep::buffer_dabits_from_bits_without_check( this->proc->bit_prep, this->proc->P); typename T::bit_type::part_type::Input input(bit_proc); input.reset_all(this->proc->P); - BufferScope scope(*this, buffer_size); + BufferScope scope(*this, buffer_size); for (int i = 0; i < buffer_size; i++) { T bit; diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index 5601db457..a6b2df725 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -71,17 +71,19 @@ class SecureShuffle void configure(int config_player, vector* perm, int n); void player_round(int config_player); - void waksman(vector& a, int depth, int start); + void waksman(StackedVector& a, int depth, int start); void cond_swap(T& x, T& y, const T& b); void iter_waksman(bool reverse = false); void waksman_round(int size, bool inwards, bool reverse); - void pre(vector& a, size_t n, size_t input_base); - void post(vector& a, size_t n, size_t input_base); + void pre(StackedVector& a, size_t n, size_t input_base); + void post(StackedVector& a, size_t n, size_t input_base); public: - SecureShuffle(vector& a, size_t n, int unit_size, + map stats; + + SecureShuffle(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc); SecureShuffle(SubProcessor& proc); @@ -101,7 +103,7 @@ class SecureShuffle * @param reverse Boolean indicating whether to apply the inverse of the permutation * @see SecureShuffle::generate for obtaining a shuffle handle */ - void apply(vector& a, size_t n, int unit_size, size_t output_base, + void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, shuffle_type& shuffle, bool reverse); /** @@ -117,7 +119,7 @@ class SecureShuffle * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) */ - void inverse_permutation(vector& stack, size_t n, size_t output_base, size_t input_base); + void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, size_t input_base); }; #endif /* PROTOCOLS_SECURESHUFFLE_H_ */ diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index f41c3f970..186b6fee1 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -58,7 +58,7 @@ SecureShuffle::SecureShuffle(SubProcessor& proc) : } template -SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, +SecureShuffle::SecureShuffle(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc) : proc(proc), unit_size(unit_size), n_shuffle(0), exact(false) { @@ -71,11 +71,13 @@ SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, } template -void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t output_base, +void SecureShuffle::apply(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, shuffle_type& shuffle, bool reverse) { this->unit_size = unit_size; + stats[n / unit_size] += unit_size; + pre(a, n, input_base); assert(shuffle.size() == proc.protocol.get_relevant_players().size()); @@ -98,7 +100,7 @@ void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t outpu template -void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t output_base, +void SecureShuffle::inverse_permutation(StackedVector &stack, size_t n, size_t output_base, size_t input_base) { int alice = 0; int bob = 1; @@ -107,9 +109,11 @@ void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t ou auto &input = proc.input; // This method only supports two players - assert(P.num_players() == 2); + if (P.num_players() != 2) + throw runtime_error("inverse permutation only implemented for two players"); // The current implementation assumes a semi-honest environment - assert(!T::malicious); + if (T::malicious) + throw runtime_error("inverse permutation only implemented for semi-honest protocols"); // We are dealing directly with permutations, so the unit_size will always be 1. this->unit_size = 1; @@ -173,7 +177,7 @@ void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t ou } template -void SecureShuffle::pre(vector& a, size_t n, size_t input_base) +void SecureShuffle::pre(StackedVector& a, size_t n, size_t input_base) { n_shuffle = n / unit_size; assert(unit_size * n_shuffle == n); @@ -204,7 +208,7 @@ void SecureShuffle::pre(vector& a, size_t n, size_t input_base) } template -void SecureShuffle::post(vector& a, size_t n, size_t output_base) +void SecureShuffle::post(StackedVector& a, size_t n, size_t output_base) { if (exact) for (size_t i = 0; i < n; i++) @@ -344,7 +348,7 @@ void SecureShuffle::configure(int config_player, vector *perm, int n) { } template -void SecureShuffle::waksman(vector& a, int depth, int start) +void SecureShuffle::waksman(StackedVector& a, int depth, int start) { int n = a.size(); diff --git a/Protocols/Semi.h b/Protocols/Semi.h index 903aca6d1..f73dfd9d5 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -87,7 +87,7 @@ class Semi : public SPDZ void buffer_random() { - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < this->buffer_size; i++) this->random.push_back(G.get()); } }; diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index 679c6bc8d..f88969d5b 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -49,7 +49,7 @@ class Semi2kShare : public SemiShare> } template - static void split(vector& dest, const vector& regs, int n_bits, + static void split(StackedVector& dest, const vector& regs, int n_bits, const Semi2kShare* source, int n_inputs, typename U::Protocol& protocol) { diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index d4c864f06..464d793a3 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -19,6 +19,8 @@ class PairwiseKeyInput : public PrepLessInput public: PairwiseKeyInput(SubProcessor* proc, PlayerBase& P); + + void maybe_init(PlayerBase& P); }; /** diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 5cdfae792..1b943d997 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -18,9 +18,17 @@ SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : } template -PairwiseKeyInput::PairwiseKeyInput(SubProcessor* proc, PlayerBase& P) : +PairwiseKeyInput::PairwiseKeyInput(SubProcessor* proc, PlayerBase&) : PrepLessInput(proc) { +} + +template +void PairwiseKeyInput::maybe_init(PlayerBase& P) +{ + if (send_prngs.size() > 0) + return; + vector to_send(P.num_players()), to_receive; for (int i = 0; i < P.num_players(); i++) { @@ -44,6 +52,7 @@ void SemiInput::reset(int player) template void SemiInput::add_mine(const typename T::clear& input, int) { + this->maybe_init(P); auto& P = this->P; typename T::open_type sum, share; for (int i = 0; i < P.num_players(); i++) @@ -57,6 +66,7 @@ void SemiInput::add_mine(const typename T::clear& input, int) template void SemiInput::add_other(int, int) { + this->maybe_init(P); } template diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 8d9b11466..39d242708 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -54,6 +54,7 @@ class SemiShare : public T, public ShareInterface public: typedef T open_type; typedef T clear; + typedef SemiShare share_type; typedef SemiMC MAC_Check; typedef DirectSemiMC Direct_MC; diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index e11b37513..815377a4a 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -256,7 +256,9 @@ vector Shamir::get_randoms(PRNG& G, int t) random_input = new ShamirInput(0, P, threshold); auto& input = *random_input; input.reset_all(P); - int buffer_size = OnlineOptions::singleton.batch_size; + auto buffer_size = this->buffer_size; + if (OnlineOptions::singleton.has_option("verbose_random")) + fprintf(stderr, "generating %d random elements\n", buffer_size); for (int i = 0; i < buffer_size; i += hyper.size()) input.add_from_all(G.get()); input.exchange(); diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index f82c6f568..1fcaf8476 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -64,6 +64,8 @@ void ShamirInput::init() template void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) { + this->maybe_init(this->P); + (void) n_bits; auto& P = this->P; int n = P.num_players(); @@ -106,6 +108,7 @@ void ShamirInput::finalize_other(int player, T& target, template void IndividualInput::add_sender(int player) { + this->maybe_init(this->P); senders[player] = true; } diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 12966cfd8..9f6a0129f 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -31,6 +31,7 @@ class ShamirShare : public T, public ShareInterface typedef T clear; typedef T open_type; typedef void sacri_type; + typedef This share_type; typedef Shamir Protocol; typedef IndirectShamirMC MAC_Check; @@ -51,6 +52,7 @@ class ShamirShare : public T, public ShareInterface const static bool variable_players = true; const static bool expensive = false; const static bool malicious = false; + const static int bit_generation_threshold = 3; static string type_short() { diff --git a/Protocols/Share.h b/Protocols/Share.h index cfab66d12..1bff3b487 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -37,6 +37,8 @@ template class TinierSecret; template class Share_ : public ShareInterface { + static V mac_key; + T a; // The share V mac; // Shares of the mac @@ -73,8 +75,10 @@ class Share_ : public ShareInterface static void read_or_generate_mac_key(string directory, const Player& P, U& key); - static void specification(octetStream& os) - { T::specification(os); } + static void specification(octetStream& os); + + static mac_key_type get_mac_key(); + static void set_mac_key(const mac_key_type& mac_key); static Share_ constant(const open_type& aa, int my_num, const typename V::Scalar& alphai) { return Share_(aa, my_num, alphai); } diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index c6f675f7f..f62f854ff 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -3,6 +3,9 @@ #include "Share.h" +template +typename Share_::mac_key_type Share_::mac_key; + template template @@ -21,6 +24,31 @@ void Share_::read_or_generate_mac_key(string directory, const Player& P, SeededPRNG G; key.randomize(G); } + + mac_key = key; + + if (OnlineOptions::singleton.has_option("output_mac")) + { + cerr << "MAC key: " << mac_key << endl; + } +} + +template +typename Share_::mac_key_type Share_::get_mac_key() +{ + return mac_key; +} + +template +void Share_::set_mac_key(const mac_key_type& mac_key) +{ + Share_::mac_key = mac_key; +} + +template +void Share_::specification(octetStream& os) +{ + T::specification(os); } template diff --git a/Protocols/ShareInterface.cpp b/Protocols/ShareInterface.cpp index 065ad3dc5..6f1cc5bd3 100644 --- a/Protocols/ShareInterface.cpp +++ b/Protocols/ShareInterface.cpp @@ -4,7 +4,17 @@ */ #include "ShareInterface.h" +#include "GC/NoShare.h" const int ShareInterface::default_length; const false_type ShareInterface::triple_matmul; + +GC::NoValue ShareInterface::get_mac_key() +{ + throw runtime_error("no MAC"); +} + +void ShareInterface::set_mac_key(GC::NoValue) +{ +} diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 187d021d2..79ec8c44b 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -17,6 +17,8 @@ class Player; class Instruction; class ValueInterface; +template class StackedVector; + namespace GC { class NoShare; @@ -50,6 +52,8 @@ class ShareInterface static const bool randoms_for_opens = false; + static const int bit_generation_threshold = 0; + static const int default_length = 1; static string type_short() { throw runtime_error("shorthand undefined"); } @@ -59,7 +63,7 @@ class ShareInterface static bool real_shares(const Player&) { return true; } template - static void split(vector, vector, int, T*, int, + static void split(StackedVector&, vector, int, T*, int, typename U::Protocol&) { throw runtime_error("split not implemented"); } @@ -75,6 +79,9 @@ class ShareInterface template static void generate_mac_key(T&, U&) {} + static GC::NoValue get_mac_key(); + static void set_mac_key(GC::NoValue); + static int threshold(int) { throw runtime_error("undefined threshold"); } template diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 5ca976e40..192cf36ce 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -7,6 +7,7 @@ #define PROTOCOLS_SHUFFLESACRIFICE_HPP_ #include "ShuffleSacrifice.h" +#include "BufferScope.h" #include "Tools/PointerVector.h" #include "GC/BitAdder.h" @@ -187,11 +188,15 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, SubProcessor& proc, bool strict, int player, ThreadQueues* queues) { -#ifdef VERBOSE_EDA - cerr << "Sacrificing edaBits of length " << n_bits << endl; Timer timer; - timer.start(); -#endif + bool verbose = OnlineOptions::singleton.has_option("verbose_eda"); + + if (verbose) + { + fprintf(stderr, "Sacrificing %zu edaBits of length %zu\n", + wholes.size(), n_bits); + timer.start(); + } auto& P = proc.P; auto& MC = proc.MC; @@ -227,9 +232,8 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, parts.clear(); parts.shrink_to_fit(); -#ifdef VERBOSE_EDA - cerr << "Initialization took " << init_timer.elapsed() << " seconds" << endl; -#endif + if (verbose) + cerr << "Initialization took " << init_timer.elapsed() << " seconds" << endl; int buffer_size = to_check.size(); int N = (buffer_size - C) / B; @@ -246,20 +250,21 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, int n_triples = DIV_CEIL((B - 1) * N * n_bits, dl); proc.personal_bit_preps.at(player)->buffer_personal_triples(n_triples, queues); + BufferScope scope(*proc.personal_bit_preps.at(player), n_triples); for (int i = 0; i < n_triples; i++) personal_prep.push_triple( proc.personal_bit_preps.at(player)->get_triple(dl)); proc.personal_bit_preps.at(player)->shrink_to_fit(); } -#ifdef VERBOSE_EDA - cerr << "Personal preprocessing took " << personal_timer.elapsed() << " seconds" << endl; -#endif + + if (verbose) + cerr << "Personal preprocessing took " << personal_timer.elapsed() << " seconds" << endl; RunningTimer shuffle_timer; shuffle(to_check, P); -#ifdef VERBOSE_EDA - cerr << "Shuffling took " << shuffle_timer.elapsed() << " seconds" << endl; -#endif + + if (verbose) + cerr << "Shuffling took " << shuffle_timer.elapsed() << " seconds" << endl; // opening C vector shares; @@ -325,10 +330,10 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, else edabit_sacrifice_buckets(to_check, strict, player, proc, 0, N, personal_prep); -#ifdef VERBOSE_EDA - cerr << "Bucket sacrifice took " << bucket_timer.elapsed() << " seconds" - << endl; -#endif + + if (verbose) + cerr << "Bucket sacrifice took " << bucket_timer.elapsed() << " seconds" + << endl; RunningTimer output_timer; to_check.resize(N); @@ -340,17 +345,16 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, for (auto& y : x.second) output.back().second.push_back(y); } -#ifdef VERBOSE_EDA - cerr << "Output took " << output_timer.elapsed() << " seconds" << endl; -#endif + + if (verbose) + cerr << "Output took " << output_timer.elapsed() << " seconds" << endl; MCB.Check(P); delete &MCB; -#ifdef VERBOSE_EDA - cerr << "Done sacrificing edaBits of length " << n_bits << " after " - << timer.elapsed() << " seconds" << endl; -#endif + if (verbose) + cerr << "Done sacrificing edaBits of length " << n_bits << " after " + << timer.elapsed() << " seconds" << endl; } template @@ -379,6 +383,9 @@ void EdabitShuffleSacrifice::edabit_sacrifice_buckets(vector>& to_c auto& P = proc.P; auto& MC = proc.MC; + if (OnlineOptions::singleton.has_option("verbose_eda")) + fprintf(stderr, "sacrificing %d edaBits\n", N); + // sacrifice buckets RunningTimer add_prep_timer; vector>> summands(n_bits_to_open, diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index a3e3f2fc2..ab657e11b 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -94,7 +94,7 @@ void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep) auto bit_MC = &bit_proc->MC; vector squares, random_shares; auto one = BitShare::constant(1, bit_proc->P.my_num(), bit_MC->get_alphai()); - bit_prep->buffer_size = buffer_size; + BufferScope scope(*bit_prep, buffer_size); for (int i = 0; i < buffer_size; i++) { BitShare a, a2; @@ -169,11 +169,15 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, typedef typename T::bit_type::part_type bit_type; vector> bits; vector sums; -#ifdef VERBOSE_EDA - cerr << "Generate edaBits of length " << n_bits << " to sacrifice" << endl; + + bool verbose = OnlineOptions::singleton.has_option("verbose_eda"); Timer timer; - timer.start(); -#endif + if (verbose) + { + cerr << "Generate edaBits of length " << n_bits << " to sacrifice" << endl; + timer.start(); + } + auto &party = GC::ShareThread::s(); SubProcessor bit_proc(party.MC->get_part_MC(), this->proc->bit_prep, this->proc->P); @@ -196,12 +200,16 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, BitAdder().add(bits, player_bits, bit_proc, bit_type::default_length, queues); player_bits.clear(); -#ifdef VERBOSE_EDA - cerr << "Adding edaBits took " << add_timer.elapsed() << " seconds" << endl; - cerr << "Done with generating edaBits after " << timer.elapsed() - << " seconds" << endl; + + if (verbose) + { + cerr << "Adding edaBits took " << add_timer.elapsed() << " seconds" + << endl; + cerr << "Done with generating edaBits after " << timer.elapsed() + << " seconds" << endl; + } + RunningTimer finalize_timer; -#endif vector> checked; for (size_t i = 0; i < sums.size(); i++) { @@ -226,9 +234,9 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, else output.back().push_back(x); } -#ifdef VERBOSE_EDA - cerr << "Finalizing took " << finalize_timer.elapsed() << " seconds" << endl; -#endif + + if (verbose) + cerr << "Finalizing took " << finalize_timer.elapsed() << " seconds" << endl; } template diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h index 296550290..bca59dab9 100644 --- a/Protocols/SpdzWise.h +++ b/Protocols/SpdzWise.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SPDZWISE_H_ #include "Replicated.h" +#include "SpdzWiseRep3Shuffler.h" template class SpdzWiseInput; @@ -33,6 +34,9 @@ class SpdzWise : public ProtocolBase virtual void zero_check(check_type t); public: + typedef typename conditional, + SpdzWiseRep3Shuffler>::type Shuffler; + static const bool uses_triples = false; Player& P; @@ -60,7 +64,7 @@ class SpdzWise : public ProtocolBase int get_n_relevant_players() { return internal.get_n_relevant_players(); } - void randoms_inst(vector& S, const Instruction& instruction); + void randoms_inst(StackedVector& S, const Instruction& instruction); }; #endif /* PROTOCOLS_SPDZWISE_H_ */ diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index b7a8c741f..90289c89f 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -5,6 +5,8 @@ #include "SpdzWise.h" +#include "BufferScope.h" + #include "mac_key.hpp" template @@ -121,6 +123,8 @@ void SpdzWise::check() internal.init_dotprod(); coefficients.clear(); + BufferScope _(internal, results.size()); + for (auto& res : results) { coefficients.push_back(internal.get_random()); @@ -158,7 +162,7 @@ void SpdzWise::buffer_random() { // proxy for initialization assert(mac_key != 0); - int batch_size = OnlineOptions::singleton.batch_size; + auto batch_size = this->buffer_size; vector rs; rs.reserve(batch_size); // cannot use member instance @@ -178,7 +182,7 @@ void SpdzWise::buffer_random() } template -void SpdzWise::randoms_inst(vector& S, +void SpdzWise::randoms_inst(StackedVector& S, const Instruction& instruction) { internal.init_mul(); diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index e94a59ee3..522104a13 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -106,7 +106,8 @@ void SpdzWisePrep::buffer_inputs(int player) { assert(this->proc != 0); assert(this->protocol != 0); - vector rs(OnlineOptions::singleton.batch_size); + vector rs(BaseMachine::input_batch_size(player, + this->buffer_size)); auto& P = this->proc->P; this->inputs.resize(P.num_players()); this->protocol->init_mul(); diff --git a/Protocols/SpdzWiseRep3Shuffler.h b/Protocols/SpdzWiseRep3Shuffler.h new file mode 100644 index 000000000..c17736204 --- /dev/null +++ b/Protocols/SpdzWiseRep3Shuffler.h @@ -0,0 +1,40 @@ +/* + * SpdzWiseShuffler.h + * + */ + +#ifndef PROTOCOLS_SPDZWISEREP3SHUFFLER_H_ +#define PROTOCOLS_SPDZWISEREP3SHUFFLER_H_ + +#include "Rep3Shuffler.h" +#include "ProtocolSet.h" + +template +class SpdzWiseRep3Shuffler +{ + SubProcessor& proc; + + ProtocolSet internal_set; + Rep3Shuffler internal; + +public: + typedef typename Rep3Shuffler::store_type store_type; + typedef typename Rep3Shuffler::shuffle_type shuffle_type; + + map stats; + + SpdzWiseRep3Shuffler(StackedVector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, SubProcessor& proc); + + SpdzWiseRep3Shuffler(SubProcessor& proc); + + int generate(int n_shuffle, store_type& store); + + void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, shuffle_type& shuffle, bool reverse); + + void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, + size_t input_base); +}; + +#endif /* PROTOCOLS_SPDZWISEREP3SHUFFLER_H_ */ diff --git a/Protocols/SpdzWiseRep3Shuffler.hpp b/Protocols/SpdzWiseRep3Shuffler.hpp new file mode 100644 index 000000000..971e4b1e4 --- /dev/null +++ b/Protocols/SpdzWiseRep3Shuffler.hpp @@ -0,0 +1,68 @@ +/* + * SpdzWiseShuffler.cpp + * + */ + +#include "SpdzWiseRep3Shuffler.h" + +template +SpdzWiseRep3Shuffler::SpdzWiseRep3Shuffler(StackedVector& a, size_t n, + int unit_size, size_t output_base, size_t input_base, + SubProcessor& proc) : + SpdzWiseRep3Shuffler(proc) +{ + store_type store; + int handle = generate(n / unit_size, store); + apply(a, n, unit_size, output_base, input_base, store.get(handle), + false); +} + +template +SpdzWiseRep3Shuffler::SpdzWiseRep3Shuffler(SubProcessor& proc) : + proc(proc), internal_set(proc.P, {}), internal(internal_set.processor) +{ +} + +template +int SpdzWiseRep3Shuffler::generate(int n_shuffle, store_type& store) +{ + return internal.generate(n_shuffle, store); +} + +template +void SpdzWiseRep3Shuffler::apply(StackedVector& a, size_t n, + int unit_size, size_t output_base, size_t input_base, + shuffle_type& shuffle, bool reverse) +{ + stats[n / unit_size] += unit_size; + + StackedVector to_shuffle; + to_shuffle.reserve(2 * n); + + for (size_t i = 0; i < n; i++) + { + auto& x = a[input_base + i]; + to_shuffle.push_back(x.get_share()); + to_shuffle.push_back(x.get_mac()); + } + + internal.apply(to_shuffle, 2 * n, 2 * unit_size, 0, 0, shuffle, reverse); + + + for (size_t i = 0; i < n; i++) + { + auto& x = a[output_base + i]; + x.set_share(to_shuffle[2 * i]); + x.set_mac(to_shuffle[2 * i + 1]); + proc.protocol.add_to_check(x); + } + + proc.protocol.maybe_check(); +} + +template +void SpdzWiseRep3Shuffler::inverse_permutation(StackedVector&, size_t, size_t, + size_t) +{ + throw not_implemented(); +} diff --git a/Protocols/SpdzWiseRingShare.h b/Protocols/SpdzWiseRingShare.h index 476b47fb2..f3b2df168 100644 --- a/Protocols/SpdzWiseRingShare.h +++ b/Protocols/SpdzWiseRingShare.h @@ -57,7 +57,7 @@ class SpdzWiseRingShare : public SpdzWiseShare>> } template - static void split(vector& dest, const vector& regs, int n_bits, + static void split(StackedVector& dest, const vector& regs, int n_bits, const SpdzWiseRingShare* source, int n_inputs, typename U::Protocol& protocol) { diff --git a/Protocols/SpdzWiseShare.hpp b/Protocols/SpdzWiseShare.hpp index acd6fa298..13d9e2720 100644 --- a/Protocols/SpdzWiseShare.hpp +++ b/Protocols/SpdzWiseShare.hpp @@ -40,6 +40,8 @@ void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& if (fresh) mac_key = typename T::Honest::Protocol(P).get_random(); + + super::set_mac_key(mac_key); } template diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 735a59531..0fc851137 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -38,15 +38,31 @@ template typename T::mac_key_type read_generate_write_mac_key(Player& P, string directory = ""); +template +class KeySetup +{ +public: + typename T::mac_share_type::open_type key; + vector key_shares; + + typename T::mac_share_type get(size_t i) const + { + if (key_shares.empty()) + return {}; + else + return key_shares.at(i); + } +}; + template class Files { public: ofstream* outf; int N; - typename T::mac_type key; + KeySetup key; PRNG& G; - Files(int N, const typename T::mac_type& key, const string& prep_data_prefix, + Files(int N, const KeySetup& key, const string& prep_data_prefix, Dtype type, PRNG& G, int thread_num = -1) : Files(N, key, get_prep_sub_dir(prep_data_prefix, N, true) @@ -54,7 +70,7 @@ class Files G, thread_num) { } - Files(int N, const typename T::mac_type& key, const string& prefix, + Files(int N, const KeySetup& key, const string& prefix, PRNG& G, int thread_num = -1) : N(N), key(key), G(G) { @@ -67,7 +83,7 @@ class Files filename << PrepBase::get_suffix(thread_num); cout << "Opening " << filename.str() << endl; outf[i].open(filename.str().c_str(),ios::out | ios::binary); - file_signature().output(outf[i]); + file_signature(key.get(i)).output(outf[i]); if (outf[i].fail()) throw file_error(filename.str().c_str()); } @@ -79,11 +95,11 @@ class Files template void output_shares(const typename U::open_type& a) { - output_shares(a, key); + output_shares(a, key.key); } - template + template void output_shares(const typename U::open_type& a, - const typename U::mac_type& key) + const V& key) { vector Sa(N); make_share(Sa,a,N,key,G); diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 57b5c227a..22da13753 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -50,12 +50,11 @@ void make_share(Share_* Sa,const U& a,int N,const V& key,PRNG& G) Sa[N-1]=S; } -template -void make_share(SpdzWiseShare>* Sa,const U& a,int N,const V& key,PRNG& G) +template +void make_share(SpdzWiseShare>* Sa,const U& a,int N,const T& key,PRNG& G) { - assert (key[0] == key[1]); - auto mac = a * key[0]; - FixedVec shares, macs; + auto mac = a * key; + FixedVec shares, macs; shares.randomize_to_sum(a, G); macs.randomize_to_sum(mac, G); @@ -112,8 +111,8 @@ void make_share(GC::TinierSecret* Sa, const U& a, int N, const V& key, PRNG& make_vector_share(Sa, a, N, key, G); } -template -void make_share(SemiShare* Sa,const T& a,int N,const U&,PRNG& G) +template +void make_share(SemiShare* Sa,const V& a,int N,const U&,PRNG& G) { T x, S = a; for (int i=0; i> VanderStore::vandermonde; template void make_share(ShamirShare* Sa, const V& a, int N, - const typename ShamirShare::mac_type&, PRNG& G) + const GC::NoValue&, PRNG& G) { auto& vandermonde = VanderStore::vandermonde; if (vandermonde.empty()) @@ -316,6 +315,7 @@ template void read_mac_key(const Names& N, typename T::mac_key_type& key) { read_mac_key(get_prep_sub_dir(N.num_players()), N, key); + T::set_mac_key(key); } template @@ -402,20 +402,36 @@ inline GC::NoValue read_generate_write_mac_key(Player&, } template -void read_global_mac_key(const string& directory, int nparties, U& key) +KeySetup read_global_mac_key(const string& directory, int nparties) { - U pp; + if (is_same()) + return {}; + + KeySetup res; + res.key_shares.resize(nparties); + + auto& key = res.key; key.assign_zero(); for (int i= 0; i < nparties; i++) { + typename U::mac_key_type pp; read_mac_key(directory, i, nparties, pp); cout << " Key " << i << ": " << pp << endl; key += pp; + res.key_shares.at(i) = pp; } cout << "--------------\n"; cout << "Final Keys : " << key << endl; + + return res; +} + +template +void read_global_mac_key(const string& directory, int nparties, U& key) +{ + key = read_global_mac_key>(directory, nparties).key; } template <> @@ -448,18 +464,18 @@ T reconstruct(vector>& shares) } template -void make_mac_key_share(typename T::mac_share_type::open_type& key, - vector& key_shares, int nplayers, T, PRNG& G) +void make_mac_key_share(KeySetup& setup, int nplayers, T, PRNG& G) { - key.randomize(G); - make_share(key_shares.data(), key, nplayers, GC::NoShare(), G); - assert(not key_shares[0].is_zero()); + setup.key.randomize(G); + make_share(setup.key_shares.data(), setup.key, nplayers, GC::NoValue(), G); + assert(not setup.key_shares[0].is_zero()); } template -void make_mac_key_share(Z2& key, - vector>>& key_shares, int nplayers, Spdz2kShare, PRNG& G) +void make_mac_key_share(KeySetup>& setup, int nplayers, Spdz2kShare, PRNG& G) { + auto& key = setup.key; + auto& key_shares = setup.key_shares; key = {}; key_shares.resize(nplayers); for (int i = 0; i < nplayers; i++) @@ -471,15 +487,17 @@ void make_mac_key_share(Z2& key, } template -void generate_mac_keys(typename T::mac_share_type::open_type& key, +void generate_mac_keys(KeySetup& key_setup, int nplayers, string prep_data_prefix, PRNG& G) { + auto& key = key_setup.key; + auto& key_shares = key_setup.key_shares; key.assign_zero(); int tmpN = 0; ifstream inpf; prep_data_prefix = get_prep_sub_dir(prep_data_prefix, nplayers, true); bool generate = false; - vector key_shares(nplayers); + key_shares.resize(nplayers); for (int i = 0; i < nplayers; i++) { @@ -512,7 +530,7 @@ void generate_mac_keys(typename T::mac_share_type::open_type& key, if (generate) { - make_mac_key_share(key, key_shares, nplayers, T(), G); + make_mac_key_share(key_setup, nplayers, T(), G); for (int i = 0; i < nplayers; i++) { @@ -545,7 +563,7 @@ inline void check_files(ofstream* outf, int N) * ntrip = Number triples needed */ template -void make_mult_triples(const typename T::mac_type& key, int N, int ntrip, +void make_mult_triples(const KeySetup& key, int N, int ntrip, bool zero, string prep_data_prefix, PRNG& G, int thread_num = -1) { T::clear::write_setup(get_prep_sub_dir(prep_data_prefix, N)); @@ -571,7 +589,7 @@ void make_mult_triples(const typename T::mac_type& key, int N, int ntrip, * ntrip = Number inverses needed */ template -void make_inverse(const typename T::mac_type& key, int N, int ntrip, bool zero, +void make_inverse(const KeySetup& key, int N, int ntrip, bool zero, string prep_data_prefix, PRNG& G) { diff --git a/README.md b/README.md index 4aba415a5..265c95839 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,8 @@ as well as information on how to solve common issues. #### TL;DR (Binary Distribution on Linux or Source Distribution on macOS) -This requires either a Linux distribution originally released 2014 or -later (glibc 2.17) or macOS High Sierra or later as well as Python 3 +This requires either a Linux distribution originally released 2018 or +later (glibc 2.18) or macOS High Sierra or later as well as Python 3 and basic command-line utilities. Download and unpack the @@ -68,7 +68,7 @@ security. make setup echo 1 2 3 4 > Player-Data/Input-P0-0 echo 1 2 3 4 > Player-Data/Input-P1-0 -Scripts/compile-run.py -E mascot tutorial +Scripts/compile-run.py mascot tutorial ``` On strong enough hardware setups (several cores and GB of RAM), you @@ -84,7 +84,7 @@ docker build --tag mpspdz:mascot-party --build-arg machine=mascot-party.x . Run the [the tutorial](Programs/Source/tutorial.mpc): ``` -docker run --rm -it mpspdz:mascot-party ./Scripts/mascot.sh tutorial +docker run --rm -it mpspdz:mascot-party ./Scripts/compile-run.py mascot tutorial ``` See the [`Dockerfile`](./Dockerfile) for examples of how it can be used. @@ -191,7 +191,7 @@ there are a few things to consider: preprocessing in smaller batches at a higher asymptotic cost. - `--batch-size`: Preprocessing in smaller batches avoids generating too much but larger batches save communication rounds. - - `--direct`: In dishonest-majority protocols, direct communication + - `--direct`: In protocols with any number of parties, direct communication instead of star-shaped saves communication rounds at the expense of a quadratic amount. This might be beneficial with a small number of parties. @@ -242,8 +242,9 @@ AES-NI pipelining (for garbled circuits). The software uses two different bytecode sets, one for arithmetic circuits and one for boolean circuits. The high-level code -slightly differs between the two variants, but we aim to keep these -differences a at minimum. +differs between the two variants. Most computation functionality is +available in both, but binary circuits are lacking some input-output +functionality. In the section on computation we will explain how to compile a high-level program for the various computation domains and then how to @@ -256,9 +257,9 @@ compute the preprocessing time for a particular computation. #### Requirements - - GCC 5 or later (tested with up to 11) or LLVM/clang 6 or later - (tested with up to 14). The default is to use clang because it performs - better. Note that GCC 5/6 and clang 9 don't support libOTe, so you + - GCC 7 or later (tested with up to 11) or LLVM/clang 6 or later + (tested with up to 19). The default is to use clang because it performs + better. clang 9 doesn't support libOTe, so you need to deactivate its use for these compilers (see the next section). - For protocols using oblivious transfer, libOTe with [the necessary @@ -309,9 +310,7 @@ compute the preprocessing time for a particular computation. - `SSL_DIR` should point to a local, unversioned directory to store ssl keys (the default is `Player-Data` in the current directory). - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. - To use KOS instead of SoftSpokenOT, add `USE_KOS = 1` and - `SECURE = -DINSECURE` to `CONFIG.mine`. This is necessary with - GCC 5 and 6 because these compilers don't support the C++ - standard used by libOTe. + `SECURE = -DINSECURE` to `CONFIG.mine`. - On macOS, there have been issues with non-system compilers. Add `CXX = /usr/bin/g++` to fix them. @@ -399,8 +398,13 @@ the integer length. Note that `-P` is optional, and it involves algorithms that are more expensive while allowing for a wider range of integer lengths. +The command-line options primarily affects non-linear computation such +as comparisons. See the [documentation on non-linear +computation](https://mp-spdz.readthedocs.io/en/latest/non-linear.html) +for more details and pointers to relevant papers. + Note that in this context integers do not wrap around according to the -bit integer bit length but the length is used for non-linear +integer bit length but the length is used for non-linear computations such as comparison. Overflow in secret integers might have security implications if no concrete prime is given. @@ -584,6 +588,14 @@ $ ../MP-SPDZ/Scripts/rep-field.sh test ### TensorFlow inference +**Note: All networks mentioned below are now supported by the +[PyTorch +interface](https://mp-spdz.readthedocs.io/en/latest/machine-learning.html#loading-pre-trained-models), +which is better integrated and thus easier to use. This section is +merely kept to document the approach used for [an earlier +paper](https://eprint.iacr.org/2019/131), but it is recommended to use +the PyTorch interface.** + MP-SPDZ supports inference with selected TensorFlow graphs, in particular DenseNet, ResNet, and SqueezeNet as used in [CrypTFlow](https://github.com/mpc-msri/EzPC). For example, you can @@ -621,7 +633,8 @@ can emulate the computation as follows: ``` ./emulate.x ``` -This runs the compiled bytecode in cleartext computation. +This runs the compiled bytecode in cleartext computation, that is, +*no* multi-party computation is performed. ## Dishonest majority @@ -647,7 +660,7 @@ The following table shows all programs for dishonest-majority computation using | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | | `chaigear-party.x` | Adapted [HighGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `chaigear.sh` | | `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | -| `temi-party.x` | Adapted [CDN01](https://eprint.iacr.org/2000/055) | Mod prime | Semi-honest | `temi.sh` | +| `temi-party.x` | Adapted [CDN01](https://eprint.iacr.org/2022/933) | Mod prime | Semi-honest | `temi.sh` | | `soho-party.x` | Somewhat homomorphic encryption | Mod prime | Semi-honest | `soho.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | @@ -669,6 +682,9 @@ Tiny denotes the adaption of SPDZ2k to the binary setting. In particular, the SPDZ2k sacrifice does not work for bits, so we replace it by cut-and-choose according to [Furukawa et al.](https://eprint.iacr.org/2016/944) +Tinier on the other hand denotes the protocol by [Frederiksen et +al.](https://eprint.iacr.org/2015/901) also using the cut-and-choose +sacrifice by Furukawa et al. The virtual machines for LowGear and HighGear run a key generation similar to the one by [Rotaru et @@ -684,7 +700,8 @@ security similar to Semi, that is, generating additively shared Beaver triples using semi-homomorphic encryption. Temi in turn denotes the adaption of [Cramer et al.](https://eprint.iacr.org/2000/055) to LWE-based -semi-homomorphic encryption. +semi-homomorphic encryption as described in Appendix B of [this +work](https://eprint.iacr.org/2022/933). Both Hemi and Temi use the diagonal packing by [Halevi and Shoup](https://eprint.iacr.org/2014/106) for matrix multiplication. diff --git a/Scripts/compile-emulate.py b/Scripts/compile-emulate.py index b6a53a969..de04b5a6b 100755 --- a/Scripts/compile-emulate.py +++ b/Scripts/compile-emulate.py @@ -10,7 +10,8 @@ compiler.prep_compile(build=False) compiler.execute = True compiler.options.execute = 'emulate' -compiler.options.ring = compiler.options.ring or '64' +if not compiler.options.binary: + compiler.options.ring = compiler.options.ring or '64' compiler.options.keep_cisc = compiler.options.keep_cisc or '' compiler.build() prog = compiler.compile_file() diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py index 22ed3b212..9cec52afb 100755 --- a/Scripts/memory-usage.py +++ b/Scripts/memory-usage.py @@ -62,8 +62,11 @@ def output(data): print ('Registers in other threads:') output(regout(thread_regs)) -min = 1 * domain_size -max = 3 * domain_size +if len(sys.argv) > 2: + min = max = int(sys.argv[2]) * domain_size +else: + min = 1 * domain_size + max = 3 * domain_size print ('The program requires at least an estimated %f-%f GB of RAM per party.' % (min * (total + thread_total) * 1e-9, diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index dd70b5f11..a038ef357 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -6,16 +6,17 @@ touch ECDSA/Fake-ECDSA.cpp make -j4 ecdsa Fake-ECDSA.x +port=${PORT:-$((RANDOM%10000+10000))} + run() { echo $1 - port=$[RANDOM+1024] if ! { for j in $(seq 0 $2); do - ./$1-ecdsa-party.x -pn $port -p $j 1 2>/dev/null & true + ./$1-ecdsa-party.x -pn $port -p $j 1 2>logs/ecdsa-$j & true done wait - } | grep "Online checking"; then + } | tee logs/ecdsa | grep "Online checking"; then exit 1 fi } diff --git a/Scripts/torch_mnist_lenet_import.py b/Scripts/torch_mnist_lenet_import.py index 9df05285d..06812b92b 100755 --- a/Scripts/torch_mnist_lenet_import.py +++ b/Scripts/torch_mnist_lenet_import.py @@ -6,8 +6,9 @@ import torch import torch.nn as nn import numpy +import sys -net = nn.Sequential( +layers = [ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.MaxPool2d(2), @@ -19,7 +20,12 @@ nn.Linear(800, 500), nn.ReLU(), nn.Linear(500, 10) -) +] + +if 'bn' in sys.argv: + layers.insert(3, nn.BatchNorm2d(20)) + +net = nn.Sequential(*layers) f = open('Player-Data/Binary-Output-P0-0') @@ -27,10 +33,12 @@ for name in state: shape = state[name].shape - size = numpy.prod(shape) - var = numpy.fromfile(f, 'double', count=size) - var = var.reshape(shape) - state[name] = torch.Tensor(var) + if shape: + size = numpy.prod(shape) + print (name, shape, size) + var = numpy.fromfile(f, 'double', count=size) + var = var.reshape(shape) + state[name] = torch.Tensor(var) net.load_state_dict(state) diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 0a20bedb7..3acc2fd06 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -135,14 +135,24 @@ void BufferBase::prune() void BufferBase::purge() { - if (file and not is_pipe()) + bool verbose = OnlineOptions::singleton.has_option("verbose_purge"); + if (not filename.empty() and not is_pipe()) { -#ifdef VERBOSE - cerr << "Removing " << filename << endl; -#endif + if (verbose) + cerr << "Removing " << filename << endl; unlink(filename.c_str()); - file->close(); - file = 0; + if (file) + { + file->close(); + file = 0; + } + } + else if (verbose) + { + cerr << "Not removing " << filename; + if (is_pipe()) + cerr << "because it's a pipe"; + cerr << endl; } } diff --git a/Tools/Buffer.h b/Tools/Buffer.h index 3eb916b57..2edb751d2 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -14,6 +14,8 @@ using namespace std; #include "Math/field_types.h" #include "Tools/time-func.h" #include "Tools/octetStream.h" +#include "Tools/pprint.h" +#include "Processor/OnlineOptions.h" #ifndef BUFFER_SIZE #define BUFFER_SIZE 101 @@ -71,10 +73,17 @@ class Buffer : public BufferBase }; template -octetStream file_signature() +octetStream file_signature(const typename T::mac_type& mac_key = {}) { octetStream res(T::type_string()); T::specification(res); + if (T::has_mac) + { + if (mac_key == typename T::mac_type()) + T::get_mac_key().pack(res); + else + mac_key.pack(res); + } return res; } @@ -95,11 +104,21 @@ octetStream check_file_signature(ifstream& file, const string& filename) throw signature_mismatch(filename); } if (file_signature() != file_spec) + { +#ifndef DEBUG_FILE_SIGNATURE + if (OnlineOptions::singleton.has_option("debug_file_signature")) +#endif + { + auto exp = file_signature(); + pprint_bytes("found ", file_spec.get_data(), file_spec.get_length()); + pprint_bytes("expected", exp.get_data(), exp.get_length()); + } throw signature_mismatch(filename); + } return file_spec; } -template +template class BufferOwner : public Buffer { ifstream* file; @@ -135,7 +154,7 @@ class BufferOwner : public Buffer BufferBase::file = file; if (file->good()) { - auto file_spec = check_file_signature(*file, this->filename); + auto file_spec = check_file_signature(*file, this->filename); this->header_length = file_spec.get_length() + sizeof(file_spec.get_length()); } diff --git a/Tools/CheckVector.h b/Tools/CheckVector.h index f74ce0963..2a0604c45 100644 --- a/Tools/CheckVector.h +++ b/Tools/CheckVector.h @@ -9,6 +9,10 @@ #include using namespace std; +#include "Math/Integer.h" +#include "Processor/Instruction.h" +#include "Processor/OnlineOptions.h" + template class CheckVector : public vector { @@ -16,7 +20,7 @@ class CheckVector : public vector CheckVector() : vector() {} CheckVector(size_t size) : vector(size) {} CheckVector(size_t size, const T& def) : vector(size, def) {} -#ifdef CHECK_SIZE +#ifndef NO_CHECK_SIZE T& operator[](size_t i) { return this->at(i); } const T& operator[](size_t i) const { return this->at(i); } #else @@ -25,4 +29,104 @@ class CheckVector : public vector #endif }; +template +class StackedVector : CheckVector +{ + vector stack; + CheckVector& full; + size_t start; + +public: + StackedVector() : + StackedVector(0) + { + } + StackedVector(size_t size) : + StackedVector(size, {}) + { + } + StackedVector(size_t size, const T& def) : + CheckVector(size, def), full(*this), start(0) + { + } + + size_t size() const { return full.size() - start; } + + void resize(size_t new_size) + { + try + { + if (OnlineOptions::singleton.has_option("verbose_registers")) + fprintf(stderr, "adding %zu %s registers to %zu\n", new_size, + T::type_string().c_str(), start); + full.resize(start + new_size); + } + catch (bad_alloc&) + { + throw runtime_error( + "not enough RAM for " + to_string(start + new_size) + + " registers"); + } + } + + void reserve(size_t new_size) { full.reserve(start + new_size); } + + auto begin() { return full.begin() + start; } + auto end() { return full.end(); } + auto begin() const { return full.begin() + start; } + auto end() const { return full.end(); } + + T& operator[](size_t i) { return full[start + i]; } + const T& operator[](size_t i) const { return full[start + i]; } + T& at(size_t i) { return full[start + i]; } + const T& at(size_t i) const { return full[start + i]; } + + void push_back(const T& x) { full.push_back(x); } + + void push_stack() + { + stack.push_back(start); + start = full.size(); + } + + void push_args(const vector& args, RegType type) + { + for (auto it = args.begin(); it < args.end(); it += 5) + if (it[1] == type and not it[0]) + { + auto dest = begin() + it[3]; + auto source = full.begin() + stack.back() + it[4]; + if (dest + it[2] > full.end()) + full.resize(start + it[1]); + assert(dest + it[2] <= full.end()); + assert(source + it[2] <= full.begin() + start); + copy(source, source + it[2], dest); + } + } + + void pop_stack(const vector& results, RegType type) + { + assert(not stack.empty()); + + for (auto it = results.begin(); it < results.end(); it += 5) + if (it[1] == type and it[0]) + { + auto source = begin() + it[4]; + auto dest = full.begin() + stack.back() + it[3]; + assert(source + it[2] <= full.end()); + assert(dest + it[2] <= full.begin() + start); + copy(source, source + it[2], dest); + } + + full.resize(start); + start = stack.back(); + stack.pop_back(); + } + + void check_index(Integer index) const + { + assert(size_t(index.get()) < size()); + } +}; + #endif /* TOOLS_CHECKVECTOR_H_ */ diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index a0e9ab4f5..88a5d3636 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -5,6 +5,16 @@ #include "Exceptions.h" #include "Math/bigint.h" +#include "Processor/OnlineOptions.h" + +void exit_error(const string& message) +{ + if (OnlineOptions::singleton.has_option("throw_exceptions")) + throw runtime_error(message); + + cerr << message << endl; + exit(1); +} IO_Error::IO_Error(const string& m) { diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index 469f544fc..f5bf59001 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -7,6 +7,8 @@ #include using namespace std; +void exit_error(const string& message); + class not_implemented: public exception { virtual const char* what() const throw() { return "Case not implemented"; } diff --git a/Tools/NamedStats.cpp b/Tools/NamedStats.cpp new file mode 100644 index 000000000..3ca528565 --- /dev/null +++ b/Tools/NamedStats.cpp @@ -0,0 +1,34 @@ +/* + * NamedStats.cpp + * + */ + +#include "NamedStats.h" + +#include +#include + +NamedStats& NamedStats::operator+=(const NamedStats& other) +{ + for (auto x : other) + (*this)[x.first] += x.second; + return *this; +} + +void NamedStats::print() +{ + long sum = 0; + for (auto x : *this) + sum += x.second; + if (sum > 0) + cerr << "Detailed costs:" << endl; + for (auto x : *this) + { + if (x.second > 0) + { + cerr.fill(' '); + cerr << " "; + cerr << setw(10) << x.second << " " << x.first << endl; + } + } +} diff --git a/Tools/NamedStats.h b/Tools/NamedStats.h new file mode 100644 index 000000000..21e9019e6 --- /dev/null +++ b/Tools/NamedStats.h @@ -0,0 +1,22 @@ +/* + * NamedStats.h + * + */ + +#ifndef TOOLS_NAMEDSTATS_H_ +#define TOOLS_NAMEDSTATS_H_ + +#include +#include + +using namespace std; + +class NamedStats : public map +{ +public: + NamedStats& operator+=(const NamedStats& other); + + void print(); +}; + +#endif /* TOOLS_NAMEDSTATS_H_ */ diff --git a/Tools/SwitchableOutput.h b/Tools/SwitchableOutput.h index 24e3d523a..54a8b950e 100644 --- a/Tools/SwitchableOutput.h +++ b/Tools/SwitchableOutput.h @@ -65,6 +65,13 @@ class SwitchableOutput if (out) out->width(w); } + + template + void signed_output(const T& x) + { + if (out) + x.output(*out, true, true); + } }; #endif /* TOOLS_SWITCHABLEOUTPUT_H_ */ diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index b909c7ba8..980c8c6aa 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -124,6 +124,14 @@ bool octetStream::equals(const octetStream& a) const } +void octetStream::flush_bits() +{ + bits[0].n = 0; + store_int<1>(bits[0].buffer); + bits[0].buffer = 0; +} + + void octetStream::append_random(size_t num) { randombytes_buf(append(num), num); @@ -226,7 +234,7 @@ void octetStream::input(istream& s) throw IO_Error("not enough data"); } -void octetStream::output(ostream& s) +void octetStream::output(ostream& s) const { s.write((char*)&len, sizeof(len)); s.write((char*)data, len); diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 77faa05eb..f36c3d864 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -147,6 +147,8 @@ class octetStream // Return pointer to next l octets and advance pointer octet* consume(size_t l); + void flush_bits(); + /* Now store and restore different types of data (with padding for decoding) */ void store_bytes(octet* x, const size_t l); //not really "bytes"... @@ -181,6 +183,14 @@ class octetStream void store_bit(char a); char get_bit(); + template + void store_bits(char a); + template + char get_bits(); + + void store_bits(char a, int n_bits); + char get_bits(int n_bits); + /// Append big integer void store(const bigint& x); /// Read big integer @@ -243,7 +253,7 @@ class octetStream /// Input from stream, overwriting current content void input(istream& s); /// Output to stream - void output(ostream& s); + void output(ostream& s) const; /// Send on ``socket_num`` while receiving on ``receiving_socket``, /// overwriting current content @@ -309,9 +319,7 @@ inline octet* octetStream::append(const size_t l) { if (bits[0].n) { - bits[0].n = 0; - store_int<1>(bits[0].buffer); - bits[0].buffer = 0; + flush_bits(); } if (len+l>mxlen) @@ -375,29 +383,67 @@ inline size_t octetStream::get_int() } inline void octetStream::store_bit(char a) +{ + store_bits<1>(a); +} + +template +inline void octetStream::store_bits(char a) { auto& n = bits[0].n; auto& buffer = bits[0].buffer; - if (n == 8) + if (n > 8 - N_BITS) append(0); - buffer |= (a & 1) << n; - n++; + buffer |= (a & ((1 << N_BITS) - 1)) << n; + n += N_BITS; } inline char octetStream::get_bit() +{ + return get_bits<1>(); +} + +template +inline char octetStream::get_bits() { auto& n = bits[1].n; auto& buffer = bits[1].buffer; - if (n == 0) + if (n < N_BITS) { buffer = get_int<1>(); n = 8; } - return (buffer >> (8 - n--)) & 1; + auto res = (buffer >> (8 - n)) & ((1 << N_BITS) - 1); + n -= N_BITS; + return res; +} + +inline void octetStream::store_bits(char a, int n_bits) +{ + switch (n_bits) + { +#define X(N) case N: store_bits(a); break; + X(1) X(2) X(3) X(4) X(5) X(6) X(7) +#undef X + default: + throw runtime_error("wrong number of bits"); + } +} + +inline char octetStream::get_bits(int n_bits) +{ + switch (n_bits) + { +#define X(N) case N: return get_bits(); + X(1) X(2) X(3) X(4) X(5) X(6) X(7) +#undef X + default: + throw runtime_error("wrong number of bits"); + } } @@ -448,6 +494,12 @@ inline int octetStream::get() return get_int(sizeof(int)); } +template<> +inline size_t octetStream::get() +{ + return get_int(sizeof(size_t)); +} + template void octetStream::store(const vector& v) { @@ -461,6 +513,7 @@ void octetStream::get(vector& v, const T& init) { size_t size; get(size); + v.clear(); v.reserve(size); for (size_t i = 0; i < size; i++) { diff --git a/Tools/pprint.h b/Tools/pprint.h index 3df479f13..89ea49569 100644 --- a/Tools/pprint.h +++ b/Tools/pprint.h @@ -1,3 +1,5 @@ +#ifndef TOOLS_PPRINT_H_ +#define TOOLS_PPRINT_H_ #include #include @@ -11,3 +13,5 @@ inline void pprint_bytes(const char *label, unsigned char *bytes, int len) cout << setfill('0') << setw(2) << hex << (int) bytes[j]; cout << dec << endl; } + +#endif diff --git a/Utils/Check-Offline-Z2k.cpp b/Utils/Check-Offline-Z2k.cpp index b7255a282..147c0a8ce 100644 --- a/Utils/Check-Offline-Z2k.cpp +++ b/Utils/Check-Offline-Z2k.cpp @@ -6,6 +6,7 @@ #include "Protocols/fake-stuff.h" #include "Protocols/fake-stuff.hpp" +#include "Protocols/Share.hpp" #include "Math/Z2k.hpp" #include diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index aec3c5954..c8a9c9f22 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -20,6 +20,7 @@ #include "Math/Setup.h" #include "Processor/Data_Files.h" +#include "Protocols/Share.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/ReplicatedPrep.hpp" #include "Processor/Data_Files.hpp" @@ -35,7 +36,7 @@ using namespace std; string PREP_DATA_PREFIX; template -void check_mult_triples(const typename T::mac_key_type& key,int N,vector*>& dataF) +void check_mult_triples(const KeySetup& key_setup,int N,vector*>& dataF) { typename T::clear a,b,c; typename T::mac_type mac; @@ -46,7 +47,11 @@ void check_mult_triples(const typename T::mac_key_type& key,int N,vectoreof(DATA_TRIPLE)) { for (int i = 0; i < N; i++) - dataF[i]->get_three(DATA_TRIPLE, Sa[i], Sb[i], Sc[i]); + { + T::set_mac_key(key_setup.key_shares.at(i)); + dataF[i]->get_three(DATA_TRIPLE, Sa[i], Sb[i], Sc[i]); + } + auto& key = key_setup.key; check_share(Sa, a, mac, N, key); check_share(Sb, b, mac, N, key); check_share(Sc, c, mac, N, key); @@ -63,7 +68,10 @@ void check_mult_triples(const typename T::mac_key_type& key,int N,vector -void check_tuples(const typename T::mac_key_type& key,int N,vector*>& dataF, Dtype type) +void check_tuples(const KeySetup& key_setup,int N,vector*>& dataF, Dtype type) { typename T::clear a,b,c,res; typename T::mac_type mac; @@ -110,9 +118,12 @@ void check_tuples(const typename T::mac_key_type& key,int N,vectoreof(type)) { for (int i = 0; i < N; i++) - dataF[i]->get_two(type, Sa[i], Sb[i]); - check_share(Sa, a, mac, N, key); - check_share(Sb, b, mac, N, key); + { + T::set_mac_key(key_setup.key_shares.at(i)); + dataF[i]->get_two(type, Sa[i], Sb[i]); + } + check_share(Sa, a, mac, N, key_setup.key); + check_share(Sb, b, mac, N, key_setup.key); check_tuple(a, b, n, type); n++; } @@ -128,7 +139,7 @@ void check_tuples(const typename T::mac_key_type& key,int N,vector -void check_bits(const typename T::mac_key_type& key,int N,vector*>& dataF) +void check_bits(const KeySetup& key_setup,int N,vector*>& dataF) { typename T::clear a,b,c,res; typename T::mac_type mac; @@ -139,8 +150,11 @@ void check_bits(const typename T::mac_key_type& key,int N,vectoreof(DATA_BIT)) { for (int i = 0; i < N; i++) + { + T::set_mac_key(key_setup.key_shares.at(i)); dataF[i]->get_one(DATA_BIT, Sa[i]); - check_share(Sa, a, mac, N, key); + } + check_share(Sa, a, mac, N, key_setup.key); if (!(a.is_zero() || a.is_one())) { @@ -159,7 +173,7 @@ void check_bits(const typename T::mac_key_type& key,int N,vector -void check_inputs(const typename T::mac_key_type& key,int N,vector*>& dataF) +void check_inputs(const KeySetup& key_setup,int N,vector*>& dataF) { typename T::clear a, x; typename T::mac_type mac; @@ -172,8 +186,11 @@ void check_inputs(const typename T::mac_key_type& key,int N,vectorinput_eof(player)) { for (int i = 0; i < N; i++) + { + T::set_mac_key(key_setup.key_shares.at(i)); dataF[i]->get_input(Sa[i], x, player); - check_share(Sa, a, mac, N, key); + } + check_share(Sa, a, mac, N, key_setup.key); if (a != (x)) throw bad_value(); n++; @@ -189,30 +206,46 @@ void check_inputs(const typename T::mac_key_type& key,int N,vector -vector*> setup(int N, DataPositions& usage, int thread_num = -1) +vector*> setup(int N, + DataPositions& usage, int thread_num = -1) { vector*> dataF(N); for (int i = 0; i < N; i++) - dataF[i] = new Sub_Data_Files(i, N, - get_prep_sub_dir(PREP_DATA_PREFIX, N), usage, thread_num); + { + dataF[i] = new Sub_Data_Files(i, N, + get_prep_sub_dir(PREP_DATA_PREFIX, N), usage, thread_num); + } return dataF; } template -void check(int N, bool only_bits = false) +void check_with_error(int N, bool only_bits = false) { - typename T::mac_key_type key; - read_global_mac_key(get_prep_sub_dir(PREP_DATA_PREFIX, N), N, key); + auto key_setup = read_global_mac_key( + get_prep_sub_dir(PREP_DATA_PREFIX, N), N); DataPositions usage(N); auto dataF = setup(N, usage); - check_bits(key, N, dataF); + check_bits(key_setup, N, dataF); if (not only_bits) { - check_mult_triples(key, N, dataF); - check_inputs(key, N, dataF); - check_tuples(key, N, dataF, DATA_SQUARE); - check_tuples(key, N, dataF, DATA_INVERSE); + check_mult_triples(key_setup, N, dataF); + check_inputs(key_setup, N, dataF); + check_tuples(key_setup, N, dataF, DATA_SQUARE); + check_tuples(key_setup, N, dataF, DATA_INVERSE); + } +} + +template +void check(int N, bool only_bits = false) +{ + try + { + check_with_error(N, only_bits); + } + catch (exception& e) + { + cerr << "Error: " << e.what() << endl; } } @@ -303,9 +336,13 @@ int main(int argc, const char** argv) if (!use_montgomery) { // no montgomery - gfp::init_field(gfp::pr(), false); + auto pr = gfp::pr(); + gfp::reset(); + gfp::init_field(pr, false); } + OnlineOptions::singleton.options.push_back("debug_file_signature"); + int N = nparties; try @@ -327,17 +364,4 @@ int main(int argc, const char** argv) gf2n_short::init_field(lg2); check>(N); } - - if (N == 3) - { - DataPositions pos(N); - auto dataF = setup>(N, pos); - check_bits({}, N, dataF); - - check>({}, N); - - auto dataF2 = setup(N, pos, 0); - check_mult_triples({}, N, dataF2); - check_bits({}, N, dataF2); - } } diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 68465bcc7..12ceb7d01 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -31,6 +31,7 @@ #include "Protocols/fake-stuff.hpp" #include "Protocols/Shamir.hpp" +#include "Protocols/Share.hpp" #include "Processor/Data_Files.hpp" #include "Math/Z2k.hpp" #include "Math/gfp.hpp" @@ -67,17 +68,17 @@ class FakeParams template void make_with_mac_key(int nplayers, int default_num, bool zero,PRNG& G, - const typename T::bit_type::mac_type& bit_key = {}); + const KeySetup& bit_keys = {}); template - void make_basic(const typename T::mac_type& key, int nplayers, int nitems, - bool zero, PRNG& G, const typename T::bit_type::mac_type& bit_key = {}); + void make_basic(const KeySetup& key, int nplayers, int nitems, + bool zero, PRNG& G, const KeySetup& bit_keys = {}); template - void make_edabits(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G, false_type, - const typename T::bit_type::mac_type& bit_key = {}); + void make_edabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, false_type, + const KeySetup& bit_key = {}); template - void make_edabits(const typename T::mac_type&, int, int, bool, PRNG&, true_type, - const typename T::bit_type::mac_type& = {}) + void make_edabits(const KeySetup&, int, int, bool, PRNG&, true_type, + const KeySetup& = {}) { } }; @@ -87,7 +88,7 @@ class FakeParams * ntrip = Number tuples needed */ template -void make_square_tuples(const typename T::mac_type& key,int N,int ntrip,const string& str,bool zero,PRNG& G) +void make_square_tuples(const KeySetup& key,int N,int ntrip,const string& str,bool zero,PRNG& G) { (void) str; Files files(N, key, prep_data_prefix, DATA_SQUARE, G); @@ -108,7 +109,7 @@ void make_square_tuples(const typename T::mac_type& key,int N,int ntrip,const st * ntrip = Number bits needed */ template -void make_bits(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G, +void make_bits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, int thread_num = -1) { @@ -124,8 +125,8 @@ void make_bits(const typename T::mac_type& key, int N, int ntrip, bool zero, PRN } template -void make_dabits(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G, - const typename T::bit_type::mac_type& bit_key = { }) +void make_dabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, + const KeySetup& bit_key = { }) { Files files(N, key, get_prep_sub_dir(prep_data_prefix, N) @@ -134,13 +135,13 @@ void make_dabits(const typename T::mac_type& key, int N, int ntrip, bool zero, P { bool bit = not zero && G.get_bit(); files.template output_shares(bit); - files.template output_shares::bit_type>(bit, bit_key); + files.template output_shares::bit_type>(bit, bit_key.key); } } template -void FakeParams::make_edabits(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G, false_type, - const typename T::bit_type::mac_type& bit_key) +void FakeParams::make_edabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, false_type, + const KeySetup& bit_key) { vector lengths; opt.get("-e")->getInts(lengths); @@ -159,7 +160,7 @@ void FakeParams::make_edabits(const typename T::mac_type& key, int N, int ntrip, for (auto& a : as) files.template output_shares(a); for (auto& b : bs) - files.template output_shares(b, bit_key); + files.template output_shares(b, bit_key.key); } } } @@ -168,7 +169,7 @@ void FakeParams::make_edabits(const typename T::mac_type& key, int N, int ntrip, * ntrip = Number inputs needed */ template -void make_inputs(const typename T::mac_type& key,int N,int ntrip,const string& str,bool zero,PRNG& G) +void make_inputs(const KeySetup& key,int N,int ntrip,const string& str,bool zero,PRNG& G) { (void) str; @@ -184,7 +185,7 @@ void make_inputs(const typename T::mac_type& key,int N,int ntrip,const string& s i); cout << "Opening " << filename << endl; outf[i].open(filename, ios::out | ios::binary); - file_signature().output(outf[i]); + file_signature(key.get(i)).output(outf[i]); if (outf[i].fail()) throw file_error(filename); } @@ -192,7 +193,7 @@ void make_inputs(const typename T::mac_type& key,int N,int ntrip,const string& s { if (!zero) a.randomize(G); - make_share(Sa,a,N,key,G); + make_share(Sa,a,N,key.key,G); for (int j=0; j -void make_PreMulC(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G) +void make_PreMulC(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) { stringstream ss; ss << get_prep_sub_dir(prep_data_prefix, N) << "PreMulC-" << T::type_short(); @@ -253,7 +254,7 @@ unsigned char sbox[256] = }; template -void make_AES(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G) { +void make_AES(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) { stringstream ss; ss << get_prep_sub_dir(prep_data_prefix, N) << "Sbox-" << T::type_short(); Files files(N, key, ss.str(), G); @@ -289,7 +290,7 @@ vector> des_sbox = { template -void make_DES(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G) +void make_DES(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) { stringstream ss; ss << get_prep_sub_dir(prep_data_prefix, N) << "SboxDes-" << T::type_short(); @@ -314,7 +315,7 @@ void make_DES(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG } template -void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G, T, true_type) +void make_Sbox(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, T, true_type) { make_AES(key, N, ntrip, zero, G); make_DES(key, N, ntrip, zero, G); @@ -322,19 +323,19 @@ void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, PRN template -void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG&, T, false_type) +void make_Sbox(const KeySetup& key, int N, int ntrip, bool zero, PRNG&, T, false_type) { (void)key, (void)N, (void)ntrip, (void)zero; } template -void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG& G) +void make_Sbox(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) { make_Sbox(key, N, ntrip, zero, G, T(), T::clear::characteristic_two); } template -void make_minimal(const typename T::mac_type& key, int nplayers, int nitems, bool zero, PRNG& G) +void make_minimal(const KeySetup& key, int nplayers, int nitems, bool zero, PRNG& G) { make_mult_triples(key, nplayers, nitems, zero, prep_data_prefix, G); make_bits(key, nplayers, nitems, zero, G); @@ -342,8 +343,8 @@ void make_minimal(const typename T::mac_type& key, int nplayers, int nitems, boo } template -void FakeParams::make_basic(const typename T::mac_type& key, int nplayers, - int nitems, bool zero, PRNG& G, const typename T::bit_type::mac_type& bit_key) +void FakeParams::make_basic(const KeySetup& key, int nplayers, + int nitems, bool zero, PRNG& G, const KeySetup& bit_key) { make_minimal(key, nplayers, nitems, zero, G); make_square_tuples(key, nplayers, nitems, T::type_short(), zero, G); @@ -363,11 +364,11 @@ void FakeParams::make_basic(const typename T::mac_type& key, int nplayers, template void FakeParams::make_with_mac_key(int nplayers, int default_num, bool zero, PRNG& G, - const typename T::bit_type::mac_type& bit_key) + const KeySetup& bit_keys) { - typename T::mac_share_type::open_type key; - generate_mac_keys(key, nplayers, prep_data_prefix, G); - make_basic(key, nplayers, default_num, zero, G, bit_key); + KeySetup keys; + generate_mac_keys(keys, nplayers, prep_data_prefix, G); + make_basic(keys, nplayers, default_num, zero, G, bit_keys); } template @@ -714,8 +715,9 @@ int FakeParams::generate() } /* Find number players and MAC keys etc*/ - typename T::mac_type::Scalar keyp; - gf2n key2; + typedef Share sgf2n; + KeySetup keyp; + KeySetup key2; // create PREP_DIR if not there if (mkdir_p(PREP_DIR) == -1) @@ -724,8 +726,6 @@ int FakeParams::generate() throw file_error(PREP_DIR); } - typedef Share sgf2n; - generate_mac_keys(keyp, nplayers, prep_data_prefix, G); generate_mac_keys(key2, nplayers, prep_data_prefix, G); @@ -766,16 +766,16 @@ int FakeParams::generate() gf2n_short::reset(); gf2n_short::init_field(); - Z2 keyt; + KeySetup> keyt; generate_mac_keys>(keyt, nplayers, prep_data_prefix, G); make_minimal>(keyt, nplayers, default_num / 64, zero, G); - gf2n_short keytt; - generate_mac_keys>(keytt, nplayers, prep_data_prefix, G); - make_minimal>(keytt, nplayers, default_num, zero, G); + KeySetup keytt; + generate_mac_keys(keytt, nplayers, prep_data_prefix, G); + make_minimal(keytt, nplayers, default_num, zero, G); make_dabits(keyp, nplayers, default_num, zero, G, keytt); make_edabits(keyp, nplayers, default_num, zero, G, false_type(), keytt); diff --git a/Utils/check-passive.cpp b/Utils/check-passive.cpp index 86923be52..6f95d9815 100644 --- a/Utils/check-passive.cpp +++ b/Utils/check-passive.cpp @@ -25,7 +25,7 @@ void check_triples(int n_players, string type_char = "") ss << "-P" << i; inputFiles[i].open(ss.str().c_str()); cout << "Opening file " << ss.str() << endl; - octetStream tmp, tmp2 = file_signature(); + octetStream tmp, tmp2 = file_signature>(); tmp.input(inputFiles[i]); assert(tmp == tmp2); } diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp index 16901471c..e2592217d 100644 --- a/Utils/stream-fake-mascot-triples.cpp +++ b/Utils/stream-fake-mascot-triples.cpp @@ -9,13 +9,14 @@ #include "Math/Setup.hpp" #include "Protocols/fake-stuff.hpp" +#include "Protocols/Share.hpp" class Info { public: int thread_num; int nplayers; - gfpvar key; + KeySetup> key; pthread_t thread; }; @@ -52,7 +53,7 @@ int main() int lgp = 128; string prep_data_prefix = PREP_DIR; gfpvar::generate_setup(prep_data_prefix, nplayers, lgp); - T::mac_key_type keyp; + KeySetup keyp; SeededPRNG G; generate_mac_keys(keyp, nplayers, prep_data_prefix, G); diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 456a68922..d49547e65 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -11,6 +11,7 @@ #include "BMR/prf.h" #include "BMR/common.h" #include "GC/ArgTuples.h" +#include "Tools/CheckVector.h" #include "GC/Processor.hpp" #include "GC/Secret.hpp" @@ -78,7 +79,7 @@ void YaoEvalWire::and_singlethread(GC::Processor >& proc party.counter += counter - party.get_gate_id(); } -void YaoEvalWire::and_(GC::Memory >& S, +void YaoEvalWire::and_(StackedVector >& S, const vector& args, size_t start, size_t end, size_t, YaoGate* gates, long& gate_id, PRNG&, map&, bool repeat, YaoEvaluator& evaluator) diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index 0f082657e..7bc1e9c0a 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -48,7 +48,7 @@ class YaoEvalWire : public YaoWire static void and_singlethread( GC::Processor>& processor, const vector& args, int total_ands); - static void and_(GC::Memory>& S, + static void and_(StackedVector>& S, const vector& args, size_t start, size_t end, size_t total_ands, YaoGate* gate, long& counter, PRNG& prng, map& timers, bool repeat, YaoEvaluator& garbler); diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 606a84acb..d0881761e 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -95,7 +95,7 @@ void YaoGarbleWire::and_singlethread(GC::Processor >& garbler.counter += counter - garbler.get_gate_id(); } -void YaoGarbleWire::and_(GC::Memory >& S, +void YaoGarbleWire::and_(StackedVector >& S, const vector& args, size_t start, size_t end, size_t, YaoGate* gate, long& counter, PRNG& prng, map& timers, bool repeat, YaoGarbler& garbler) diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 20d56ef8f..4af6b0272 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -49,7 +49,7 @@ class YaoGarbleWire : public YaoWire static void and_singlethread( GC::Processor>& processor, const vector& args, bool repeat); - static void and_(GC::Memory>& S, + static void and_(StackedVector>& S, const vector& args, size_t start, size_t end, size_t total_ands, YaoGate* gate, long& counter, PRNG& prng, map& timers, bool repeat, YaoGarbler& garbler); diff --git a/deps/libOTe b/deps/libOTe index cd89232ff..f613f2216 160000 --- a/deps/libOTe +++ b/deps/libOTe @@ -1 +1 @@ -Subproject commit cd89232ffac74286a4963d93480db44cbf604c41 +Subproject commit f613f221650144367e0fddce5ca07fc2dda09e32 diff --git a/doc/add-instruction.rst b/doc/add-instruction.rst new file mode 100644 index 000000000..2b02a1c7e --- /dev/null +++ b/doc/add-instruction.rst @@ -0,0 +1,171 @@ +Adding an Instruction +--------------------- + +If you want to add functionality that isn't captured by the current +virtual machine design, you might need to add further +instructions. This section explains how the virtual machine is built +on both the frontend (Python) and the backend (C++). + + +Frontend +======== + +The instructions are defined as classes in +:download:`../Compiler/instructions.py` and +:download:`../Compiler/GC/instructions.py`. Every class requires the +attributes :py:obj:`opcode` and :py:obj:`arg_format` to be +set. Consider the example of :py:class:`~Compiler.instructions.prefixsums`:: + + @base.vectorize + class prefixsums(base.Instruction): + """ Prefix sum. + + :param: result (sint) + :param: input (sint) + + """ + __slots__ = [] + code = base.opcodes['PREFIXSUMS'] + arg_format = ['sw','s'] + +:py:obj:`opcode` is set from :py:obj:`opcodes` in +:download:`../Compiler/instructions_base.py`. This is simply for +convenience as it allow copying from the C++ code (see below). The +only requirement for opcodes is that they are unique 10-bit integers. + +:py:obj:`arg_format` has to be iterable over strings indicating the +nature of arguments. In the example, ``sw`` indicates that a secret +integer register is written to, and ``s`` indicates that a secret +integer register is read from. :py:obj:`ArgFormats` defines all +possible argument types:: + + ArgFormats = { + 'c': ClearModpAF, + 's': SecretModpAF, + 'cw': ClearModpAF, + 'sw': SecretModpAF, + 'cg': ClearGF2NAF, + 'sg': SecretGF2NAF, + 'cgw': ClearGF2NAF, + 'sgw': SecretGF2NAF, + 'ci': ClearIntAF, + 'ciw': ClearIntAF, + '*': AnyRegAF, + '*w': AnyRegAF, + 'i': ImmediateModpAF, + 'ig': ImmediateGF2NAF, + 'int': IntArgFormat, + 'long': LongArgFormat, + 'p': PlayerNoAF, + 'str': String, + 'varstr': VarString, + } + +The values of this dictionary are classes defined in +:download:`../Compiler/instructions_base.py` which can encode them +into bytes to written to the bytecode files. Most types are encoded as +four-byte values except ``long``, which uses eight bytes, and +``varstr``, which has a variable length. The type classes also define +functionality to check arguments for correctness such as Python type. + +By default, register arguments are understood as single registers. The +:py:obj:`vectorize` decorator is an easy way to allow vector arguments +if all register arguments have the same length. The vector size stored +independently of the arguments. + +All instruction classes should inherit from :py:obj:`Instruction` in +:download:`../Compiler/instructions_base.py`. + + +Backend +======= + +.. default-domain:: cpp + +The backend functionality has three aspects: + +1. Parsing the bytecode and creating an internal representation +2. Figuring out the resource requirements of the instruction + (registers and memory) +3. Execution + + +Parsing +~~~~~~~ + +The internal representation is done via the :cpp:class:`Instruction` +class defined in :download:`../Processor/Instruction.h`. The arguments +are parsed in :cpp:func:`parse_operands` defined in +:download:`../Processor/Instruction.hpp`. It contains a large switch +statement covering all opcodes. Sticking to the example of +:py:class:`~Compiler.instructions.prefixsums`, the relevant code there +is as follows:: + + case PREFIXSUMS: + ... + get_ints(r, s, 2); + break; + +This puts the two integer values corresponding to the two arguments +into ``r[0]`` and ``r[1]`` within the :cpp:class:`Instruction` +object. :cpp:member:`r` is an array of four 32-bit integers, which is +enough for many simple instructions. More complex instruction use +:cpp:member:`start`, which is a variable-length C++ vector of 32-bit +integers. + + +Resourcing +~~~~~~~~~~ + +Because the number of registers depends on the programs, the virtual +machine has to find out the requirements for every single +instruction. The main function for this is :cpp:func:`get_max_reg` in +:download:`../Processor/Instruction.hpp`, which returns the maximum +register that is written to for a particular register type. It +contains two switch statements. The first one contains special +treatment for instructions that write to more than one register type +such as :py:class:`~Compiler.instructions.dabit`. However, for most +instruction including :py:class:`~Compiler.instructions.prefixsums`, +it checks the type currently queried against the type defined by +:cpp:func:`get_reg_type` and returns 0 if there is a +mismatch. :cpp:func:`get_reg_type` makes use of the fact that the +opcodes are grouped. For :py:class:`prefixsums`, it returns ``SINT``, +which is the default. + +The second switch statement then treats further special cases +where :cpp:class:`start` is used or `r` contains registers of +different types. None of this applies for :py:class:`prefixsums`, so +the return value is simply the maximum over :cpp:member:`r` and +:cpp:member:`start` plus the vector size:: + + unsigned res = 0; + for (auto x : start) + res = max(res, (unsigned)x); + for (auto x : r) + res = max(res, (unsigned)x); + return res + size; + + +Execution +~~~~~~~~~ + +Execution is again defined by several switch statements over the +opcode, the outermost of which is in :py:func:`Program::execute` +defined in :download:`../Processor/Instruction.hpp`. It uses the `X +macro `_ pattern for a compact +representation. :py:class:`~Compiler.instructions.prefixsums` is +implemented in :download:`../Processor/instructions.h` as follows:: + + X(PREFIXSUMS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ + sint s, \ + s += *op1++; *dest++ = s) \ + +This macro has three arguments: the opcode, the setup step, and vector +loop step. The setup step involves getting pointers according to the +register addresses ``r[0]`` and ``r[1]`` as well as initializing a +running variable. The loop step then adds the next input element to +the running variable and stores in the destination. + +Another important switch statement is in +:cpp:member:`Instruction::execute`. See :ref:`execution` for further +examples. diff --git a/doc/conf.py b/doc/conf.py index fb935940c..2024f08bf 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -15,6 +15,7 @@ import os import sys sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('../ExternalIO')) exec(compile(open('gen-instructions.py').read(), 'gen', 'exec')) diff --git a/doc/index.rst b/doc/index.rst index 47b481a33..2cecbe0da 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -23,6 +23,7 @@ If you're new to MP-SPDZ, consider the following: Compiler utils journey + optimization instructions low-level ml-quickstart @@ -35,6 +36,7 @@ If you're new to MP-SPDZ, consider the following: preprocessing lowest-level add-protocol + add-instruction homomorphic-encryption ecdsa troubleshooting diff --git a/doc/instructions.rst b/doc/instructions.rst index 354597e2e..39bcf309b 100644 --- a/doc/instructions.rst +++ b/doc/instructions.rst @@ -65,7 +65,10 @@ potentially be used. The length 106 is composed as follows: assuming 64-bit integers, the difference used for comparison is a 65-bit integer, to which 40 bits are added for statistical masking, resulting in a 105 bits, and it takes a 106-bit prime to able to contain all -105-bit numbers. +105-bit numbers. Finally, the last line indicates which compile-time +options would change the program. This supports the virtual machine +in suggesting options that are compatible with the protocol +implementation. Bytecode diff --git a/doc/io.rst b/doc/io.rst index fd092993d..99813ac81 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -89,10 +89,24 @@ functions are available for :py:class:`~Compiler.types.sfix` and See also :ref:`client ref` below. +Secret Shares via Socket +~~~~~~~~~~~~~~~~~~~~~~~~ + +Secret can be sent and received via socket by using +:py:func:`~Compiler.types.sint.write_to_socket` and +:py:func:`~Compiler.types.sint.read_from_socket` (and the same +functions in :py:class:`~Compiler.types.sfix`). The connections are set +up in the same way as in the previous section. See :ref:`multinode` +for an example how this is used to distribute every party among +multiple nodes. If you use the client interface, you should use the +:cpp:class:`octetStream` class for serialization. The format is the same +as in the following section. + + .. _persistence: -Secret Shares -~~~~~~~~~~~~~ +Secret Shares via Files +~~~~~~~~~~~~~~~~~~~~~~~ :py:func:`Compiler.types.sint.read_from_file` and :py:func:`Compiler.types.sint.write_to_file` allow reading and writing @@ -125,10 +139,76 @@ etc. Note also that all types based on address is only a base address. This means that vectors will be written to the memory starting at the given address. + +Python Trusted Client Tutorial +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In this section, we will illustrate how to use the client interface to +supplement individual parties in the secure computation. This example +consists of :download:`../Programs/Source/personal_client_example.py` +for the server side and +:download:`../ExternalIO/personal-client-example.py` for the client +side. + +The servers start by listening for and accepting one connection:: + + listen_for_clients(15000) + socket = accept_client_connection(15000) + +The clients in turn connect to the server that is assigned to them:: + + party = int(sys.argv[1]) + client = Client(['localhost'], 15000 + party, 0) + +:py:obj:`party` stands for the number of the relevant server. Then, +the clients of the of the first two servers sample 1000 random values +and send them to their assigned server:: + + n = 1000 + if party < 2: + client.send_public_inputs(random.gauss(0, 1) * 2 ** 16 for i in range(n)) + +Note that the values are multiplied by :math:`2^{16}` to match the +default fixed-point precision. + +The first two servers then receive these values, convert them to +shares, and then send the *shares* to their personal client:: + + n = 1000 + for i in range(2): + x = personal.read_fix_from_socket(i, socket, n) + sfix(x).write_fully_to_socket(socket) + +Note that all servers run this code because they are all involved in +the secret-sharing process. If you're aiming for the secret sharing to +happen on the client side, see `this section `_. + +The clients receive the shares and sum them pair-wise before sending them +back:: + + x = [client.receive_plain_values() for i in range(2)] + client.send_public_inputs(a + b for a, b in zip(*x)) + +Note that this works whether the shares have MACs or not because +adding shares with MACs amounts to simply adding both. + +The servers finally receive the summed values, perform another sum, +and output the result:: + + res = sum(sfix.read_from_socket(socket, n)) + print_ln('%s', res.reveal()) + + +Python Reference +~~~~~~~~~~~~~~~~ + +.. autoclass:: ExternalIO.client.Client + :members: + .. _client ref: -Reference -~~~~~~~~~ +C++ Reference +~~~~~~~~~~~~~ .. doxygenclass:: Client :members: diff --git a/doc/journey.rst b/doc/journey.rst index 64ba67be4..b6b7521a1 100644 --- a/doc/journey.rst +++ b/doc/journey.rst @@ -151,6 +151,8 @@ the sixth indicates the compilation command line. The remaining lines indicate further options used during compilation. +.. _execution: + Execution --------- diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst index 5dc2fc865..a6ac1b7d5 100644 --- a/doc/machine-learning.rst +++ b/doc/machine-learning.rst @@ -446,6 +446,12 @@ This outputs the accuracy of the network. You can use probability distributions or top guesses (the latter with ``top=True``) for any sample data. +You can also use some networks provided within PyTorch as demonstrated +by :download:`../Programs/Source/torch_squeeze.py`:: + + model = torchvision.models.get_model('SqueezeNet1_1', weights='DEFAULT') + layers = ml.layers_from_torch(model, ...) + Storing and loading models ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -478,6 +484,8 @@ Using ``var.input_from(player)`` instead the model would be input privately by a party. +.. _reveal-model: + Exporting models ~~~~~~~~~~~~~~~~ diff --git a/doc/ml-quickstart.rst b/doc/ml-quickstart.rst index f6114378c..ab0231a61 100644 --- a/doc/ml-quickstart.rst +++ b/doc/ml-quickstart.rst @@ -48,40 +48,54 @@ The first call should give the following output: .. code-block:: console $ Scripts/compile-emulate.py foo - Default bit length: 63 - Default security parameter: 40 + Default bit length for compilation: 63 + Default security parameter for compilation: 40 Compiling file Programs/Source/foo.mpc Writing binary data to Player-Data/Input-Binary-P0-0 Setting learning rate to 0.01 Using SGD Initializing dense weights in [-1.224745,1.224745] + Writing to Programs/Bytecode/foo-TruncPr(3)_47_16-2.bc Writing to Programs/Bytecode/foo-multithread-1.bc 2 runs per epoch - Writing to Programs/Bytecode/foo-multithread-3.bc - Writing to Programs/Bytecode/foo-multithread-4.bc - Writing to Programs/Bytecode/foo-multithread-5.bc + Writing to Programs/Bytecode/foo-TruncPr(1)_47_16-5.bc + Writing to Programs/Bytecode/foo-Dense-forward-4.bc + Writing to Programs/Bytecode/foo-TruncPr(1)_45_14-7.bc + Writing to Programs/Bytecode/foo-exp2_fx(1)_31_16_False-9.bc + Writing to Programs/Bytecode/foo-log2_fx(1)_31_16-11.bc + Writing to Programs/Bytecode/foo-TruncPr(1)_46_15-13.bc + Writing to Programs/Bytecode/foo-Output-forward-6.bc + Writing to Programs/Bytecode/foo-multithread-15.bc + Writing to Programs/Bytecode/foo-multithread-16.bc + Writing to Programs/Bytecode/foo-TruncPr(3)_46_15-18.bc + Writing to Programs/Bytecode/foo-multithread-17.bc Initializing dense weights in [-1.224745,1.224745] - Writing to Programs/Bytecode/foo-multithread-7.bc - Writing to Programs/Bytecode/foo-multithread-8.bc - Writing to Programs/Bytecode/foo-multithread-9.bc + Writing to Programs/Bytecode/foo-multithread-19.bc + Writing to Programs/Bytecode/foo-TruncPr(2)_47_16-22.bc + Writing to Programs/Bytecode/foo-multithread-21.bc + Writing to Programs/Bytecode/foo-multithread-23.bc + Writing to Programs/Bytecode/foo-Dense-forward-20.bc + Writing to Programs/Bytecode/foo-FPDiv(1)_31_16-24.bc Writing to Programs/Schedules/foo.sch Writing to Programs/Bytecode/foo-0.bc - Hash: 33f8d22d99960897f41fb2da31e7f5a0501d2e1071789e52d73b4043e5343831 + Hash: 8227349c6796977e0035cd9e925585603531eb9aa98ac586440c1abd360ae712 Program requires at most: - 8 integer inputs from player 0 - 61054 integer bits - 190109 integer triples - 200 matrix multiplications (1x3 * 3x1) - 200 matrix multiplications (3x1 * 1x1) - 1 matrix multiplications (2x3 * 3x1) - 28406 virtual machine rounds - Using security parameter 40 + 8 integer inputs from player 0 + 2402 integer opens + 67654 integer bits + 204509 integer triples + 200 matrix multiplications (1x3 * 3x1) + 200 matrix multiplications (3x1 * 1x1) + 1 matrix multiplications (2x3 * 3x1) + 37109 virtual machine rounds + Compilation finished, running program... + Using statistical security parameter 40 Trying to run 64-bit computation Using SGD done with epoch 99 [0, 1] The following benchmarks are including preprocessing (offline phase). - Time = 0.0250086 seconds + Time = 0.0390132 seconds See `the documentation `_ diff --git a/doc/multinode.rst b/doc/multinode.rst index 69cb47be0..8ad415ec9 100644 --- a/doc/multinode.rst +++ b/doc/multinode.rst @@ -1,3 +1,6 @@ +.. _multinode: + + Multinode Computation Example ============================= diff --git a/doc/optimization.rst b/doc/optimization.rst new file mode 100644 index 000000000..f223365cf --- /dev/null +++ b/doc/optimization.rst @@ -0,0 +1,148 @@ +Compiler Optimizations +====================== + +A core tenet of MP-SPDZ is to reduce the number of communication +rounds by parallelizing operations. The are two mechanisms for this, +called merging and CISC. Merging is the more basic mechanism, doing +just round reduction and nothing else. CISC in addition uses +vectorization, which reduces the cost during compilation and the +footprint. In the following, we will explain the two mechanisms using +examples. + + +Merging +------- + +Merging is the original mechanism forming the basis for `Keller et +al. `_ It is based on extending the +common three-operand instruction design to any number of argument +tuples. Consider the following high-level code:: + + sint(1) * sint(2) + sint(3) * sint(4) + +Putting this into ``Programs/Source/2mul.py`` and running +``./compile.py 2mul -a debug`` results in the following content in +``debug-2mul-0``:: + + # 2mul-0--0 + ldsi s2, 1 # 0 + ldsi s3, 2 # 1 + ldsi s4, 3 # 2 + ldsi s5, 4 # 3 + muls 8, 1, s0, s2, s3, 1, s1, s4, s5 # 4 + +You can see all the inputs loaded in the first few lines appear in the +last one. The arguments are follows (see also +:py:class:`Compiler.instructions.muls`): 8 indicates the number of +arguments to follow, and every four arguments correspond to one +multiplication. The first number withing every four arguments is the +vector size, following by the result register and the input +registers. + +On the backend side, the multiplication is implemented in the +:cpp:func:`muls` member function in +:download:`../Processor/Processor.hpp`. Notice that +``protocol.exchange()`` (where the communication happens) is only +called once. See :ref:`low-level` for more information. + +Within the compiler, the optimization is executed in the +:py:class:`Merger` class in :download:`../Compiler/allocator.py`. The +:py:func:`dependency_graph` member function builds a dependency graph +for all instructions, and instructions are merged in +:py:func:`longest_path_merge`. + + +CISC +---- + +This mechanism takes its name from `Complex Instruction Set Computer +`_, +which refers to the fact that more complex operations are treated as +instructions internally before being merged and spelled in vectorized +instructions. For a simple example, consider the following high-level +code:: + + program.use_trunc_pr = True + sfix(1) * sfix(2) + sfix(3) * sfix(4) + +The resulting file ``debug-2fmul-0`` starts similarly to the one +above:: + + # 2fmul-0-cisc-1 + ldsi s8, 65536 # 0 + ldsi s9, 131072 # 1 + ldsi s10, 196608 # 2 + ldsi s11, 262144 # 3 + muls 8, 1, s0, s8, s9, 1, s7, s10, s11 # 4 + +This corresponds to the fact that fixed-point multiplication starts +with an integer multiplication before truncation. The file continues +as follows:: + + # 2fmul-0--2 + concats 5, s3(2), 1, s0, 1, s7 # 5 + vmovs 2, s3(2), s3(2) # 6 + # 2fmul-0-update-3 + jmp 11 # 7 + +The first instruction creates a vector of size from the multiplication +results, and the last instruction jumps over the next code block just +to end up here:: + + # 2fmul-0-end-TruncPr(2)_47_16-5 + ldint ci0, 3 # 19 + stmint ci0, 8192 # 20 + jmp -14 # 21 + +This code prepares for a function call. Functions are used for code +reusability, i.e., the same code only has to be compiled once per tape +and vector size. The last instruction jumps to the start of the +function here (the code block jumped over above):: + + # 2fmul-0-begin-TruncPr(2)_47_16-4 + ldarg ci1 # 8 + vldi 2, c0(2), 32768 # 9 + vmulci 2, c2(2), c0(2), 2147483647 # 10 + vaddci 2, c0(2), c2(2), 32768 # 11 + vaddm 2, s5(2), s3(2), c0(2) # 12 + vtrunc_pr 2, 4, s3(2), s5(2), 47, 16 # 13 + vsubsi 2, s5(2), s3(2), 1073741824 # 14 + vmovs 2, s3(2), s5(2) # 15 + vmovs 2, s1(2), s3(2) # 16 + ldmint ci1, 8192 # 17 + jmpi ci1 # 18 + +The actual truncation happens in the vectorized instructions starting +with v. The only communication-relevant instruction is (v)trunc_pr, +where truncation by 16 bits is done via a protocol defined in the +virtual machine. For example, the implementation for Rep3 is found in +:download:`../Protocols/Replicated.hpp`. The other vectorized +instructions are required to turn negative values into positive ones, +which is a precondition for the protocol. See Protocol in 3.1 in +`Catrina and Saxena `_ for an +explanation. Lastly, the last two instructions load where to jump back +to, which is here:: + + # 2fmul-0-call-TruncPr(2)_47_16-6 + picks s0, s1(2), 0, 1 # 22 + picks s0, s1(2), 1, 1 # 23 + +The two instructions extract the results from the vector, which makes +them available individually for further computation. + +While this example saves relatively little by using the CISC +functionality, this isn't the case for other usages. For example, +removing ``program.use_trunc_pr`` results in several hundred +instructions, and more involved mathematical functions corresponds to +thousands of instructions. + +Internally, the CISC functionality is in implemented in functions +decorators in :download:`../Compiler/instructions_base.py`. Our +example uses :py:func:`ret_cisc` on :py:func:`TruncPr` in +:download:`../Compiler/floatingpoint.py`. This is because +:py:func:`TruncPr` returns an :py:class:`sint`. Other decorators are +:py:func:`cisc` (the result is stored in the first argument, which +must be an :py:class:`sint`) and :py:func:`sfix_cisc` (the result and +all arguments are instances of :py:class:`sfix`). diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 74a686eb6..71f4964b4 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -57,6 +57,7 @@ follows: - Length to follow (little-endian 8-byte number) - Protocol descriptor - Domain descriptor +- MAC if applicable The protocol descriptor is defined by ``::type_string()``. For SPDZ modulo a prime it is ``SPDZ gfp``. @@ -81,11 +82,14 @@ As an example, the following output of ``hexdump -C`` describes SPDZ modulo the default 128-bit prime (170141183460469231731687303715885907969):: - 00000000 1d 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |........SPDZ gfp| + 00000000 2d 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |-.......SPDZ gfp| 00000010 00 10 00 00 00 80 00 00 00 00 00 00 00 00 00 00 |................| - 00000020 00 00 1b 80 01 |.....| - 00000025 + 00000020 00 00 1b 80 01 3a ed c2 28 c0 3d 5e 24 8f 2c a5 |.....:..(.=^$.,.| + 00000030 9b d6 2d 83 12 +The last 128 bits denote the MAC and will differ from instance to +instance. The MAC is stored to avoid errors that are hard to track +otherwise. The actual data is stored is by simple concatenation. For example, triples are stored as repetitions of ``a, b, ab``, and daBits are diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 7517312c7..b9f918615 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -113,6 +113,14 @@ restore the representation after a multiplication. See `Catrina and Saxena to deterministic rounding by calling ``sfix.round_nearest = True``. +Only party 0 produces outputs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is to improve readability when running all parties in the same +terminal. You can activate outputs on other parties using ``-OF .`` as +an argument to a virtual machine (``*-party.x``). + + Order of memory instructions not preserved ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~