diff --git a/.gitignore b/.gitignore index 1a3b15339..ce68ca4e4 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,7 @@ Programs/Public-Input/* *.d local/ *.map +*.gch # Packages # ############ diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 04af155f4..a5499e13d 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -574,11 +574,11 @@ void EvalRegister::inputb(GC::Processor >& processor, } for (auto& access : accesses) access.prepare_masks(my_os); - party.P->Broadcast_Receive(oss, true); + party.P->unchecked_broadcast(oss); my_os.reset_write_head(); for (auto& access : accesses) access.received_masks(oss); - party.P->Broadcast_Receive(oss, true); + party.P->unchecked_broadcast(oss); for (auto& access : accesses) access.received_labels(oss); } @@ -622,7 +622,7 @@ void EvalRegister::other_input(EvalInputter& inputter, int from) void EvalInputter::exchange() { - party.P->Broadcast_Receive(oss, true); + party.P->unchecked_broadcast(oss); for (auto& tuple : tuples) { if (tuple.from != party.P->my_num()) @@ -646,7 +646,7 @@ void EvalInputter::exchange() #endif } - party.P->Broadcast_Receive(oss, true); + party.P->unchecked_broadcast(oss); } void EvalRegister::finalize_input(EvalInputter& inputter, int, int) diff --git a/BMR/Register.hpp b/BMR/Register.hpp index 8c8c9a976..fe9c44e61 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -236,7 +236,7 @@ void EvalRegister::load(vector >& accesses, } } - party.P->Broadcast_Receive(keys, true); + party.P->unchecked_broadcast(keys); int base = 0; for (auto access : accesses) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0cbe8a95f..78748b74a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.1 (Dec 11, 2020) + +- Virtual machines automatically use the modulus used during compilation +- Non-linear computation modulo a prime without large gap in bit length +- Fewer communication rounds in several protocols + ## 0.2.0 (Oct 28, 2020) - Rep4: honest-majority four-party computation with malicious security diff --git a/CONFIG b/CONFIG index 5e7b7eba9..9348d851d 100644 --- a/CONFIG +++ b/CONFIG @@ -40,9 +40,9 @@ endif # MAX_MOD_SZ (for FHE) must be least and GFP_MOD_SZ (for computation) # must be exactly ceil(len(p)/len(word)) for the relevant prime p -# Default for GFP_MOD_SZ is 2, which is good for 128-bit p +# GFP_MOD_SZ only needs to be set for primes of bit length more that 256. # Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols -# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=2 +# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) LDLIBS += -lboost_system -lssl -lcrypto diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index ed0dd350d..c9fd4fad1 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -521,7 +521,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, base.Mergeable): - """ Copy private input to secret bit register bit by bit. The input is + """ Copy private input to secret bit registers bit by bit. The input is read as floating-point number, multiplied by a power of two, rounded to an integer, and then decomposed into bits. diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 8d3b3c8f2..4a413469c 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -1,3 +1,10 @@ +""" +This modules contains basic types for binary circuits. The +fixed-length types obtained by :py:obj:`get_type(n)` are the preferred +way of using them, and in some cases required in connection with +container types. +""" + from Compiler.types import MemValue, read_mem_value, regint, Array, cint from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint from Compiler.program import Tape, Program @@ -10,6 +17,7 @@ from functools import reduce class bits(Tape.Register, _structure, _bit): + """ Base class for binary registers. """ n = 40 unit = 64 PreOp = staticmethod(floatingpoint.PreOpN) @@ -21,6 +29,7 @@ def PreOR(l): [1 - x for x in l])] @classmethod def get_type(cls, length): + """ Returns a fixed-length type. """ if length == 1: return cls.bit_type if length not in cls.types: @@ -169,6 +178,7 @@ def _new_by_number(self, i, size=1): return res class cbits(bits): + """ Clear bits register. Helper type with limited functionality. """ max_length = 64 reg_type = 'cb' is_clear = True @@ -255,6 +265,25 @@ def to_regint_by_bit(self): return res class sbits(bits): + """ + Secret bits register. This type supports basic bit-wise operations:: + + sb32 = sbits.get_type(32) + a = sb32(3) + b = sb32(5) + print_ln('XOR: %s', (a ^ b).reveal()) + print_ln('AND: %s', (a & b).reveal()) + print_ln('NOT: %s', (~a).reveal()) + + This will output the following:: + + XOR: 6 + AND: 1 + NOT: -4 + + Instances can be also be initalized from :py:obj:`~Compiler.types.regint` + and :py:obj:`~Compiler.types.sint`. + """ max_length = 128 reg_type = 'sb' is_clear = False @@ -287,6 +316,10 @@ def get_random_bit(): return res @classmethod def get_input_from(cls, player, n_bits=None): + """ Secret input from :py:obj:`player`. + + :param: player (int) + """ if n_bits is None: n_bits = cls.n res = cls() @@ -450,6 +483,9 @@ def two_power(cls, n): res.load_int(1 << n) return res def popcnt(self): + """ Population count / Hamming weight. + + :return: :py:obj:`sbits` of required length """ return sbitvec(self).popcnt().elements()[0] @classmethod def trans(cls, rows): @@ -466,7 +502,14 @@ def trans(cls, rows): inst.trans(len(res), *(res + rows)) return res def if_else(self, x, y): - # vectorized if/else + """ + Vectorized oblivious selection:: + + sb32 = sbits.get_type(32) + print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal()) + + This will output 1. + """ return result_conv(x, y)(self & (x ^ y) ^ y) @staticmethod def bit_adder(*args, **kwargs): @@ -475,13 +518,62 @@ def bit_adder(*args, **kwargs): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) def to_sint(self, n_bits): + """ Convert the :py:obj:`n_bits` least significant bits to + :py:obj:`~Compiler.types.sint`. """ bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0] bits = sint(bits, size=n_bits) return sint.bit_compose(bits) class sbitvec(_vec): + """ Vector of registers of secret bits, effectively a matrix of secret bits. + This facilitates parallel arithmetic operations in binary circuits. + Container types are not supported, use :py:obj:`sbitvec.get_type` for that. + + You can access the rows by member :py:obj:`v` and the columns by calling + :py:obj:`elements`. + + There are three ways to create an instance: + + 1. By transposition:: + + sb32 = sbits.get_type(32) + x = sbitvec([sb32(5), sb32(3), sb32(0)]) + print_ln('%s', [x.v[0].reveal(), x.v[1].reveal(), x.v[2].reveal()]) + print_ln('%s', [x.elements()[0].reveal(), x.elements()[1].reveal()]) + + This should output:: + + [3, 2, 1] + [5, 3] + + 2. Without transposition:: + + sb32 = sbits.get_type(32) + x = sbitvec.from_vec([sb32(5), sb32(3)]) + print_ln('%s', [x.v[0].reveal(), x.v[1].reveal()]) + + This should output:: + + [5, 3] + + 3. From :py:obj:`~Compiler.types.sint`:: + + y = sint(5) + x = sbitvec(y, 3, 3) + print_ln('%s', [x.v[0].reveal(), x.v[1].reveal(), x.v[2].reveal()]) + + This should output:: + + [1, 0, 1] + """ + bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) @classmethod def get_type(cls, n): + """ Create type for fixed-length vector of registers of secret bits. + + As with :py:obj:`sbitvec`, you can access the rows by member + :py:obj:`v` and the columns by calling :py:obj:`elements`. + """ class sbitvecn(cls, _structure): @staticmethod def malloc(size, creator_tape=None): @@ -491,16 +583,32 @@ def n_elements(): return n @classmethod def get_input_from(cls, player): + """ Secret input from :py:obj:`player`. The input is decomposed + into bits. + + :param: player (int) + """ res = cls.from_vec(sbit() for i in range(n)) inst.inputbvec(n + 3, 0, player, *res.v) return res get_raw_input_from = get_input_from - def __init__(self, other=None): + @classmethod + def from_vec(cls, vector): + res = cls() + res.v = _complement_two_extend(list(vector), n)[:n] + return res + def __init__(self, other=None, size=None): + assert size in (None, 1) if other is not None: if util.is_constant(other): self.v = [sbit((other >> i) & 1) for i in range(n)] + elif isinstance(other, _vec): + self.v = self.bit_extend(other.v, n) + elif isinstance(other, (list, tuple)): + self.v = self.bit_extend(sbitvec(other).v, n) else: self.v = sbits(other, n=n).bit_decompose(n) + assert len(self.v) == n @classmethod def load_mem(cls, address): if not isinstance(address, int) and len(address) == n: @@ -509,17 +617,22 @@ def load_mem(cls, address): return cls.from_vec(sbit.load_mem(address + i) for i in range(n)) def store_in_mem(self, address): - assert self.v[0].n == 1 + for x in self.v: + assert util.is_constant(x) or x.n == 1 + v = [sbit.conv(x) for x in self.v] if not isinstance(address, int) and len(address) == n: - for x, y in zip(self.v, address): + for x, y in zip(v, address): x.store_in_mem(y) else: for i in range(n): - self.v[i].store_in_mem(address + i) + v[i].store_in_mem(address + i) def reveal(self): revealed = [cbit() for i in range(len(self))] for i in range(len(self)): - inst.reveal(1, revealed[i], self.v[i]) + try: + inst.reveal(1, revealed[i], self.v[i]) + except: + revealed[i] = cbit.conv(self.v[i]) return cbits.get_type(len(self)).bit_compose(revealed) @classmethod def two_power(cls, nn): @@ -531,6 +644,7 @@ def coerce(self, other): return super(sbitvecn, self).coerce(other) @classmethod def bit_compose(cls, bits): + bits = list(bits) if len(bits) < n: bits += [0] * (n - len(bits)) assert len(bits) == n @@ -553,14 +667,24 @@ def combine(cls, vectors): def from_matrix(cls, matrix): # any number of rows, limited number of columns return cls.combine(cls(row) for row in matrix) - def __init__(self, elements=None, length=None): + def __init__(self, elements=None, length=None, input_length=None): if length: assert isinstance(elements, sint) if Program.prog.use_split(): x = elements.split_to_two_summands(length) v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True) else: - assert Program.prog.options.ring + prog = Program.prog + if not prog.options.ring: + # force the use of edaBits + backup = prog.use_edabit() + prog.use_edabit(True) + from Compiler.floatingpoint import BitDecFieldRaw + self.v = BitDecFieldRaw(elements, + input_length or prog.bit_length, + length, prog.security) + prog.use_edabit(backup) + return l = int(Program.prog.options.ring) r, r_bits = sint.get_edabit(length, size=elements.size) c = ((elements - r) << (l - length)).reveal() @@ -573,10 +697,13 @@ def __init__(self, elements=None, length=None): elements == 0): self.v = sbits.trans(elements) def popcnt(self): + """ Population count / Hamming weight. + + :return: :py:obj:`sbitintvec` of required length """ res = sbitint.wallace_tree([[b] for b in self.v]) while util.is_zero(res[-1]): del res[-1] - return self.from_vec(res) + return sbitintvec.get_type(len(res)).from_vec(res) def elements(self, start=None, stop=None): if stop is None: start, stop = stop, start @@ -607,7 +734,10 @@ def __getitem__(self, index): return self.v[index] @classmethod def conv(cls, other): - return cls.from_vec(other.v) + if isinstance(other, cls): + return cls.from_vec(other.v) + else: + return cls(other) @property def size(self): if not self.v or util.is_constant(self.v[0]): @@ -620,7 +750,7 @@ def n_bits(self): def store_in_mem(self, address): for i, x in enumerate(self.elements()): x.store_in_mem(address + i) - def bit_decompose(self, n_bits=None): + def bit_decompose(self, n_bits=None, security=None): return self.v[:n_bits] bit_compose = from_vec def reveal(self): @@ -643,6 +773,8 @@ def bit_and(self, other): return self & other def bit_xor(self, other): return self ^ other + def right_shift(self, m, k, security=None, signed=True): + return self.from_vec(self.v[m:]) class bit(object): n = 1 @@ -663,7 +795,15 @@ def result_conv(x, y): return lambda x: x class sbit(bit, sbits): + """ Single secret bit. """ def if_else(self, x, y): + """ Non-vectorized oblivious selection:: + + sb32 = sbits.get_type(32) + print_ln('%s', sbit(1).if_else(sb32(5), sb32(2)).reveal()) + + This will output 5. + """ return result_conv(x, y)(self * (x ^ y) ^ y) class cbit(bit, cbits): @@ -756,14 +896,38 @@ def Norm(self, k, f, kappa=None, simplex_flag=False): absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose()) part_reciprocal = absolute_val_2k * acc return part_reciprocal, signed_acc + def pow2(self, k): + l = int(math.ceil(math.log(k, 2))) + bits = [self.equal(i, l) for i in range(k)] + return self.get_type(k).bit_compose(bits) class sbitint(_bitint, _number, sbits, _sbitintbase): + """ Secret signed integer in one binary register. Use :py:obj:`get_type()` + to specify the bit length:: + + si32 = sbitint.get_type(32) + print_ln('add: %s', (si32(5) + si32(3)).reveal()) + print_ln('sub: %s', (si32(5) - si32(3)).reveal()) + print_ln('mul: %s', (si32(5) * si32(3)).reveal()) + print_ln('lt: %s', (si32(5) < si32(3)).reveal()) + + This should output:: + + add: 8 + sub: 2 + mul: 15 + lt: 0 + + """ n_bits = None bin_type = None types = {} vector_mul = True @classmethod def get_type(cls, n, other=None): + """ Returns a signed integer type with fixed length. + + :param n: length """ if isinstance(other, sbitvec): return sbitvec if n in cls.types: @@ -800,7 +964,7 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False): raise CompilerError('round to nearest not implemented') self_bits = self.bit_decompose() other_bits = other.bit_decompose() - if len(self_bits) + len(other_bits) != k: + if len(self_bits) + len(other_bits) > k: raise Exception('invalid parameters for TruncMul: ' 'self:%d, other:%d, k:%d' % (len(self_bits), len(other_bits), k)) @@ -811,7 +975,7 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False): product = a * b res_bits = product.bit_decompose()[m:k] res_bits += [res_bits[-1]] * (self.n - len(res_bits)) - t = self.combo_type(other) + t = self.combo_type(other).get_type(k - m) return t.bit_compose(res_bits) def __mul__(self, other): if isinstance(other, sbitintvec): @@ -836,15 +1000,47 @@ def get_bit_matrix(cls, self_bits, other): return res @classmethod def popcnt_bits(cls, bits): - res = sbitvec.from_vec(bits).popcnt().elements()[0] + res = sbitintvec.popcnt_bits(bits).elements()[0] res = cls.conv(res) return res def pow2(self, k): - l = int(math.ceil(math.log(k, 2))) - bits = [self.equal(i, l) for i in range(k)] - return self.bit_compose(bits) + """ Computer integer power of two. + + :param k: bit length of input """ + return _sbitintbase.pow2(self, k) class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): + """ + Vector of signed integers for parallel binary computation:: + + sb32 = sbits.get_type(32) + siv32 = sbitintvec.get_type(32) + a = siv32([sb32(3), sb32(5)]) + b = siv32([sb32(4), sb32(6)]) + c = (a + b).elements() + print_ln('add: %s, %s', c[0].reveal(), c[1].reveal()) + c = (a * b).elements() + print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal()) + c = (a - b).elements() + print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal()) + c = (a < b).bit_decompose() + print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal()) + + This should output:: + + add: 7, 11 + mul: 12, 30 + sub: -1, 11 + lt: 1, 1 + + """ + bit_extend = staticmethod(_complement_two_extend) + @classmethod + def popcnt_bits(cls, bits): + return sbitvec.from_vec(bits).popcnt() + def elements(self): + return [sbitint.get_type(len(self.v))(x) + for x in sbitvec.elements(self)] def __add__(self, other): if util.is_zero(other): return self @@ -853,15 +1049,14 @@ def __add__(self, other): v = sbitint.bit_adder(self.v, other.v) return self.from_vec(v) __radd__ = __add__ - def less_than(self, other, *args, **kwargs): - assert(len(self.v) == len(other.v)) - return self.from_vec(sbitint.bit_less_than(self.v, other.v)) def __mul__(self, other): if isinstance(other, sbits): return self.from_vec(other * x for x in self.v) + elif isinstance(other, sbitfixvec): + return NotImplemented matrix = [] for i, b in enumerate(util.bit_decompose(other)): - matrix.append([x * b for x in self.v[:len(self.v)-i]]) + matrix.append([x & b for x in self.v[:len(self.v)-i]]) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ @@ -871,12 +1066,17 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False): raise CompilerError('round to nearest not implemented') if not isinstance(other, sbitintvec): other = sbitintvec(other) - assert len(self.v) + len(other.v) == k - a = self.from_vec(_complement_two_extend(self.v, k)) - b = self.from_vec(_complement_two_extend(other.v, k)) + assert len(self.v) + len(other.v) <= k + a = self.get_type(k).from_vec(_complement_two_extend(self.v, k)) + b = self.get_type(k).from_vec(_complement_two_extend(other.v, k)) tmp = a * b assert len(tmp.v) == k - return self.from_vec(tmp[m:]) + return self.get_type(k - m).from_vec(tmp[m:]) + def pow2(self, k): + """ Computer integer power of two. + + :param k: bit length of input """ + return _sbitintbase.pow2(self, k) sbitint.vec = sbitintvec @@ -900,11 +1100,29 @@ def output(self): inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0), cbits(0)) class sbitfix(_fix): + """ Secret signed integer in one binary register. + Use :py:obj:`set_precision()` to change the precision. + + Example:: + + print_ln('add: %s', (sbitfix(0.5) + sbitfix(0.3)).reveal()) + print_ln('mul: %s', (sbitfix(0.5) * sbitfix(0.3)).reveal()) + print_ln('sub: %s', (sbitfix(0.5) - sbitfix(0.3)).reveal()) + print_ln('lt: %s', (sbitfix(0.5) < sbitfix(0.3)).reveal()) + + will output roughly:: + + add: 0.800003 + mul: 0.149994 + sub: 0.199997 + lt: 0 + + """ float_type = type(None) clear_type = cbitfix @classmethod def set_precision(cls, f, k=None): - super(cls, sbitfix).set_precision(f, k) + super(sbitfix, cls).set_precision(f, k) cls.int_type = sbitint.get_type(cls.k) @classmethod def load_mem(cls, address, size=None): @@ -915,6 +1133,10 @@ def load_mem(cls, address, size=None): return super(sbitfix, cls).load_mem(address) @classmethod def get_input_from(cls, player): + """ Secret input from :py:obj:`player`. + + :param: player (int) + """ v = cls.int_type() inst.inputb(player, cls.k, cls.f, v) return cls._new(v) @@ -937,39 +1159,72 @@ class cls(_fix): cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -sbitfix.set_precision(16, 31) - class sbitfixvec(_fix): - int_type = sbitintvec + """ Vector of fixed-point numbers for parallel binary computation. + + Use :py:obj:`set_precision()` to change the precision. + + Example:: + + a = sbitfixvec([sbitfix(0.3), sbitfix(0.5)]) + b = sbitfixvec([sbitfix(0.4), sbitfix(0.6)]) + c = (a + b).elements() + print_ln('add: %s, %s', c[0].reveal(), c[1].reveal()) + c = (a * b).elements() + print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal()) + c = (a - b).elements() + print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal()) + c = (a < b).bit_decompose() + print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal()) + + This should output roughly:: + + add: 0.699997, 1.10001 + mul: 0.119995, 0.300003 + sub: -0.0999908, -0.100021 + lt: 1, 1 + + """ + int_type = sbitintvec.get_type(sbitfix.k) float_type = type(None) - clear_type = type(None) - _f = None - _k = None - @property - def f(self): - if self._f is None: - return sbitfix.f - else: - return self._f - @f.setter - def f(self, value): - self._f = value - @property - def k(self): - if self._k is None: - return sbitfix.k + clear_type = cbitfix + @classmethod + def set_precision(cls, f, k=None): + super(sbitfixvec, cls).set_precision(f=f, k=k) + cls.int_type = sbitintvec.get_type(cls.k) + @classmethod + def get_input_from(cls, player): + """ Secret input from :py:obj:`player`. + + :param: player (int) + """ + v = [sbit() for i in range(sbitfix.k)] + inst.inputbvec(len(v) + 3, sbitfix.f, player, *v) + return cls._new(cls.int_type.from_vec(v)) + def __init__(self, value=None, *args, **kwargs): + if isinstance(value, (list, tuple)): + self.v = self.int_type.from_vec(sbitvec([x.v for x in value])) else: - return self._k - @k.setter - def k(self, value): - self._k = value - def coerce(self, other): - return other + super(sbitfixvec, self).__init__(value, *args, **kwargs) + def elements(self): + return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()] def mul(self, other): if isinstance(other, sbits): return self._new(self.v * other) else: return super(sbitfixvec, self).mul(other) + def __xor__(self, other): + return self._new(self.v ^ other.v) + @staticmethod + def multipliable(other, k, f, size): + class cls(_fix): + int_type = sbitint.get_type(k) + clear_type = cbitfix + cls.set_precision(f, k) + return cls._new(cls.int_type(other), k, f) + +sbitfix.set_precision(16, 31) +sbitfixvec.set_precision(16, 31) sbitfix.vec = sbitfixvec diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 1ae37a61c..0ff79329e 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -280,11 +280,11 @@ def longest_paths_merge(self): preorder = None - if len(instructions) > 100000: + if len(instructions) > 1000000: print("Topological sort ...") order = Compiler.graph.topological_sort(G, preorder) instructions[:] = [instructions[i] for i in order if instructions[i] is not None] - if len(instructions) > 100000: + if len(instructions) > 1000000: print("Done at", time.asctime()) return len(merges) @@ -356,14 +356,16 @@ def mem_access(n, instr, last_access_this_kind, last_access_other_kind): addr_i = addr + i handle_mem_access(addr_i, reg_type, last_access_this_kind, last_access_other_kind) - if not warned_about_mem and (instr.get_size() > 100): + if block.warn_about_mem and not warned_about_mem and \ + (instr.get_size() > 100): print('WARNING: Order of memory instructions ' \ 'not preserved due to long vector, errors possible') warned_about_mem.append(True) else: handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind) - if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction): + if block.warn_about_mem and not warned_about_mem and \ + not isinstance(instr, DirectMemoryInstruction): print('WARNING: Order of memory instructions ' \ 'not preserved, errors possible') # hack @@ -477,7 +479,7 @@ def keep_merged_order(instr, n, t): if not G.pred[n]: self.sources.append(n) - if n % 100000 == 0 and n > 0: + if n % 1000000 == 0 and n > 0: print("Processed dependency of %d/%d instructions at" % \ (n, len(block.instructions)), time.asctime()) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 3005036a2..d330e3614 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -69,6 +69,11 @@ def divide_by_two(res, x, m=1): inv2m(tmp, m) mulc(res, x, tmp) +def require_ring_size(k, op): + if int(program.options.ring) < k: + raise CompilerError('ring size too small for %s, compile ' + 'with \'-R %d\' or more' % (op, k)) + @instructions_base.cisc def LTZ(s, a, k, kappa): """ @@ -86,7 +91,7 @@ def LTZ(s, a, k, kappa): return elif program.options.ring: from . import floatingpoint - assert(int(program.options.ring) >= k) + require_ring_size(k, 'comparison') m = k - 1 shift = int(program.options.ring) - k r_prime, r_bin = MaskingBitsInRing(k) @@ -116,7 +121,6 @@ def Trunc(d, a, k, m, kappa, signed): m: compile-time integer signed: True/False, describes a """ - a_prime = program.curr_block.new_reg('s') t = program.curr_block.new_reg('s') c = [program.curr_block.new_reg('c') for i in range(3)] c2m = program.curr_block.new_reg('c') @@ -125,10 +129,8 @@ def Trunc(d, a, k, m, kappa, signed): return elif program.options.ring: return TruncRing(d, a, k, m, signed) - elif m == 1: - Mod2(a_prime, a, k, kappa, signed) else: - Mod2m(a_prime, a, k, m, kappa, signed) + a_prime = program.non_linear.mod2m(a, k, m, signed) subs(t, a, a_prime) ldi(c[1], 1) divide_by_two(c[2], c[1], m) @@ -218,14 +220,9 @@ def TruncRoundNearest(a, k, m, kappa, signed=False): """ if m == 0: return a - if k == int(program.options.ring): - # cannot work with bit length k+1 - tmp = TruncRing(None, a, k, m - 1, signed) - return TruncRing(None, tmp + 1, k - m + 1, 1, signed) - from .types import sint - res = sint() - Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed) - return res + nl = program.non_linear + nl.check_security(kappa) + return program.non_linear.trunc_round_nearest(a, k, m, signed) @instructions_base.cisc def Mod2m(a_prime, a, k, m, kappa, signed): @@ -236,18 +233,9 @@ def Mod2m(a_prime, a, k, m, kappa, signed): m: compile-time integer signed: True/False, describes a """ - if not util.is_constant(m): - raise CompilerError('m must be a public constant') - if m >= k: - movs(a_prime, a) - return - if program.options.ring: - return Mod2mRing(a_prime, a, k, m, signed) - else: - if m == 1: - return Mod2(a_prime, a, k, kappa, signed) - else: - return Mod2mField(a_prime, a, k, m, kappa, signed) + nl = program.non_linear + nl.check_security(kappa) + movs(a_prime, program.non_linear.mod2m(a, k, m, signed)) def Mod2mRing(a_prime, a, k, m, signed): assert(int(program.options.ring) >= k) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 6e4e31ef1..79f3b2a8e 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,48 +1,29 @@ from Compiler.program import Program -from Compiler.config import * -from Compiler.exceptions import * -from . import instructions, instructions_base, types, comparison, library from .GC import types as GC_types -import random -import time import sys -def run(args, options, reallocate=True, debug=False): +def run(args, options): """ Compile a file and output a Program object. If options.merge_opens is set to True, will attempt to merge any parallelisable open instructions. """ prog = Program(args, options) - instructions.program = prog - instructions_base.program = prog - types.program = prog - comparison.program = prog - prog.DEBUG = debug VARS['program'] = prog if options.binary: VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) - VARS['sfix'] = GC_types.sbitfix - comparison.set_variant(options) + VARS['sfix'] = GC_types.sbitfixvec print('Compiling file', prog.infile) - if instructions_base.Instruction.count != 0: - print('instructions count', instructions_base.Instruction.count) - instructions_base.Instruction.count = 0 # make compiler modules directly accessible sys.path.insert(0, 'Compiler') # create the tapes exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS) - - # optimize the tapes - for tape in prog.tapes: - tape.optimize(options) - - if prog.tapes: - prog.update_req(prog.curr_tape) + + prog.finalize() if prog.req_num: print('Program requires:') @@ -54,7 +35,4 @@ def run(args, options, reallocate=True, debug=False): print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) print('Memory size:', dict(prog.allocated_mem)) - # finalize the memory - prog.finalize_memory() - return prog diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 1fde0aab4..c720ce914 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -22,13 +22,7 @@ def two_power(n): return res def shift_two(n, pos): - if pos < 63: - return n >> pos - else: - res = (n >> (pos%63)) - for i in range(pos // 63): - res >>= 63 - return res + return n >> pos def maskRing(a, k): @@ -48,7 +42,9 @@ def maskField(a, k, kappa): c = types.cint() r = [types.sint() for i in range(k)] comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) - asm_open(c, a + two_power(k) * r_dprime + r_prime)# + 2**(k-1)) + # always signed due to usage in equality testing + a += two_power(k) + asm_open(c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc @@ -59,14 +55,8 @@ def EQZ(a, k, kappa): v = sbitvec(a, k).v bit = util.tree_reduce(operator.and_, (~b for b in v)) return types.sint.conv(bit) - if program.Program.prog.options.ring: - c, r = maskRing(a, k) - else: - c, r = maskField(a, k, kappa) - d = [None]*k - for i,b in enumerate(r[0].bit_decompose_clear(c, k)): - d[i] = r[i].bit_xor(b) - return 1 - types.sint.conv(KOR(d, kappa)) + prog.non_linear.check_security(kappa) + return prog.non_linear.eqz(a, k) def bits(a,m): """ Get the bits of an int """ @@ -206,7 +196,7 @@ def KOpL(op, a): t2 = KOpL(op, a[k//2:]) return op(t1, t2) -def KORL(a, kappa): +def KORL(a, kappa=None): """ log rounds k-ary OR """ k = len(a) if k == 1: @@ -254,61 +244,23 @@ def BitAdd(a, b, bits_to_compute=None): bits_to_compute = list(range(k)) d = [None] * k for i in range(1,k): - #assert(a[i].value == 0 or a[i].value == 1) - #assert(b[i].value == 0 or b[i].value == 1) t = a[i]*b[i] d[i] = (a[i] + b[i] - 2*t, t) - #assert(d[i][0].value == 0 or d[i][0].value == 1) d[0] = (None, a[0]*b[0]) pg = PreOpL(carry, d) c = [pair[1] for pair in pg] - # (for testing) - def print_state(): - print('a: ', end=' ') - for i in range(k): - print('%d ' % a[i].value, end=' ') - print('\nb: ', end=' ') - for i in range(k): - print('%d ' % b[i].value, end=' ') - print('\nd: ', end=' ') - for i in range(k): - print('%d ' % d[i][0].value, end=' ') - print('\n ', end=' ') - for i in range(k): - print('%d ' % d[i][1].value, end=' ') - print('\n\npg:', end=' ') - for i in range(k): - print('%d ' % pg[i][0].value, end=' ') - print('\n ', end=' ') - for i in range(k): - print('%d ' % pg[i][1].value, end=' ') - print('') - - for bit in c: - pass#assert(bit.value == 0 or bit.value == 1) s = [None] * (k+1) if 0 in bits_to_compute: s[0] = a[0] + b[0] - 2*c[0] bits_to_compute.remove(0) - #assert(c[0].value == a[0].value*b[0].value) - #assert(s[0].value == 0 or s[0].value == 1) for i in bits_to_compute: s[i] = a[i] + b[i] + c[i-1] - 2*c[i] - try: - pass#assert(s[i].value == 0 or s[i].value == 1) - except AssertionError: - print('#assertion failed in BitAdd for s[%d]' % i) - print_state() s[k] = c[k-1] - #print_state() return s def BitDec(a, k, m, kappa, bits_to_compute=None): - if program.Program.prog.options.ring: - return BitDecRing(a, k, m) - else: - return BitDecField(a, k, m, kappa, bits_to_compute) + return program.Program.prog.non_linear.bit_dec(a, k, m) def BitDecRing(a, k, m): n_shift = int(program.Program.prog.options.ring) - m @@ -330,7 +282,7 @@ def BitDecRing(a, k, m): bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return [types.sint.conv(bit) for bit in bits] -def BitDecField(a, k, m, kappa, bits_to_compute=None): +def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): r_dprime = types.sint() r_prime = types.sint() c = types.cint() @@ -349,6 +301,10 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): print('a =', a.value) print('a mod 2^%d =' % k, (a.value % 2**k)) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) + return res + +def BitDecField(a, k, m, kappa, bits_to_compute=None): + res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute) return [types.sint.conv(bit) for bit in res] @@ -536,18 +492,15 @@ def TruncPr(a, k, m, kappa=None, signed=True): """ Probabilistic truncation [a/2^m + u] where Pr[u = 1] = (a % 2^m) / 2^m """ - if isinstance(a, types.cint): - return shift_two(a, m) - if program.Program.prog.options.ring: - return TruncPrRing(a, k, m, signed=signed) - else: - return TruncPrField(a, k, m, kappa) + nl = program.Program.prog.non_linear + nl.check_security(kappa) + return nl.trunc_pr(a, k, m, signed) def TruncPrRing(a, k, m, signed=True): if m == 0: return a n_ring = int(program.Program.prog.options.ring) - assert n_ring >= k, '%d too large' % k + comparison.require_ring_size(k, 'truncation') if k == n_ring: program.Program.prog.curr_tape.require_bit_length(1) if program.Program.prog.use_edabit(): @@ -652,64 +605,67 @@ def SDiv_mono(a, b, l, kappa): y = TruncPr(y, 3 * l, 2 * l, kappa) return y - -def FPDiv(a, b, k, f, kappa): - theta = int(ceil(log(k/3.5))) - alpha = types.cint(1 * two_power(2*f)) - - w = AppRcr(b, k, f, kappa) - x = alpha - b * w - y = a * w - y = TruncPr(y, 2*k, f, kappa) - - for i in range(theta): - y = y * (alpha + x) - x = x * x - y = TruncPr(y, 2*k, 2*f, kappa) - x = TruncPr(x, 2*k, 2*f, kappa) - - y = y * (alpha + x) - y = TruncPr(y, 2*k, 2*f, kappa) - return y - -def AppRcr(b, k, f, kappa): - """ - Approximate reciprocal of [b]: - Given [b], compute [1/b] - """ - alpha = types.cint(int(2.9142 * (2**k))) - c, v = Norm(b, k, f, kappa) - d = alpha - 2 * c - w = d * v - w = TruncPr(w, 2 * k, 2 * (k - f)) - return w - -def Norm(b, k, f, kappa): - """ - Computes secret integer values [c] and [v_prime] st. - 2^{k-1} <= c < 2^k and c = b*v_prime - """ - temp = types.sint() - comparison.LTZ(temp, b, k, kappa) - sign = 1 - 2 * temp # 1 - 2 * [b < 0] - - x = sign * b - #x = |b| - bits = x.bit_decompose(k) - y = PreOR(bits) - - z = [0] * k - for i in range(k - 1): - z[i] = y[i] - y[i + 1] - - z[k - 1] = y[k - 1] - # z[i] = 0 for all i except when bits[i + 1] = first one - - #now reverse bits of z[i] - v = types.sint() - for i in range(k): - v += two_power(k - i - 1) * z[i] - c = x * v - v_prime = sign * v - return c, v_prime - +# LT bit comparison on shared bit values +# Assumes b has the larger size +# - From the paper +# Unconditionally Secure Constant-Rounds Multi-party Computation +# for Equality, Comparison, Bits and Exponentiation +def BITLT(a, b, bit_length): + sint = types.sint + e = [sint(0)]*bit_length + g = [sint(0)]*bit_length + h = [sint(0)]*bit_length + for i in range(bit_length): + # Compute the XOR (reverse order of e for PreOpL) + e[bit_length-i-1] = a[i].bit_xor(b[i]) + f = PreOpL(or_op, e) + g[bit_length-1] = f[0] + for i in range(bit_length-1): + # reverse order of f due to PreOpL + g[i] = f[bit_length-i-1]-f[bit_length-i-2] + ans = 0 + for i in range(bit_length): + h[i] = g[i]*b[i] + ans = ans + h[i] + return ans + +# Exact BitDec with no need for a statistical gap +# - From the paper +# Multiparty Computation for Interval, Equality, and Comparison without +# Bit-Decomposition Protocol +def BitDecFull(a): + from .library import get_program, do_while, if_, break_point + from .types import sint, regint, longint + p=int(get_program().options.prime) + assert p + bit_length = p.bit_length() + bbits = [sint(size=a.size) for i in range(bit_length)] + tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)] + pbits = util.bit_decompose(p) + # Loop until we get some random integers less than p + done = [regint(0) for i in range(a.size)] + @do_while + def get_bits_loop(): + for j in range(a.size): + @if_(done[j] == 0) + def _(): + for i in range(bit_length): + tbits[j][i].link(sint.get_random_bit()) + c = regint(BITLT(tbits[j], pbits, bit_length).reveal()) + done[j].link(c) + return (sum(done) != a.size) + for j in range(a.size): + for i in range(bit_length): + movs(bbits[i][j], tbits[j][i]) + b = sint.bit_compose(bbits) + c = (a-b).reveal() + t = (p-c).bit_decompose(bit_length) + c = longint(c, bit_length) + czero = (c==0) + q = 1-BITLT( bbits, t, bit_length) + fbar=((1<.data``. + + :param: number of arguments to follow / number of shares plus two (int) + :param: starting position in number of shares from beginning (regint) + :param: destination for final position, -1 for eof reached, or -2 for file not found (regint) + :param: destination for share (sint) + :param: (repeat from destination for share)... """ __slots__ = [] code = base.opcodes['READFILESHARE'] diff --git a/Compiler/library.py b/Compiler/library.py index dcbad5561..4af9ae1d9 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -841,7 +841,7 @@ def for_range(start, stop=None, step=None): Decorator to execute loop bodies consecutively. Arguments work as in Python :py:func:`range`, but they can by any public integer. Information has to be passed out via container types such - as :py:class:`Compiler.types.Array` or declaring registers as + as :py:class:`~Compiler.types.Array` or declaring registers as :py:obj:`global`. Note that changing Python data structures such as lists within the loop is not possible, but the compiler cannot warn about this. @@ -1519,49 +1519,36 @@ def approximate_reciprocal(divisor, k, f, theta): """ def twos_complement(x): bits = x.bit_decompose(k)[::-1] - bit_array = Array(k, cint) - bit_array.assign(bits) - twos_result = MemValue(cint(0)) - @for_range(k) - def block(i): - val = twos_result.read() + twos_result = cint(0) + for i in range(k): + val = twos_result val <<= 1 - val += 1 - bit_array[i] - twos_result.write(val) + val += 1 - bits[i] + twos_result = val - return twos_result.read() + 1 + return twos_result + 1 - bit_array = Array(k, cint) bits = divisor.bit_decompose(k)[::-1] - bit_array.assign(bits) - cnt_leading_zeros = MemValue(regint(0)) + flag = regint(0) + cnt_leading_zeros = regint(0) + normalized_divisor = divisor - flag = MemValue(regint(0)) - cnt_leading_zeros = MemValue(regint(0)) - normalized_divisor = MemValue(divisor) - - @for_range(k) - def block(i): - flag.write(flag.read() | bit_array[i] == 1) - @if_(flag.read() == 0) - def block(): - cnt_leading_zeros.write(cnt_leading_zeros.read() + 1) - normalized_divisor.write(normalized_divisor << 1) - - q = MemValue(two_power(k)) - e = MemValue(twos_complement(normalized_divisor.read())) + for i in range(k): + flag = flag | (bits[i] == 1) + flag_zero = cint(flag == 0) + cnt_leading_zeros += flag_zero + normalized_divisor <<= flag_zero - qr = q.read() - er = e.read() + q = two_power(k) + e = twos_complement(normalized_divisor) for i in range(theta): - qr = qr + shift_two(qr * er, k) - er = shift_two(er * er, k) + q += (q * e) >> k + e = (e * e) >> k - q = qr - res = q >> (2*k - 2*f - cnt_leading_zeros) + res = q >> cint(2*k - 2*f - cnt_leading_zeros) return res @@ -1583,19 +1570,16 @@ def cint_cint_division(a, b, k, f): absolute_b = b * sign_b absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) - A = Array(theta, cint) - B = Array(theta, cint) - W = Array(theta, cint) - A[0] = absolute_a - B[0] = absolute_b - W[0] = w0 - for i in range(1, theta): - A[i] = shift_two(A[i - 1] * W[i - 1], f) - B[i] = shift_two(B[i - 1] * W[i - 1], f) - W[i] = two - B[i] + A = absolute_a + B = absolute_b + W = w0 - return (sign_a * sign_b) * A[theta - 1] + for i in range(1, theta): + A = (A * W) >> f + B = (B * W) >> f + W = two - B + return (sign_a * sign_b) * A from Compiler.program import Program def sint_cint_division(a, b, k, f, kappa): @@ -1610,23 +1594,18 @@ def sint_cint_division(a, b, k, f, kappa): absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) - A = Array(theta, sint) - B = Array(theta, cint) - W = Array(theta, cint) - - A[0] = absolute_a - B[0] = absolute_b - W[0] = w0 + A = absolute_a + B = absolute_b + W = w0 @for_range(1, theta) def block(i): - A[i] = TruncPr(A[i - 1] * W[i - 1], 2*k, f, kappa) - temp = shift_two(B[i - 1] * W[i - 1], f) - # no reading and writing to the same variable in a for loop. - W[i] = two - temp - B[i] = temp - return (sign_a * sign_b) * A[theta - 1] + A.link(TruncPr(A * W, 2*k, f, kappa)) + temp = (B * W) >> f + W.link(two - temp) + B.link(temp) + return (sign_a * sign_b) * A def IntDiv(a, b, k, kappa=None): return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k, @@ -1637,7 +1616,9 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): """ Goldschmidt method as presented in Catrina10, """ - if 2 * k == int(get_program().options.ring): + prime = get_program().prime + if 2 * k == int(get_program().options.ring) or \ + (prime and 2 * k <= (prime.bit_length() - 1)): # not fitting otherwise nearest = True if get_program().options.binary: diff --git a/Compiler/ml.py b/Compiler/ml.py index 6814d15a6..735cb781a 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1,7 +1,8 @@ """ This module contains machine learning functionality. It is work in progress, so you must expect things to change. The only tested -functionality for training is logistic regression. It can be run as +functionality for training is using consective dense/fully-connected +layers. This includes logistic regression. It can be run as follows:: sgd = ml.SGD([ml.Dense(n_examples, n_features, 1), @@ -39,6 +40,9 @@ See the `readme `_ for an example of how to run MP-SPDZ on TensorFlow graphs. + +See also `this repository `_ +for an example of how to train a model for MNIST. """ import math @@ -149,6 +153,14 @@ def set_n_threads(n_threads): Layer.n_threads = n_threads Optimizer.n_threads = n_threads +def _no_mem_warnings(function): + def wrapper(*args, **kwargs): + get_program().warn_about_mem.append(False) + res = function(*args, **kwargs) + get_program().warn_about_mem.pop() + return res + return wrapper + class Tensor(MultiArray): def __init__(self, *args, **kwargs): kwargs['alloc'] = False @@ -374,6 +386,10 @@ def from_args(program, N, n_output): res = MultiOutput(N, n_output, approx='approx' in program.args) res.cheaper_loss = 'mse' in program.args res.compute_loss = not 'no_loss' in program.args + for arg in program.args: + m = re.match('approx=(.*)', arg) + if m: + res.approx = float(m.group(1)) return res class MultiOutput(MultiOutputBase): @@ -401,7 +417,11 @@ def forward(self, batch): @for_range_opt_multithread(self.n_threads, N) def _(i): if self.approx: - positives = self.X[i].get_vector() > (0 if self.cheaper_loss else 0.1) + if self.cheaper_loss or isinstance(self.approx, float): + limit = 0 + else: + limit = 0.1 + positives = self.X[i].get_vector() > limit relus = positives.if_else(self.X[i].get_vector(), 0) self.positives[i].assign_vector(positives) self.relus[i].assign_vector(relus) @@ -464,6 +484,8 @@ def _(j): self.nabla_X[i][j] = self.positives[i][j].if_else(res, fallback) return relus = self.relus[i].get_vector() + if isinstance(self.approx, float): + relus += self.approx positives = self.positives[i].get_vector() inv = (1 / sum(relus)).expand_to_vector(d_out) truths = self.Y[batch[i]].get_vector() @@ -1485,12 +1507,13 @@ def batch_for(self, layer, batch): batch.assign(regint.inc(len(batch))) return batch + @_no_mem_warnings def forward(self, N=None, batch=None, keep_intermediate=True, model_from=None): """ Compute graph. :param N: batch size (used if batch not given) - :param batch: indices for computation (:py:class:`Compiler.types.Array`. or list) + :param batch: indices for computation (:py:class:`~Compiler.types.Array` or list) :param keep_intermediate: do not free memory of intermediate results after use """ if batch is None: @@ -1511,6 +1534,7 @@ def forward(self, N=None, batch=None, keep_intermediate=True, for theta in layer.thetas(): theta.delete() + @_no_mem_warnings def eval(self, data): """ Compute evaluation after training. """ N = len(data) @@ -1518,6 +1542,7 @@ def eval(self, data): self.forward(N) return self.layers[-1].eval(N) + @_no_mem_warnings def backward(self, batch): """ Compute backward propagation. """ for layer in reversed(self.layers): @@ -1531,6 +1556,7 @@ def backward(self, batch): layer.inputs[0].nabla_Y.assign_vector( layer.nabla_X.get_part_vector(0, len(batch))) + @_no_mem_warnings def run(self, batch_size=None, stop_on_loss=0): """ Run training. @@ -1601,6 +1627,7 @@ def _(j): return res print_ln('finished after %s epochs and %s iterations', i, n_iterations) + @_no_mem_warnings def run_by_args(self, program, n_runs, batch_size, test_X, test_Y): for arg in program.args: m = re.match('rate(.*)', arg) @@ -1724,6 +1751,7 @@ def __init__(self, layers, n_epochs, debug=False, report_loss=None): self.i_epoch = MemValue(0) self.stopped_on_loss = MemValue(0) + @_no_mem_warnings def reset(self, X_by_label=None): """ Reset layer parameters. diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 918b0c591..3508b63b4 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -34,7 +34,12 @@ 0.00000000000000000040] ## # @private -p_1045 = [math.log(2) ** i / math.factorial(i) for i in range(12)] +p_1045 = [math.log(2) ** i / math.factorial(i) for i in range(100)] + +p_2508 = [-4.585323876456, 18.351352559641, -51.525644374262, + 111.76784165654, -174.170840774074, 191.731001033848, + -145.61191979671, 72.650082977468, -21.447349196774, + 2.840799797315] ## # @private @@ -270,6 +275,18 @@ def exp2_fx(a, zero_output=False): :return: :math:`2^a` if it is within the range. Undefined otherwise """ + def exp_from_parts(whole_exp, frac): + class my_fix(type(a)): + pass + # improve precision + my_fix.set_precision(a.k - 2, a.k) + n_shift = a.k - 2 - a.f + x = my_fix._new(frac.v << n_shift) + # evaluates fractional part of a in p_1045 + e = p_eval(p_1045, x) + g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift, + nearest=a.round_nearest), a.k, a.f) + return g if types.program.options.ring: sint = types.sint intbitint = types.intbitint @@ -352,8 +369,7 @@ def exp2_fx(a, zero_output=False): assert(len(higher_bits) == n_bits - a.f) pow2_bits = [sint.conv(x) for x in higher_bits] d = floatingpoint.Pow2_from_bits(pow2_bits) - e = p_eval(p_1045, c) - g = d * e + g = exp_from_parts(d, c) small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, 2 ** n_int_bits, signed=False, nearest=types.sfix.round_nearest), @@ -371,15 +387,13 @@ def exp2_fx(a, zero_output=False): c = a - b # squares integer part of a d = b.pow2(a.k - a.f) - # evaluates fractional part of a in p_1045 - e = p_eval(p_1045, c) - g = d * e + g = exp_from_parts(d, c) return s.if_else(1 / g, g) @types.vectorize @instructions_base.sfix_cisc -def log2_fx(x): +def log2_fx(x, use_division=False): """ Returns the result of :math:`\log_2(x)` for any unbounded number. This is achieved by changing :py:obj:`x` into @@ -407,10 +421,14 @@ def log2_fx(x): # isolates mantisa of d, now the n can be also substituted by the # secret shared p from d in the expresion above. # polynomials for the log_2 evaluation of f are calculated - P = p_eval(p_2524, v) - Q = p_eval(q_2524, v) + if use_division: + P = p_eval(p_2524, v) + Q = p_eval(q_2524, v) + approx = P / Q + else: + approx = p_eval(p_2508, v) # the log is returned by adding the result of the division plus p. - a = P / Q + (vlen + p) + a = approx + (vlen + p) return a # *(1-(f.z))*(1-f.s)*(1-f.error) diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py new file mode 100644 index 000000000..6af4b414e --- /dev/null +++ b/Compiler/non_linear.py @@ -0,0 +1,145 @@ +from .comparison import * +from .floatingpoint import * +from .types import * +from . import comparison + +class NonLinear: + kappa = None + + def set_security(self, kappa): + pass + + def check_security(self, kappa): + pass + + def mod2m(self, a, k, m, signed): + """ + a_prime = a % 2^m + + k: bit length of a + m: compile-time integer + signed: True/False, describes a + """ + if not util.is_constant(m): + raise CompilerError('m must be a public constant') + if m >= k: + return a + else: + return self._mod2m(a, k, m, signed) + + def trunc_pr(self, a, k, m, signed=True): + if isinstance(a, types.cint): + return shift_two(a, m) + return self._trunc_pr(a, k, m, signed) + + def trunc_round_nearest(self, a, k, m, signed): + res = sint() + comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, self.kappa, + signed) + return res + +class Masking(NonLinear): + def eqz(self, a, k): + c, r = self._mask(a, k) + d = [None]*k + for i,b in enumerate(r[0].bit_decompose_clear(c, k)): + d[i] = r[i].bit_xor(b) + return 1 - types.sint.conv(self.kor(d)) + +class Prime(Masking): + """ Non-linear functionality modulo a prime with statistical masking. """ + def __init__(self, kappa): + self.set_security(kappa) + + def set_security(self, kappa): + self.kappa = kappa + + def check_security(self, kappa): + assert self.kappa == kappa or kappa is None + + def _mod2m(self, a, k, m, signed): + res = sint() + if m == 1: + Mod2(res, a, k, self.kappa, signed) + else: + Mod2mField(res, a, k, m, self.kappa, signed) + return res + + def _mask(self, a, k): + return maskField(a, k, self.kappa) + + def _trunc_pr(self, a, k, m, signed=None): + return TruncPrField(a, k, m, self.kappa) + + def bit_dec(self, a, k, m): + return BitDecField(a, k, m, self.kappa) + + def kor(self, d): + return KOR(d, self.kappa) + +class KnownPrime(NonLinear): + """ Non-linear functionality modulo a prime known at compile time. """ + def __init__(self, prime): + self.prime = prime + + def _mod2m(self, a, k, m, signed): + if signed: + a += cint(1) << (k - 1) + return sint.bit_compose(self.bit_dec(a, k, k)[:m]) + + def _trunc_pr(self, a, k, m, signed): + # nearest truncation + return self.trunc_round_nearest(a, k, m, signed) + + def trunc_round_nearest(self, a, k, m, signed): + a += cint(1) << (m - 1) + if signed: + a += cint(1) << (k - 1) + k += 1 + res = sint.bit_compose(self.bit_dec(a, k, k)[m:]) + if signed: + res -= cint(1) << (k - m - 2) + return res + + def bit_dec(self, a, k, m): + assert k < self.prime.bit_length() + bits = BitDecFull(a) + if len(bits) < m: + raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) + return bits[:m] + + def eqz(self, a, k): + # always signed + a += two_power(k) + return 1 - KORL(self.bit_dec(a, k, k)) + +class Ring(Masking): + """ Non-linear functionality modulo a power of two known at compile time. + """ + def __init__(self, ring_size): + self.ring_size = ring_size + + def _mod2m(self, a, k, m, signed): + res = sint() + Mod2mRing(res, a, k, m, signed) + return res + + def _mask(self, a, k): + return maskRing(a, k) + + def _trunc_pr(self, a, k, m, signed): + return TruncPrRing(a, k, m, signed=signed) + + def bit_dec(self, a, k, m): + return BitDecRing(a, k, m) + + def kor(self, d): + return KORL(d) + + def trunc_round_nearest(self, a, k, m, signed): + if k == self.ring_size: + # cannot work with bit length k+1 + tmp = TruncRing(None, a, k, m - 1, signed) + return TruncRing(None, tmp + 1, k - m + 1, 1, signed) + else: + return super(Ring, self).trunc_round_nearest(a, k, m, signed) diff --git a/Compiler/oram.py b/Compiler/oram.py index db26bcf10..7ca4cb54f 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1,3 +1,16 @@ +""" +This module contains an implementation of the tree-based oblivious +RAM as proposed by `Shi et al. `_ as +well as the straight-forward construction using linear scanning. +Unlike :py:class:`~Compiler.types.Array`, this allows access by a +secret index:: + + a = OptimalORAM(1000) + i = sint.get_input_from(0) + a[i] = sint.get_input_from(1) + +""" + import random import math import collections @@ -1645,6 +1658,12 @@ class OneLevelORAM(TreeORAM): index_structure = BaseORAMIndexStructure def OptimalORAM(size,*args,**kwargs): + """ Create an ORAM instance suitable for the size based on + experiments. + + :param size: number of elements + :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` + """ if optimal_threshold is None: if n_threads == 1: threshold = 2**11 diff --git a/Compiler/program.py b/Compiler/program.py index 38364bad4..004f4dc26 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -1,10 +1,15 @@ +""" +This module contains the building blocks of the compiler such as code +blocks and registers. Most relevant is the central :py:class:`Program` +object that holds various properties of the computation. +""" + from Compiler.config import * from Compiler.exceptions import * from Compiler.instructions_base import RegType import Compiler.instructions import Compiler.instructions_base import Compiler.instructions_base as inst_base -from . import compilerLib from . import allocator as al from . import util import random @@ -34,29 +39,70 @@ bit = 2, ) +class defaults: + debug = False + verbose = False + outfile = None + ring = 0 + field = 0 + binary = 0 + prime = None + galois = 40 + budget = 100000 + mixed = False + edabit = False + split = None + cisc = False + comparison = None + merge_opens = True + preserve_mem_order = False + max_parallel_open = 0 + dead_code_elimination = False + noreallocate = False + asmoutfile = None + stop = False + insecure = False class Program(object): - """ A program consists of a list of tapes and a scheduled order - of execution for these tapes. - - These are created by executing a file containing appropriate instructions - and threads. """ - def __init__(self, args, options): + """ A program consists of a list of tapes representing the whole + computation. + + When compiling an :file:`.mpc` file, the single instances is + available as :py:obj:`program` in order. When compiling directly + from Python code, an instance has to be created before running any + instructions. + """ + def __init__(self, args, options=defaults): + from .non_linear import Ring, Prime, KnownPrime self.options = options self.verbose = options.verbose self.args = args self.init_names(args) + self._security = 40 + self.prime = None if sum(x != 0 for x in(options.ring, options.field, options.binary)) > 1: raise CompilerError('can only use one out of -B, -R, -F') + if options.prime and (options.ring or options.binary): + raise CompilerError('can only use one out of -B, -R, -p') if options.ring: self.bit_length = int(options.ring) - 1 + self.non_linear = Ring(int(options.ring)) else: self.bit_length = int(options.binary) or int(options.field) - if not self.bit_length: - self.bit_length = 64 + if options.prime: + self.prime = int(options.prime) + max_bit_length = int(options.prime).bit_length() - 2 + if self.bit_length > max_bit_length: + raise CompilerError('integer bit length can be maximal %s' % + max_bit_length) + self.bit_length = self.bit_length or max_bit_length + self.non_linear = KnownPrime(self.prime) + else: + self.non_linear = Prime(self.security) + if not self.bit_length: + self.bit_length = 64 print('Default bit length:', self.bit_length) - self.security = 40 print('Default security parameter:', self.security) self.galois_length = int(options.galois) if self.verbose: @@ -64,7 +110,7 @@ def __init__(self, args, options): self.tape_counter = 0 self.tapes = [] self._curr_tape = None - self.DEBUG = False + self.DEBUG = options.debug self.allocated_mem = RegType.create_dict(lambda: USER_MEM) self.free_mem_blocks = defaultdict(al.BlockAllocator) self.allocated_mem_blocks = {} @@ -95,14 +141,23 @@ def __init__(self, args, options): self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] self.use_trunc_pr = False + """ Setting whether to use special probabilistic truncation. """ self.use_dabit = options.mixed + """ Setting whether to use daBits for non-linear functionality. """ self._edabit = options.edabit self._split = False if options.split: self.use_split(int(options.split)) self._square = False self._always_raw = False + self.warn_about_mem = [True] Program.prog = self + from . import instructions_base, instructions, types, comparison + instructions.program = self + instructions_base.program = self + types.program = self + comparison.program = self + comparison.set_variant(options) def get_args(self): return self.args @@ -202,7 +257,7 @@ def update_req(self, tape): else: self.req_num += tape.req_num - def write_bytes(self, outfile=None): + def write_bytes(self): """ Write all non-empty threads and schedule to files. """ nonempty_tapes = [t for t in self.tapes] @@ -216,6 +271,13 @@ def write_bytes(self, outfile=None): sch_file.write('1 0\n') sch_file.write('0\n') sch_file.write(' '.join(sys.argv) + '\n') + req = max(x.req_bit_length['p'] for x in self.tapes) + if self.options.ring: + sch_file.write('R:%s' % (self.options.ring if req else 0)) + elif self.options.prime: + sch_file.write('p:%s' % self.options.prime) + else: + sch_file.write('lgp:%s' % req) for tape in self.tapes: tape.write_bytes() @@ -282,6 +344,23 @@ def free(self, addr, mem_type): size = self.allocated_mem_blocks.pop((addr,mem_type)) self.free_mem_blocks[mem_type].push(addr, size) + def finalize(self): + # optimize the tapes + for tape in self.tapes: + tape.optimize(self.options) + + if self.tapes: + self.update_req(self.curr_tape) + + # finalize the memory + self.finalize_memory() + + self.write_bytes() + + if self.options.asmoutfile: + for tape in self.tapes: + tape.write_str(self.options.asmoutfile + '-' + tape.name) + def finalize_memory(self): from . import library self.curr_tape.start_new_basicblock(None, 'memory-usage') @@ -300,19 +379,32 @@ def finalize_memory(self): print('Saved %s memory units through reallocation' % self.saved) def public_input(self, x): + """ Append a value to the public input file. """ if self.public_input_file is None: self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w') self.public_input_file.write('%s\n' % str(x)) def set_bit_length(self, bit_length): + """ Change the integer bit length for non-linear functions. """ self.bit_length = bit_length print('Changed bit length for comparisons etc. to', bit_length) def set_security(self, security): - self.security = security + self._security = security + self.non_linear.set_security(security) print('Changed statistical security for comparison etc. to', security) + @property + def security(self): + """ The statistical security parameter for non-linear + functions. """ + return self._security + + @security.setter + def security(self, security): + self.set_security(security) + def optimize_for_gc(self): pass @@ -322,6 +414,12 @@ def get_tape_counter(self): return res def use_edabit(self, change=None): + """ Setting whether to use edaBits for non-linear + functionality (default: false). + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ if change is None: return self._edabit else: @@ -331,6 +429,12 @@ def use_edabit_for(self, *args): return True def use_split(self, change=None): + """ Setting whether to use local arithmetic-binary share + conversion for non-linear functionality (default: false). + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ if change is None: return self._split else: @@ -340,6 +444,12 @@ def use_split(self, change=None): self._split = change def use_square(self, change=None): + """ Setting whether to use preprocessed square tuples + (default: false). + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ if change is None: return self._square else: @@ -352,17 +462,24 @@ def always_raw(self, change=None): self._always_raw = change def options_from_args(self): + """ Set a number of options from the command-line arguments. """ if 'trunc_pr' in self.args: self.use_trunc_pr = True if 'split' in self.args or 'split3' in self.args: self.use_split(3) if 'split4' in self.args: self.use_split(4) + if 'split2' in self.args: + self.use_split(2) if 'raw' in self.args: self.always_raw(True) if 'edabit' in self.args: self.use_edabit(True) + def disable_memory_warnings(self): + self.warn_about_mem.append(False) + self.curr_block.warn_about_mem = False + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -405,6 +522,7 @@ def __init__(self, parent, name, scope, exit_condition=None): self.purged = False self.n_rounds = 0 self.n_to_merge = 0 + self.warn_about_mem = parent.program.warn_about_mem[-1] def __len__(self): return len(self.instructions) @@ -506,13 +624,7 @@ def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) def init_names(self, name): - # ignore path to file - source must be in Programs/Source - name = name.split('/')[-1] - if name.endswith('.asm'): - self.name = name[:-4] - else: - self.name = name - self.infile = self.program.programs_dir + '/Source/' + self.name + '.asm' + self.name = name self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' def purge(self): @@ -557,14 +669,14 @@ def optimize(self, options): merger = al.Merger(block, options, \ tuple(self.program.to_merge)) if options.dead_code_elimination: - if len(block.instructions) > 100000: + if len(block.instructions) > 1000000: print('Eliminate dead code...') merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: block.used_from_scope = util.set_by_id() continue - if len(block.instructions) > 100000: + if len(block.instructions) > 1000000: print('Merging instructions...') numrounds = merger.longest_paths_merge() block.n_rounds = numrounds @@ -626,7 +738,7 @@ def alloc_loop(block): if child.instructions: left.append(child) for i,block in enumerate(reversed(self.basicblocks)): - if len(block.instructions) > 100000: + if len(block.instructions) > 1000000: print('Allocating %s, %d/%d' % \ (block.name, i, len(self.basicblocks))) if block.exit_condition is not None: @@ -870,6 +982,8 @@ def close_scope(self, outer_scope, parent_req_node, name): def require_bit_length(self, bit_length, t='p'): if t == 'p': + if self.program.prime: + assert bit_length < self.program.prime.bit_length() - 1 self.req_bit_length[t] = max(bit_length + 1, \ self.req_bit_length[t]) else: @@ -916,8 +1030,6 @@ def __init__(self, reg_type, program, size=None, i=None): self.caller = [frame[1:] for frame in inspect.stack()[1:]] else: self.caller = None - if self.i % 1000000 == 0 and self.i > 0: - print("Initialized %d registers at" % self.i, time.asctime()) @property def i(self): @@ -971,6 +1083,8 @@ def get_all(self): return self.vector or [self] def __getitem__(self, index): + if self.size == 1 and index == 0: + return self if not self.vector: self.create_vector_elements() return self.vector[index] diff --git a/Compiler/types.py b/Compiler/types.py index 41bedba00..c7042542d 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -142,8 +142,10 @@ def vectorized_operation(self, *args, **kwargs): raise CompilerError('Different vector sizes of operands: %d/%d' % (self.size, args[0].size)) set_global_vector_size(self.size) - res = operation(self, *args, **kwargs) - reset_global_vector_size() + try: + res = operation(self, *args, **kwargs) + finally: + reset_global_vector_size() return res copy_doc(vectorized_operation, operation) return vectorized_operation @@ -157,8 +159,10 @@ def vectorized_operation(self, *args, **kwargs): except AttributeError: pass set_global_vector_size(size) - res = operation(self, *args, **kwargs) - reset_global_vector_size() + try: + res = operation(self, *args, **kwargs) + finally: + reset_global_vector_size() return res copy_doc(vectorized_operation, operation) return vectorized_operation @@ -170,8 +174,10 @@ def vectorized_function(cls, *args, **kwargs): size = kwargs.pop('size') if size: set_global_vector_size(size) - res = function(cls, *args, **kwargs) - reset_global_vector_size() + try: + res = function(cls, *args, **kwargs) + finally: + reset_global_vector_size() else: res = function(cls, *args, **kwargs) return res @@ -191,8 +197,10 @@ def vectorized_init(*args, **kwargs): size = kwargs['size'] if size is not None: set_global_vector_size(size) - res = function(*args, **kwargs) - reset_global_vector_size() + try: + res = function(*args, **kwargs) + finally: + reset_global_vector_size() else: res = function(*args, **kwargs) return res @@ -202,8 +210,10 @@ def vectorized_init(*args, **kwargs): def set_instruction_type(operation): def instruction_typed_operation(self, *args, **kwargs): set_global_instruction_type(self.instruction_type) - res = operation(self, *args, **kwargs) - reset_global_instruction_type() + try: + res = operation(self, *args, **kwargs) + finally: + reset_global_instruction_type() return res copy_doc(instruction_typed_operation, operation) return instruction_typed_operation @@ -353,6 +363,11 @@ def bit_xor(self, other): :param self/other: 0 or 1 (any compatible type) :return: type depends on inputs (secret if any of them is) """ + if util.is_constant(other): + if other: + return 1 - self + else: + return self return self + other - 2 * self * other def bit_and(self, other): @@ -362,6 +377,10 @@ def bit_and(self, other): :rtype: depending on inputs (secret if any of them is) """ return self * other + def bit_not(self): + """ NOT in arithmetic circuits. """ + return 1 - self + def half_adder(self, other): """ Half adder in arithmetic circuits. @@ -388,6 +407,10 @@ def bit_and(self, other): :rtype: depending on inputs (secret if any of them is) """ return self & other + def bit_not(self): + """ NOT in binary circuits. """ + return ~self + def half_adder(self, other): """ Half adder in binary circuits. @@ -397,14 +420,14 @@ def half_adder(self, other): return self ^ other, self & other class _gf2n(_bit): - """ GF(2^n) functionality. """ + """ :math:`\mathrm{GF}(2^n)` functionality. """ def if_else(self, a, b): - """ MUX in GF(2^n) circuits. Similar to :py:meth:`_int.if_else`. """ + """ MUX in :math:`\mathrm{GF}(2^n)` circuits. Similar to :py:meth:`_int.if_else`. """ return b ^ self * self.hard_conv(a ^ b) def cond_swap(self, a, b, t=None): - """ Swapping in GF(2^n). Similar to :py:meth:`_int.if_else`. """ + """ Swapping in :math:`\mathrm{GF}(2^n)`. Similar to :py:meth:`_int.if_else`. """ prod = self * self.hard_conv(a ^ b) res = a ^ prod, b ^ prod if t is None: @@ -413,12 +436,15 @@ def cond_swap(self, a, b, t=None): return tuple(t.conv(r) for r in res) def bit_xor(self, other): - """ XOR in GF(2^n) circuits. + """ XOR in :math:`\mathrm{GF}(2^n)` circuits. :param self/other: 0 or 1 (any compatible type) :rtype: depending on inputs (secret if any of them is) """ return self ^ other + def bit_not(self): + return self ^ 1 + class _structure(object): """ Interface for type-dependent container types. """ @@ -754,10 +780,6 @@ def read_from_socket(cls, client_id, n=1): else: return res - @vectorize - def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): - writesocketc(client_id, message_type, self) - @vectorized_classmethod def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): """ Send a list of clear values to socket """ @@ -997,7 +1019,7 @@ def output_if(self, cond): class cgf2n(_clear, _gf2n): """ - Clear GF(2^n) value. n is 40 or 128, + Clear :math:`\mathrm{GF}(2^n)` value. n is 40 or 128, depending on USE_GF2N_LONG compile-time variable. """ __slots__ = [] @@ -1006,7 +1028,7 @@ class cgf2n(_clear, _gf2n): @classmethod def bit_compose(cls, bits, step=None): - """ Clear GF(2^n) bit composition. + """ Clear :math:`\mathrm{GF}(2^n)` bit composition. :param bits: list of cgf2n :param step: set every :py:obj:`step`-th bit in output (defaults to 1) """ @@ -1056,7 +1078,7 @@ def load_int(self, val): sum += chunk def __mul__(self, other): - """ Clear GF(2^n) multiplication. + """ Clear :math:`\mathrm{GF}(2^n)` multiplication. :param other: cgf2n/regint/int """ return super(cgf2n, self).__mul__(other) @@ -1100,7 +1122,7 @@ def __rshift__(self, other): def bit_decompose(self, bit_length=None, step=None): """ Clear bit decomposition. - :param bit_length: number of bits (defaults to global GF(2^n) bit length) + :param bit_length: number of bits (defaults to global :math:`\mathrm{GF}(2^n)` bit length) :param step: extract every :py:obj:`step`-th bit (defaults to 1) """ bit_length = bit_length or program.galois_length step = step or 1 @@ -1190,10 +1212,6 @@ def read_from_socket(cls, client_id, n=1): else: return res - @vectorize - def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): - writesocketint(client_id, message_type, self) - @vectorized_classmethod def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): """ Send a list of integers to socket """ @@ -1456,6 +1474,55 @@ def __init__(self, player, value): self.player = player self._v = value +class longint: + def __init__(self, value, length=None, n_limbs=None): + assert length is None or n_limbs is None + if isinstance(value, longint): + if n_limbs is None: + n_limbs = int(math.ceil(length / 64)) + assert n_limbs <= len(value.v) + self.v = value.v[:n_limbs] + elif isinstance(value, list): + assert length is None + self.v = value[:] + else: + if length is None: + length = 64 * n_limbs + if isinstance(value, int): + self.v = [(value >> i) for i in range(0, length, 64)] + else: + self.v = [(value >> i).to_regint(0) + for i in range(0, length, 64)] + + def coerce(self, other): + return longint(other, n_limbs=len(self.v)) + + def __eq__(self, other): + return reduce(operator.mul, (x == y for x, y in + zip(self.v, self.coerce(other).v))) + + def __add__(self, other): + other = self.coerce(other) + assert len(self.v) == len(other.v) + res = [] + carry = 0 + for x, y in zip(self.v, other.v): + res.append(x + y + carry) + carry = (res[-1] - 2 ** 31) < (x - 2 ** 31) + return longint(res) + + __radd__ = __add__ + + def __sub__(self, other): + return self + -other + + def bit_decompose(self, bit_length): + assert bit_length <= 64 * len(self.v) + res = [] + for x in self.v: + res += x.bit_decompose(64) + return res[:bit_length] + class _secret(_register): __slots__ = [] @@ -1706,7 +1773,7 @@ def reveal_to(self, player): Result written to ``Player-Data/Private-Output-P`` :param player: int - :returns: value to be used with :py:func:`Compiler.library.print_ln_to` + :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` """ masked = self.__class__() res = personal(player, self.clear_type()) @@ -1818,7 +1885,11 @@ def get_raw_input_from(cls, player): @classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): - """ Securely obtain shares of n values input by a client """ + """ Securely obtain shares of values input by a client. + + :param n: number of inputs (int) + :param client_id: regint + """ # send shares of a triple to client triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) sint.write_shares_to_socket(client_id, triples, message_type) @@ -1839,11 +1910,6 @@ def read_from_socket(cls, client_id, n=1): else: return res - @vectorize - def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): - """ Send share and MAC share to socket """ - writesockets(client_id, message_type, self) - @vectorized_classmethod def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): """ Send a list of shares and MAC shares to socket """ @@ -1856,7 +1922,11 @@ def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType @vectorized_classmethod def write_shares_to_socket(cls, client_id, values, message_type=ClientMessageType.NoType, include_macs=False): - """ Send shares of a list of values to a specified client socket """ + """ Send shares of a list of values to a specified client socket. + + :param client_id: regint + :param values: list of sint + """ if include_macs: writesockets(client_id, message_type, *values) else: @@ -2126,7 +2196,7 @@ def reveal_to(self, player): Result potentially written to ``Player-Data/Private-Output-P.`` :param player: public integer (int/regint/cint): - :returns: value to be used with :py:func:`Compiler.library.print_ln_to` + :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` """ if not util.is_constant(player) or self.size > 1: secret_mask = sint() @@ -2138,7 +2208,7 @@ def reveal_to(self, player): return super(sint, self).reveal_to(player) class sgf2n(_secret, _gf2n): - """ Secret GF(2^n) value. """ + """ Secret :math:`\mathrm{GF}(2^n)` value. """ __slots__ = [] instruction_type = 'gf2n' clear_type = cgf2n @@ -2155,7 +2225,7 @@ def get_raw_input_from(cls, player): return res def add(self, other): - """ Secret GF(2^n) addition (XOR). + """ Secret :math:`\mathrm{GF}(2^n)` addition (XOR). :param other: sg2fn/cgf2n/regint/int """ if isinstance(other, sgf2nint): @@ -2164,7 +2234,7 @@ def add(self, other): return super(sgf2n, self).add(other) def mul(self, other): - """ Secret GF(2^n) multiplication. + """ Secret :math:`\mathrm{GF}(2^n)` multiplication. :param other: sg2fn/cgf2n/regint/int """ if isinstance(other, (sgf2nint)): @@ -2236,7 +2306,7 @@ def right_shift(self, other, bit_length=None): """ Secret right shift by public value: :param other: compile-time (int) - :param bit_length: number of bits of :py:obj:`self` (defaults to GF(2^n) bit length) """ + :param bit_length: number of bits of :py:obj:`self` (defaults to :math:`\mathrm{GF}(2^n)` bit length) """ bits = self.bit_decompose(bit_length) return sum(b << i for i,b in enumerate(bits[other:])) @@ -2519,10 +2589,18 @@ def __sub__(self, other): raise CompilerError('Unclear subtraction') a = self.bit_decompose() b = util.bit_decompose(other, self.n_bits) - d = [(reduce(util.bit_xor, (ai, bi, 1)), (1 - ai) * bi) + from util import bit_not, bit_and, bit_xor + n = 1 + for x in (a + b): + try: + n = x.n + break + except: + pass + d = [(bit_not(bit_xor(ai, bi), n), bit_and(bit_not(ai, n), bi)) for (ai,bi) in zip(a,b)] borrow = lambda y,x,*args: \ - (x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1])) + (bit_and(x[0], y[0]), util.OR(x[1], bit_and(x[0], y[1]))) borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1] return self.compose(reduce(util.bit_xor, (ai, bi, borrow)) \ for (ai,bi,borrow) in zip(a,b,borrows)) @@ -2597,7 +2675,15 @@ def __ne__(self, other): equal = __eq__ def __neg__(self): - return 1 + self.compose(1 ^ b for b in self.bit_decompose()) + bits = self.bit_decompose() + n = 1 + for b in bits: + try: + n = x.n + break + except: + pass + return 1 + self.compose(util.bit_not(b, n) for b in bits) def __abs__(self): return util.if_else(self.bit_decompose()[-1], -self, self) @@ -2837,11 +2923,6 @@ def read_from_socket(cls, client_id, n=1): return cfix._new(cint_inputs) else: return list(map(cfix, cint_inputs)) - - @vectorize - def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): - """ Send cfix to socket. Value is sent as bit shifted cint. """ - writesocketc(client_id, message_type, cint(self.v)) @vectorized_classmethod def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): @@ -2872,13 +2953,18 @@ def _new(cls, other, k=None, f=None): return res @staticmethod - def int_rep(v, f): - v = v * (2 ** f) + def int_rep(v, f, k=None): + res = v * (2 ** f) try: - v = int(round(v)) + res = int(round(res)) + if k and abs(res) >= 2 ** k: + raise CompilerError( + 'Value out of fixed-point range (maximum %d). ' + 'Use `sfix.set_precision(f, k)` with k being at least f+%d' + % (2 ** (k - f), math.ceil(math.log(abs(v), 2)) + 1)) except TypeError: pass - return v + return res @vectorize_init @read_mem_value @@ -2889,7 +2975,7 @@ def __init__(self, v=None, k=None, f=None, size=None): self.f = f self.k = k if isinstance(v, cfix.scalars): - v = self.int_rep(v, f) + v = self.int_rep(v, f=f, k=k) self.v = cint(v, size=size) elif isinstance(v, cfix): self.v = v.v @@ -3109,7 +3195,7 @@ class _single(_number, _structure): """ Representation as single integer preserving the order """ """ E.g. fixed-point numbers """ __slots__ = ['v'] - kappa = 40 + kappa = None round_nearest = False @classmethod @@ -3142,15 +3228,15 @@ def coerce(cls, other): @classmethod def malloc(cls, size, creator_tape=None): - return program.malloc(size, cls.int_type, creator_tape=creator_tape) + return cls.int_type.malloc(size, creator_tape=creator_tape) @classmethod def free(cls, addr): return cls.int_type.free(addr) - @staticmethod - def n_elements(): - return 1 + @classmethod + def n_elements(cls): + return cls.int_type.n_elements() @classmethod def dot_product(cls, x, y, res_params=None): @@ -3339,7 +3425,7 @@ def __init__(self, _v=None, k=None, f=None, size=None): elif isinstance(_v, self.int_type): self.load_int(_v) elif isinstance(_v, cfix.scalars): - self.v = self.int_type(cfix.int_rep(_v, f=f), size=size) + self.v = self.int_type(cfix.int_rep(_v, f=f, k=k), size=size) elif isinstance(_v, self.float_type): p = (f + _v.p) b = (p.greater_equal(0, _v.vlen)) @@ -3421,6 +3507,13 @@ def __truediv__(self, other): """ Secret fixed-point division. :param other: sfix/cfix/sint/cint/regint/int """ + if util.is_constant_float(other): + assert other != 0 + other_length = self.f + math.ceil(math.log(abs(other), 2)) + if other_length >= self.k: + factor = 2 ** (self.k - other_length - 1) + self *= factor + other *= factor other = self.coerce(other) assert self.k == other.k assert self.f == other.f @@ -3541,7 +3634,7 @@ def reveal_to(self, player): ``Player-Data/Private-Output-P.`` :param player: public integer (int/regint/cint) - :returns: value to be used with :py:func:`Compiler.library.print_ln_to` + :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` """ return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) @@ -3794,7 +3887,7 @@ class sfloat(_number, _structure): # single precision vlen = 24 plen = 8 - kappa = 40 + kappa = None round_nearest = False @staticmethod @@ -5115,13 +5208,11 @@ def write(self, value): :param value: convertible to relevant basic type """ self.check() if isinstance(value, MemValue): - self.register = value.read() - elif isinstance(value, int): - self.register = self.value_type(value) - else: - if value.size != self.size: - raise CompilerError('size mismatch') - self.register = value + value = value.read() + value = self.value_type.conv(value) + if value.size != self.size: + raise CompilerError('size mismatch') + self.register = value if not isinstance(self.register, self.value_type): raise CompilerError('Mismatch in register type, cannot write \ %s to %s' % (type(self.register), self.value_type)) diff --git a/Compiler/util.py b/Compiler/util.py index 10f2693a2..d586d61f1 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -120,7 +120,7 @@ def tree_reduce(function, sequence): return tree_reduce(function, reduced + sequence[n//2*2:]) def or_op(a, b): - return a + b - a * b + return a + b - bit_and(a, b) OR = or_op @@ -133,6 +133,21 @@ def bit_xor(a, b): else: return a.bit_xor(b) +def bit_and(a, b): + if is_constant(a): + if is_constant(b): + return a & b + else: + return b.bit_and(a) + else: + return a.bit_and(b) + +def bit_not(a, n): + if is_constant(a): + return ~a & (2 ** n - 1) + else: + return a.bit_not() + def pow2(bits): powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)] return tree_reduce(operator.mul, powers) diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 2440b1368..0d4f5cf1c 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -73,6 +73,11 @@ P256Element P256Element::operator *(const Scalar& other) const return res; } +P256Element operator*(const P256Element::Scalar& x, const P256Element& y) +{ + return y * x; +} + P256Element& P256Element::operator +=(const P256Element& other) { *this = *this + other; diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 2b817d61f..4c335b8b7 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -65,6 +65,8 @@ class P256Element : public ValueInterface void unpack(octetStream& os); }; +P256Element operator*(const P256Element::Scalar& x, const P256Element& y); + ostream& operator<<(ostream& s, const P256Element& x); #endif /* ECDSA_P256ELEMENT_H_ */ diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index 5a3627f07..c17a7ff12 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -21,6 +21,7 @@ #include "Protocols/MaliciousRepMC.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/fake-stuff.hpp" +#include "Protocols/MaliciousRepPrep.hpp" #include "Processor/Input.hpp" #include "Processor/Processor.hpp" #include "Processor/Data_Files.hpp" @@ -48,7 +49,7 @@ void run(int argc, const char** argv) OnlineOptions::singleton.batch_size = 1; // synchronize Bundle bundle(P); - P.Broadcast_Receive(bundle, false); + P.unchecked_broadcast(bundle); Timer timer; timer.start(); auto stats = P.comm_stats; diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index ad2ce981a..237b86500 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -110,7 +110,7 @@ void run(int argc, const char** argv) pShare sk, __; // synchronize Bundle bundle(P); - P.Broadcast_Receive(bundle, false); + P.unchecked_broadcast(bundle); Timer timer; timer.start(); auto stats = P.comm_stats; diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index e36e2905c..78d5322a2 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -139,7 +139,7 @@ void sign_benchmark(vector>& tuples, T sk, // synchronize Bundle bundle(P); - P.Broadcast_Receive(bundle, true); + P.unchecked_broadcast(bundle); Timer timer; timer.start(); auto stats = P.comm_stats; diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index cf5cdf440..a2b6b1296 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -38,6 +38,8 @@ #include "Math/Setup.h" #include "Protocols/fake-stuff.h" +#include "Math/gfp.hpp" + #include #include #include diff --git a/FHE/AddableVector.cpp b/FHE/AddableVector.cpp index 3f4ba2d99..9bf09c362 100644 --- a/FHE/AddableVector.cpp +++ b/FHE/AddableVector.cpp @@ -37,9 +37,3 @@ AddableVector AddableVector::mul_by_X_i(int j, template AddableVector> AddableVector>::mul_by_X_i(int j, const FHE_PK& pk) const; -template -AddableVector> AddableVector>::mul_by_X_i(int j, - const FHE_PK& pk) const; -template -AddableVector> AddableVector>::mul_by_X_i(int j, - const FHE_PK& pk) const; diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 84a47e158..d06194597 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -3,6 +3,8 @@ #include "P2Data.h" #include "Exceptions/Exceptions.h" +#include "Math/modp.hpp" + Ciphertext::Ciphertext(const FHE_PK& pk) : Ciphertext(pk.get_params()) { } diff --git a/FHE/DiscreteGauss.h b/FHE/DiscreteGauss.h index b65247a40..7a51b75f4 100644 --- a/FHE/DiscreteGauss.h +++ b/FHE/DiscreteGauss.h @@ -6,8 +6,6 @@ */ #include -#include "Math/modp.h" -#include "Math/gfp.h" #include "Tools/random.h" #include #include diff --git a/FHE/FFT.cpp b/FHE/FFT.cpp index 8e63f1ae8..baed86eec 100644 --- a/FHE/FFT.cpp +++ b/FHE/FFT.cpp @@ -2,6 +2,9 @@ #include "FHE/FFT.h" #include "Math/Zp_Data.h" +#include "Math/modp.hpp" + + /* Computes the FFT via Horner's Rule theta is assumed to be an Nth root of unity */ diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index 2498e23ab..ecf87ac92 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -3,6 +3,8 @@ #include "FHE/Subroutines.h" +#include "Math/modp.hpp" + void FFT_Data::assign(const FFT_Data& FFTD) { diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index e68d93900..fc339c81e 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -3,7 +3,7 @@ #include "Math/modp.h" #include "Math/Zp_Data.h" -#include "Math/gfp.h" +#include "Math/gfpvar.h" #include "Math/fixint.h" #include "FHE/Ring.h" @@ -37,7 +37,7 @@ class FFT_Data public: typedef gfp T; typedef bigint S; - typedef fixint poly_type; + typedef fixint poly_type; void init(const Ring& Rg,const Zp_Data& PrD); diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index b7ba184ce..ec9e6e956 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -5,6 +5,8 @@ #include "PPData.h" #include "FFT_Data.h" +#include "Math/modp.hpp" + FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p()) { @@ -155,15 +157,6 @@ void FHE_PK::encrypt(Ciphertext& c, encrypt(c, mess.get_poly(), rc); } -template -void FHE_PK::encrypt(Ciphertext& c, const vector& mess, - const Random_Coins& rc) const -{ - Rq_Element mm((*params).FFTD(),polynomial,polynomial); - mm.from(Iterator(mess)); - quasi_encrypt(c, mm, rc); -} - void FHE_PK::quasi_encrypt(Ciphertext& c, const Rq_Element& mess,const Random_Coins& rc) const { @@ -400,11 +393,6 @@ template Ciphertext FHE_PK::encrypt(const Plaintext_& mess, template Ciphertext FHE_PK::encrypt(const Plaintext_& mess) const; template Ciphertext FHE_PK::encrypt(const Plaintext_& mess) const; -template void FHE_PK::encrypt(Ciphertext& c, const vector& mess, - const Random_Coins& rc) const; -template void FHE_PK::encrypt(Ciphertext& c, const vector>& mess, - const Random_Coins& rc) const; - template Plaintext_ FHE_SK::decrypt(const Ciphertext& c, const FFT_Data& FieldD); template Plaintext_ FHE_SK::decrypt(const Ciphertext& c, diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index decb7f592..70d9689f7 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -177,4 +177,13 @@ class FHE_KeyPair } }; +template +void FHE_PK::encrypt(Ciphertext& c, const vector& mess, + const Random_Coins& rc) const +{ + Rq_Element mm((*params).FFTD(),polynomial,polynomial); + mm.from(Iterator(mess)); + quasi_encrypt(c, mm, rc); +} + #endif diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 035bd6263..59f548bc6 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -1,7 +1,6 @@ #include "FHE_Params.h" #include "FHE/Ring_Element.h" -#include "Math/gfp.h" #include "Exceptions/Exceptions.h" diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 7b3b1a61e..53980239c 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -2,7 +2,7 @@ #include "FHE/NTL-Subs.h" #include "Math/Setup.h" -#include "Math/gfp.h" +#include "Math/gfpvar.h" #include "Math/gf2n.h" #include "FHE/P2Data.h" diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index d8c4549dc..1ec1e8714 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -50,7 +50,9 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, * (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m) + 16 * sqrt(n * h * phi_m))) << slack; B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3); +#ifdef VERBOSE cout << "log(slack): " << slack << endl; +#endif } drown = 1 + n * (bigint(1) << sec); diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index 1cdc66448..63f38b840 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -1,10 +1,11 @@ #include "FHE/P2Data.h" #include "Math/Setup.h" +#include "Math/fixint.h" #include -void P2Data::forward(vector& ans,const vector& a) const +void P2Data::forward(vector& ans,const vector& a) const { int n=gf2n_short::degree(); @@ -25,12 +26,12 @@ void P2Data::forward(vector& ans,const vector& a) const } -void P2Data::backward(vector& ans,const vector& a) const +void P2Data::backward(vector& ans,const vector& a) const { int n=gf2n_short::degree(); BitVector bv(a.size()); for (size_t i = 0; i < a.size(); i++) - bv.set_bit(i, a[i]); + bv.set_bit(i, a[i].get_limb(0)); ans.resize(slots); word y; diff --git a/FHE/P2Data.h b/FHE/P2Data.h index 37707a266..adb6f70ca 100644 --- a/FHE/P2Data.h +++ b/FHE/P2Data.h @@ -20,7 +20,7 @@ class P2Data public: typedef gf2n_short T; typedef int S; - typedef int poly_type; + typedef fixint<0> poly_type; int num_slots() const { return slots; } int degree() const { return A.size() ? A.size() : 0; } @@ -28,10 +28,8 @@ class P2Data void check_dimensions() const; - // Despite only dealing with bits, we still use bigint's so - // we can easily dovetail into the FHE code - void forward(vector& ans,const vector& a) const; - void backward(vector& ans,const vector& a) const; + void forward(vector& ans,const vector& a) const; + void backward(vector& ans,const vector& a) const; int get_prime() const { return 2; } diff --git a/FHE/PPData.cpp b/FHE/PPData.cpp index b92b2012b..b73277e19 100644 --- a/FHE/PPData.cpp +++ b/FHE/PPData.cpp @@ -72,13 +72,13 @@ void PPData::from_eval(vector& elem) const void PPData::reset_iteration() { - pow=1; theta.assign(root); thetaPow=theta; + pow=1; theta = (root); thetaPow=theta; } void PPData::next_iteration() { do - { thetaPow.mul(theta); + { thetaPow *= (theta); pow++; } while (gcd(pow,m())!=1); @@ -91,12 +91,12 @@ gfp PPData::get_evaluation(const vector& mess) const { // Uses Horner's rule gfp ans; - to_gfp(ans,mess[mess.size()-1]); + ans = mess[mess.size()-1]; gfp coeff; for (int j=mess.size()-2; j>=0; j--) - { ans.mul(thetaPow); - to_gfp(coeff,mess[j]); - ans.add(coeff); + { ans *= (thetaPow); + coeff = mess[j]; + ans += (coeff); } return ans; } diff --git a/FHE/PPData.h b/FHE/PPData.h index 10a41d72e..fcb5a3fd1 100644 --- a/FHE/PPData.h +++ b/FHE/PPData.h @@ -3,9 +3,10 @@ #include "Math/modp.h" #include "Math/Zp_Data.h" -#include "Math/gfp.h" +#include "Math/gfpvar.h" #include "Math/fixint.h" #include "FHE/Ring.h" +#include "FHE/FFT_Data.h" /* Class for holding modular arithmetic data wrt the ring * @@ -15,9 +16,9 @@ class PPData { public: - typedef gf2n_short T; + typedef gfp T; typedef bigint S; - typedef fixint poly_type; + typedef typename FFT_Data::poly_type poly_type; Ring R; Zp_Data prData; diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index bbdae4442..afdec58b3 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -6,6 +6,7 @@ #include "FHE/Rq_Element.h" #include "FHE_Keys.h" #include "Math/Z2k.hpp" +#include "Math/modp.hpp" @@ -60,7 +61,7 @@ void Plaintext::from_poly() const (*Field_Data).to_eval(aa); a.resize(n_slots); for (unsigned int i=0; i::set_poly_mod(const Generator& generator,const template<> void Plaintext::set_poly_mod(const vector& vv,const bigint& mod) { - vector pol(vv.size()); + vector pol(vv.size()); bigint te; for (unsigned int i=0; imod/2) { te=vv[i]-mod; } @@ -221,21 +222,7 @@ void rand_poly(vector& b,PRNG& G,const bigint& pr,bool positive=true) { for (unsigned int i=0; i& b,PRNG& G,const bigint& pr,bool positive=true) -{ - (void)positive; - if (pr!=2) { throw bad_value(); } - int l=0; - unsigned char ch=0; - for (unsigned int i=0; i>=1; l--; + b[i].randomBnd(G, pr, positive); } } @@ -256,19 +243,13 @@ void Plaintext::randomize(PRNG& G,condition cond) break; default: // Gen a plaintext with 0/1 in each slot - int nb=0; - unsigned char ch=0; a.resize(n_slots); for (unsigned int i=0; i>1; nb--; } type=Evaluation; break; @@ -313,7 +294,7 @@ void Plaintext::randomize(PRNG& G, int n_bits, bool Diag, bool binary, P { case Polynomial: for (int i = 0; i < n_slots; i++) - G.get(b[i], n_bits, false); + b[i].generateUniform(G, n_bits, false); break; default: throw not_implemented(); @@ -401,7 +382,7 @@ void add(Plaintext& z,const Plaintext& { z.a.resize(z.n_slots); for (unsigned int i=0; i& z,const Plaintext& x, { z.a.resize(z.n_slots); for (unsigned int i=0; i& z,const Plaintext& { z.a.resize(z.n_slots); for (unsigned int i=0; i& z,const Plaintext& x, { z.a.resize(z.n_slots); for (unsigned int i=0; i& z,const Plaintext& x,const Plaintext z.allocate(); for (unsigned int i=0; i -void sqr(Plaintext& z,const Plaintext& x) -{ - if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); } - - if (x.type==Polynomial) { throw not_implemented(); } - z.type=Evaluation; - - z.allocate(); - for (unsigned int i=0; i::equals(const Plaintext& x) const { a.resize(n_slots); for (unsigned int i=0; i::print_evaluation(int n_elements, string desc) const template class Plaintext; template void mul(Plaintext& z,const Plaintext& x,const Plaintext& y); -template void sqr(Plaintext& z,const Plaintext& x); template class Plaintext; template void mul(Plaintext& z,const Plaintext& x,const Plaintext& y); -template void sqr(Plaintext& z,const Plaintext& x); template class Plaintext; template void mul(Plaintext& z,const Plaintext& x,const Plaintext& y); -template void sqr(Plaintext& z,const Plaintext& x); diff --git a/FHE/Random_Coins.h b/FHE/Random_Coins.h index 102ac6889..e6be91b8f 100644 --- a/FHE/Random_Coins.h +++ b/FHE/Random_Coins.h @@ -15,6 +15,8 @@ class Int_Random_Coins : public AddableMatrix> const FHE_Params* params; public: + typedef value_type::value_type rand_type; + Int_Random_Coins(const FHE_Params& params) : params(¶ms) { resize(3, params.phi_m()); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index b80e78137..500d80a00 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -3,6 +3,8 @@ #include "Exceptions/Exceptions.h" #include "FHE/FFT.h" +#include "Math/modp.hpp" + void reduce_step(vector& aa,int i,const FFT_Data& FFTD) { modp temp=aa[i]; for (int j=0; j #include "Tools/Subroutines.h" +#include "Math/modp.hpp" + /* * This creates the "pseudo-encryption" of the R_q element mess, * - As required for key switching. diff --git a/FHEOffline/EncCommit.cpp b/FHEOffline/EncCommit.cpp index 693c0fb82..208b2ceb9 100644 --- a/FHEOffline/EncCommit.cpp +++ b/FHEOffline/EncCommit.cpp @@ -11,6 +11,8 @@ #include using namespace std; +#include "Math/modp.hpp" + // XXXX File_prefix is only used for active code #ifndef file_prefix #define file_prefix "/tmp/" diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 3f6ecbf51..89017c754 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -7,6 +7,8 @@ #include "FHEOffline/PairwiseGenerator.h" #include "FHEOffline/PairwiseMachine.h" +#include "Math/modp.hpp" + template Multiplier::Multiplier(int offset, PairwiseGenerator& generator) : Multiplier(offset, generator.machine, generator.P, generator.timers) diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp index 7286031a6..ed5fb303e 100644 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -14,7 +14,7 @@ #include "Protocols/SemiInput.hpp" #include "Protocols/ReplicatedInput.hpp" #include "Processor/Input.hpp" -#include "Math/gfp.hpp" +#include "Math/modp.hpp" template PairwiseGenerator::PairwiseGenerator(int thread_num, diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 1f6512eaf..649bd454a 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -11,6 +11,7 @@ #include "FHEOffline/PairwiseMachine.h" #include "Tools/Commit.h" #include "Tools/Bundle.h" +#include "Processor/OnlineOptions.h" template void PairwiseSetup::init(const Player& P, int sec, int plaintext_length, diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index 8551c737a..362a42204 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -13,8 +13,6 @@ #include "SimpleMachine.h" #include "Tools/mkpath.h" -#include "Math/gfp.hpp" - template Producer::Producer(int output_thread, bool write_output) : n_slots(0), output_thread(output_thread), write_output(write_output), @@ -396,8 +394,8 @@ void gfpBitProducer::run(const Player& P, const FHE_PK& pk, else { marks[i] = 0; - gfp temp = s.element(i).sqrRoot(); - temp.invert(); + gfp temp; + temp.invert(s.element(i).sqrRoot()); s.set_element(i, temp); } } @@ -416,7 +414,7 @@ void gfpBitProducer::run(const Player& P, const FHE_PK& pk, // Step j and k Share a; gfp two_inv, zero; - to_gfp(two_inv, (dd.f.get_field().get_prime() + 1) / 2); + two_inv = bigint((dd.f.get_field().get_prime() + 1) / 2); zero.assign_zero(); one.assign_one(); bits.clear(); @@ -586,15 +584,15 @@ void InputProducer::run(const Player& P, const FHE_PK& pk, for (auto& x : m) x.randomize(G); personal_EC.generate_proof(C, m, ciphertexts, cleartexts); - P.send_all(ciphertexts, true); - P.send_all(cleartexts, true); + P.send_all(ciphertexts); + P.send_all(cleartexts); } else { - P.receive_player(j, ciphertexts, true); - P.receive_player(j, cleartexts, true); + P.receive_player(j, ciphertexts); + P.receive_player(j, cleartexts); C.resize(personal_EC.machine->sec, pk.get_params()); - Verifier>(personal_EC.proof).NIZKPoK(C, ciphertexts, + Verifier(personal_EC.proof).NIZKPoK(C, ciphertexts, cleartexts, pk, false, false); } diff --git a/FHEOffline/Proof.cpp b/FHEOffline/Proof.cpp index 4fb9e7e39..e5bc641db 100644 --- a/FHEOffline/Proof.cpp +++ b/FHEOffline/Proof.cpp @@ -86,14 +86,14 @@ class AbsoluteBoundChecker } }; -template bool Proof::check_bounds(T& z, X& t, int i) const { unsigned int j,k; // Check Bound 1 and Bound 2 - AbsoluteBoundChecker> plain_checker(plain_check * n_proofs); - AbsoluteBoundChecker> rand_checker(rand_check * n_proofs); + AbsoluteBoundChecker plain_checker(plain_check * n_proofs); + AbsoluteBoundChecker rand_checker( + rand_check * n_proofs); for (j=0; j>& z, AddableMatrix>& t, int i) const; -template bool Proof::check_bounds(AddableVector>& z, AddableMatrix>& t, int i) const; -template bool Proof::check_bounds(AddableVector>& z, AddableMatrix>& t, int i) const; diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index d94710d11..e4b3c41fb 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -29,17 +29,18 @@ class Proof public: typedef AddableVector< Int_Random_Coins > Randomness; + typedef typename FFT_Data::poly_type bound_type; class Preimages { typedef Int_Random_Coins::value_type::value_type r_type; - fixint m_tmp; + bound_type m_tmp; AddableVector r_tmp; public: Preimages(int size, const FHE_PK& pk, const bigint& p, int n_players); - AddableMatrix> m; + AddableMatrix m; Randomness r; void add(octetStream& os); void pack(octetStream& os); @@ -67,6 +68,9 @@ class Proof static double dist; protected: + typedef AddableVector T; + typedef AddableMatrix X; + Proof(int sc, const bigint& Tau, const bigint& Rho, const FHE_PK& pk, int n_proofs = 1) : B_plain_length(0), B_rand_length(0), pk(&pk), n_proofs(n_proofs) @@ -115,7 +119,6 @@ class Proof void set_challenge(PRNG& G); void generate_challenge(const Player& P); - template bool check_bounds(T& z, X& t, int i) const; template diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index 6bf7977d3..0993d4fe1 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -4,6 +4,7 @@ #include "FHE/P2Data.h" #include "Tools/random.h" #include "Math/Z2k.hpp" +#include "Math/modp.hpp" template @@ -72,8 +73,8 @@ bool Prover::Stage_2(Proof& P, octetStream& cleartexts, unsigned int i; #ifndef LESS_ALLOC_MORE_MEM - AddableVector> z; - AddableMatrix> t; + AddableVector> z; + AddableMatrix> t; #endif cleartexts.reset_write_head(); cleartexts.store(P.V); diff --git a/FHEOffline/Prover.h b/FHEOffline/Prover.h index f6f121212..5e4b28c0b 100644 --- a/FHEOffline/Prover.h +++ b/FHEOffline/Prover.h @@ -14,7 +14,7 @@ class Prover AddableVector< Plaintext_ > y; #ifdef LESS_ALLOC_MORE_MEM - AddableVector> z; + AddableVector z; AddableMatrix t; #endif diff --git a/FHEOffline/Reshare.cpp b/FHEOffline/Reshare.cpp index 11777a9a5..a5bb31ae2 100644 --- a/FHEOffline/Reshare.cpp +++ b/FHEOffline/Reshare.cpp @@ -4,6 +4,8 @@ #include "FHE/P2Data.h" #include "Tools/random.h" +#include "Math/modp.hpp" + template void Reshare(Plaintext& m,Ciphertext& cc, const Ciphertext& cm,bool NewCiphertext, diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index 2926f0486..3bd5cd255 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -69,7 +69,7 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, Sh_Tau[i].sub(Sh_Tau[i],temp); temp.mul(b2[i],PO[2*i]); Sh_Tau[i].sub(Sh_Tau[i],temp); - te.mul(PO[2*i],PO[2*i+1]); + te = (PO[2*i] * PO[2*i+1]); Sh_Tau[i].sub(Sh_Tau[i],te,P.my_num(),MC.get_alphai()); } MC.POpen_Begin(Tau,Sh_Tau,P); @@ -192,7 +192,7 @@ void Square_Checking(const Player& P, MAC_Check& MC, int ns, T te,t,t2; Create_Random(t,P); - t2.mul(t,t); + t2 = t * t; vector > Sh_PO(amortize); vector PO(amortize); vector > f(amortize),h(amortize),a(amortize),b(amortize); @@ -259,7 +259,7 @@ void Bit_Checking(const Player& P, MAC_Check& MC, int nb, gfp te,t,t2; Create_Random(t,P); - t2.mul(t,t); + t2 = t * t; vector > Sh_PO(amortize); vector PO(amortize); vector > f(amortize),h(amortize),a(amortize),b(amortize); diff --git a/FHEOffline/Sacrificing.h b/FHEOffline/Sacrificing.h index 9ccbecc61..b29c26447 100644 --- a/FHEOffline/Sacrificing.h +++ b/FHEOffline/Sacrificing.h @@ -9,7 +9,7 @@ #include "Networking/Player.h" #include "Protocols/MAC_Check.h" #include "Math/Setup.h" -#include "Math/gfp.h" +#include "Math/gfpvar.h" template class TripleSacriFactory diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index 28ac67ab4..742482500 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -302,9 +302,9 @@ void SummingEncCommit::create_more() preimages.pack(cleartexts); this->timers["Verifying"].start(); #ifdef LESS_ALLOC_MORE_MEM - Verifier& verifier = this->verifier; + Verifier& verifier = this->verifier; #else - Verifier verifier(proof); + Verifier verifier(proof); #endif verifier.Stage_2(this->c, ciphertexts, cleartexts, this->pk, false, false); diff --git a/FHEOffline/SimpleEncCommit.h b/FHEOffline/SimpleEncCommit.h index c286faf2f..e3af52ee8 100644 --- a/FHEOffline/SimpleEncCommit.h +++ b/FHEOffline/SimpleEncCommit.h @@ -45,8 +45,6 @@ template class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_ { protected: - typedef fixint S; - const PlayerBase& P; const FHE_PK& pk; const FD& FTD; @@ -60,7 +58,7 @@ class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_ #ifdef LESS_ALLOC_MORE_MEM Proof::Randomness r; Prover > prover; - Verifier verifier; + Verifier verifier; #endif map& timers; @@ -129,8 +127,6 @@ template class SummingEncCommit: public SimpleEncCommitFactory, public SimpleEncCommitBase_ { - typedef fixint S; - InteractiveProof proof; const FHE_PK& pk; const FD& FTD; @@ -139,7 +135,7 @@ class SummingEncCommit: public SimpleEncCommitFactory, #ifdef LESS_ALLOC_MORE_MEM Prover > prover; - Verifier verifier; + Verifier verifier; Proof::Preimages preimages; #endif diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp index 3c2163f3c..b2701b2c5 100644 --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -9,7 +9,6 @@ #include "Protocols/MAC_Check.h" #include "Protocols/MAC_Check.hpp" -#include "Math/gfp.hpp" template