From cd25c2e9f192a14be2430f48e8d4e3855cb68dc0 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 9 Nov 2022 11:21:34 +1100 Subject: [PATCH] Decision tree training. --- BMR/Register.h | 3 + CHANGELOG.md | 11 + CONFIG | 4 +- Compiler/GC/instructions.py | 49 ++ Compiler/GC/types.py | 59 +- Compiler/allocator.py | 6 +- Compiler/circuit.py | 2 - Compiler/circuit_oram.py | 3 +- Compiler/compilerLib.py | 7 + Compiler/decision_tree.py | 504 ++++++++++++++++++ Compiler/floatingpoint.py | 1 + Compiler/instructions.py | 13 + Compiler/instructions_base.py | 15 +- Compiler/library.py | 6 +- Compiler/ml.py | 17 +- Compiler/mpc_math.py | 32 ++ Compiler/non_linear.py | 2 + Compiler/oram.py | 11 +- Compiler/program.py | 27 +- Compiler/sorting.py | 17 +- Compiler/sqrt_oram.py | 21 +- Compiler/types.py | 81 ++- ECDSA/Fake-ECDSA.cpp | 1 + ECDSA/P256Element.cpp | 19 +- ECDSA/P256Element.h | 6 +- ECDSA/fake-spdz-ecdsa-party.cpp | 1 + ECDSA/hm-ecdsa-party.hpp | 3 + ECDSA/ot-ecdsa-party.hpp | 1 + ExternalIO/README.md | 2 +- FHEOffline/PairwiseSetup.cpp | 8 + FHEOffline/Prover.cpp | 1 + GC/BitAdder.hpp | 3 +- GC/BitPrepFiles.h | 6 +- GC/FakeSecret.h | 7 + GC/Instruction.h | 1 + GC/Machine.h | 2 +- GC/Machine.hpp | 4 +- GC/Memory.h | 2 + GC/NoShare.h | 1 + GC/PersonalPrep.hpp | 5 +- GC/PostSacriBin.cpp | 1 + GC/Processor.h | 1 + GC/Processor.hpp | 41 +- GC/Program.h | 2 + GC/Secret.h | 3 + GC/Semi.cpp | 36 ++ GC/Semi.h | 31 ++ GC/SemiPrep.cpp | 27 +- GC/SemiPrep.h | 6 +- GC/SemiSecret.h | 5 + GC/SemiSecret.hpp | 55 ++ GC/ShareParty.h | 2 - GC/ShareParty.hpp | 2 - GC/ShareSecret.h | 1 + GC/ShareSecret.hpp | 10 +- GC/ShareThread.h | 5 +- GC/ShareThread.hpp | 49 +- GC/ThreadMaster.hpp | 1 + GC/TinierSharePrep.hpp | 2 +- GC/TinyMC.h | 2 +- GC/TinyPrep.hpp | 2 + GC/instructions.h | 1 + License.txt | 30 +- Machines/MalRep.hpp | 2 + Machines/Rep.hpp | 5 +- Machines/dealer-ring-party.cpp | 8 +- Machines/emulate.cpp | 1 + Machines/malicious-rep-bin-party.cpp | 2 + Machines/mascot-offline.cpp | 1 + Machines/no-party.cpp | 2 + Machines/ps-rep-bin-party.cpp | 3 + Machines/real-bmr-party.cpp | 1 + Machines/replicated-bin-party.cpp | 2 + Machines/replicated-ring-party.cpp | 1 - Machines/sy-rep-field-party.cpp | 3 +- Machines/sy-rep-ring-party.cpp | 3 +- Machines/sy-shamir-party.cpp | 1 + Machines/tinier-party.cpp | 1 + Makefile | 69 ++- Math/BitVec.h | 4 + Math/Square.hpp | 11 + Math/Zp_Data.h | 4 +- Math/field_types.h | 3 +- Math/mpn_fixed.h | 14 - Networking/data.h | 2 +- OT/BaseOT.h | 2 +- OT/BitMatrix.h | 3 + OT/BitMatrix.hpp | 4 +- OT/MamaRectangle.h | 2 + OT/NPartyTripleGenerator.h | 5 +- OT/NPartyTripleGenerator.hpp | 43 +- OT/OTCorrelator.hpp | 2 +- OT/OTExtensionWithMatrix.cpp | 7 +- OT/OTExtensionWithMatrix.h | 6 +- OT/OTMultiplier.h | 2 + OT/OTMultiplier.hpp | 58 ++ Processor/BaseMachine.cpp | 16 +- Processor/BaseMachine.h | 3 +- Processor/Data_Files.hpp | 17 +- Processor/Instruction.h | 1 + Processor/Instruction.hpp | 46 +- Processor/Machine.h | 2 +- Processor/Machine.hpp | 3 +- Processor/Online-Thread.hpp | 10 +- Processor/Processor.h | 2 +- Processor/Processor.hpp | 5 +- Processor/Program.h | 2 + Processor/ThreadQueues.cpp | 26 +- Processor/instructions.h | 7 + Programs/Source/adult.mpc | 54 ++ Programs/Source/bench-dt.mpc | 32 ++ Programs/Source/benchmark_secureNN.mpc | 7 +- Programs/Source/gc_oram.mpc | 3 - Programs/Source/mnist_full_A.mpc | 1 + Programs/Source/spect.mpc | 49 ++ Programs/Source/test_gc.mpc | 2 +- Protocols/Beaver.h | 1 + Protocols/Beaver.hpp | 10 +- Protocols/DabitSacrifice.hpp | 3 +- Protocols/DealerMC.h | 2 +- Protocols/DealerMC.hpp | 4 +- Protocols/DealerPrep.hpp | 1 + Protocols/FakeProtocol.h | 47 ++ Protocols/HemiMatrixPrep.hpp | 7 +- Protocols/HemiPrep.h | 9 +- Protocols/HemiPrep.hpp | 54 ++ Protocols/HighGearKeyGen.hpp | 2 +- Protocols/LowGearKeyGen.hpp | 1 + Protocols/MAC_Check.h | 6 +- Protocols/MAC_Check.hpp | 6 +- Protocols/MAC_Check_Base.h | 2 +- Protocols/MAC_Check_Base.hpp | 2 +- Protocols/MalRepRingPrep.hpp | 3 +- Protocols/MaliciousRepPrep.hpp | 5 + Protocols/MascotPrep.hpp | 14 - Protocols/PostSacriRepRingShare.h | 1 + Protocols/ProtocolSetup.h | 8 + Protocols/Rep3Share.h | 5 +- Protocols/Rep3Share2k.h | 4 +- Protocols/Rep3Shuffler.h | 33 ++ Protocols/Rep3Shuffler.hpp | 131 +++++ Protocols/Replicated.h | 7 + Protocols/Replicated.hpp | 9 +- Protocols/ReplicatedInput.h | 6 +- Protocols/ReplicatedInput.hpp | 3 +- Protocols/ReplicatedPrep.h | 2 + Protocols/ReplicatedPrep.hpp | 96 +++- Protocols/SecureShuffle.hpp | 2 +- Protocols/Semi.h | 15 +- Protocols/SemiInput.h | 18 +- Protocols/SemiInput.hpp | 22 +- Protocols/SemiMC.h | 6 +- Protocols/SemiMC.hpp | 29 +- Protocols/SemiPrep.h | 10 +- Protocols/SemiPrep.hpp | 31 +- Protocols/SemiPrep2k.h | 6 + .../{ReplicatedPrep2k.h => SemiRep3Prep.h} | 21 +- Protocols/Shamir.h | 3 +- Protocols/Shamir.hpp | 22 +- Protocols/ShamirInput.h | 14 +- Protocols/ShamirInput.hpp | 44 +- Protocols/ShamirMC.h | 2 +- Protocols/ShamirMC.hpp | 2 +- Protocols/ShamirShare.h | 1 + Protocols/ShuffleSacrifice.hpp | 6 +- Protocols/SpdzWiseMC.h | 2 +- README.md | 12 +- Scripts/build.sh | 3 +- Scripts/compile-for-emulation.sh | 3 + Scripts/emulate-append.sh | 7 + Scripts/run-common.sh | 11 +- Scripts/test_tutorial.sh | 1 + Scripts/tldr.sh | 5 + Tools/ExecutionStats.cpp | 5 +- Tools/names.cpp | 2 +- Utils/Check-Offline.cpp | 1 + Utils/binary-example.cpp | 1 + Utils/l2h-example.cpp | 1 + azure-pipelines.yml | 2 +- doc/Compiler.rst | 17 + doc/Doxyfile | 2 +- doc/compilation.rst | 5 + doc/index.rst | 1 + doc/io.rst | 2 + doc/machine-learning.rst | 3 + doc/non-linear.rst | 6 +- doc/troubleshooting.rst | 10 + 187 files changed, 2356 insertions(+), 328 deletions(-) create mode 100644 Compiler/decision_tree.py create mode 100644 GC/Semi.cpp create mode 100644 GC/Semi.h create mode 100644 Programs/Source/adult.mpc create mode 100644 Programs/Source/bench-dt.mpc create mode 100644 Programs/Source/spect.mpc create mode 100644 Protocols/Rep3Shuffler.h create mode 100644 Protocols/Rep3Shuffler.hpp rename Protocols/{ReplicatedPrep2k.h => SemiRep3Prep.h} (51%) create mode 100755 Scripts/compile-for-emulation.sh create mode 100755 Scripts/emulate-append.sh diff --git a/BMR/Register.h b/BMR/Register.h index 6a15a720c..4def65901 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -235,6 +235,9 @@ class Phase template static void ands(T& processor, const vector& args) { processor.ands(args); } template + static void andrsvec(T& processor, const vector& args) + { processor.andrsvec(args); } + template static void xors(T& processor, const vector& args) { processor.xors(args); } template static void inputb(T& processor, const vector& args) { processor.input(args); } diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e015348..f201d4640 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.4 (Nov 9, 2022) + +- Decision tree learning +- Optimized oblivious shuffle in Rep3 +- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC +- Optimized element-vector AND in SemiBin +- Optimized input protocol in Shamir-based protocols +- Square-root ORAM (@Quitlox) +- Improved ORAM in binary circuits +- UTF-8 outputs + ## 0.3.3 (Aug 25, 2022) - Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate diff --git a/CONFIG b/CONFIG index fb9db2009..0d41c9ef7 100644 --- a/CONFIG +++ b/CONFIG @@ -67,8 +67,11 @@ endif # MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) +LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS += -lboost_system -lssl -lcrypto +CFLAGS += -I./local/include + ifeq ($(USE_NTL),1) CFLAGS += -DUSE_NTL LDLIBS := -lntl $(LDLIBS) @@ -100,5 +103,4 @@ ifeq ($(USE_KOS),1) CFLAGS += -DUSE_KOS else CFLAGS += -std=c++17 -LDLIBS += -llibOTe -lcryptoTools endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 2b5ec46ad..73a8af216 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -13,6 +13,7 @@ import Compiler.tools as tools import collections import itertools +import math class SecretBitsAF(base.RegisterArgFormat): reg_type = 'sb' @@ -50,6 +51,7 @@ class ClearBitsAF(base.RegisterArgFormat): INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, + ANDRSVEC = 0x24a, XORCBI = 0x210, BITDECC = 0x211, NOTCB = 0x212, @@ -155,6 +157,52 @@ class andrs(BinaryVectorInstruction): def add_usage(self, req_node): req_node.increment(('bit', 'triple'), sum(self.args[::4])) + req_node.increment(('bit', 'mixed'), + sum(int(math.ceil(x / 64)) for x in self.args[::4])) + +class andrsvec(base.VarArgsInstruction, base.Mergeable, + base.DynFormatInstruction): + """ Constant-vector AND of secret bit registers (vectorized version). + + :param: total number of arguments to follow (int) + :param: number of arguments to follow for one operation / + operation vector size plus three (int) + :param: vector size (int) + :param: result vector (sbit) + :param: (repeat)... + :param: constant operand (sbits) + :param: vector operand + :param: (repeat)... + :param: (repeat from number of arguments to follow for one operation)... + + """ + code = opcodes['ANDRSVEC'] + + def __init__(self, *args, **kwargs): + super(andrsvec, self).__init__(*args, **kwargs) + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + for x in self.args[i + 2:i + n]: + assert x.n == size + + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + n_args = (n - 3) // 2 + assert n_args > 0 + for j in range(n_args): + yield 'sbw' + for j in range(n_args + 1): + yield 'sb' + yield 'int' + + def add_usage(self, req_node): + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + req_node.increment(('bit', 'triple'), size * (n - 3) // 2) + req_node.increment(('bit', 'mixed'), size) class ands(BinaryVectorInstruction): """ Bitwise AND of secret bit register vector. @@ -605,6 +653,7 @@ def dynamic_arg_format(cls, args): for i, n in cls.bases(args): yield 'int' yield 'p' + assert n > 3 for j in range(n - 3): yield 'sbw' yield 'int' diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index d092a4740..e895061a5 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -652,7 +652,7 @@ class sbitvec(_vec, _bit): You can access the rows by member :py:obj:`v` and the columns by calling :py:obj:`elements`. - There are three ways to create an instance: + There are four ways to create an instance: 1. By transposition:: @@ -685,6 +685,11 @@ class sbitvec(_vec, _bit): This should output:: [1, 0, 1] + + 4. Private input:: + + x = sbitvec.get_type(32).get_input_from(player) + """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) is_clear = False @@ -904,6 +909,34 @@ def half_adder(self, other): def __mul__(self, other): if isinstance(other, int): return self.from_vec(x * other for x in self.v) + if isinstance(other, sbitvec): + if len(other.v) == 1: + other = other.v[0] + elif len(self.v) == 1: + self, other = other, self.v[0] + else: + raise CompilerError('no operand of lenght 1: %d/%d', + (len(self.v), len(other.v))) + if not isinstance(other, sbits): + return NotImplemented + ops = [] + for x in self.v: + if not util.is_zero(x): + assert x.n == other.n + ops.append(x) + if ops: + prods = [sbits.get_type(other.n)() for i in ops] + inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops) + res = [] + i = 0 + for x in self.v: + if util.is_zero(x): + res.append(0) + else: + res.append(prods[i]) + i += 1 + return sbitvec.from_vec(res) + __rmul__ = __mul__ def __add__(self, other): return self.from_vec(x + y for x, y in zip(self.v, other)) def bit_and(self, other): @@ -945,6 +978,13 @@ def expand(self, other, expand=True): else: res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) return res + def demux(self): + if len(self) == 1: + return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]]) + a = sbitvec.from_vec(self.v[:len(self) // 2]).demux() + b = sbitvec.from_vec(self.v[len(self) // 2:]).demux() + prod = [a * bb for bb in b.v] + return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod))) class bit(object): n = 1 @@ -1243,20 +1283,19 @@ def __mul__(self, other): return other * self.v[0] elif isinstance(other, sbitfixvec): return NotImplemented - _, other_bits = self.expand(other, False) + my_bits, other_bits = self.expand(other, False) + matrix = [] m = float('inf') - for x in itertools.chain(self.v, other_bits): + for x in itertools.chain(my_bits, other_bits): try: m = min(m, x.n) except: pass - if m == 1: - op = operator.mul - else: - op = operator.and_ - matrix = [] for i, b in enumerate(other_bits): - matrix.append([op(x, b) for x in self.v[:len(self.v)-i]]) + if m == 1: + matrix.append([x * b for x in my_bits[:len(self.v)-i]]) + else: + matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ @@ -1366,7 +1405,7 @@ class cls(_fix): cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -class sbitfixvec(_fix): +class sbitfixvec(_fix, _vec): """ Vector of fixed-point numbers for parallel binary computation. Use :py:obj:`set_precision()` to change the precision. diff --git a/Compiler/allocator.py b/Compiler/allocator.py index bf431ca38..e5c99a7be 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -261,6 +261,7 @@ def longest_paths_merge(self): instructions = self.instructions merge_nodes = self.open_nodes depths = self.depths + self.req_num = defaultdict(lambda: 0) if not merge_nodes: return 0 @@ -281,6 +282,7 @@ def longest_paths_merge(self): print('Merging %d %s in round %d/%d' % \ (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) + self.req_num[t.__name__, 'round'] += 1 preorder = None @@ -530,7 +532,9 @@ def eliminate_dead_code(self): can_eliminate_defs = True for reg in inst.get_def(): for dup in reg.duplicates: - if not dup.can_eliminate: + if not (dup.can_eliminate and reduce( + operator.and_, + (x.can_eliminate for x in dup.vector), True)): can_eliminate_defs = False break # remove if instruction has result that isn't used diff --git a/Compiler/circuit.py b/Compiler/circuit.py index 9c4187f75..41e4df9ee 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -137,8 +137,6 @@ def sha3_256(x): 0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7 0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067 - Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only - implemented for computation modulo a power of two. """ global Keccak_f diff --git a/Compiler/circuit_oram.py b/Compiler/circuit_oram.py index f5ddebfd6..a2cada540 100644 --- a/Compiler/circuit_oram.py +++ b/Compiler/circuit_oram.py @@ -1,5 +1,6 @@ -from Compiler.path_oram import * +from Compiler.oram import * +from Compiler.path_oram import PathORAM, XOR from Compiler.util import bit_compose def first_diff(a_bits, b_bits): diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 4a4706ff6..c304ebc46 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -125,6 +125,13 @@ def build_option_parser(self): default=defaults.binary, help="bit length of sint in binary circuit (default: 0 for arithmetic)", ) + parser.add_option( + "-G", + "--garbled-circuit", + dest="garbled", + action="store_true", + help="compile for binary circuits only (default: false)", + ) parser.add_option( "-F", "--field", diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py new file mode 100644 index 000000000..89e3fe5c7 --- /dev/null +++ b/Compiler/decision_tree.py @@ -0,0 +1,504 @@ +from Compiler.types import * +from Compiler.sorting import * +from Compiler.library import * +from Compiler import util, oram + +from itertools import accumulate +import math + +debug = False +debug_split = False +debug_layers = False +max_leaves = None + +def get_type(x): + if isinstance(x, (Array, SubMultiArray)): + return x.value_type + elif isinstance(x, (tuple, list)): + x = x[0] + x[-1] + if util.is_constant(x): + return cint + else: + return type(x) + else: + return type(x) + +def PrefixSum(x): + return x.get_vector().prefix_sum() + +def PrefixSumR(x): + tmp = get_type(x).Array(len(x)) + tmp.assign_vector(x) + break_point() + tmp[:] = tmp.get_reverse_vector().prefix_sum() + break_point() + return tmp.get_reverse_vector() + +def PrefixSum_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x, base=1) + tmp[0] = 0 + return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x)) + +def PrefixSumR_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x) + tmp[-1] = 0 + return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x)) + +class SortPerm: + def __init__(self, x): + B = sint.Matrix(len(x), 2) + B.set_column(0, 1 - x.get_vector()) + B.set_column(1, x.get_vector()) + self.perm = Array.create_from(dest_comp(B)) + def apply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, False) + return res + def unapply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, True) + return res + +def Sort(keys, *to_sort, n_bits=None, time=False): + if time: + start_timer(1) + for k in keys: + assert len(k) == len(keys[0]) + n_bits = n_bits or [None] * len(keys) + bs = Matrix.create_from( + sum([k.get_vector().bit_decompose(nb) + for k, nb in reversed(list(zip(keys, n_bits)))], [])) + res = Matrix.create_from(to_sort) + res = res.transpose() + if time: + start_timer(11) + print_ln('sort') + radix_sort_from_matrix(bs, res) + if time: + stop_timer(11) + stop_timer(1) + return res.transpose() + +def VectMax(key, *data): + def reducer(x, y): + b = x[0] > y[0] + return [b.if_else(xx, yy) for xx, yy in zip(x, y)] + if debug: + key = list(key) + data = [list(x) for x in data] + print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data)) + return util.tree_reduce(reducer, zip(key, *data))[1:] + +def GroupSum(g, x): + assert len(g) == len(x) + p = PrefixSumR(x) * g + pi = SortPerm(g.get_vector().bit_not()) + p1 = pi.apply(p) + s1 = PrefixSumR_inv(p1) + d1 = PrefixSum_inv(s1) + d = pi.unapply(d1) * g + return PrefixSum(d) + +def GroupPrefixSum(g, x): + assert len(g) == len(x) + s = get_type(x).Array(len(x) + 1) + s[0] = 0 + s.assign_vector(PrefixSum(x), base=1) + q = get_type(s).Array(len(x)) + q.assign_vector(s.get_vector(size=len(x)) * g) + return s.get_vector(size=len(x), base=1) - GroupSum(g, q) + +def GroupMax(g, keys, *x): + if debug: + print_ln('group max input g=%s keys=%s x=%s', util.reveal(g), + util.reveal(keys), util.reveal(x)) + assert len(keys) == len(g) + for xx in x: + assert len(xx) == len(g) + n = len(g) + m = int(math.ceil(math.log(n, 2))) + keys = Array.create_from(keys) + x = [Array.create_from(xx) for xx in x] + g_new = Array.create_from(g) + g_old = g_new.same_shape() + for d in range(m): + w = 2 ** d + g_old[:] = g_new[:] + break_point() + vsize = n - w + g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( + g_old.get_vector(size=vsize, base=w)), base=w) + b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w) + for xx in [keys] + x: + a = b.if_else(xx.get_vector(size=vsize), + xx.get_vector(size=vsize, base=w)) + xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( + xx.get_vector(size=vsize, base=w), a), base=w) + break_point() + if debug: + print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(), + util.reveal(a), util.reveal(keys), + util.reveal(x), g_new.reveal()) + t = sint.Array(len(g)) + t[-1] = 1 + t.assign_vector(g.get_vector(size=n - 1, base=1)) + if debug: + print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g), + util.reveal(t), util.reveal(keys), util.reveal(x)) + return [GroupSum(g, t[:] * xx) for xx in [keys] + x] + +def ModifiedGini(g, y, debug=False): + assert len(g) == len(y) + y = [y.get_vector().bit_not(), y] + u = [GroupPrefixSum(g, yy) for yy in y] + s = [GroupSum(g, yy) for yy in y] + w = [ss - uu for ss, uu in zip(s, u)] + us = sum(u) + ws = sum(w) + uqs = u[0] ** 2 + u[1] ** 2 + wqs = w[0] ** 2 + w[1] ** 2 + res = sfix(uqs) / us + sfix(wqs) / ws + if debug: + print_ln('u0=%s', util.reveal(u[0])) + print_ln('u0=%s', util.reveal(u[1])) + print_ln('us=%s', util.reveal(us)) + print_ln('w0=%s', util.reveal(w[0])) + print_ln('w1=%s', util.reveal(w[1])) + print_ln('ws=%s', util.reveal(ws)) + print_ln('p=%s', util.reveal(p)) + print_ln('q=%s', util.reveal(q)) + print_ln('g=%s y=%s s=%s', + util.reveal(g), util.reveal(y), + util.reveal(s)) + if debug: + print_ln('gini %s %s', str(res), util.reveal(res)) + return res + +MIN_VALUE = -10000 + +def FormatLayer(h, g, *a): + return CropLayer(h, *FormatLayer_without_crop(g, *a)) + +def FormatLayer_without_crop(g, *a): + for x in a: + assert len(x) == len(g) + v = [g.if_else(aa, 0) for aa in a] + v = Sort([g.bit_not()], *v, n_bits=[1]) + return v + +def CropLayer(k, *v): + if max_leaves: + n = min(2 ** k, max_leaves) + else: + n = 2 ** k + return [vv[:min(n, len(vv))] for vv in v] + +def TrainLeafNodes(h, g, y, NID): + assert len(g) == len(y) + assert len(g) == len(NID) + Label = GroupSum(g, y.bit_not()) < GroupSum(g, y) + return FormatLayer(h, g, NID, Label) + +def GroupSame(g, y): + assert len(g) == len(y) + s = GroupSum(g, [sint(1)] * len(g)) + s0 = GroupSum(g, y.bit_not()) + s1 = GroupSum(g, y) + if debug_split: + print_ln('group same g=%s', util.reveal(g)) + print_ln('group same y=%s', util.reveal(y)) + return (s == s0).bit_or(s == s1) + +def GroupFirstOne(g, b): + assert len(g) == len(b) + s = GroupPrefixSum(g, b) + return s * b == 1 + +class TreeTrainer: + """ Decision tree training by `Hamada et al.`_ + + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :param y: binary labels (list or sint vector) + :param h: height (int) + :param binary: binary attributes instead of continuous + :param attr_lengths: attribute description for mixed data + (list of 0/1 for continuous/binary) + :param n_threads: number of threads (default: single thread) + + .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906 + + """ + def ApplyTests(self, x, AID, Threshold): + m = len(x) + n = len(AID) + assert len(AID) == len(Threshold) + for xx in x: + assert len(xx) == len(AID) + e = sint.Matrix(m, n) + AID = Array.create_from(AID) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + e[j][:] = AID[:] == j + xx = sum(x[j] * e[j] for j in range(m)) + if debug: + print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx)) + return 2 * xx < Threshold + + def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False): + assert len(g) == len(x) + assert len(g) == len(y) + if time: + start_timer(2) + s = ModifiedGini(g, y, debug=debug) + if time: + stop_timer(2) + if debug: + print_ln('gini %s', s.reveal()) + xx = x + t = get_type(x).Array(len(x)) + t[-1] = MIN_VALUE + t.assign_vector(xx.get_vector(size=len(x) - 1) + \ + xx.get_vector(size=len(x) - 1, base=1)) + gg = g + p = sint.Array(len(x)) + p[-1] = 1 + p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( + xx.get_vector(size=len(x) - 1) == \ + xx.get_vector(size=len(x) - 1, base=1))) + break_point() + if debug: + print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p)) + s = p[:].if_else(MIN_VALUE, s) + t = p[:].if_else(MIN_VALUE, t[:]) + if debug: + print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) + if time: + start_timer(3) + s, t = GroupMax(gg, s, t) + if time: + stop_timer(3) + if debug: + print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) + return t, s + + def GlobalTestSelection(self, x, y, g): + assert len(y) == len(g) + for xx in x: + assert(len(xx) == len(g)) + m = len(x) + n = len(y) + u, t = [get_type(x).Matrix(m, n) for i in range(2)] + v = get_type(y).Matrix(m, n) + s = sfix.Matrix(m, n) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + single = not self.n_threads or self.n_threads == 1 + print_ln('run %s', j) + @if_e(self.attr_lengths[j]) + def _(): + u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, + n_bits=[util.log2(n), 1], time=single) + @else_ + def _(): + u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, + n_bits=[util.log2(n), None], + time=single) + if self.debug_threading: + print_ln('global sort %s %s %s', j, util.reveal(u[j]), + util.reveal(v[j])) + t[j][:], s[j][:] = self.AttributeWiseTestSelection( + g, u[j], v[j], time=single, debug=self.debug_selection) + if self.debug_threading: + print_ln('global attribute %s %s %s', j, util.reveal(t[j]), + util.reveal(s[j])) + n = len(g) + a, tt = [sint.Array(n) for i in range(2)] + if self.debug_threading: + print_ln('global s=%s', util.reveal(s)) + if self.debug_gini: + print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)), + *(ss[0].reveal() for ss in s)) + start_timer(4) + a[:], tt[:] = VectMax((s[j][:] for j in range(m)), range(m), + (t[j][:] for j in range(m))) + stop_timer(4) + return a[:], tt[:] + + def TrainInternalNodes(self, k, x, y, g, NID): + assert len(g) == len(y) + for xx in x: + assert len(xx) == len(g) + AID, Threshold = self.GlobalTestSelection(x, y, g) + s = GroupSame(g[:], y[:]) + if debug or debug_split: + print_ln('AID=%s', util.reveal(AID)) + print_ln('Threshold=%s', util.reveal(Threshold)) + print_ln('GroupSame=%s', util.reveal(s)) + AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold) + b = self.ApplyTests(x, AID, Threshold) + return FormatLayer_without_crop(g[:], NID, AID, Threshold), b + + @method_block + def train_layer(self, k): + x = self.x + y = self.y + g = self.g + NID = self.NID + layer_matrix = self.layer_matrix + self.layer_matrix[k], b = \ + self.TrainInternalNodes(k, x, y, g, NID) + if debug: + print_ln('internal %s %s', + util.reveal(layer_matrix[k]), util.reveal(b)) + if debug_layers: + print_ln('layer %s:', k) + for name, data in zip(('NID', 'AID', 'Thr'), layer_matrix[k]): + print_ln(' %s: %s', name, data.reveal()) + NID[:] = 2 ** k * b + NID + b_not = b.bit_not() + if debug: + print_ln('b_not=%s', b_not.reveal()) + g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b) + y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1]) + for i, xxx in enumerate(xx): + x[i] = xxx + + def __init__(self, x, y, h, binary=False, attr_lengths=None, + n_threads=None): + assert not (binary and attr_lengths) + if binary: + attr_lengths = [1] * len(x) + else: + attr_lengths = attr_lengths or ([0] * len(x)) + for l in attr_lengths: + assert l in (0, 1) + self.attr_lengths = Array.create_from(regint(attr_lengths)) + Array.check_indices = False + Matrix.disable_index_checks() + for xx in x: + assert len(xx) == len(y) + n = len(y) + self.g = sint.Array(n) + self.g.assign_all(0) + self.g[0] = 1 + self.NID = sint.Array(n) + self.NID.assign_all(1) + self.y = Array.create_from(y) + self.x = Matrix.create_from(x) + self.layer_matrix = sint.Tensor([h, 3, n]) + self.n_threads = n_threads + self.debug_selection = False + self.debug_threading = False + self.debug_gini = True + + def train(self): + """ Train and return decision tree. """ + h = len(self.layer_matrix) + @for_range(h) + def _(k): + self.train_layer(k) + return self.get_tree(h) + + def train_with_testing(self, *test_set): + """ Train decision tree and test against test data. + + :param y: binary labels (list or sint vector) + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :returns: tree + + """ + for k in range(len(self.layer_matrix)): + self.train_layer(k) + tree = self.get_tree(k + 1) + output_decision_tree(tree) + test_decision_tree('train', tree, self.y, self.x, + n_threads=self.n_threads) + if test_set: + test_decision_tree('test', tree, *test_set, + n_threads=self.n_threads) + return tree + + def get_tree(self, h): + Layer = [None] * (h + 1) + for k in range(h): + Layer[k] = CropLayer(k, *self.layer_matrix[k]) + Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID) + return Layer + +def DecisionTreeTraining(x, y, h, binary=False): + return TreeTrainer(x, y, h, binary=binary).train() + +def output_decision_tree(layers): + """ Print decision tree output by :py:class:`TreeTrainer`. """ + print_ln('full model %s', util.reveal(layers)) + for i, layer in enumerate(layers[:-1]): + print_ln('level %s:', i) + for j, x in enumerate(('NID', 'AID', 'Thr')): + print_ln(' %s: %s', x, util.reveal(layer[j])) + print_ln('leaves:') + for j, x in enumerate(('NID', 'result')): + print_ln(' %s: %s', x, util.reveal(layers[-1][j])) + +def pick(bits, x): + if len(bits) == 1: + return bits[0] * x[0] + else: + try: + return x[0].dot_product(bits, x) + except: + return sum(aa * bb for aa, bb in zip(bits, x)) + +def run_decision_tree(layers, data): + """ Run decision tree against sample data. + + :param layers: tree output by :py:class:`TreeTrainer` + :param data: sample data (:py:class:`~Compiler.types.Array`) + :returns: binary label + + """ + h = len(layers) - 1 + index = 1 + for k, layer in enumerate(layers[:-1]): + assert len(layer) == 3 + for x in layer: + assert len(x) <= 2 ** k + bits = layer[0].equal(index, k) + threshold = pick(bits, layer[2]) + key_index = pick(bits, layer[1]) + if key_index.is_clear: + key = data[key_index] + else: + key = pick( + oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) + child = 2 * key < threshold + index += child * 2 ** k + bits = layers[h][0].equal(index, h) + return pick(bits, layers[h][1]) + +def test_decision_tree(name, layers, y, x, n_threads=None): + start_timer(100) + n = len(y) + x = x.transpose().reveal() + y = y.reveal() + guess = regint.Array(n) + truth = regint.Array(n) + correct = regint.Array(2) + parts = regint.Array(2) + layers = [Matrix.create_from(util.reveal(layer)) for layer in layers] + @for_range_multithread(n_threads, 1, n) + def _(i): + guess[i] = run_decision_tree([[part[:] for part in layer] + for layer in layers], x[i]).reveal() + truth[i] = y[i].reveal() + @for_range(n) + def _(i): + parts[truth[i]] += 1 + c = (guess[i].bit_xor(truth[i]).bit_not()) + correct[truth[i]] += c + print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, + sum(correct), n, correct[0], parts[0], correct[1], parts[1]) + stop_timer(100) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 94a47f1bf..7786f73c8 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -311,6 +311,7 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): @instructions_base.ret_cisc def Pow2(a, l, kappa): + comparison.program.curr_tape.require_bit_length(l - 1) m = int(ceil(log(l, 2))) t = BitDec(a, m, m, kappa) return Pow2_from_bits(t) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index a8894a0df..c51318322 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -614,6 +614,18 @@ class submr(base.SubBase): code = base.opcodes['SUBMR'] arg_format = ['sw','c','s'] +@base.vectorize +class prefixsums(base.Instruction): + """ Prefix sum. + + :param: result (sint) + :param: input (sint) + + """ + __slots__ = [] + code = base.opcodes['PREFIXSUMS'] + arg_format = ['sw','s'] + @base.gf2n @base.vectorize class mulc(base.MulBase): @@ -2301,6 +2313,7 @@ def dynamic_arg_format(self, args): yield 'int' for i, n in self.bases(args): yield 's' + field + 'w' + assert n > 2 for j in range(n - 2): yield 's' + field yield 'int' diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 7a47c46c8..f811e47c8 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -80,6 +80,7 @@ SUBSI = 0x2A, SUBCFI = 0x2B, SUBSFI = 0x2C, + PREFIXSUMS = 0x2D, # Multiplication/division MULC = 0x30, MULM = 0x31, @@ -702,10 +703,16 @@ class ClearIntAF(RegisterArgFormat): reg_type = RegType.ClearInt class IntArgFormat(ArgFormat): + n_bits = 32 + @classmethod def check(cls, arg): - if not isinstance(arg, int) and not arg is None: - raise ArgumentError(arg, 'Expected an integer-valued argument') + if not arg is None: + if not isinstance(arg, int): + raise ArgumentError(arg, 'Expected an integer-valued argument') + if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits: + raise ArgumentError( + arg, 'Immediate value outside of %d-bit range' % cls.n_bits) @classmethod def encode(cls, arg): @@ -718,6 +725,8 @@ def __str__(self): return str(self.i) class LongArgFormat(IntArgFormat): + n_bits = 64 + @classmethod def encode(cls, arg): return list(struct.pack('>Q', arg)) @@ -729,8 +738,6 @@ class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): super(ImmediateModpAF, cls).check(arg) - if arg >= 2**32 or arg < -2**32: - raise ArgumentError(arg, 'Immediate value outside of 32-bit range') class ImmediateGF2NAF(IntArgFormat): @classmethod diff --git a/Compiler/library.py b/Compiler/library.py index 42a5826dc..1f1fd88c3 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -139,7 +139,7 @@ def print_str_if(cond, ss, *args): """ Print string conditionally. See :py:func:`print_ln_if` for details. """ if util.is_constant(cond): if cond: - print_ln(ss, *args) + print_str(ss, *args) else: subs = ss.split('%s') assert len(subs) == len(args) + 1 @@ -1021,9 +1021,11 @@ def write_state_to_memory(r): def f(i): state = tuplify(initializer()) start_block = get_block() + j = i * n_parallel + one = regint(1) for k in range(n_parallel): - j = i * n_parallel + k state = reducer(tuplify(loop_body(j)), state) + j += one if n_parallel > 1 and start_block != get_block(): print('WARNING: parallelization broken ' 'by control flow instruction') diff --git a/Compiler/ml.py b/Compiler/ml.py index 173c2eac0..bc93933dc 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -73,8 +73,13 @@ def log_e(x): return mpc_math.log_fx(x, math.e) +use_mux = False + def exp(x): - return mpc_math.pow_fx(math.e, x) + if use_mux: + return mpc_math.mux_exp(math.e, x) + else: + return mpc_math.pow_fx(math.e, x) def get_limit(x): exp_limit = 2 ** (x.k - x.f - 1) @@ -164,13 +169,16 @@ def softmax(x): return softmax_from_exp(exp_for_softmax(x)[0]) def exp_for_softmax(x): - m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2) + m = util.max(x) - get_limit(x[0]) + math.log(len(x)) mv = m.expand_to_vector(len(x)) try: x = x.get_vector() except AttributeError: x = sfix(x) - return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m + if use_mux: + return exp(x - mv), m + else: + return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m def softmax_from_exp(x): return x / sum(x) @@ -2002,6 +2010,9 @@ def from_args(program, layers): return res def __init__(self, report_loss=None): + if get_program().options.binary: + raise CompilerError( + 'machine learning code not compatible with binary circuits') self.tol = 0.000 self.report_loss = report_loss self.X_by_label = None diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index a16214a82..8b5836bc6 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -8,6 +8,8 @@ import math +import operator +from functools import reduce from Compiler import floatingpoint from Compiler import types from Compiler import comparison @@ -398,6 +400,36 @@ class my_fix(type(a)): return s.if_else(1 / g, g) +def mux_exp(x, y, block_size=8): + assert util.is_constant_float(x) + from Compiler.GC.types import sbitvec, sbits + bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v + sign = bits[-1] + m = math.log(2 ** (y.k - y.f - 1), x) + del bits[int(math.ceil(math.log(m, 2))) + y.f:] + parts = [] + for i in range(0, len(bits), block_size): + one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v + exp = [] + try: + for j in range(len(one_hot)): + exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f)) + except OverflowError: + pass + exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp)) + bin_part = [0] * max(x.bit_length() for x in exp) + for j in range(len(bin_part)): + for k, (a, b) in enumerate(zip(one_hot, exp)): + bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \ + else 0 + if util.is_zero(bin_part[j]): + bin_part[j] = sbits.get_type(y.size)(0) + if i == 0: + bin_part[j] = sign.if_else(0, bin_part[j]) + parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part)))) + return util.tree_reduce(operator.mul, parts) + + @types.vectorize @instructions_base.sfix_cisc def log2_fx(x, use_division=True): diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 01cb4db58..66e82908d 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -32,6 +32,8 @@ def trunc_pr(self, a, k, m, signed=True): return shift_two(a, m) prog = program.Program.prog if prog.use_trunc_pr: + if not prog.options.ring: + prog.curr_tape.require_bit_length(k + prog.security) if signed and prog.use_trunc_pr != -1: a += (1 << (k - 1)) res = sint() diff --git a/Compiler/oram.py b/Compiler/oram.py index ebc5b8a02..bbaa3938c 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1034,8 +1034,9 @@ def get_n_threads_for_tree(size): class TreeORAM(AbstractORAM): """ Tree ORAM. """ - def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ + def __init__(self, size, value_type=None, value_length=1, entry_size=None, \ bucket_oram=TrivialORAM, init_rounds=-1): + value_type = value_type or sint print('create oram of size', size) self.bucket_oram = bucket_oram # heuristic bucket size @@ -1722,6 +1723,8 @@ def OptimalORAM(size,*args,**kwargs): :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / :py:class:`sfix` """ + if not util.is_constant(size): + raise CompilerError('ORAM size has be a compile-time constant') if get_program().options.binary: return BinaryORAM(size, *args, **kwargs) if optimal_threshold is None: @@ -1772,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty): def test_oram(oram_type, N, value_type=sint, iterations=100): stop_grind() oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0) + test_oram_initialized(oram, iterations) + return oram + +def test_oram_initialized(oram, iterations=100): + N = oram.size + value_type = oram.value_type value_type = value_type.get_type(32) index_type = value_type.get_type(log2(N)) start_grind() diff --git a/Compiler/program.py b/Compiler/program.py index f92ab4971..7431d6009 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -29,6 +29,7 @@ bit=2, inverse=3, dabit=4, + mixed=5, ) field_types = dict( @@ -45,6 +46,7 @@ class defaults: ring = 0 field = 0 binary = 0 + garbled = False prime = None galois = 40 budget = 100000 @@ -150,10 +152,11 @@ def __init__(self, args, options=defaults, name=None): gc.ldmsd, gc.stmsd, gc.stmsdci, - gc.xors, gc.andrs, gc.ands, gc.inputb, + gc.inputbvec, + gc.reveal, ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ @@ -350,7 +353,8 @@ def write_bytes(self): print("Writing to", sch_filename) sch_file.write(str(self.max_par_tapes()) + "\n") sch_file.write(str(len(nonempty_tapes)) + "\n") - sch_file.write(" ".join(tape.name for tape in nonempty_tapes) + "\n") + sch_file.write(" ".join("%s:%d" % (tape.name, len(tape)) + for tape in nonempty_tapes) + "\n") sch_file.write("1 0\n") sch_file.write("0\n") sch_file.write(" ".join(sys.argv) + "\n") @@ -506,7 +510,8 @@ def security(self, security): self.set_security(security) def optimize_for_gc(self): - pass + import Compiler.GC.instructions as gc + self.to_merge += [gc.xors] def get_tape_counter(self): res = self.tape_counter @@ -686,6 +691,7 @@ def __init__(self, parent, name, scope, exit_condition=None): self.purged = False self.n_rounds = 0 self.n_to_merge = 0 + self.rounds = Tape.ReqNum() self.warn_about_mem = parent.program.warn_about_mem[-1] def __len__(self): @@ -750,6 +756,7 @@ def add_usage(self, req_node): inst.add_usage(req_node) req_node.num["all", "round"] += self.n_rounds req_node.num["all", "inv"] += self.n_to_merge + req_node.num += self.rounds def expand_cisc(self): new_instructions = [] @@ -796,7 +803,14 @@ def init_names(self, name): self.name = name self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" + def __len__(self): + if self.purged: + return self.size + else: + return sum(len(block) for block in self.basicblocks) + def purge(self): + self.size = len(self) for block in self.basicblocks: block.purge() self._is_empty = len(self.basicblocks) == 0 @@ -865,6 +879,8 @@ def optimize(self, options): numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) + if options.verbose: + block.rounds = merger.req_num if merger.counter and self.program.verbose: print( "Block requires", @@ -1113,7 +1129,8 @@ def __mul__(self, other): __rmul__ = __mul__ def set_all(self, value): - if value == float("inf") and self["all", "inv"] > 0: + if Program.prog.options.verbose and \ + value == float("inf") and self["all", "inv"] > 0: print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: @@ -1142,6 +1159,8 @@ def t(x): res = [] for req, num in self.items(): domain = t(req[0]) + if num < 0: + num = float('inf') n = "%12.0f" % num if req[1] == "input": res += ["%s %s inputs from player %d" % (n, domain, req[2])] diff --git a/Compiler/sorting.py b/Compiler/sorting.py index 248b3ea07..fc619b732 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -3,12 +3,7 @@ def dest_comp(B): Bt = B.transpose() - Bt_flat = Bt.get_vector() - St_flat = Bt.value_type.Array(len(Bt_flat)) - St_flat.assign(Bt_flat) - @library.for_range(len(St_flat) - 1) - def _(i): - St_flat[i + 1] = St_flat[i + 1] + St_flat[i] + St_flat = Bt.get_vector().prefix_sum() Tt_flat = Bt.get_vector() * St_flat.get_vector() Tt = types.Matrix(*Bt.sizes, B.value_type) Tt.assign_vector(Tt_flat) @@ -37,8 +32,14 @@ def radix_sort(k, D, n_bits=None, signed=True): bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) if signed and len(bs) > 1: bs[-1][:] = bs[-1][:].bit_not() - B = types.sint.Matrix(len(k), 2) - h = types.Array.create_from(types.sint(types.regint.inc(len(k)))) + radix_sort_from_matrix(bs, D) + +def radix_sort_from_matrix(bs, D): + n = len(D) + for b in bs: + assert(len(b) == n) + B = types.sint.Matrix(n, 2) + h = types.Array.create_from(types.sint(types.regint.inc(n))) @library.for_range(len(bs)) def _(i): b = bs[i] diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 1024ab889..741baaf74 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -10,9 +10,7 @@ from Compiler.program import Program from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint, regint, sint, sintbit) -from oram import demux_array, get_n_threads - -program = Program.prog +from Compiler.oram import demux_array, get_n_threads # Adds messages on completion of heavy computation steps debug = False @@ -44,6 +42,13 @@ def get_n_threads(n_loops): class SqrtOram(Generic[T, B]): + """Oblivious RAM using the "Square-Root" algorithm. + + :param MultiArray data: The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). + :param sint value_type: The secret type to use, defaults to sint. + :param int k: Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. + :param int period: Leave at None, this parameter is used to recursively pass down the top-level period. + """ # TODO: Preferably this is an Array of vectors, but this is currently not supported # One should regard these structures as Arrays where an entry may hold more # than one value (which is a nice property to have when using the ORAM in @@ -69,14 +74,6 @@ class SqrtOram(Generic[T, B]): t: cint def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None: - """Initialize a new Oblivious RAM using the "Square-Root" algorithm. - - Args: - data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). - value_type (sint): The secret type to use, defaults to sint. - k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. - period (int): Leave at None, this parameter is used to recursively pass down the top-level period. - """ global debug, allow_memory_allocation # Correctly initialize the shuffle (memory) depending on the type of data @@ -103,6 +100,7 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type self.index_size = util.log2(self.n) + 1 # +1 because signed self.index_type = value_type.get_type(self.index_size) self.entry_length = entry_length + self.size = self.n if debug: lib.print_ln( @@ -632,6 +630,7 @@ def get_position(self, logical_address: T, fake: B) -> _clear: # The item at logical_address # will be in block with index h (block.) # at position l in block.data (block.data) + program = Program.prog h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)( logical_address).right_shift(pack_log, program.bit_length))) l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) diff --git a/Compiler/types.py b/Compiler/types.py index 5e4893e3d..3366e2f40 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -749,7 +749,14 @@ def expand_to_vector(self, size=None): self.mov(res[i], self) return res -class _clear(_register): +class _arithmetic_register(_register): + """ Arithmetic circuit type. """ + def __init__(self, *args, **kwargs): + if program.options.garbled: + raise CompilerError('functionality only available in arithmetic circuits') + super(_arithmetic_register, self).__init__(*args, **kwargs) + +class _clear(_arithmetic_register): """ Clear domain-dependent type. """ __slots__ = [] mov = staticmethod(movc) @@ -1085,6 +1092,8 @@ def __eq__(self, other): def __ne__(self, other): return 1 - (self == other) + equal = lambda self, other, *args, **kwargs: self.__eq__(other) + def __lshift__(self, other): """ Clear left shift. @@ -1836,7 +1845,7 @@ def bit_decompose(self, bit_length): res += x.bit_decompose(64) return res[:bit_length] -class _secret(_register, _secret_structure): +class _secret(_arithmetic_register, _secret_structure): __slots__ = [] mov = staticmethod(set_instruction_type(movs)) @@ -2682,6 +2691,15 @@ def int_div(self, other, bit_length=None, security=None): comparison.Trunc(res, tmp, 2 * k, k, kappa, True) return res + @vectorize + def int_mod(self, other, bit_length=None): + """ Secret integer modulo. + + :param other: sint + :param bit_length: bit length of input (default: global bit length) + """ + return self - other * self.int_div(other, bit_length=bit_length) + def trunc_zeros(self, n_zeros, bit_length=None, signed=True): bit_length = bit_length or program.bit_length return comparison.TruncZeros(self, bit_length, n_zeros, signed) @@ -2808,6 +2826,13 @@ def inverse_permutation(self): res = res.get_vector() return res + @vectorize + def prefix_sum(self): + """ Prefix sum. """ + res = sint() + prefixsums(res, self) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -3940,6 +3965,8 @@ def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) + :returns: list of length ``n`` + """ sint_inputs = cls.int_type.receive_from_client(n, client_id, message_type) @@ -3977,6 +4004,8 @@ def load_mem(cls, address, mem_type=None): def conv(cls, other): if isinstance(other, cls): return other + elif isinstance(other, (list, tuple)): + return type(other)(cls.conv(x) for x in other) else: try: return cls.from_sint(other) @@ -4216,7 +4245,7 @@ def conv(cls, other): if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f): return other else: - return cls(other) + return super(_fix, cls).conv(other) @classmethod def _new(cls, other, k=None, f=None): @@ -4524,6 +4553,9 @@ def secure_permute(self, *args, **kwargs): return self._new(self.v.secure_permute(*args, **kwargs), k=self.k, f=self.f) + def prefix_sum(self): + return self._new(self.v.prefix_sum(), k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -5271,6 +5303,8 @@ class Array(_vectorizable): a[:] += b[:] """ + check_indices = True + @classmethod def create_from(cls, l): """ Convert Python iterator or vector to array. Basic type will be taken @@ -5283,7 +5317,9 @@ def create_from(cls, l): """ if isinstance(l, cls): - return l + res = l.same_shape() + res[:] = l[:] + return res if isinstance(l, _number): tmp = l t = type(l) @@ -5304,7 +5340,6 @@ def __init__(self, length, value_type, address=None, debug=None, alloc=True): self.debug = debug self.creator_tape = program.curr_tape self.sink = None - self.check_indices = True if alloc: self.alloc() @@ -5435,7 +5470,10 @@ def _load(self, address): return self.value_type.load_mem(address) def _store(self, value, address): - self.value_type.conv(value).store_in_mem(address) + tmp = self.value_type.conv(value) + if not isinstance(tmp, _vec) and tmp.size != self.value_type.mem_size(): + raise CompilerError('size mismatch in array assignment') + tmp.store_in_mem(address) def __len__(self): return self.length @@ -5506,6 +5544,12 @@ def get_vector(self, base=0, size=None): get_part_vector = get_vector + def get_reverse_vector(self): + """ Return vector with content in reverse order. """ + size = self.length + address = regint.inc(size, size - 1, -1) + return self.value_type.load_mem(self.address + address, size=size) + def get_part(self, base, size): """ Part array. @@ -5605,7 +5649,6 @@ def __sub__(self, other): """ Vector subtraction. :param other: vector or container of same length and type that supports operations with type of this array """ - assert len(self) == len(other) return self.get_vector() - other def __mul__(self, value): @@ -5668,7 +5711,7 @@ def reveal(self): """ Reveal the whole array. :returns: Array of relevant clear type. """ - return Array.create_from(x.reveal() for x in self) + return Array.create_from(self.get_vector().reveal()) def reveal_list(self): """ Reveal as list. """ @@ -6367,13 +6410,15 @@ def transpose(self): res = Matrix(self.sizes[1], self.sizes[0], self.value_type) library.break_point() if self.value_type.n_elements() == 1: - @library.for_range_opt(self.sizes[0]) - def _(j): - res.set_column(j, self[j][:]) + nr = self.sizes[1] + nc = self.sizes[0] + a = regint.inc(nr * nc, 0, nr, 1, nc) + b = regint.inc(nr * nc, 0, 1, nc) + res[:] = self.value_type.load_mem(self.address + a + b) else: - @library.for_range_opt(self.sizes[1]) + @library.for_range_opt(self.sizes[1], budget=100) def _(i): - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt(self.sizes[0], budget=100) def _(j): res[i][j] = self[j][i] library.break_point() @@ -6424,7 +6469,7 @@ def sort(self, key_indices=None, n_bits=None): def randomize(self, *args): """ Randomize according to data type. """ - if self.total_size() < program.options.budget: + if self.total_size() < program.budget: self.assign_vector( self.value_type.get_random(*args, size=self.total_size())) else: @@ -6432,6 +6477,12 @@ def randomize(self, *args): def _(i): self[i].randomize(*args) + def reveal(self): + """ Reveal to :py:obj:`MultiArray` of same shape. """ + res = MultiArray(self.sizes, self.value_type.clear_type) + res[:] = self.get_vector().reveal() + return res + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6542,7 +6593,7 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): @staticmethod def create_from(rows): rows = list(rows) - if isinstance(rows[0], (list, tuple)): + if isinstance(rows[0], (list, tuple, Array)): t = type(rows[0][0]) else: t = type(rows[0]) diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 23f81b9ef..ecf7011ba 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -22,4 +22,5 @@ int main() generate_mac_keys>(key, 2, prefix); make_mult_triples>(key, 2, 1000, false, prefix); make_inverse>(key, 2, 1000, false, prefix); + P256Element::finish(); } diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 2c8c776d2..1ff3273f8 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -14,7 +14,14 @@ void P256Element::init() curve = EC_GROUP_new_by_curve_name(NID_secp256k1); assert(curve != 0); auto modulus = EC_GROUP_get0_order(curve); - Scalar::init_field(BN_bn2dec(modulus), false); + auto mod = BN_bn2dec(modulus); + Scalar::init_field(mod, false); + free(mod); +} + +void P256Element::finish() +{ + EC_GROUP_free(curve); } P256Element::P256Element() @@ -42,6 +49,11 @@ P256Element::P256Element(word other) : BN_free(exp); } +P256Element::~P256Element() +{ + EC_POINT_free(point); +} + P256Element& P256Element::operator =(const P256Element& other) { assert(EC_POINT_copy(point, other.point) != 0); @@ -99,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const return not cmp; } -void P256Element::pack(octetStream& os) const +void P256Element::pack(octetStream& os, int) const { octet* buffer; size_t length = EC_POINT_point2buf(curve, point, @@ -107,9 +119,10 @@ void P256Element::pack(octetStream& os) const assert(length != 0); os.store_int(length, 8); os.append(buffer, length); + free(buffer); } -void P256Element::unpack(octetStream& os) +void P256Element::unpack(octetStream& os, int) { size_t length = os.get_int(8); assert( diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 27ea7f75c..bd005c840 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -32,11 +32,13 @@ class P256Element : public ValueInterface static string type_string() { return "P256"; } static void init(); + static void finish(); P256Element(); P256Element(const P256Element& other); P256Element(const Scalar& other); P256Element(word other); + ~P256Element(); P256Element& operator=(const P256Element& other); @@ -58,8 +60,8 @@ class P256Element : public ValueInterface bool is_zero() { return *this == P256Element(); } void add(octetStream& os) { *this += os.get(); } - void pack(octetStream& os) const; - void unpack(octetStream& os); + void pack(octetStream& os, int = -1) const; + void unpack(octetStream& os, int = -1); octetStream hash(size_t n_bytes) const; diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 5bef730d5..ea19c8ee3 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -64,4 +64,5 @@ int main(int argc, const char** argv) pShare::MAC_Check::teardown(); Share::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index fc19e989b..07520f336 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -30,6 +30,8 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/MalRep.hpp" +#include "Machines/Rep.hpp" #include @@ -69,4 +71,5 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); // check(tuples, sk, {}, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + P256Element::finish(); } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index ebf0aea96..550c0ac8a 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -140,4 +140,5 @@ void run(int argc, const char** argv) pShare::MAC_Check::teardown(); T::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ExternalIO/README.md b/ExternalIO/README.md index f5f418ed9..89328440e 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -15,7 +15,7 @@ make bankers-bonus-client.x ./compile.py bankers_bonus 1 Scripts/setup-ssl.sh Scripts/setup-clients.sh 3 -Scripts/.sh bankers_bonus-1 & +PLAYERS= Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 0 100 0 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 019711829..bc890ed21 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -116,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine, ofstream file(filename); os.output(file); } + + if (OnlineOptions::singleton.verbose) + { + cerr << "Ciphertext length: " << params.p0().numBits(); + for (size_t i = 1; i < params.FFTD().size(); i++) + cerr << "+" << params.FFTD()[i].get_prime().numBits(); + cerr << endl; + } } template diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index d92f30806..7127b8c77 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -128,6 +128,7 @@ size_t Prover::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl bool ok=false; int cnt=0; + (void) cnt; while (!ok) { cnt++; Stage_1(P,ciphertexts,c,pk); diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index 437af179a..a3f821a0d 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -44,7 +44,8 @@ void BitAdder::add(vector>& res, const vector>>& summ &supplies); BitAdder().add(res, summands, start, summands[0][0].size(), proc, T::default_length); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else add(res, summands, 0, res.size(), proc, length); diff --git a/GC/BitPrepFiles.h b/GC/BitPrepFiles.h index 0a406a461..e8c4d0cf2 100644 --- a/GC/BitPrepFiles.h +++ b/GC/BitPrepFiles.h @@ -6,12 +6,12 @@ #ifndef GC_BITPREPFILES_H_ #define GC_BITPREPFILES_H_ -namespace GC -{ - #include "ShiftableTripleBuffer.h" #include "Processor/Data_Files.h" +namespace GC +{ + template class BitPrepFiles : public ShiftableTripleBuffer, public Sub_Data_Files { diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index ee7a84462..cd43ae1d5 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -11,11 +11,13 @@ #include "GC/Access.h" #include "GC/ArgTuples.h" #include "GC/NoShare.h" +#include "GC/Processor.h" #include "Math/gf2nlong.h" #include "Tools/SwitchableOutput.h" #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Protocols/FakePrep.h" #include "Protocols/FakeMC.h" #include "Protocols/FakeProtocol.h" @@ -85,6 +87,11 @@ class FakeSecret : public ShareInterface, public BitVec { processor.andrs(args); } static void ands(GC::Processor& processor, const vector& regs); template + static void andrsvec(T&, const vector&) + { throw runtime_error("andrsvec not implemented"); } + static void andm(GC::Processor& processor, const ::Instruction& instruction) + { processor.andm(instruction); } + template static void xors(GC::Processor& processor, const vector& regs) { processor.xors(regs); } template diff --git a/GC/Instruction.h b/GC/Instruction.h index e990f954e..ab6f3f478 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -64,6 +64,7 @@ enum INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, + ANDRSVEC = 0x24a, // write to clear CLEAR_WRITE = 0x210, XORCBI = 0x210, diff --git a/GC/Machine.h b/GC/Machine.h index ecf352cc6..991f00143 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -47,7 +47,7 @@ class Machine : public ::BaseMachine, public Memories ~Machine(); void load_schedule(const string& progname); - void load_program(const string& threadname, const string& filename); + size_t load_program(const string& threadname, const string& filename); template void reset(const U& program); diff --git a/GC/Machine.hpp b/GC/Machine.hpp index 8cfe08f2a..8b555f6c6 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -35,12 +35,14 @@ Machine::~Machine() } template -void Machine::load_program(const string& threadname, const string& filename) +size_t Machine::load_program(const string& threadname, + const string& filename) { (void)threadname; progs.push_back({}); progs.back().parse_file(filename); reset(progs.back()); + return progs.back().size(); } template diff --git a/GC/Memory.h b/GC/Memory.h index 006a91d94..aa02d563e 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -18,6 +18,8 @@ using namespace std; class NoMemory { +public: + void resize_min(size_t, const char*) {} }; namespace GC diff --git a/GC/NoShare.h b/GC/NoShare.h index 917e71c5e..ec2c85ac0 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -154,6 +154,7 @@ class NoShare : public ShareInterface static void xors(Processor&, const vector&) { fail(); } static void ands(Processor&, const vector&) { fail(); } static void andrs(Processor&, const vector&) { fail(); } + static void andrsvec(Processor&, const vector&) { fail(); } static void trans(Processor&, Integer, const vector&) { fail(); } diff --git a/GC/PersonalPrep.hpp b/GC/PersonalPrep.hpp index df1725854..44c4080e0 100644 --- a/GC/PersonalPrep.hpp +++ b/GC/PersonalPrep.hpp @@ -8,6 +8,8 @@ #include "PersonalPrep.h" +#include "Protocols/ShuffleSacrifice.hpp" + namespace GC { @@ -36,7 +38,8 @@ void PersonalPrep::buffer_personal_triples(size_t batch_size, ThreadQueues* q PersonalTripleJob job(&triples, input_player); int start = queues->distribute(job, batch_size); buffer_personal_triples(triples, start, batch_size); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else buffer_personal_triples(triples, 0, batch_size); diff --git a/GC/PostSacriBin.cpp b/GC/PostSacriBin.cpp index 742480600..aff82818b 100644 --- a/GC/PostSacriBin.cpp +++ b/GC/PostSacriBin.cpp @@ -10,6 +10,7 @@ #include "Protocols/Replicated.hpp" #include "Protocols/MaliciousRepMC.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/ReplicatedPrep.hpp" #include "ShareSecret.hpp" namespace GC diff --git a/GC/Processor.h b/GC/Processor.h index a5acb950a..e21cf6007 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -91,6 +91,7 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching void and_(const vector& args, bool repeat); void andrs(const vector& args) { and_(args, true); } void ands(const vector& args) { and_(args, false); } + void andrsvec(const vector& args); void input(const vector& args); void inputb(typename T::Input& input, ProcessorBase& input_processor, diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 96b2d62d8..87296edfb 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -15,6 +15,7 @@ using namespace std; #include "GC/Program.h" #include "Access.h" #include "Processor/FixInput.h" +#include "Math/BitVec.h" #include "GC/Machine.hpp" #include "Processor/ProcessorBase.hpp" @@ -205,9 +206,13 @@ template void Processor::mem_op(int n, Memory& dest, const Memory& source, Integer dest_address, Integer source_address) { + dest.check_index(dest_address + n - 1); + source.check_index(source_address + n - 1); + auto d = &dest[dest_address]; + auto s = &source[source_address]; for (int i = 0; i < n; i++) { - dest[dest_address + i] = source[source_address + i]; + *d++ = *s++; } } @@ -302,6 +307,40 @@ void Processor::and_(const vector& args, bool repeat) } } +template +void Processor::andrsvec(const vector& args) +{ + int N_BITS = T::default_length; + auto it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + int base = *(it + n_args); + assert(n_args <= N_BITS); + for (int i = 0; i < size; i += 1) + { + if (i % N_BITS == 0) + for (int j = 0; j < n_args; j++) + S.at(*(it + j) + i / N_BITS).resize_regs( + min(N_BITS, size - i)); + + T y; + y.get_regs().push_back(S.at(base + i / N_BITS).get_reg(i % N_BITS)); + for (int j = 0; j < n_args; j++) + { + T x, tmp; + x.get_regs().push_back( + S.at(*(it + n_args + 1 + j) + i / N_BITS).get_reg( + i % N_BITS)); + tmp.and_(1, x, y, false); + S.at(*(it + j) + i / N_BITS).get_reg(i % N_BITS) = tmp.get_reg(0); + } + } + it += 2 * n_args + 1; + } +} + template void Processor::input(const vector& args) { diff --git a/GC/Program.h b/GC/Program.h index 8280c3f75..5d4b16435 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -40,6 +40,8 @@ class Program Program(); + size_t size() const { return p.size(); } + // Read in a program void parse_file(const string& filename); void parse(const string& programe); diff --git a/GC/Secret.h b/GC/Secret.h index c4b6e8eb1..9fee3f2ff 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -98,6 +98,9 @@ class Secret static void ands(Processor& processor, const vector& args) { T::ands(processor, args); } template + static void andrsvec(Processor& processor, const vector& args) + { T::andrsvec(processor, args); } + template static void xors(Processor& processor, const vector& args) { T::xors(processor, args); } template diff --git a/GC/Semi.cpp b/GC/Semi.cpp new file mode 100644 index 000000000..e00fed69e --- /dev/null +++ b/GC/Semi.cpp @@ -0,0 +1,36 @@ +/* + * Semi.cpp + * + */ + +#include "Semi.h" +#include "SemiPrep.h" + +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/Beaver.hpp" + +namespace GC +{ + +void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, + bool repeat) +{ + if (repeat and OnlineOptions::singleton.live_prep) + { + this->triples.push_back({{}}); + auto& triple = this->triples.back(); + triple = dynamic_cast(prep)->get_mixed_triple(n); + for (int i = 0; i < 2; i++) + triple[1 + i] = triple[1 + i].mask(n); + triple[0] = triple[0].extend_bit().mask(n); + shares.push_back(y - triple[0]); + shares.push_back(x - triple[1]); + lengths.push_back(n); + } + else + prepare_mul(x, y, n); +} + +} /* namespace GC */ diff --git a/GC/Semi.h b/GC/Semi.h new file mode 100644 index 000000000..92f9139aa --- /dev/null +++ b/GC/Semi.h @@ -0,0 +1,31 @@ +/* + * Semi.h + * + */ + +#ifndef GC_SEMI_H_ +#define GC_SEMI_H_ + +#include "Protocols/Beaver.h" +#include "SemiSecret.h" + +namespace GC +{ + +class Semi : public Beaver +{ + typedef Beaver super; + +public: + Semi(Player& P) : + super(P) + { + } + + void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, + bool repeat); +}; + +} /* namespace GC */ + +#endif /* GC_SEMI_H_ */ diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 9eed3b316..3adc385d5 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -4,6 +4,7 @@ */ #include "SemiPrep.h" +#include "Semi.h" #include "ThreadMaster.h" #include "OT/NPartyTripleGenerator.h" #include "OT/BitDiagonal.h" @@ -21,7 +22,7 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) : { } -void SemiPrep::set_protocol(Beaver& protocol) +void SemiPrep::set_protocol(SemiSecret::Protocol& protocol) { if (triple_generator) { @@ -53,6 +54,9 @@ SemiPrep::~SemiPrep() { if (triple_generator) delete triple_generator; + this->print_left("mixed triples", mixed_triples.size(), + SemiSecret::type_string(), + this->usage.files.at(DATA_GF2N).at(DATA_MIXED)); } void SemiPrep::buffer_bits() @@ -64,4 +68,25 @@ void SemiPrep::buffer_bits() } } +array SemiPrep::get_mixed_triple(int n) +{ + assert(n < 128); + + if (mixed_triples.empty()) + { + assert(this->triple_generator); + this->triple_generator->generateMixedTriples(); + for (auto& x : this->triple_generator->mixedTriples) + { + this->mixed_triples.push_back({{x[0], x[1], x[2]}}); + } + this->triple_generator->unlock(); + } + + this->count(DATA_MIXED); + auto res = mixed_triples.back(); + mixed_triples.pop_back(); + return res; +} + } /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 737cfb986..ee4a7abe0 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -25,11 +25,13 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer> mixed_triples; + public: SemiPrep(DataPositions& usage, bool = true); ~SemiPrep(); - void set_protocol(Beaver& protocol); + void set_protocol(SemiSecret::Protocol& protocol); void buffer_triples(); void buffer_bits(); @@ -37,6 +39,8 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer get_mixed_triple(int n); + void get(Dtype type, SemiSecret* data) { BufferPrep::get(type, data); diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index e95554bf5..dc9e0a341 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -19,6 +19,7 @@ namespace GC class SemiPrep; class DealerPrep; +class Semi; template class SemiSecretBase : public V, public ShareSecret @@ -88,9 +89,13 @@ class SemiSecret: public SemiSecretBase> typedef MC MAC_Check; typedef SemiInput Input; typedef SemiPrep LivePrep; + typedef Semi Protocol; static MC* new_mc(typename SemiShare::mac_key_type); + static void andrsvec(Processor& processor, + const vector& args); + SemiSecret() { } diff --git a/GC/SemiSecret.hpp b/GC/SemiSecret.hpp index f6a4d3984..b147cce36 100644 --- a/GC/SemiSecret.hpp +++ b/GC/SemiSecret.hpp @@ -8,6 +8,7 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/DealerMC.h" #include "SemiSecret.h" +#include "Semi.h" namespace GC { @@ -60,6 +61,60 @@ void SemiSecretBase::trans(Processor& processor, int n_outputs, } } +inline +void SemiSecret::andrsvec(Processor& processor, + const vector& args) +{ + int N_BITS = default_length; + auto protocol = ShareThread::s().protocol; + assert(protocol); + protocol->init_mul(); + auto it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + it += n_args; + int base = *it++; + assert(n_args <= N_BITS); + for (int i = 0; i < size; i += N_BITS) + { + square64 square; + for (int j = 0; j < n_args; j++) + square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get(); + int n_ops = min(N_BITS, size - i); + square.transpose(n_args, n_ops); + for (int j = 0; j < n_ops; j++) + { + long bit = processor.S.at(base + i / N_BITS).get_bit(j); + auto y_ext = SemiSecret(bit).extend_bit(); + protocol->prepare_mult(square.rows[j], y_ext, n_args, true); + } + } + it += n_args; + } + + protocol->exchange(); + + it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + for (int i = 0; i < size; i += N_BITS) + { + int n_ops = min(N_BITS, size - i); + square64 square; + for (int j = 0; j < n_ops; j++) + square.rows[j] = protocol->finalize_mul(n_args).get(); + square.transpose(n_ops, n_args); + for (int j = 0; j < n_args; j++) + processor.S.at(*(it + j) + i / N_BITS) = square.rows[j]; + } + it += 2 * n_args + 1; + } +} + template void SemiSecretBase::load_clear(int n, const Integer& x) { diff --git a/GC/ShareParty.h b/GC/ShareParty.h index 389efa331..ceda2f01f 100644 --- a/GC/ShareParty.h +++ b/GC/ShareParty.h @@ -6,8 +6,6 @@ #ifndef GC_SHAREPARTY_H_ #define GC_SHAREPARTY_H_ -#include "Protocols/ReplicatedMC.h" -#include "Protocols/MaliciousRepMC.h" #include "ShareSecret.h" #include "Processor.h" #include "Program.h" diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index 28c28710b..57beaec0d 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -16,14 +16,12 @@ #include "Protocols/fake-stuff.h" #include "ShareThread.hpp" -#include "RepPrep.hpp" #include "ThreadMaster.hpp" #include "Thread.hpp" #include "ShareSecret.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/ReplicatedPrep.hpp" -#include "Protocols/MaliciousRepMC.hpp" #include "Protocols/fake-stuff.hpp" namespace GC diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index fb254486b..d8c0c18c5 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -63,6 +63,7 @@ class ShareSecret static void ands(Processor& processor, const vector& args) { and_(processor, args, false); } static void and_(Processor& processor, const vector& args, bool repeat); + static void andrsvec(Processor& processor, const vector& args); static void xors(Processor& processor, const vector& args); static void inputb(Processor& processor, const vector& args) { inputb(processor, processor, args); } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 12568ef8f..db57e3dd9 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -8,16 +8,12 @@ #include "ShareSecret.h" -#include "MaliciousRepSecret.h" -#include "Protocols/MaliciousRepMC.h" #include "ShareThread.h" #include "Thread.h" #include "square64.h" #include "Protocols/Share.h" -#include "Protocols/ReplicatedMC.hpp" -#include "Protocols/Beaver.hpp" #include "ShareParty.h" #include "ShareThread.hpp" #include "Thread.hpp" @@ -288,6 +284,12 @@ void ShareSecret::and_( ShareThread::s().and_(processor, args, repeat); } +template +void ShareSecret::andrsvec(Processor& processor, const vector& args) +{ + ShareThread::s().andrsvec(processor, args); +} + template void ShareSecret::xors(Processor& processor, const vector& args) { diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 9c5f4ddb5..70aae69b3 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -7,11 +7,7 @@ #define GC_SHARETHREAD_H_ #include "Thread.h" -#include "MaliciousRepSecret.h" -#include "RepPrep.h" -#include "SemiHonestRepPrep.h" #include "Processor/Data_Files.h" -#include "Protocols/ReplicatedInput.h" #include @@ -45,6 +41,7 @@ class ShareThread void check(); void and_(Processor& processor, const vector& args, bool repeat); + void andrsvec(Processor& processor, const vector& args); void xors(Processor& processor, const vector& args); }; diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 27eefda06..b0eea1b0b 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -107,7 +107,7 @@ void ShareThread::and_(Processor& processor, else processor.S[right + j].mask(y_ext, n); processor.S[left + j].mask(x_ext, n); - protocol->prepare_mul(x_ext, y_ext, n); + protocol->prepare_mult(x_ext, y_ext, n, repeat); } } @@ -127,6 +127,53 @@ void ShareThread::and_(Processor& processor, } } +template +void ShareThread::andrsvec(Processor& processor, const vector& args) +{ + int N_BITS = T::default_length; + auto& protocol = this->protocol; + assert(protocol); + protocol->init_mul(); + auto it = args.begin(); + T x_ext, y_ext; + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + it += n_args; + int base = *it++; + assert(n_args <= N_BITS); + for (int i = 0; i < size; i += N_BITS) + { + int n_ops = min(N_BITS, size - i); + for (int j = 0; j < n_args; j++) + { + processor.S.at(*(it + j) + i / N_BITS).mask(x_ext, n_ops); + processor.S.at(base + i / N_BITS).mask(y_ext, n_ops); + protocol->prepare_mul(x_ext, y_ext, n_ops); + } + } + it += n_args; + } + + protocol->exchange(); + + it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + for (int i = 0; i < size; i += N_BITS) + { + int n_ops = min(N_BITS, size - i); + for (int j = 0; j < n_args; j++) + protocol->finalize_mul(n_ops).mask( + processor.S.at(*(it + j) + i / N_BITS), n_ops); + } + it += 2 * n_args + 1; + } +} + template void ShareThread::xors(Processor& processor, const vector& args) { diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index a754b2e75..03eea7813 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -68,6 +68,7 @@ void ThreadMaster::run() P = new PlainPlayer(N, "main"); machine.load_schedule(progname); + machine.reset(machine.progs[0], memory); for (int i = 0; i < machine.nthreads; i++) threads.push_back(new_thread(i)); diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index e136ec446..d288d8260 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -8,7 +8,7 @@ #include "TinierSharePrep.h" -#include "PersonalPrep.h" +#include "PersonalPrep.hpp" namespace GC { diff --git a/GC/TinyMC.h b/GC/TinyMC.h index c94677ff5..8ef5e10fd 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -46,7 +46,7 @@ class TinyMC : public MAC_Check_Base sizes.reserve(n); } - void prepare_open(const T& secret) + void prepare_open(const T& secret, int = -1) { for (auto& part : secret.get_regs()) part_MC.prepare_open(part); diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 897b3b482..d3efbb831 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -6,6 +6,8 @@ #include "TinierSharePrep.h" #include "Protocols/MascotPrep.hpp" +#include "Protocols/ShuffleSacrifice.hpp" +#include "Protocols/MalRepRingPrep.hpp" namespace GC { diff --git a/GC/instructions.h b/GC/instructions.h index 49443cc23..62a71603f 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -45,6 +45,7 @@ X(NOTS, processor.nots(INST)) \ X(NOTCB, processor.notcb(INST)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ + X(ANDRSVEC, T::andrsvec(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ X(ANDM, T::andm(PROC, instruction)) \ X(ADDCB, C0 = PC1 + PC2) \ diff --git a/License.txt b/License.txt index ccaafe01e..ab7ae3bb9 100644 --- a/License.txt +++ b/License.txt @@ -1,19 +1,17 @@ -CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) -Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. -All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material. -Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -* Neither the name of CSIRO nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission of CSIRO. -EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED "AS-IS". CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. -TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR OTHER LIABILITY HOWSOEVER INCURRED. WITHOUT LIMITING THE SCOPE OF THE PREVIOUS SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY. -APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". TO THE EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO'S OPTION) TO ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES: -(a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR SUPPLYING RELEVANT SERVICES AGAIN; -(b) THE REPAIR OF THE SOFTWARE; -(c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE REPAIRED. -IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON BEHALF OF THE RELEVANT THIRD PARTY. -Third Party Components -The following third party components are distributed with the Software. You agree to comply with the licence terms for these components as part of accessing the Software. Other third party software may also be identified in separate files distributed with the Software. +The Software is copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. + +CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +The following third party components are distributed with the Software. ___________________________________________________________________ SPDZ-2 [https://github.com/bristolcrypto/SPDZ-2] Copyright (c) 2018, The University of Bristol diff --git a/Machines/MalRep.hpp b/Machines/MalRep.hpp index 020477fae..b68da12cd 100644 --- a/Machines/MalRep.hpp +++ b/Machines/MalRep.hpp @@ -9,5 +9,7 @@ #include "Protocols/MalRepRingPrep.hpp" #include "Protocols/MaliciousRepPrep.hpp" #include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Rep.hpp" #endif /* MACHINES_MALREP_HPP_ */ diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index a480860f8..d684909b7 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -4,7 +4,8 @@ */ #include "Protocols/MalRepRingPrep.h" -#include "Protocols/ReplicatedPrep2k.h" +#include "Protocols/SemiRep3Prep.h" +#include "GC/SemiHonestRepPrep.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" @@ -12,6 +13,8 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" +#include "Protocols/ReplicatedMC.hpp" +#include "Protocols/Rep3Shuffler.hpp" #include "Math/Z2k.hpp" #include "GC/ShareSecret.hpp" #include "GC/RepPrep.hpp" diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp index 890a24ab5..e738cd868 100644 --- a/Machines/dealer-ring-party.cpp +++ b/Machines/dealer-ring-party.cpp @@ -15,8 +15,14 @@ #include "Protocols/DealerMC.hpp" #include "Protocols/DealerMatrixPrep.hpp" #include "Protocols/Beaver.hpp" -#include "Semi.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/SemiMC.hpp" #include "GC/DealerPrep.h" +#include "GC/SemiPrep.h" +#include "GC/SemiSecret.hpp" int main(int argc, const char** argv) { diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 5999050c2..469116953 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -17,6 +17,7 @@ #include "Protocols/ReplicatedPrep.hpp" #include "Protocols/FakeShare.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/MAC_Check_Base.hpp" int main(int argc, const char** argv) { diff --git a/Machines/malicious-rep-bin-party.cpp b/Machines/malicious-rep-bin-party.cpp index 2ae796715..d2747a0e4 100644 --- a/Machines/malicious-rep-bin-party.cpp +++ b/Machines/malicious-rep-bin-party.cpp @@ -7,12 +7,14 @@ #include "GC/ShareParty.hpp" #include "GC/ShareSecret.hpp" #include "GC/MaliciousRepSecret.h" +#include "GC/RepPrep.h" #include "GC/Machine.hpp" #include "GC/Processor.hpp" #include "GC/Program.hpp" #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/RepPrep.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MaliciousRepMC.hpp" diff --git a/Machines/mascot-offline.cpp b/Machines/mascot-offline.cpp index 975ae030b..e24735b7d 100644 --- a/Machines/mascot-offline.cpp +++ b/Machines/mascot-offline.cpp @@ -9,6 +9,7 @@ #include "Math/gfp.hpp" #include "Processor/FieldMachine.hpp" #include "Processor/OfflineMachine.hpp" +#include "Protocols/MascotPrep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/no-party.cpp b/Machines/no-party.cpp index ce542de18..ceb35b089 100644 --- a/Machines/no-party.cpp +++ b/Machines/no-party.cpp @@ -9,6 +9,8 @@ #include "Processor/Machine.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MAC_Check_Base.hpp" #include "Math/gfp.hpp" #include "Math/Z2k.hpp" diff --git a/Machines/ps-rep-bin-party.cpp b/Machines/ps-rep-bin-party.cpp index 98ffb2984..4ab361398 100644 --- a/Machines/ps-rep-bin-party.cpp +++ b/Machines/ps-rep-bin-party.cpp @@ -5,8 +5,11 @@ #include "GC/PostSacriBin.h" #include "GC/PostSacriSecret.h" +#include "GC/RepPrep.h" #include "GC/ShareParty.hpp" +#include "GC/RepPrep.hpp" +#include "Protocols/MaliciousRepMC.hpp" int main(int argc, const char** argv) { diff --git a/Machines/real-bmr-party.cpp b/Machines/real-bmr-party.cpp index 42000ddf8..8f329971d 100644 --- a/Machines/real-bmr-party.cpp +++ b/Machines/real-bmr-party.cpp @@ -7,6 +7,7 @@ #include "BMR/RealProgramParty.hpp" #include "Machines/SPDZ.hpp" +#include "Protocols/MascotPrep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/replicated-bin-party.cpp b/Machines/replicated-bin-party.cpp index 763b1918b..153d830eb 100644 --- a/Machines/replicated-bin-party.cpp +++ b/Machines/replicated-bin-party.cpp @@ -4,6 +4,7 @@ */ #include "GC/ShareParty.h" +#include "GC/SemiHonestRepPrep.h" #include "GC/ShareParty.hpp" #include "GC/ShareSecret.hpp" @@ -12,6 +13,7 @@ #include "GC/Program.hpp" #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/RepPrep.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MaliciousRepMC.hpp" diff --git a/Machines/replicated-ring-party.cpp b/Machines/replicated-ring-party.cpp index 2b3646fe4..a295eafe1 100644 --- a/Machines/replicated-ring-party.cpp +++ b/Machines/replicated-ring-party.cpp @@ -4,7 +4,6 @@ */ #include "Protocols/Rep3Share2k.h" -#include "Protocols/ReplicatedPrep2k.h" #include "Processor/RingOptions.h" #include "Math/Integer.h" #include "Machines/RepRing.hpp" diff --git a/Machines/sy-rep-field-party.cpp b/Machines/sy-rep-field-party.cpp index 1da856768..a457e3b09 100644 --- a/Machines/sy-rep-field-party.cpp +++ b/Machines/sy-rep-field-party.cpp @@ -13,10 +13,10 @@ #include "Math/gf2n.h" #include "Tools/ezOptionParser.h" #include "GC/MaliciousCcdSecret.h" +#include "GC/SemiHonestRepPrep.h" #include "Processor/FieldMachine.hpp" #include "Protocols/Replicated.hpp" -#include "Protocols/MaliciousRepMC.hpp" #include "Protocols/Share.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/SpdzWise.hpp" @@ -30,6 +30,7 @@ #include "GC/RepPrep.hpp" #include "GC/ThreadMaster.hpp" #include "Math/gfp.hpp" +#include "MalRep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/sy-rep-ring-party.cpp b/Machines/sy-rep-ring-party.cpp index 728466f72..45faca6f1 100644 --- a/Machines/sy-rep-ring-party.cpp +++ b/Machines/sy-rep-ring-party.cpp @@ -11,10 +11,10 @@ #include "Protocols/MalRepRingPrep.h" #include "Processor/RingOptions.h" #include "GC/MaliciousCcdSecret.h" +#include "GC/SemiHonestRepPrep.h" #include "Processor/RingMachine.hpp" #include "Protocols/Replicated.hpp" -#include "Protocols/MaliciousRepMC.hpp" #include "Protocols/Share.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/SpdzWise.hpp" @@ -32,6 +32,7 @@ #include "GC/ShareSecret.hpp" #include "GC/RepPrep.hpp" #include "GC/ThreadMaster.hpp" +#include "MalRep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/sy-shamir-party.cpp b/Machines/sy-shamir-party.cpp index b009abb3f..d251e7cdc 100644 --- a/Machines/sy-shamir-party.cpp +++ b/Machines/sy-shamir-party.cpp @@ -12,6 +12,7 @@ #include "Math/gf2n.h" #include "GC/CcdSecret.h" #include "GC/MaliciousCcdSecret.h" +#include "GC/SemiHonestRepPrep.h" #include "Protocols/Share.hpp" #include "Protocols/SpdzWise.hpp" diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 35aae3aa9..1ea00ffe3 100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -25,6 +25,7 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/MascotPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" int main(int argc, const char** argv) { diff --git a/Makefile b/Makefile index 12fda5bdc..467e6d8f1 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR) -GC_SEMI = GC/SemiPrep.o GC/square64.o +GC_SEMI = GC/SemiPrep.o GC/square64.o GC/Semi.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) $(LIBSIMPLEOT) OT_EXE = ot.x ot-offline.x @@ -40,6 +40,17 @@ LIBSIMPLEOT_ASM = deps/SimpleOT/libsimpleot.a LIBSIMPLEOT += $(LIBSIMPLEOT_ASM) endif +STATIC_OTE = local/lib/liblibOTe.a +SHARED_OTE = local/lib/liblibOTe.so + +ifeq ($(USE_KOS), 0) +ifeq ($(USE_SHARED_OTE), 1) +OT += $(SHARED_OTE) local/lib/libcryptoTools.so +else +OT += $(STATIC_OTE) local/lib/libcryptoTools.a +endif +endif + # used for dependency generation OBJS = $(BMR) $(FHEOBJS) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp)) DEPS := $(wildcard */*.d */*/*.d) @@ -106,6 +117,7 @@ endif tldr: libote $(MAKE) mascot-party.x + mkdir Player-Data 2> /dev/null; true ifeq ($(ARM), 1) Tools/intrinsics.h: deps/simde/simde @@ -130,8 +142,8 @@ $(SHAREDLIB): $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o $(FHEOFFLINE): $(FHEOBJS) $(SHAREDLIB) $(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS) -static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) - $(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl +static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a local/lib/liblibOTe.a + $(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl @@ -201,13 +213,13 @@ replicated-field-party.x: GC/square64.o brain-party.x: GC/square64.o malicious-rep-bin-party.x: GC/square64.o ps-rep-bin-party.x: GC/PostSacriBin.o -semi-bin-party.x: $(OT) GC/SemiPrep.o GC/square64.o +semi-bin-party.x: $(OT) $(GC_SEMI) tiny-party.x: $(OT) tinier-party.x: $(OT) spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) -semi-party.x: $(OT) GC/SemiPrep.o GC/square64.o -semi2k-party.x: $(OT) GC/SemiPrep.o GC/square64.o +semi-party.x: $(OT) $(GC_SEMI) +semi2k-party.x: $(OT) $(GC_SEMI) hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) @@ -232,15 +244,15 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o rep4-ring-party.x: GC/Rep4Secret.o no-party.x: Protocols/ShareInterface.o -semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o +semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) $(GC_SEMI) mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) emulate.x: GC/FakeSecret.o -semi-bmr-party.x: GC/SemiPrep.o $(OT) +semi-bmr-party.x: $(GC_SEMI) $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o -mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o $(GC_SEMI) GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o $(GC_SEMI) GC/AtlasSecret.o Machines/Tinier.o l2h-example.x: $(VM) $(OT) Machines/Tinier.o he-example.x: $(FHEOFFLINE) mascot-offline.x: $(VM) $(TINIER) @@ -272,14 +284,15 @@ OT/BaseOT.o: deps/SimplestOT_C/ref10/Makefile deps/SimplestOT_C/ref10/Makefile: git submodule update --init deps/SimplestOT_C || git clone https://github.com/mkskeller/SimplestOT_C deps/SimplestOT_C - cd deps/SimplestOT_C/ref10; cmake . + cd deps/SimplestOT_C/ref10; PATH=$(CURDIR)/local/bin:$(PATH) cmake . .PHONY: Programs/Circuits Programs/Circuits: git submodule update --init Programs/Circuits -.PHONY: mpir-setup mpir-global mpir -mpir-setup: +.PHONY: mpir-setup mpir-global +mpir-setup: deps/mpir/Makefile +deps/mpir/Makefile: git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir cd deps/mpir; \ autoreconf -i; \ @@ -292,35 +305,45 @@ mpir-global: mpir-setup $(MAKE) -C deps/mpir sudo $(MAKE) -C deps/mpir install -mpir: mpir-setup +mpir: local/lib/libmpirxx.so +local/lib/libmpirxx.so: deps/mpir/Makefile cd deps/mpir; \ ./configure --enable-cxx --prefix=$(CURDIR)/local $(MAKE) -C deps/mpir install - -echo MY_CFLAGS += -I./local/include >> CONFIG.mine - -echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine deps/libOTe/libOTe: - git submodule update --init --recursive deps/libOTe - -echo MY_CFLAGS += -I./local/include >> CONFIG.mine - -echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine - + git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe boost: deps/libOTe/libOTe cd deps/libOTe; \ python3 build.py --setup --boost --install=$(CURDIR)/local OTE_OPTS = -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX) -DCMAKE_INSTALL_LIBDIR=lib +ifeq ($(USE_SHARED_OTE), 1) +OTE = $(SHARED_OTE) +else +OTE = $(STATIC_OTE) +endif + +libote: + rm $(STATIC_OTE) $(SHARED_OTE)* 2>/dev/null; true + $(MAKE) $(OTE) + +local/lib/libcryptoTools.a: $(STATIC_OTE) +local/lib/libcryptoTools.so: $(SHARED_OTE) +OT/OTExtensionWithMatrix.o: $(OTE) + ifeq ($(ARM), 1) -libote: deps/libOTe/libOTe +local/lib/liblibOTe.a: deps/libOTe/libOTe cd deps/libOTe; \ PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 -DENABLE_AVX=OFF -DENABLE_SSE=OFF $(OTE_OPTS) else -libote: deps/libOTe/libOTe +local/lib/liblibOTe.a: deps/libOTe/libOTe cd deps/libOTe; \ PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS) endif -libote-shared: deps/libOTe/libOTe +$(SHARED_OTE): deps/libOTe/libOTe cd deps/libOTe; \ python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS) diff --git a/Math/BitVec.h b/Math/BitVec.h index f0d60a1b9..f4b0a1e2c 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -69,6 +69,8 @@ class BitVec_ : public IntBase { if (n == -1) pack(os); + else if (n == 1) + os.store_int<1>(this->a & 1); else os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } @@ -77,6 +79,8 @@ class BitVec_ : public IntBase { if (n == -1) unpack(os); + else if (n == 1) + this->a = os.get_int<1>(); else this->a = os.get_int(DIV_CEIL(n, 8)); } diff --git a/Math/Square.hpp b/Math/Square.hpp index 98b646ee9..7ca997eb0 100644 --- a/Math/Square.hpp +++ b/Math/Square.hpp @@ -4,6 +4,7 @@ */ #include "Math/Square.h" +#include "Math/BitVec.h" template Square& Square::sub(const Square& other) @@ -40,6 +41,16 @@ void Square::bit_sub(const BitVector& bits, int start) } } +template<> +inline +void Square::bit_sub(const BitVector& bits, int start) +{ + for (int i = 0; i < BitVec::length(); i++) + { + rows[i] -= bits.get_portion(start + i); + } +} + template void Square::conditional_add(BitVector& conditions, Square& other, int offset) diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 60a132d60..3d3ecc20d 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -20,10 +20,10 @@ using namespace std; #ifndef MAX_MOD_SZ - #if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 10 + #if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 11 #define MAX_MOD_SZ GFP_MOD_SZ #else - #define MAX_MOD_SZ 10 + #define MAX_MOD_SZ 11 #endif #endif diff --git a/Math/field_types.h b/Math/field_types.h index 9f54d3afa..052cc40a8 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -16,7 +16,8 @@ enum Dtype DATA_BIT, DATA_INVERSE, DATA_DABIT, - N_DTYPE + DATA_MIXED, + N_DTYPE, }; #endif /* MATH_FIELD_TYPES_H_ */ diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index b1c5642be..49bc8528f 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -70,20 +70,6 @@ inline void mpn_add_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb ); } -template <> -inline void mpn_add_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) -{ - memcpy(res, y, 3 * sizeof(mp_limb_t)); - __asm__ ( - "add %3, %0 \n" - "adc %4, %1 \n" - "adc %5, %2 \n" - : "+&r"(res[0]), "+&r"(res[1]), "+r"(res[2]) - : "rm"(x[0]), "rm"(x[1]), "rm"(x[2]) - : "cc" - ); -} - template <> inline void mpn_add_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) { diff --git a/Networking/data.h b/Networking/data.h index 6d7fb7289..e2bda0428 100644 --- a/Networking/data.h +++ b/Networking/data.h @@ -26,7 +26,7 @@ inline void short_memcpy(void* out, void* in, size_t n_bytes) X(1) X(2) X(3) X(4) X(5) X(6) X(7) X(8) #undef X default: - throw invalid_length("length outside range"); + throw invalid_length("length outside range: " + to_string(n_bytes)); } } diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 3069bd89a..4faf92832 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -68,7 +68,7 @@ class BaseOT void set_receiver_inputs(const BitVector& new_inputs) { if ((int)new_inputs.size() != nOT) - throw invalid_length(); + throw invalid_length("BaseOT"); receiver_inputs = new_inputs; } diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 445ebc3c7..a797b9798 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -127,6 +127,9 @@ class Matrix vector< U, aligned_allocator > squares; + typename U::RowType& operator[](int i) + { return squares[i / U::n_rows()].rows[i % U::n_rows()]; } + size_t vertical_size(); void resize_vertical(int length) diff --git a/OT/BitMatrix.hpp b/OT/BitMatrix.hpp index 00ede6337..74a682a71 100644 --- a/OT/BitMatrix.hpp +++ b/OT/BitMatrix.hpp @@ -19,7 +19,7 @@ template bool Matrix::operator==(Matrix& other) { if (squares.size() != other.squares.size()) - throw invalid_length(); + throw invalid_length("Matrix"); for (size_t i = 0; i < squares.size(); i++) if (not(squares[i] == other.squares[i])) return false; @@ -109,7 +109,7 @@ template Slice& Slice::rsub(Slice& other) { if (bm.squares.size() < other.end) - throw invalid_length(); + throw invalid_length("rsub"); for (size_t i = other.start; i < other.end; i++) bm.squares[i].rsub(other.bm.squares[i]); return *this; diff --git a/OT/MamaRectangle.h b/OT/MamaRectangle.h index 98da4d5a7..a17e3064b 100644 --- a/OT/MamaRectangle.h +++ b/OT/MamaRectangle.h @@ -18,6 +18,8 @@ class MamaRectangle typename T::Square squares[N]; public: + typedef GC::NoValue RowType; + static int n_rows() { return T::Square::n_rows(); } static int n_columns() { return T::Square::n_columns(); } static int n_row_bytes() { return T::Square::n_row_bytes(); } diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 3c58e690d..b212a4805 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -6,6 +6,7 @@ #include "Tools/random.h" #include "Tools/time-func.h" #include "Processor/InputTuple.h" +#include "Protocols/dabit.h" #include "OT/OTTripleSetup.h" #include "OT/MascotParams.h" @@ -98,7 +99,8 @@ class OTTripleGenerator : public GeneratorThread vector> preampTriples; vector> plainTriples; - vector plainBits; + vector> plainBits; + vector> mixedTriples; typename T::MAC_Check* MC; @@ -114,6 +116,7 @@ class OTTripleGenerator : public GeneratorThread void plainTripleRound(int k = 0); void generatePlainBits(); + void generateMixedTriples(); void run_multipliers(MultJob job); diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index f2b981c10..47df8f49c 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -489,7 +489,8 @@ void OTTripleGenerator::generatePlainBits() machine.set_passive(); machine.output = false; - int n = multiple_minimum(nPreampTriplesPerLoop, T::open_type::size_in_bits()); + int n = multiple_minimum(100 * nPreampTriplesPerLoop, + T::open_type::size_in_bits()); valueBits.resize(1); valueBits[0].resize(n); @@ -500,16 +501,52 @@ void OTTripleGenerator::generatePlainBits() wait_for_multipliers(); plainBits.clear(); + typename T::open_type two = 2; + for (int j = 0; j < n; j++) { if (j % T::open_type::size_in_bits() < T::open_type::length()) { - plainBits.push_back(valueBits[0].get_bit(j)); - plainBits.back() += ot_multipliers[0]->c_output[j] * 2; + bool b = valueBits[0].get_bit(j); + plainBits.push_back({b, b}); + plainBits.back().first += ot_multipliers[0]->c_output[j] * two; } } } +template +void OTTripleGenerator::generateMixedTriples() +{ + assert(ot_multipliers.size() == 1); + + machine.set_passive(); + machine.output = false; + + int n = multiple_minimum(100 * nPreampTriplesPerLoop, + T::open_type::size_in_bits()); + + valueBits.resize(2); + valueBits[0].resize(n); + valueBits[0].randomize(share_prg); + valueBits[1].resize(n * T::open_type::N_BITS); + valueBits[1].randomize(share_prg); + + signal_multipliers(DATA_MIXED); + + wait_for_multipliers(); + mixedTriples.clear(); + + for (int j = 0; j < n; j++) + { + auto a = valueBits[0].get_bit(j); + auto b = valueBits[1].template get_portion(j); + auto c = a ? b : typename T::open_type(); + for (auto& x : ot_multipliers) + c += x->c_output[j]; + mixedTriples.push_back({{a, b, c}}); + } +} + template void OTTripleGenerator::plainTripleRound(int k) { diff --git a/OT/OTCorrelator.hpp b/OT/OTCorrelator.hpp index 00561d3c1..d6c19761b 100644 --- a/OT/OTCorrelator.hpp +++ b/OT/OTCorrelator.hpp @@ -188,7 +188,7 @@ template void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output, int start) { if (receiverOutputMatrix.squares.size() < nTriples + start) - throw invalid_length(); + throw invalid_length("reduce_squares"); output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 258e74309..409a4f995 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -9,7 +9,10 @@ #ifndef USE_KOS #include "Networking/PlayerCtSocket.h" -osuCrypto::IOService OTExtensionWithMatrix::ios; +#include +#include + +osuCrypto::IOService ot_extension_ios; #endif #include "OTCorrelator.hpp" @@ -112,7 +115,7 @@ void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newRecei resize(nOTs_requested); if (not channel) - channel = new osuCrypto::Channel(ios, new PlayerCtSocket(*player)); + channel = new osuCrypto::Channel(ot_extension_ios, new PlayerCtSocket(*player)); if (player->my_num()) { diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index e15ac9537..e6eab6da0 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -11,8 +11,9 @@ #include "Math/gf2n.h" #ifndef USE_KOS -#include -#include +namespace osuCrypto { +class Channel; +} #endif template @@ -57,7 +58,6 @@ class OTExtensionWithMatrix : public OTCorrelator int nsubloops; #ifndef USE_KOS - static osuCrypto::IOService ios; osuCrypto::Channel* channel; #endif diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 21ec0622b..0f86bc0ca 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -59,6 +59,7 @@ class OTMultiplier : public OTMultiplierMac } void multiplyForBits(); + void multiplyForMixed(); void after_correlation(); diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index 24ad88a1d..63f4dd087 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -128,6 +128,9 @@ void OTMultiplier::multiply() case DATA_TRIPLE: multiplyForTriples(); break; + case DATA_MIXED: + multiplyForMixed(); + break; default: throw not_implemented(); } @@ -188,6 +191,55 @@ void SemiMultiplier::multiplyForBits() this->outbox.push({}); } +template +void SemiMultiplier::multiplyForMixed() +{ + auto& rot_ext = this->rot_ext; + + typedef Square X; + OTCorrelator> otCorrelator( + this->generator.players[this->thread_num], BOTH, true); + + BitVector aBits = this->generator.valueBits[0]; + rot_ext.extend_correlated(aBits); + + auto& baseSenderOutputs = otCorrelator.matrices; + auto& baseReceiverOutput = otCorrelator.senderOutputMatrices[0]; + + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + + if (this->generator.get_player().num_players() == 2) + { + c_output.clear(); + + for (unsigned j = 0; j < aBits.size(); j++) + { + this->generator.valueBits[1].set_portion(j, + BitVec(baseSenderOutputs[0][j] ^ baseSenderOutputs[1][j])); + c_output.push_back(baseReceiverOutput[j] ^ baseSenderOutputs[0][j]); + } + + this->outbox.push({}); + return; + } + + otCorrelator.setup_for_correlation(aBits, baseSenderOutputs, + baseReceiverOutput); + otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(), + this->generator.valueBits[1], false, -1); + + c_output.clear(); + + for (unsigned j = 0; j < aBits.size(); j++) + { + c_output.push_back( + otCorrelator.receiverOutputMatrix[j] + ^ otCorrelator.senderOutputMatrices[0][j]); + } + + this->outbox.push({}); +} + template void OTMultiplier::multiplyForTriples() { @@ -592,3 +644,9 @@ void OTMultiplier::multiplyForBits() { throw runtime_error("bit generation not implemented in this case"); } + +template +void OTMultiplier::multiplyForMixed() +{ + throw runtime_error("mixed generation not implemented in this case"); +} diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 3a4a63b4c..d49d08c73 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -67,6 +67,14 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) string threadname; for (int i=0; i> threadname; + size_t split = threadname.find(":"); + long expected = -1; + if (split != string::npos) + { + expected = atoi(threadname.substr(split + 1).c_str()); + threadname = threadname.substr(0, split); + } + string filename = "Programs/Bytecode/" + threadname + ".bc"; bc_filenames.push_back(filename); if (load_bytecode) @@ -74,8 +82,11 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) #ifdef DEBUG_FILES cerr << "Loading program " << i << " from " << filename << endl; #endif - load_program(threadname, filename); + long size = load_program(threadname, filename); + if (expected >= 0 and expected != size) + throw runtime_error("broken bytecode file"); } + } for (auto i : {1, 0, 0}) @@ -99,7 +110,8 @@ void BaseMachine::print_compiler() cerr << "Compiler: " << compiler << endl; } -void BaseMachine::load_program(const string& threadname, const string& filename) +size_t BaseMachine::load_program(const string& threadname, + const string& filename) { (void)threadname; (void)filename; diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 564affe04..6b5a029f1 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -31,7 +31,8 @@ class BaseMachine string domain; string relevant_opts; - virtual void load_program(const string& threadname, const string& filename); + virtual size_t load_program(const string& threadname, + const string& filename); public: static thread_local int thread_num; diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 3d40e2ca7..46c84903c 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -7,8 +7,7 @@ #include "Protocols/dabit.h" #include "Math/Setup.h" #include "GC/BitPrepFiles.h" - -#include "Protocols/MascotPrep.hpp" +#include "Tools/benchmarking.h" template Preprocessing* Preprocessing::get_live_prep(SubProcessor* proc, @@ -44,6 +43,20 @@ Preprocessing* Preprocessing::get_new( BaseMachine::thread_num); } +template +T Preprocessing::get_random_from_inputs(int nplayers) +{ + T res; + for (int j = 0; j < nplayers; j++) + { + T tmp; + typename T::open_type _; + this->get_input_no_count(tmp, _, j); + res += tmp; + } + return res; +} + template Sub_Data_Files::Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num) : diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 1de58c994..011dcb581 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -84,6 +84,7 @@ enum SUBSI = 0x2A, SUBCFI = 0x2B, SUBSFI = 0x2C, + PREFIXSUMS = 0x2D, // Multiplication/division/other arithmetic MULC = 0x30, MULM = 0x31, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e3761f5f5..da4dd01ea 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -130,6 +130,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case DABIT: case SHUFFLE: case ACCEPTCLIENTCONNECTION: + case PREFIXSUMS: get_ints(r, s, 2); break; // instructions with 1 register operand @@ -458,6 +459,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case STMSDCI: case XORS: case ANDRS: + case ANDRSVEC: case ANDS: case INPUTB: case INPUTBVEC: @@ -646,6 +648,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const int offset = 0; int size_offset = 0; int size = this->size; + bool n_prefix = 0; // special treatment for instructions writing to different types switch (opcode) @@ -731,25 +734,17 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const offset = 1; size_offset = -1; break; + case ANDRSVEC: + n_prefix = 2; + break; case INPUTB: skip = 4; offset = 3; size_offset = -2; break; case INPUTBVEC: - { - int res = 0; - auto it = start.begin(); - while (it < start.end()) - { - int n = *it - 3; - it += 3; - assert(it + n <= start.end()); - for (int i = 0; i < n; i++) - res = max(res, *it++); - } - return res + 1; - } + n_prefix = 3; + break; case ANDM: case NOTS: case NOTCB: @@ -795,6 +790,22 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const break; } + if (n_prefix > 0) + { + int res = 0; + auto it = start.begin(); + while (it < start.end()) + { + int n = *it - n_prefix; + int size = DIV_CEIL(*(it + 1), 64); + it += n_prefix; + assert(it + n <= start.end()); + for (int i = 0; i < n; i++) + res = max(res, *it++ + size); + } + return res; + } + if (skip > 0) { unsigned m = 0; @@ -1323,8 +1334,13 @@ void Program::execute(Processor& Proc) const (void) start; #ifdef COUNT_INSTRUCTIONS +#ifdef TIME_INSTRUCTIONS + RunningTimer timer; + int PC = Proc.PC; +#else Proc.stats[p[Proc.PC].get_opcode()]++; #endif +#endif #ifdef OUTPUT_INSTRUCTIONS cerr << instruction << endl; @@ -1352,6 +1368,10 @@ void Program::execute(Processor& Proc) const default: instruction.execute(Proc); } + +#if defined(COUNT_INSTRUCTIONS) and defined(TIME_INSTRUCTIONS) + Proc.stats[p[PC].get_opcode()] += timer.elapsed() * 1e9; +#endif } } diff --git a/Processor/Machine.h b/Processor/Machine.h index 8b3d018cc..d3c1346b2 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -44,7 +44,7 @@ class Machine : public BaseMachine Player* P; - void load_program(const string& threadname, const string& filename); + size_t load_program(const string& threadname, const string& filename); void prepare(const string& progname_str); diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 6fb9b5fc9..4ff526084 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -199,7 +199,7 @@ Machine::~Machine() } template -void Machine::load_program(const string& threadname, +size_t Machine::load_program(const string& threadname, const string& filename) { progs.push_back(N.num_players()); @@ -208,6 +208,7 @@ void Machine::load_program(const string& threadname, M2.minimum_size(SGF2N, CGF2N, progs[i], threadname); Mp.minimum_size(SINT, CINT, progs[i], threadname); Mi.minimum_size(NONE, INT, progs[i], threadname); + return progs.back().size(); } template diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 038f28d2c..0c845ccf3 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -126,7 +126,8 @@ void thread_info::Sub_Main_Func() program = job.prognum; wait_timer.stop(); #ifdef DEBUG_THREADS - printf("\tRunning program %d\n",program); + printf("\tRunning program %d/job %d in thread %d\n", program, job.type, + num); #endif if (program==-1) @@ -208,6 +209,10 @@ void thread_info::Sub_Main_Func() *(vector>*) job.output, job.length, job.prognum, job.arg, Proc.Procp, job.begin, job.end, job.supply); +#ifdef DEBUG_THREADS + printf("\tSignalling I have finished with job %d in thread %d\n", + job.type, num); +#endif queues->finished(job); } else if (job.type == PERSONAL_TRIPLE_JOB) @@ -282,7 +287,8 @@ void thread_info::Sub_Main_Func() } #ifdef DEBUG_THREADS - printf("\tSignalling I have finished\n"); + printf("\tSignalling I have finished with program %d" + "in thread %d\n", program, num); #endif wait_timer.start(); queues->finished(job, P.total_comm()); diff --git a/Processor/Processor.h b/Processor/Processor.h index 3fedb3df7..273b3a660 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -32,7 +32,7 @@ class SubProcessor DataPositions bit_usage; - SecureShuffle shuffler; + typename T::Protocol::Shuffler shuffler; void resize(size_t size) { C.resize(size); S.resize(size); } diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 7bd800c60..8236071e9 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -651,8 +651,9 @@ void SubProcessor::conv2ds(const Instruction& instruction) template void SubProcessor::secure_shuffle(const Instruction& instruction) { - SecureShuffle(S, instruction.get_size(), instruction.get_n(), - instruction.get_r(0), instruction.get_r(1), *this); + typename T::Protocol::Shuffler(S, instruction.get_size(), + instruction.get_n(), instruction.get_r(0), instruction.get_r(1), + *this); } template diff --git a/Processor/Program.h b/Processor/Program.h index 8fb3df141..96a70e5eb 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -36,6 +36,8 @@ class Program unknown_usage(false), writes_persistence(false) { compute_constants(); } + size_t size() const { return p.size(); } + // Read in a program void parse(string filename); void parse(istream& s); diff --git a/Processor/ThreadQueues.cpp b/Processor/ThreadQueues.cpp index de0134994..ecca7bbe3 100644 --- a/Processor/ThreadQueues.cpp +++ b/Processor/ThreadQueues.cpp @@ -19,6 +19,9 @@ int ThreadQueues::distribute(ThreadJob job, int n_items, int base, int ThreadQueues::find_available() { +#ifdef VERBOSE_QUEUES + cerr << available.size() << " threads in use" << endl; +#endif if (not available.empty()) return 0; for (size_t i = 1; i < size(); i++) @@ -32,7 +35,7 @@ int ThreadQueues::find_available() int ThreadQueues::get_n_per_thread(int n_items, int granularity) { - int n_per_thread = ceil(n_items / (available.size() + 1.0)) / granularity + int n_per_thread = int(ceil(n_items / (available.size() + 1.0)) / granularity) * granularity; return n_per_thread; } @@ -40,11 +43,23 @@ int ThreadQueues::get_n_per_thread(int n_items, int granularity) int ThreadQueues::distribute_no_setup(ThreadJob job, int n_items, int base, int granularity, const vector* supplies) { +#ifdef VERBOSE_QUEUES + cerr << "Distribute " << job.type << " among " << available.size() << endl; +#endif + int n_per_thread = get_n_per_thread(n_items, granularity); + + if (n_items and (n_per_thread == 0 or base + n_per_thread > n_items)) + { + available.clear(); + return base; + } + for (size_t i = 0; i < available.size(); i++) { if (base + (i + 1) * n_per_thread > size_t(n_items)) { + assert(i); available.resize(i); return base + i * n_per_thread; } @@ -59,7 +74,14 @@ int ThreadQueues::distribute_no_setup(ThreadJob job, int n_items, int base, void ThreadQueues::wrap_up(ThreadJob job) { +#ifdef VERBOSE_QUEUES + cerr << "Wrap up " << available.size() << " threads" << endl; +#endif for (int i : available) - assert(at(i)->result().output == job.output); + { + auto result = at(i)->result(); + assert(result.output == job.output); + assert(result.type == job.type); + } available.clear(); } diff --git a/Processor/instructions.h b/Processor/instructions.h index bf443b0f7..f22fde8e6 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -62,6 +62,9 @@ X(SUBCFI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \ typename sint::clear op2 = int(n), \ *dest++ = op2 - *op1++) \ + X(PREFIXSUMS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ + sint s, \ + s += *op1++; *dest++ = s) \ X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ auto op2 = &Procp.get_C()[r[2]], \ *dest++ = *op1++ * *op2++) \ @@ -380,6 +383,10 @@ X(PREP, throw not_implemented(),) \ X(GPREP, throw not_implemented(),) \ X(CISC, throw not_implemented(),) \ + X(SECSHUFFLE, throw not_implemented(),) \ + X(GENSECSHUFFLE, throw not_implemented(),) \ + X(APPLYSHUFFLE, throw not_implemented(),) \ + X(DELSHUFFLE, throw not_implemented(),) \ #define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS diff --git a/Programs/Source/adult.mpc b/Programs/Source/adult.mpc new file mode 100644 index 000000000..373e332be --- /dev/null +++ b/Programs/Source/adult.mpc @@ -0,0 +1,54 @@ +m = 6 +n_train = 32561 +n_test = 16281 + +combo = 'combo' in program.args +binary = 'binary' in program.args +mixed = 'mixed' in program.args +nocap = 'nocap' in program.args + +try: + n_threads = int(program.args[2]) +except: + n_threads = None + +if combo: + n_train += n_test + +if binary: + m = 60 + attr_lengths = [1] * m +elif mixed or nocap: + cont = 6 if mixed else 3 + m = 60 + cont + attr_lengths = [0] * cont + [1] * 60 +else: + attr_lengths = None + +program.set_bit_length(32) +program.options_from_args() + +train = sint.Array(n_train), sint.Matrix(m, n_train) +test = sint.Array(n_test), sint.Matrix(m, n_test) + +for x in train + test: + x.input_from(0) + +import decision_tree, util + +#decision_tree.debug_layers = True +decision_tree.max_leaves = 3000 + +if 'nearest' in program.args: + sfix.round_nearest = True + +sfix.set_precision_from_args(program, True) + +trainer = decision_tree.TreeTrainer( + train[1], train[0], int(program.args[1]), attr_lengths=attr_lengths, + n_threads=n_threads) +trainer.debug_selection = 'debug_selection' in program.args +trainer.debug_gini = True +layers = trainer.train_with_testing(*test) + +#decision_tree.output_decision_tree(layers) diff --git a/Programs/Source/bench-dt.mpc b/Programs/Source/bench-dt.mpc new file mode 100644 index 000000000..4c8c64c90 --- /dev/null +++ b/Programs/Source/bench-dt.mpc @@ -0,0 +1,32 @@ +binary = 'binary' in program.args + +program.set_bit_length(32) + +n_train = int(program.args[1]) +m = int(program.args[2]) + +try: + n_levels = int(program.args[3]) +except: + n_levels = 1 + +try: + n_threads = int(program.args[4]) +except: + n_threads = None + +train = sint.Array(n_train), sint.Matrix(m, n_train) + +import decision_tree, util + +decision_tree.max_leaves = 2000 + +if 'nearest' in program.args: + sfix.round_nearest = True + +layers = decision_tree.TreeTrainer( + train[1], train[0], n_levels, binary=binary, n_threads=n_threads).train() + +#decision_tree.output_decision_tree(layers) + +#decision_tree.test_decision_tree('foo', layers, *train) diff --git a/Programs/Source/benchmark_secureNN.mpc b/Programs/Source/benchmark_secureNN.mpc index 7bba218be..9b2f96744 100644 --- a/Programs/Source/benchmark_secureNN.mpc +++ b/Programs/Source/benchmark_secureNN.mpc @@ -28,7 +28,12 @@ NetworkC = [ (500, 10, 'FC') ] -network = globals()['Network' + program.args[1]] +try: + network = globals()['Network' + program.args[1]] +except: + import sys + print('Usage: %s [A,B,C,D]' % ' '.join(sys.argv)) + sys.exit(1) # c5.9xlarge has 36 cores n_threads = 8 diff --git a/Programs/Source/gc_oram.mpc b/Programs/Source/gc_oram.mpc index fa7ac702f..5ddcb5a70 100644 --- a/Programs/Source/gc_oram.mpc +++ b/Programs/Source/gc_oram.mpc @@ -7,9 +7,6 @@ from Compiler.GC.instructions import * bits.unit = 64 -program.to_merge = [ldmsdi, stmsdi, ldmsd, stmsd, stmsdci, xors, andrs] -program.stop_class = type(None) - from Compiler.circuit_oram import * from Compiler import circuit_oram diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 37cd73d2d..caca22140 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -10,6 +10,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program, adapt_ring=True) +ml.use_mux = 'mux' in program.args MultiArray.disable_index_checks() if 'profile' in program.args: diff --git a/Programs/Source/spect.mpc b/Programs/Source/spect.mpc new file mode 100644 index 000000000..95fb7d12c --- /dev/null +++ b/Programs/Source/spect.mpc @@ -0,0 +1,49 @@ +m = 22 +n_train = 80 +n_test = 187 + +debug = 'debug' in program.args +combo = 'combo' in program.args + +if debug: + n_train = 7 + +if combo: + n_train += n_test + +Array.check_indices = False +MultiArray.disable_index_checks() + +train = sint.Array(n_train), sint.Matrix(m, n_train) +test = sint.Array(n_test), sint.Matrix(m, n_test) + +for x in train: + x.input_from(0) + +if not (debug or combo): + for x in test: + x.input_from(0) + +import decision_tree, util + +#decision_tree.debug = True + +if 'nearest' in program.args: + sfix.round_nearest = True + +sfix.set_precision_from_args(program, True) + +try: + n_threads = int(program.args[3]) +except: + n_threads = None + +trainer = decision_tree.TreeTrainer( + train[1], train[0], int(program.args[1]), binary=int(program.args[2]), + n_threads=n_threads) + +if not (debug or combo): + layers = trainer.train_with_testing(*test) +else: + layers = trainer.train() + test_decision_tree('train', layers, y, x) diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index cc6a5ea12..9792aa665 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -71,7 +71,7 @@ test(r * sbit(1) + sbit(1) * r, 0) test(sbits.get_type(64)(2**64 - 1).popcnt(), 64) a = [sbits.new(x, 2) for x in range(4)] -x, y = sbits.trans(a) +x, y, *z = sbits.trans(a) test(x, 0xa) test(y, 0xc) diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 9b695d0d1..e24cad3ae 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -27,6 +27,7 @@ class Beaver : public ProtocolBase vector shares; vector opened; vector> triples; + vector lengths; typename vector::iterator it; typename vector>::iterator triple; Preprocessing* prep; diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index dc9814870..8c89f420b 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -37,6 +37,7 @@ void Beaver::init_mul() shares.clear(); opened.clear(); triples.clear(); + lengths.clear(); } template @@ -48,12 +49,19 @@ void Beaver::prepare_mul(const T& x, const T& y, int n) triple = prep->get_triple(n); shares.push_back(x - triple[0]); shares.push_back(y - triple[1]); + lengths.push_back(n); } template void Beaver::exchange() { - MC->POpen(opened, shares, P); + assert(shares.size() == 2 * lengths.size()); + MC->init_open(P, shares.size()); + for (size_t i = 0; i < shares.size(); i++) + MC->prepare_open(shares[i], lengths[i / 2]); + MC->exchange(P); + for (size_t i = 0; i < shares.size(); i++) + opened.push_back(MC->finalize_raw()); it = opened.begin(); triple = triples.begin(); } diff --git a/Protocols/DabitSacrifice.hpp b/Protocols/DabitSacrifice.hpp index 74d9f0267..d6f485cc9 100644 --- a/Protocols/DabitSacrifice.hpp +++ b/Protocols/DabitSacrifice.hpp @@ -109,7 +109,8 @@ void DabitSacrifice::sacrifice_and_check_bits(vector >& dabits, ThreadJob job(&products, &multiplicands); int start = queues->distribute(job, multiplicands.size()); protocol.multiply(products, multiplicands, start, multiplicands.size(), proc); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else protocol.multiply(products, multiplicands, 0, multiplicands.size(), proc); diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h index 4e6681366..db1ed813b 100644 --- a/Protocols/DealerMC.h +++ b/Protocols/DealerMC.h @@ -22,7 +22,7 @@ class DealerMC : public MAC_Check_Base ~DealerMC(); void init_open(const Player& P, int n = 0); - void prepare_open(const T& secret); + void prepare_open(const T& secret, int n_bits = -1); void exchange(const Player& P); typename T::open_type finalize_raw(); array finalize_several(int n); diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp index 0f63b93dc..08b4b4587 100644 --- a/Protocols/DealerMC.hpp +++ b/Protocols/DealerMC.hpp @@ -46,10 +46,10 @@ void DealerMC::init_open(const Player& P, int n) } template -void DealerMC::prepare_open(const T& secret) +void DealerMC::prepare_open(const T& secret, int n_bits) { if (sub_player) - internal.prepare_open(secret); + internal.prepare_open(secret, n_bits); else { if (secret != T()) diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index cc010dd71..ea334257c 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -7,6 +7,7 @@ #define PROTOCOLS_DEALERPREP_HPP_ #include "DealerPrep.h" +#include "GC/SemiSecret.h" template void DealerPrep::buffer_triples() diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 018ac3384..c40308c59 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -13,6 +13,51 @@ #include +template +class FakeShuffle +{ +public: + FakeShuffle(SubProcessor&) + { + } + + FakeShuffle(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, SubProcessor&) + { + apply(a, n, unit_size, output_base, input_base, 0, 0); + } + + size_t generate(size_t) + { + return 0; + } + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int, bool) + { + auto source = a.begin() + input_base; + auto dest = a.begin() + output_base; + for (size_t i = 0; i < n; i++) + // just copy + *dest++ = *source++; + + if (n > 1) + { + // swap first two to pass check + for (int i = 0; i < unit_size; i++) + swap(a[output_base + i], a[output_base + i + unit_size]); + } + } + + void del(size_t) + { + } + + void inverse_permutation(vector&, size_t, size_t, size_t) + { + } +}; + template class FakeProtocol : public ProtocolBase { @@ -31,6 +76,8 @@ class FakeProtocol : public ProtocolBase map ltz_stats; public: + typedef FakeShuffle Shuffler; + Player& P; FakeProtocol(Player& P) : diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 3446733e5..062c22391 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -4,6 +4,7 @@ */ #include "HemiMatrixPrep.h" +#include "MAC_Check.h" #include "FHE/Diagonalizer.h" #include "Tools/Bundle.h" @@ -113,7 +114,8 @@ void HemiMatrixPrep::buffer_triples() job.begin = start; job.end = n_matrices; matrix_rand_mult(job); - queues.wrap_up(job); + if (start) + queues.wrap_up(job); } else { @@ -177,7 +179,8 @@ void HemiMatrixPrep::buffer_triples() #endif for (int i = start; i < n_inner; i++) products[i] = multiplicands.at(i) * multiplicands2.at(i); - queues.wrap_up(job); + if (start) + queues.wrap_up(job); #ifdef VERBOSE_HE fprintf(stderr, "adding at %f\n", timer.elapsed()); fflush(stderr); diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index b2b510aa0..6db5bf432 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -30,6 +30,10 @@ class HemiPrep : public SemiHonestRingPrep map timers; + SemiPrep* two_party_prep; + + SemiPrep& get_two_party_prep(); + public: static void basic_setup(Player& P); static void teardown(); @@ -40,7 +44,7 @@ class HemiPrep : public SemiHonestRingPrep HemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), - SemiHonestRingPrep(proc, usage) + SemiHonestRingPrep(proc, usage), two_party_prep(0) { } ~HemiPrep(); @@ -48,6 +52,9 @@ class HemiPrep : public SemiHonestRingPrep vector*>& get_multipliers(); void buffer_triples(); + + void buffer_bits(); + void buffer_dabits(ThreadQueues* queues); }; #endif /* PROTOCOLS_HEMIPREP_H_ */ diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index ce55bce75..099466de3 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -56,6 +56,13 @@ HemiPrep::~HemiPrep() { for (auto& x : multipliers) delete x; + + if (two_party_prep) + { + auto& usage = two_party_prep->usage; + delete two_party_prep; + delete &usage; + } } template @@ -110,4 +117,51 @@ void HemiPrep::buffer_triples() {{ a.element(i), b.element(i), c.element(i) }}); } +template +SemiPrep& HemiPrep::get_two_party_prep() +{ + assert(this->proc); + assert(this->proc->P.num_players() == 2); + + if (not two_party_prep) + { + two_party_prep = new SemiPrep(this->proc, + *new DataPositions(this->proc->P.num_players())); + two_party_prep->set_protocol(this->proc->protocol); + } + + return *two_party_prep; +} + +template +void HemiPrep::buffer_bits() +{ + assert(this->proc); + if (this->proc->P.num_players() == 2) + { + auto& prep = get_two_party_prep(); + prep.buffer_dabits(0); + for (auto& x : prep.dabits) + this->bits.push_back(x.first); + prep.dabits.clear(); + } + else + SemiHonestRingPrep::buffer_bits(); +} + +template +void HemiPrep::buffer_dabits(ThreadQueues* queues) +{ + assert(this->proc); + if (this->proc->P.num_players() == 2) + { + auto& prep = get_two_party_prep(); + prep.buffer_dabits(queues); + this->dabits = prep.dabits; + prep.dabits.clear(); + } + else + SemiHonestRingPrep::buffer_dabits(queues); +} + #endif diff --git a/Protocols/HighGearKeyGen.hpp b/Protocols/HighGearKeyGen.hpp index 49fa6702b..4c405472d 100644 --- a/Protocols/HighGearKeyGen.hpp +++ b/Protocols/HighGearKeyGen.hpp @@ -168,7 +168,7 @@ void HighGearKeyGen::run(PartSetup& setup, MachineBase& machine) timer.reset(); map timers; - SimpleEncCommit_ EC(P, setup.pk, setup.FieldD, timers, machine, 0, true); + SummingEncCommit EC(P, setup.pk, setup.FieldD, timers, machine, 0, true); Plaintext_ alpha(setup.FieldD); EC.next(alpha, setup.calpha); assert(alpha.is_diagonal()); diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index 4b9d0d057..5056c3a84 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -10,6 +10,7 @@ #include "Machines/SPDZ.hpp" #include "ShareVector.hpp" +#include "MascotPrep.hpp" template LowGearKeyGen::LowGearKeyGen(Player& P, PairwiseMachine& machine, diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index d0b062c43..311de4d9a 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -92,7 +92,7 @@ class Tree_MAC_Check : public TreeSum, public MAC_Check_B virtual ~Tree_MAC_Check(); virtual void init_open(const Player& P, int n = 0); - virtual void prepare_open(const U& secret); + virtual void prepare_open(const U& secret, int = -1); virtual void exchange(const Player& P); virtual void AddToCheck(const U& share, const T& value, const Player& P); @@ -143,7 +143,7 @@ class MAC_Check_Z2k : public Tree_MAC_Check MAC_Check_Z2k(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0); MAC_Check_Z2k(const T& ai, Names& Nms, int thread_num); - void prepare_open(const W& secret); + void prepare_open(const W& secret, int = -1); void prepare_open_no_mask(const W& secret); virtual void Check(const Player& P); @@ -184,7 +184,7 @@ class Direct_MAC_Check: public MAC_Check_ ~Direct_MAC_Check(); void init_open(const Player& P, int n = 0); - void prepare_open(const T& secret); + void prepare_open(const T& secret, int = -1); void exchange(const Player& P); }; diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index 5798d9a43..fe6a0108f 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -96,7 +96,7 @@ void Tree_MAC_Check::init_open(const Player&, int n) } template -void Tree_MAC_Check::prepare_open(const U& secret) +void Tree_MAC_Check::prepare_open(const U& secret, int) { this->values.push_back(secret.get_share()); macs.push_back(secret.get_mac()); @@ -242,7 +242,7 @@ MAC_Check_Z2k::MAC_Check_Z2k(const T& ai, Names& Nms, } template -void MAC_Check_Z2k::prepare_open(const W& secret) +void MAC_Check_Z2k::prepare_open(const W& secret, int) { prepare_open_no_mask(secret + (get_random_element() << W::clear::N_BITS)); } @@ -402,7 +402,7 @@ void Direct_MAC_Check::init_open(const Player& P, int n) } template -void Direct_MAC_Check::prepare_open(const T& secret) +void Direct_MAC_Check::prepare_open(const T& secret, int) { this->values.push_back(secret.get_share()); this->macs.push_back(secret.get_mac()); diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index b4f684bcf..fed190ef4 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -59,7 +59,7 @@ class MAC_Check_Base /// Initialize opening round virtual void init_open(const Player& P, int n = 0); /// Add value to be opened - virtual void prepare_open(const T& secret); + virtual void prepare_open(const T& secret, int n_bits = -1); /// Run opening protocol virtual void exchange(const Player& P) = 0; /// Get next opened value diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 47528e006..01096fa97 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -53,7 +53,7 @@ void MAC_Check_Base::init_open(const Player&, int n) } template -void MAC_Check_Base::prepare_open(const T& secret) +void MAC_Check_Base::prepare_open(const T& secret, int) { secrets.push_back(secret); } diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 6ce2e2442..1be5bb9a5 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -165,7 +165,8 @@ void TripleShuffleSacrifice::triple_sacrifice(vector>& triples, TripleSacrificeJob job(&triples, &check_triples); int start = queues->distribute(job, N); triple_sacrifice(triples, check_triples, P, MC, start, N); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else triple_sacrifice(triples, check_triples, P, MC, 0, N); diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index b7296748e..ec26bc846 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_MALICIOUSREPPREP_HPP_ +#define PROTOCOLS_MALICIOUSREPPREP_HPP_ + #include "MaliciousRepPrep.h" #include "Tools/Subroutines.h" #include "Processor/OnlineOptions.h" @@ -232,3 +235,5 @@ void MaliciousRepPrep::buffer_inputs(int player) assert(proc); this->buffer_inputs_as_usual(player, proc); } + +#endif diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 1393bb464..f5c09941d 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -108,18 +108,4 @@ void MascotInputPrep::buffer_inputs(int player) this->inputs[player].push_back(input); } -template -T Preprocessing::get_random_from_inputs(int nplayers) -{ - T res; - for (int j = 0; j < nplayers; j++) - { - T tmp; - typename T::open_type _; - this->get_input_no_count(tmp, _, j); - res += tmp; - } - return res; -} - #endif diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 7cbd483c4..fccb9e0cf 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -9,6 +9,7 @@ #include "Protocols/MaliciousRep3Share.h" #include "Protocols/MalRepRingShare.h" #include "Protocols/Rep3Share2k.h" +#include "GC/MaliciousRepSecret.h" template class MalRepRingPrepWithBits; template class PostSacrifice; diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h index 2f417c427..63c471725 100644 --- a/Protocols/ProtocolSetup.h +++ b/Protocols/ProtocolSetup.h @@ -65,6 +65,14 @@ class ProtocolSetup { return mac_key; } + + /** + * Set how much preprocessing is produced at once. + */ + static void set_batch_size(size_t batch_size) + { + OnlineOptions::singleton.batch_size = batch_size; + } }; /** diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index fb02d26ff..cd321b26e 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -9,11 +9,13 @@ #include "Math/FixedVec.h" #include "Math/Integer.h" #include "Protocols/Replicated.h" +#include "Protocols/Rep3Shuffler.h" #include "GC/ShareSecret.h" #include "ShareInterface.h" #include "Processor/Instruction.h" template class ReplicatedPrep; +template class SemiRep3Prep; template class ReplicatedRingPrep; template class ReplicatedPO; template class SpecificPrivateOutput; @@ -109,7 +111,8 @@ class Rep3Share : public RepShare typedef ReplicatedInput Input; typedef ReplicatedPO PO; typedef SpecificPrivateOutput PrivateOutput; - typedef ReplicatedPrep LivePrep; + typedef typename conditional, SemiRep3Prep>::type LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index e52d160bb..0fc2e50e4 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -11,7 +11,7 @@ #include "Math/Z2k.h" #include "GC/square64.h" -template class ReplicatedPrep2k; +template class SemiRep3Prep; template class Rep3Share2 : public Rep3Share> @@ -26,7 +26,7 @@ class Rep3Share2 : public Rep3Share> typedef ReplicatedInput Input; typedef ReplicatedPO PO; typedef SpecificPrivateOutput PrivateOutput; - typedef ReplicatedPrep2k LivePrep; + typedef SemiRep3Prep LivePrep; typedef Rep3Share2 Honest; typedef SignedZ2 clear; diff --git a/Protocols/Rep3Shuffler.h b/Protocols/Rep3Shuffler.h new file mode 100644 index 000000000..ec80a48e4 --- /dev/null +++ b/Protocols/Rep3Shuffler.h @@ -0,0 +1,33 @@ +/* + * Rep3Shuffler.h + * + */ + +#ifndef PROTOCOLS_REP3SHUFFLER_H_ +#define PROTOCOLS_REP3SHUFFLER_H_ + +template +class Rep3Shuffler +{ + SubProcessor& proc; + + vector, 2>> shuffles; + +public: + Rep3Shuffler(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, SubProcessor& proc); + + Rep3Shuffler(SubProcessor& proc); + + int generate(int n_shuffle); + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse); + + void inverse_permutation(vector& stack, size_t n, size_t output_base, + size_t input_base); + + void del(int handle); +}; + +#endif /* PROTOCOLS_REP3SHUFFLER_H_ */ diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp new file mode 100644 index 000000000..a2edfb76f --- /dev/null +++ b/Protocols/Rep3Shuffler.hpp @@ -0,0 +1,131 @@ +/* + * Rep3Shuffler.cpp + * + */ + +#ifndef PROTOCOLS_REP3SHUFFLER_HPP_ +#define PROTOCOLS_REP3SHUFFLER_HPP_ + +#include "Rep3Shuffler.h" + +template +Rep3Shuffler::Rep3Shuffler(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc) : + proc(proc) +{ + apply(a, n, unit_size, output_base, input_base, generate(n / unit_size), + false); + shuffles.pop_back(); +} + +template +Rep3Shuffler::Rep3Shuffler(SubProcessor& proc) : + proc(proc) +{ +} + +template +int Rep3Shuffler::generate(int n_shuffle) +{ + shuffles.push_back({}); + auto& shuffle = shuffles.back(); + for (int i = 0; i < 2; i++) + { + auto& perm = shuffle[i]; + for (int j = 0; j < n_shuffle; j++) + perm.push_back(j); + for (int j = 0; j < n_shuffle; j++) + { + int k = proc.protocol.shared_prngs[i].get_uint(n_shuffle - j); + swap(perm[k], perm[k + j]); + } + } + return shuffles.size() - 1; +} + +template +void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, int handle, bool reverse) +{ + assert(proc.P.num_players() == 3); + assert(not T::malicious); + assert(not T::dishonest_majority); + assert(n % unit_size == 0); + + auto& shuffle = shuffles.at(handle); + vector to_shuffle; + for (size_t i = 0; i < n; i++) + to_shuffle.push_back(a[input_base + i]); + + typename T::Input input(proc); + + vector to_share(n); + + for (int ii = 0; ii < 3; ii++) + { + int i; + if (reverse) + i = 2 - ii; + else + i = ii; + + if (proc.P.get_player(i) == 0) + { + for (size_t j = 0; j < n / unit_size; j++) + for (int k = 0; k < unit_size; k++) + if (reverse) + to_share.at(j * unit_size + k) = to_shuffle.at( + shuffle[0].at(j) * unit_size + k).sum(); + else + to_share.at(shuffle[0].at(j) * unit_size + k) = + to_shuffle.at(j * unit_size + k).sum(); + } + else if (proc.P.get_player(i) == 1) + { + for (size_t j = 0; j < n / unit_size; j++) + for (int k = 0; k < unit_size; k++) + if (reverse) + to_share[j * unit_size + k] = to_shuffle[shuffle[1][j] + * unit_size + k][0]; + else + to_share[shuffle[1][j] * unit_size + k] = to_shuffle[j + * unit_size + k][0]; + } + + input.reset_all(proc.P); + + if (proc.P.get_player(i) < 2) + for (auto& x : to_share) + input.add_mine(x); + + for (int k = 0; k < 2; k++) + input.add_other((-i + 3 + k) % 3); + + input.exchange(); + to_shuffle.clear(); + + for (size_t j = 0; j < n; j++) + { + T x = input.finalize((-i + 3) % 3) + input.finalize((-i + 4) % 3); + to_shuffle.push_back(x); + } + } + + for (size_t i = 0; i < n; i++) + a[output_base + i] = to_shuffle[i]; +} + +template +void Rep3Shuffler::del(int handle) +{ + for (int i = 0; i < 2; i++) + shuffles.at(handle)[i].clear(); +} + +template +void Rep3Shuffler::inverse_permutation(vector&, size_t, size_t, size_t) +{ + throw runtime_error("inverse permutation not implemented"); +} + +#endif diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 48b014408..05f132ce7 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -20,6 +20,8 @@ template class SubProcessor; template class ReplicatedMC; template class ReplicatedInput; template class Preprocessing; +template class SecureShuffle; +template class Rep3Shuffler; class Instruction; /** @@ -59,6 +61,8 @@ class ProtocolBase public: typedef T share_type; + typedef SecureShuffle Shuffler; + int counter; ProtocolBase(); @@ -81,6 +85,7 @@ class ProtocolBase virtual void init_mul() = 0; /// Schedule multiplication of operand pair virtual void prepare_mul(const T& x, const T& y, int n = -1) = 0; + virtual void prepare_mult(const T& x, const T& y, int n, bool repeat); /// Run multiplication protocol virtual void exchange() = 0; /// Get next multiplication result @@ -143,6 +148,8 @@ class Replicated : public ReplicatedBase, public ProtocolBase public: static const bool uses_triples = false; + typedef Rep3Shuffler Shuffler; + Replicated(Player& P); Replicated(const ReplicatedBase& other); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index f398da7fe..494a7e0de 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -41,7 +41,7 @@ inline ReplicatedBase::ReplicatedBase(Player& P) : P(P) { assert(P.num_players() == 3); if (not P.is_encrypted()) - insecure("unencrypted communication"); + insecure("unencrypted communication", false); shared_prngs[0].ReSeed(); octetStream os; @@ -121,6 +121,13 @@ T ProtocolBase::mul(const T& x, const T& y) return finalize_mul(); } +template +void ProtocolBase::prepare_mult(const T& x, const T& y, int n, + bool) +{ + prepare_mul(x, y, n); +} + template void ProtocolBase::finalize_mult(T& res, int n) { diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 9e1498df0..26451f170 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -17,19 +17,17 @@ template class PrepLessInput : public InputBase { protected: - vector shares; - size_t i_share; + PointerVector shares; public: PrepLessInput(SubProcessor* proc) : - InputBase(proc ? proc->Proc : 0), i_share(0) {} + InputBase(proc ? proc->Proc : 0) {} virtual ~PrepLessInput() {} virtual void reset(int player) = 0; virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; virtual void add_other(int player, int n_bits = - 1) = 0; - virtual void send_mine() = 0; virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 1cfac4a16..ffc34d6f6 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -19,7 +19,6 @@ void ReplicatedInput::reset(int player) if (player == P.my_num()) { this->shares.clear(); - this->i_share = 0; os.resize(2); for (auto& o : os) o.reset_write_head(); @@ -89,7 +88,7 @@ inline void ReplicatedInput::finalize_other(int player, T& target, template T PrepLessInput::finalize_mine() { - return this->shares[this->i_share++]; + return this->shares.next(); } #endif diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 8a30749c3..e73d9cc2c 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -38,6 +38,8 @@ class BufferPrep : public Preprocessing template void buffer_inverses(false_type) { throw runtime_error("no inverses"); } + virtual bool bits_from_dabits() { return false; } + protected: vector> triples; vector> squares; diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index a172b05b3..867b844d1 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -7,6 +7,7 @@ #define PROTOCOlS_REPLICATEDPREP_HPP_ #include "ReplicatedPrep.h" +#include "SemiRep3Prep.h" #include "DabitSacrifice.h" #include "Spdz2kPrep.h" @@ -64,17 +65,24 @@ BufferPrep::~BufferPrep() * T::default_length); size_t used_bits = my_usage.at(DATA_BIT); - if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) - // add dabits with computation modulo power of two but without MAC - used_bits += my_usage.at(DATA_DABIT); + size_t used_dabits = my_usage.at(DATA_DABIT); + if (bits_from_dabits()) + { + if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) + // add dabits with computation modulo power of two but without MAC + used_bits += my_usage.at(DATA_DABIT); + } + else + used_dabits += used_bits; + this->print_left("bits", bits.size(), type_string, used_bits); + this->print_left("dabits", dabits.size(), type_string, used_dabits); #define X(KIND, TYPE) \ this->print_left(#KIND, KIND.size(), type_string, \ this->usage.files.at(T::clear::field_type()).at(TYPE)); X(squares, DATA_SQUARE) X(inverses, DATA_INVERSE) - X(dabits, DATA_DABIT) #undef X for (auto& x : this->edabits) @@ -549,7 +557,8 @@ void MaliciousRingPrep::buffer_personal_edabits(int n_bits, vector& wholes int start = queues->distribute(job, buffer_size, 0, BT::default_length); this->template buffer_personal_edabits_without_check<0>(n_bits, sums, bits, proc, input_player, start, buffer_size); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else this->template buffer_personal_edabits_without_check<0>(n_bits, sums, @@ -651,12 +660,78 @@ void RingPrep::buffer_dabits_without_check(vector>& dabits, int start = queues->distribute(job, buffer_size, old_size); this->buffer_dabits_without_check(dabits, start, dabits.size()); - queues->wrap_up(job); + if (start > old_size) + queues->wrap_up(job); } else buffer_dabits_without_check(dabits, old_size, dabits.size()); } +template +void SemiRep3Prep::buffer_dabits(ThreadQueues*) +{ + assert(this->protocol); + assert(this->proc); + + typedef typename T::bit_type BT; + int n_blocks = DIV_CEIL(this->buffer_size, BT::default_length); + int n_bits = n_blocks * BT::default_length; + + vector b(n_blocks); + + vector> a(n_bits); + Player& P = this->proc->P; + + for (int i = 0; i < 2; i++) + { + for (auto& x : b) + x[i].randomize(this->protocol->shared_prngs[i]); + + int j = P.get_offset(i); + + for (int k = 0; k < n_bits; k++) + a[k][j][i] = b[k / BT::default_length][i].get_bit( + k % BT::default_length); + } + + // the first multiplication + vector first(n_bits), second(n_bits); + typename T::Input input(P); + + if (P.my_num() == 0) + { + for (auto& x : a) + input.add_mine(x[0][0] * x[1][1]); + } + else + input.add_other(0); + + input.exchange(); + + for (int k = 0; k < n_bits; k++) + first[k] = a[k][0] + a[k][1] - 2 * input.finalize(0); + + input.reset_all(P); + + if (P.my_num() != 0) + { + for (int k = 0; k < n_bits; k++) + input.add_mine(first[k].local_mul(a[k][2])); + } + + input.add_other(1); + input.add_other(2); + input.exchange(); + + for (int k = 0; k < n_bits; k++) + { + second[k] = first[k] + a[k][2] + - 2 * (input.finalize(1) + input.finalize(2)); + this->dabits.push_back({second[k], + b[k / BT::default_length].get_bit(k % BT::default_length)}); + } +} + template void RingPrep::buffer_dabits_without_check(vector>& dabits, size_t begin, size_t end) @@ -718,7 +793,8 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, ThreadJob job(n_bits, &sums, &bits); int start = queues->distribute(job, rounded, 0, dl); buffer_edabits_without_check<0>(n_bits, sums, bits, start, rounded); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else buffer_edabits_without_check<0>(n_bits, sums, bits, 0, rounded); @@ -844,7 +920,8 @@ void RingPrep::sanitize(vector>& edabits, int n_bits, SanitizeJob job(&edabits, n_bits, player); int start = queues->distribute(job, edabits.size()); sanitize<0>(edabits, n_bits, player, start, edabits.size()); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else sanitize<0>(edabits, n_bits, player, 0, edabits.size()); @@ -1027,6 +1104,7 @@ void BufferPrep::get_dabit_no_count(T& a, typename T::bit_type& b) InScope in_scope(this->do_count, false); ThreadQueues* queues = 0; buffer_dabits(queues); + assert(not dabits.empty()); } a = dabits.back().first; b = dabits.back().second; @@ -1085,7 +1163,7 @@ template void BufferPrep::buffer_edabits_with_queues(bool strict, int n_bits) { ThreadQueues* queues = 0; - if (BaseMachine::thread_num == 0) + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) queues = &BaseMachine::s().queues; buffer_edabits(strict, n_bits, queues); } diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index d8c3d8e61..752798b2a 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -21,7 +21,7 @@ SecureShuffle::SecureShuffle(SubProcessor& proc) : template SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc) : - proc(proc), unit_size(unit_size) + proc(proc), unit_size(unit_size), n_shuffle(0), exact(false) { pre(a, n, input_base); diff --git a/Protocols/Semi.h b/Protocols/Semi.h index 5f63a9d62..903aca6d1 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -59,7 +59,20 @@ class Semi : public SPDZ for (auto& info : infos) { if (not info.big_gap()) - throw runtime_error("bit length too large"); + { + if (not T::clear::invertible) + { + int min_size = 64 * DIV_CEIL( + info.k + OnlineOptions::singleton.trunc_error, 64); + throw runtime_error( + "Bit length too large for trunc_pr. " + "Disable it or increase the ring size " + "during compilation using '-R " + + to_string(min_size) + "'."); + } + else + throw runtime_error("bit length too large"); + } if (this->P.my_num()) for (int i = 0; i < size; i++) proc.get_S_ref(info.dest_base + i) = -open_type( diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index c40d0c170..d4c864f06 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -6,20 +6,28 @@ #ifndef PROTOCOLS_SEMIINPUT_H_ #define PROTOCOLS_SEMIINPUT_H_ -#include "ShamirInput.h" +#include "ReplicatedInput.h" template class SemiMC; +template +class PairwiseKeyInput : public PrepLessInput +{ +protected: + vector send_prngs; + vector recv_prngs; + +public: + PairwiseKeyInput(SubProcessor* proc, PlayerBase& P); +}; + /** * Additive secret sharing input protocol */ template -class SemiInput : public InputBase +class SemiInput : public PairwiseKeyInput { - vector send_prngs; - vector recv_prngs; PlayerBase& P; - vector> shares; public: SemiInput(SubProcessor& proc, SemiMC&) : diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index f0fefe137..7ab4a855a 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -12,9 +12,15 @@ template SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : - InputBase(proc), P(P) + PairwiseKeyInput(proc, P), P(P) +{ + this->reset_all(P); +} + +template +PairwiseKeyInput::PairwiseKeyInput(SubProcessor* proc, PlayerBase& P) : + PrepLessInput(proc) { - shares.resize(P.num_players()); vector to_send(P.num_players()), to_receive; for (int i = 0; i < P.num_players(); i++) { @@ -26,13 +32,13 @@ SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : for (int i = 0; i < P.num_players(); i++) if (i != P.my_num()) recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE)); - this->reset_all(P); } template void SemiInput::reset(int player) { - shares[player].clear(); + if (player == P.my_num()) + this->shares.clear(); } template @@ -43,9 +49,9 @@ void SemiInput::add_mine(const typename T::clear& input, int) for (int i = 0; i < P.num_players(); i++) { if (i != P.my_num()) - sum += send_prngs[i].template get(); + sum += this->send_prngs[i].template get(); } - shares[P.my_num()].push_back(input - sum); + this->shares.push_back(input - sum); } template @@ -62,13 +68,13 @@ template void SemiInput::finalize_other(int player, T& target, octetStream&, int) { - target = recv_prngs[player].template get(); + target = this->recv_prngs[player].template get(); } template T SemiInput::finalize_mine() { - return shares[P.my_num()].next(); + return this->shares.next(); } #endif diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index fe4d9db6c..27fd3b71f 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -15,13 +15,17 @@ template class SemiMC : public TreeSum, public MAC_Check_Base { +protected: + vector lengths; + public: // emulate MAC_Check SemiMC(const typename T::mac_key_type& _ = {}, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; } virtual ~SemiMC() {} - virtual void prepare_open(const T& secret); + virtual void init_open(const Player& P, int n = 0); + virtual void prepare_open(const T& secret, int n_bits = -1); virtual void exchange(const Player& P); void Check(const Player& P) { (void)P; } diff --git a/Protocols/SemiMC.hpp b/Protocols/SemiMC.hpp index b54878577..75aa0c6eb 100644 --- a/Protocols/SemiMC.hpp +++ b/Protocols/SemiMC.hpp @@ -11,9 +11,18 @@ #include "MAC_Check.hpp" template -void SemiMC::prepare_open(const T& secret) +void SemiMC::init_open(const Player& P, int n) +{ + MAC_Check_Base::init_open(P, n); + lengths.clear(); + lengths.reserve(n); +} + +template +void SemiMC::prepare_open(const T& secret, int n_bits) { this->values.push_back(secret); + lengths.push_back(n_bits); } template @@ -28,6 +37,8 @@ void DirectSemiMC::POpen_(vector& values, { this->values.clear(); this->values.reserve(S.size()); + this->lengths.clear(); + this->lengths.reserve(S.size()); for (auto& secret : S) this->prepare_open(secret); this->exchange_(P); @@ -39,10 +50,20 @@ void DirectSemiMC::exchange_(const PlayerBase& P) { Bundle oss(P); oss.mine.reserve(this->values.size()); - for (auto& x : this->values) - x.pack(oss.mine); + assert(this->values.size() == this->lengths.size()); + for (size_t i = 0; i < this->lengths.size(); i++) + this->values[i].pack(oss.mine, this->lengths[i]); P.unchecked_broadcast(oss); - direct_add_openings(this->values, P, oss); + size_t n = P.num_players(); + size_t me = P.my_num(); + for (size_t i = 0; i < this->lengths.size(); i++) + for (size_t j = 0; j < n; j++) + if (j != me) + { + T tmp; + tmp.unpack(oss[j], this->lengths[i]); + this->values[i] += tmp; + } } template diff --git a/Protocols/SemiPrep.h b/Protocols/SemiPrep.h index 3580a73bf..9646e9453 100644 --- a/Protocols/SemiPrep.h +++ b/Protocols/SemiPrep.h @@ -8,18 +8,26 @@ #include "MascotPrep.h" +template class HemiPrep; + /** * Semi-honest triple generation based on oblivious transfer */ template class SemiPrep : public virtual OTPrep, public virtual SemiHonestRingPrep { + friend class HemiPrep; + public: SemiPrep(SubProcessor* proc, DataPositions& usage); void buffer_triples(); - void buffer_bits(); + void buffer_dabits(ThreadQueues* queues); + + void get_one_no_count(Dtype dtype, T& a); + + bool bits_from_dabits(); }; #endif /* PROTOCOLS_SEMIPREP_H_ */ diff --git a/Protocols/SemiPrep.hpp b/Protocols/SemiPrep.hpp index bc61787d4..f1ec6efd9 100644 --- a/Protocols/SemiPrep.hpp +++ b/Protocols/SemiPrep.hpp @@ -6,6 +6,8 @@ #include "SemiPrep.h" #include "ReplicatedPrep.hpp" +#include "MascotPrep.hpp" +#include "OT/NPartyTripleGenerator.hpp" template SemiPrep::SemiPrep(SubProcessor* proc, DataPositions& usage) : @@ -31,16 +33,37 @@ void SemiPrep::buffer_triples() } template -void SemiPrep::buffer_bits() +bool SemiPrep::bits_from_dabits() { assert(this->proc); - if (this->proc->P.num_players() == 2 and not T::clear::characteristic_two) + return this->proc->P.num_players() == 2 and not T::clear::characteristic_two; +} + +template +void SemiPrep::buffer_dabits(ThreadQueues* queues) +{ + if (bits_from_dabits()) { assert(this->triple_generator); this->triple_generator->generatePlainBits(); for (auto& x : this->triple_generator->plainBits) - this->bits.push_back(x); + this->dabits.push_back({x.first, x.second}); + } + else + SemiHonestRingPrep::buffer_dabits(queues); +} + +template +void SemiPrep::get_one_no_count(Dtype dtype, T& a) +{ + if (bits_from_dabits()) + { + if (dtype != DATA_BIT) + throw not_implemented(); + + typename T::bit_type b; + this->get_dabit_no_count(a, b); } else - SemiHonestRingPrep::buffer_bits(); + SemiHonestRingPrep::get_one_no_count(dtype, a); } diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h index 50311c594..49ccca479 100644 --- a/Protocols/SemiPrep2k.h +++ b/Protocols/SemiPrep2k.h @@ -49,6 +49,12 @@ class SemiPrep2k : public SemiPrep, public RepRingOnlyEdabitPrep void get_dabit_no_count(T& a, typename T::bit_type& b) { + if (this->bits_from_dabits()) + { + SemiPrep::get_dabit_no_count(a, b); + return; + } + this->get_one_no_count(DATA_BIT, a); b = a & 1; } diff --git a/Protocols/ReplicatedPrep2k.h b/Protocols/SemiRep3Prep.h similarity index 51% rename from Protocols/ReplicatedPrep2k.h rename to Protocols/SemiRep3Prep.h index da35865e4..5d68f03ea 100644 --- a/Protocols/ReplicatedPrep2k.h +++ b/Protocols/SemiRep3Prep.h @@ -3,8 +3,8 @@ * */ -#ifndef PROTOCOLS_REPLICATEDPREP2K_H_ -#define PROTOCOLS_REPLICATEDPREP2K_H_ +#ifndef PROTOCOLS_SEMIREP3PREP_H_ +#define PROTOCOLS_SEMIREP3PREP_H_ #include "ReplicatedPrep.h" @@ -12,11 +12,13 @@ * Preprocessing for three-party replicated secret sharing modulo a power of two */ template -class ReplicatedPrep2k : public virtual SemiHonestRingPrep, +class SemiRep3Prep : public virtual SemiHonestRingPrep, public virtual ReplicatedRingPrep { + void buffer_dabits(ThreadQueues*); + public: - ReplicatedPrep2k(SubProcessor* proc, DataPositions& usage) : + SemiRep3Prep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), SemiHonestRingPrep(proc, usage), ReplicatedRingPrep(proc, usage) @@ -25,11 +27,14 @@ class ReplicatedPrep2k : public virtual SemiHonestRingPrep, void buffer_bits() { this->buffer_bits_without_check(); } - void get_dabit_no_count(T& a, typename T::bit_type& b) + void get_one_no_count(Dtype dtype, T& a) { - this->get_one_no_count(DATA_BIT, a); - b = a & 1; + if (dtype != DATA_BIT) + throw not_implemented(); + + typename T::bit_type b; + this->get_dabit_no_count(a, b); } }; -#endif /* PROTOCOLS_REPLICATEDPREP2K_H_ */ +#endif /* PROTOCOLS_SEMIREP3PREP_H_ */ diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 402173e98..db056ae49 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -49,7 +49,8 @@ class Shamir : public ProtocolBase Player& P; static U get_rec_factor(int i, int n); - static U get_rec_factor(int i, int n_total, int start, int threshold); + static U get_rec_factor(int i, int n_total, int start, int threshold, + int target = -1); Shamir(Player& P, int threshold = 0); ~Shamir(); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 8bfdf70ea..89fa6853e 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -20,14 +20,24 @@ typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n) template typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n_total, - int start, int n_points) + int start, int n_points, int target) { U res = 1; for (int j = 0; j < n_points; j++) { - int other = positive_modulo(start + j, n_total); + int other; + if (n_total > 0) + other = positive_modulo(start + j, n_total); + else + other = start + j; if (i != other) - res *= U(other + 1) / (U(other + 1) - U(i + 1)); + { + res *= (U(other + 1) - U(target + 1)) / (U(other + 1) - U(i + 1)); +#ifdef DEBUG_SHAMIR + cout << "res=" << res << " other+1=" << (other + 1) << " target=" + << target << " i+1=" << (i + 1) << endl; +#endif + } } return res; } @@ -43,6 +53,7 @@ Shamir::Shamir(Player& P, int t) : else threshold = ShamirMachine::s().threshold; n_mul_players = 2 * threshold + 1; + resharing = new ShamirInput(0, P); } template @@ -69,11 +80,6 @@ int Shamir::get_n_relevant_players() template void Shamir::reset() { - if (resharing == 0) - { - resharing = new ShamirInput(0, P); - } - for (int i = 0; i < P.num_players(); i++) resharing->reset(i); diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 91e093091..eaa72f2d3 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -8,7 +8,7 @@ #include "Processor/Input.h" #include "Shamir.h" -#include "ReplicatedInput.h" +#include "SemiInput.h" #include "Machines/ShamirMachine.h" /** @@ -16,7 +16,7 @@ * to every other player */ template -class IndividualInput : public PrepLessInput +class IndividualInput : public PairwiseKeyInput { protected: Player& P; @@ -25,7 +25,7 @@ class IndividualInput : public PrepLessInput public: IndividualInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), P(P), senders(P.num_players()) + PairwiseKeyInput(proc, P), P(P), senders(P.num_players()) { this->reset_all(P); } @@ -53,14 +53,14 @@ class ShamirInput : public IndividualInput { friend class Shamir; - vector> vandermonde; - - SeededPRNG secure_prng; + vector> reconstruction; vector randomness; int threshold; + void init(); + public: static vector> get_vandermonde(size_t t, size_t n); @@ -79,6 +79,7 @@ class ShamirInput : public IndividualInput else threshold = ShamirMachine::s().threshold; + init(); } ShamirInput(ShamirMC&, Preprocessing&, Player& P) : @@ -87,6 +88,7 @@ class ShamirInput : public IndividualInput } void add_mine(const typename T::open_type& input, int n_bits = -1); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; #endif /* PROTOCOLS_SHAMIRINPUT_H_ */ diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index 6d9992ad7..41c880121 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -10,6 +10,7 @@ #include "Machines/ShamirMachine.h" #include "Protocols/ReplicatedInput.hpp" +#include "Protocols/SemiInput.hpp" template void IndividualInput::reset(int player) @@ -17,7 +18,6 @@ void IndividualInput::reset(int player) if (player == P.my_num()) { this->shares.clear(); - this->i_share = 0; os.reset(P); } @@ -45,6 +45,20 @@ vector> ShamirInput::get_vandermonde( return vandermonde; } +template +void ShamirInput::init() +{ + reconstruction.resize(this->P.num_players() - threshold); + for (size_t i = 0; i < reconstruction.size(); i++) + { + auto& x = reconstruction[i]; + for (int j = 0; j <= threshold; j++) + x.push_back( + Shamir::get_rec_factor(j - 1, 0, -1, threshold + 1, + i + threshold)); + } +} + template void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) { @@ -53,18 +67,20 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) int n = P.num_players(); int t = threshold; - if (vandermonde.empty()) - vandermonde = get_vandermonde(t, n); - randomness.resize(t); - for (auto& x : randomness) - x.randomize(secure_prng); + for (int i = 0; i < t; i++) + { + randomness[i].randomize(this->send_prngs[i]); + if (i == P.my_num()) + this->shares.push_back(randomness[i]); + } - for (int i = 0; i < n; i++) + for (int i = threshold; i < n; i++) { - typename T::open_type x = input; + typename T::open_type x = input + * reconstruction.at(i - threshold).at(0); for (int j = 0; j < t; j++) - x += randomness[j] * vandermonde[i][j]; + x += randomness[j] * reconstruction.at(i - threshold).at(j + 1); if (i == P.my_num()) this->shares.push_back(x); else @@ -74,6 +90,16 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) this->senders[P.my_num()] = true; } +template +void ShamirInput::finalize_other(int player, T& target, + octetStream& o, int n_bits) +{ + if (this->P.my_num() < threshold) + target.randomize(this->recv_prngs.at(player)); + else + IndividualInput::finalize_other(player, target, o, n_bits); +} + template void IndividualInput::add_sender(int player) { diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index c6a88f0ad..bd0cc3176 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -67,7 +67,7 @@ class ShamirMC : public IndirectShamirMC void POpen_End(vector& values,const vector& S,const Player& P); virtual void init_open(const Player& P, int n = 0); - virtual void prepare_open(const T& secret); + virtual void prepare_open(const T& secret, int = -1); virtual void exchange(const Player& P); virtual typename T::open_type finalize_raw(); diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 7238aa5ef..585a6896b 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -72,7 +72,7 @@ void ShamirMC::prepare(const vector& S, const Player& P) } template -void ShamirMC::prepare_open(const T& share) +void ShamirMC::prepare_open(const T& share, int) { share.pack(os->mine); } diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index bf40cb287..318f050dd 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -9,6 +9,7 @@ #include "Protocols/Shamir.h" #include "Protocols/ShamirInput.h" #include "Machines/ShamirMachine.h" +#include "GC/NoShare.h" #include "ShareInterface.h" template class ReplicatedPrep; diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 4d03dd67e..150cdb610 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -141,7 +141,8 @@ void DabitShuffleSacrifice::dabit_sacrifice(vector >& output, int start = queues->distribute(job, products.size()); protocol.multiply(products, multiplicands, start, products.size(), proc); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else protocol.multiply(products, multiplicands, 0, products.size(), proc); @@ -311,7 +312,8 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, &supplies); edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, start, N, personal_prep); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, 0, N, diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9ad761985..f48adfc9a 100644 --- a/Protocols/SpdzWiseMC.h +++ b/Protocols/SpdzWiseMC.h @@ -36,7 +36,7 @@ class SpdzWiseMC : public MAC_Check_Base { inner_MC.init_open(P, n); } - void prepare_open(const T& secret) + void prepare_open(const T& secret, int = -1) { inner_MC.prepare_open(secret.get_share()); } diff --git a/README.md b/README.md index 5f3fa2a46..c14f41ce9 100644 --- a/README.md +++ b/README.md @@ -270,8 +270,8 @@ compute the preprocessing time for a particular computation. #### Requirements - - GCC 5 or later (tested with up to 11) or LLVM/clang 5 or later - (tested with up to 12). We recommend clang because it performs + - GCC 5 or later (tested with up to 11) or LLVM/clang 6 or later + (tested with up to 14). We recommend clang because it performs better. Note that GCC 5/6 and clang 9 don't support libOTe, so you need to deactivate its use for these compilers (see the next section). @@ -694,7 +694,7 @@ Compile the virtual machine: and the high-level program: -`./compile.py -B ` +`./compile.py -G -B ` Then run as follows: @@ -874,7 +874,7 @@ three parties, change the definition of `MAX_N_PARTIES` in In order to compile a high-level program, use `./compile.py -B`: -`./compile.py -B 32 tutorial` +`./compile.py -G -B 32 tutorial` Finally, run the two parties as follows: @@ -1004,7 +1004,7 @@ you entirely delete the definition, it will be able to run for any number of parties albeit slower. Compile the virtual machine: -`make -j 8 libote` + `make -j 8 bmr` After compiling the mpc file: @@ -1020,7 +1020,7 @@ You can benchmark the ORAM implementation as follows: 1) Edit `Program/Source/gc_oram.mpc` to change size and to choose Circuit ORAM or linear scan without ORAM. -2) Run `./compile.py -D gc_oram`. The `-D` argument instructs the +2) Run `./compile.py -G -D gc_oram`. The `-D` argument instructs the compiler to remove dead code. This is useful for more complex programs such as this one. 3) Run `gc_oram` in the virtual machines as explained above. diff --git a/Scripts/build.sh b/Scripts/build.sh index c541152ae..1c3f72866 100755 --- a/Scripts/build.sh +++ b/Scripts/build.sh @@ -6,7 +6,8 @@ function build echo GDEBUG = >> CONFIG.mine root=`pwd` cd deps/libOTe - python3 build.py --install=$root/local -- -DENABLE_SOFTSPOKEN_OT=ON -DBUILD_SHARED_LIBS=0 $3 + rm -R out + python3 build.py --install=$root/local -- -DENABLE_SOFTSPOKEN_OT=ON -DBUILD_SHARED_LIBS=0 -DCMAKE_INSTALL_LIBDIR=lib $3 cd $root make clean rm -R static diff --git a/Scripts/compile-for-emulation.sh b/Scripts/compile-for-emulation.sh new file mode 100755 index 000000000..b808ef0cd --- /dev/null +++ b/Scripts/compile-for-emulation.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +pypy3 ./compile.py -CDR 64 -K '' $* diff --git a/Scripts/emulate-append.sh b/Scripts/emulate-append.sh new file mode 100755 index 000000000..554752107 --- /dev/null +++ b/Scripts/emulate-append.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +. $(dirname $0)/run-common.sh +prog=${1%.sch} +prog=${prog##*/} +shift +$prefix ./emulate.x $prog $* 2>&1 | tee -a logs/emulate-append-$prog diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index c6835069f..fe3c54e71 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -41,12 +41,21 @@ run_player() { if test "$prog"; then log_prefix=$prog- fi + if test "$BENCH"; then + log_prefix=$log_prefix$bin-$(echo "$*" | sed 's/ /-/g')-N$players- + fi set -o pipefail for i in $(seq 0 $[players-1]); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params log=$SPDZROOT/logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | - { if test $i = 0; then tee $log; else cat > $log; fi; } & + { + if test "$BENCH"; then + if test $i = 0; then tee -a $log; else cat >> $log; fi; + else + if test $i = 0; then tee $log; else cat > $log; fi; + fi + } & codes[$i]=$! done for i in $(seq 0 $[players-1]); do diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index e58edef8e..60157c934 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -41,6 +41,7 @@ function test_vm run_opts="$run_opts -B 5" export PORT=$((RANDOM%10000+10000)) +export BENCH= for dabit in ${dabit:-0 1 2}; do if [[ $dabit = 1 ]]; then diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index ce906a3a3..bd5b396a7 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -24,6 +24,11 @@ if test "$flags"; then if $flags | grep -q avx2; then cpu=avx2 else + if test `uname -m` != x86_64; then + echo Binaries are not available for `uname -m` + echo Use the source distribution: https://github.com/data61/MP-SPDZ/#tldr-source-distribution + exit 1 + fi cpu=amd64 fi diff --git a/Tools/ExecutionStats.cpp b/Tools/ExecutionStats.cpp index bbc36dca6..daa2309aa 100644 --- a/Tools/ExecutionStats.cpp +++ b/Tools/ExecutionStats.cpp @@ -26,6 +26,7 @@ void ExecutionStats::print() { sorted_stats.insert({x.second, x.first}); } + size_t total = 0; for (auto& x : sorted_stats) { auto opcode = x.second; @@ -35,7 +36,7 @@ void ExecutionStats::print() switch (opcode) { #define X(NAME, PRE, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break; - ARITHMETIC_INSTRUCTIONS + ALL_INSTRUCTIONS #undef X #define X(NAME, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break; COMBI_INSTRUCTIONS @@ -48,5 +49,7 @@ void ExecutionStats::print() for (int i = 0; i < n_fill; i++) cerr << " "; cerr << dec << calls << endl; + total += calls; } + cerr << "\tTotal:" << string(9, ' ') << total << endl; } diff --git a/Tools/names.cpp b/Tools/names.cpp index 062beb022..220263c89 100644 --- a/Tools/names.cpp +++ b/Tools/names.cpp @@ -2,4 +2,4 @@ const char* DataPositions::dtype_names[N_DTYPE + 1] = { "Triples", "Squares", "Bits", "Inverses", - "daBits", "None" }; + "daBits", "Mixed triples", "None" }; diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index 203328f22..aec3c5954 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -15,6 +15,7 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/SemiSecret.h" +#include "GC/RepPrep.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index d13f79d33..a07d74276 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -29,6 +29,7 @@ #include "Protocols/fake-stuff.hpp" #include "Machines/ShamirMachine.hpp" #include "Machines/Rep4.hpp" +#include "Machines/Rep.hpp" template void run(int argc, char** argv); diff --git a/Utils/l2h-example.cpp b/Utils/l2h-example.cpp index 475bcb8aa..91ce7f0be 100644 --- a/Utils/l2h-example.cpp +++ b/Utils/l2h-example.cpp @@ -7,6 +7,7 @@ #include "Math/gfp.hpp" #include "Machines/SPDZ.hpp" +#include "Protocols/MascotPrep.hpp" int main(int argc, char** argv) { diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 5dbb1fe7b..f026f3005 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -22,7 +22,7 @@ steps: - script: echo MY_CFLAGS += -DFEWER_RINGS >> CONFIG.mine - script: - echo MY_CFLAGS += -DCOMP_SEC=64 >> CONFIG.mine + echo MY_CFLAGS += -DCOMP_SEC=10 >> CONFIG.mine - script: echo CXX = clang++ >> CONFIG.mine - script: diff --git a/doc/Compiler.rst b/doc/Compiler.rst index db5c1e9c2..34343c51e 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -77,6 +77,13 @@ Compiler.ml module :show-inheritance: .. autofunction:: approx_sigmoid +Compiler.decision_tree module +----------------------------- + +.. automodule:: Compiler.decision_tree + :members: + :no-undoc-members: + Compiler.circuit module ----------------------- @@ -112,3 +119,13 @@ Compiler.oram module TrivialORAMIndexStructure, ValueTuple, demux, get_log_value_size, get_parallel, get_value_size, gf2nBlock, intBlock + + +Compiler.sqrt_oram module +------------------------- + +.. automodule:: Compiler.sqrt_oram + :members: + :no-undoc-members: + :exclude-members: LinearPositionMap, PositionMap, RecursivePositionMap, + refresh, shuffle_the_shuffle diff --git a/doc/Doxyfile b/doc/Doxyfile index 8420157a5..f82046ebc 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -933,7 +933,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* -EXCLUDE_PATTERNS = +EXCLUDE_PATTERNS = *.d # The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names # (namespaces, classes, functions, etc.) that should be excluded from the diff --git a/doc/compilation.rst b/doc/compilation.rst index f01581ff4..01753edde 100644 --- a/doc/compilation.rst +++ b/doc/compilation.rst @@ -54,6 +54,11 @@ The following options influence the computation domain: Compile for binary computation using *integer length* as default. +.. cmdoption:: -G + --garbled-circuit + + Compile for garbled circuits (does not replace :option:`-B`). + For arithmetic computation (:option:`-F`, :option:`-P`, and :option:`-R`) you can set the bit length during execution using ``program.set_bit_length(length)``. For diff --git a/doc/index.rst b/doc/index.rst index 0abbac7c8..648546c89 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -6,6 +6,7 @@ If you're new to MP-SPDZ, consider the following: 1. `Quickstart tutorial `_ 2. `Implemented protocols `_ 3. :ref:`troubleshooting` +4. :ref:`io` lists all the ways of getting data in and out. .. toctree:: :maxdepth: 4 diff --git a/doc/io.rst b/doc/io.rst index a4d00cee8..50128d945 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -1,3 +1,5 @@ +.. _io: + Input/Output ------------ diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst index 084bc1c80..54764e37f 100644 --- a/doc/machine-learning.rst +++ b/doc/machine-learning.rst @@ -5,6 +5,9 @@ MP-SPDZ supports a limited subset of the Keras interface for machine learning. This includes the SGD and Adam optimizers and the following layer types: dense, 2D convolution, 2D max-pooling, and dropout. +The machine learning code only works in with arithmetic machines, that +is, you cannot compile it with ``-B``. + In the following we will walk through the example code in ``keras_mnist_dense.mpc``, which trains a dense neural network for MNIST. It starts by defining tensors to hold data:: diff --git a/doc/non-linear.rst b/doc/non-linear.rst index 969e6d6c3..4687cc637 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -1,3 +1,5 @@ +.. _nonlinear: + Non-linear Computation ---------------------- @@ -8,14 +10,14 @@ throughout MP-SPDZ: Unknown prime modulus This approach goes back to `Catrina and de Hoogh - `_. It crucially relies on + `_. It crucially relies on the use of secret random bits in the arithmetic domain. Enough such bits allow to mask a secret value so that it is secure to reveal the masked value. This can then be split in bits as it is public. The public bits and the secret mask bits are then used to compute a number of non-linear functions. The same idea has been used to implement `fixed-point - `_ and + `_ and `floating-point `_ computation. We call this method "unknown prime modulus" because it only mandates a minimum modulus size for a given cleartext range, which diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index f76d2d5b1..6a32bd37e 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -140,6 +140,16 @@ This indicates an error in the internal accounting of preprocessing. Please file a bug report. +Required prime bit length is not the same as ``-F`` parameter during compilation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is related to statistical masking that requires the prime to be a +fair bit larger than the actual "payload". The technique goes to back +to `Catrina and de Hoogh +`_. +See also the paragraph on unknown prime moduli in :ref:`nonlinear`. + + Windows/VirtualBox performance ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~