diff --git a/.gitignore b/.gitignore index c6c59de3a..2d9ea905f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,10 +4,42 @@ Player-Data/* Prep-Data/* logs/* Language-Definition/main.pdf +keys/* # Personal CONFIG file # ############################## CONFIG.mine +config_mine.py + +# Temporary files # +################### +*.bak +*.orig +*.rej +*.tmp +callgrind.out.* + +# Vim +.*.swp +tags + +# Eclipse # +########### +.project +.cproject +.settings + +# VS Code IDE # +############### +.vscode/** + +# Temporary files # +################### +*.bak +*.orig +*.rej +*.tmp +callgrind.out.* # Compiled source # ################### @@ -25,6 +57,8 @@ Programs/Public-Input/* *.bc *.sch *.a +*.static +*.d # Packages # ############ @@ -59,6 +93,8 @@ Programs/Public-Input/* *.log *.sql *.sqlite +*.data +Persistence/* # OS generated files # ###################### @@ -69,4 +105,5 @@ Programs/Public-Input/* .Spotlight-V100 .Trashes ehthumbs.db -Thumbs.db \ No newline at end of file +Thumbs.db +**/*.x.dSYM/** diff --git a/Auth/MAC_Check.cpp b/Auth/MAC_Check.cpp index 7bff41a31..e529c9be7 100644 --- a/Auth/MAC_Check.cpp +++ b/Auth/MAC_Check.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Auth/MAC_Check.h" diff --git a/Auth/MAC_Check.h b/Auth/MAC_Check.h index b61cc0c75..e290ee54a 100644 --- a/Auth/MAC_Check.h +++ b/Auth/MAC_Check.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _MAC_Check #define _MAC_Check diff --git a/Auth/Subroutines.cpp b/Auth/Subroutines.cpp index efaff68f9..2e6336626 100644 --- a/Auth/Subroutines.cpp +++ b/Auth/Subroutines.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Auth/Subroutines.h" diff --git a/Auth/Subroutines.h b/Auth/Subroutines.h index 6680718f4..50028945e 100644 --- a/Auth/Subroutines.h +++ b/Auth/Subroutines.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Subroutines #define _Subroutines diff --git a/Auth/Summer.cpp b/Auth/Summer.cpp index b2de05c87..691f36b6e 100644 --- a/Auth/Summer.cpp +++ b/Auth/Summer.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Summer.cpp diff --git a/Auth/Summer.h b/Auth/Summer.h index 4a7c54046..c3a9df138 100644 --- a/Auth/Summer.h +++ b/Auth/Summer.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Summer.h diff --git a/Auth/fake-stuff.cpp b/Auth/fake-stuff.cpp index 40b0f4261..7cef045a1 100644 --- a/Auth/fake-stuff.cpp +++ b/Auth/fake-stuff.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Math/gf2n.h" diff --git a/Auth/fake-stuff.h b/Auth/fake-stuff.h index 3fc6783de..0695f7602 100644 --- a/Auth/fake-stuff.h +++ b/Auth/fake-stuff.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _fake_stuff diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..f6336c16e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,56 @@ +The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here. + +## 0.0.2 (Sep 13, 2017) + +### Support sockets based external client input and output to a SPDZ MPC program. + +See the [ExternalIO directory](./ExternalIO/README.md) for more details and examples. + +Note that [libsodium](https://download.libsodium.org/doc/) is now a dependency on the SPDZ build. + +Added compiler instructions: + +* LISTEN +* ACCEPTCLIENTCONNECTION +* CONNECTIPV4 +* WRITESOCKETSHARE +* WRITESOCKETINT + +Removed instructions: + +* OPENSOCKET +* CLOSESOCKET + +Modified instructions: + +* READSOCKETC +* READSOCKETS +* READSOCKETINT +* WRITESOCKETC +* WRITESOCKETS + +Support secure external client input and output with new instructions: + +* READCLIENTPUBLICKEY +* INITSECURESOCKET +* RESPSECURESOCKET + +### Read/Write secret shares to disk to support persistence in a SPDZ MPC program. + +Added compiler instructions: + +* READFILESHARE +* WRITEFILESHARE + +### Other instructions + +Added compiler instructions: + +* DIGESTC - Clear truncated hash computation +* PRINTINT - Print register value + +## 0.0.1 (Sep 2, 2016) + +### Initial Release + +* See `README.md` and `tutorial.md`. diff --git a/CONFIG b/CONFIG index 8c763a37e..32f940fbe 100644 --- a/CONFIG +++ b/CONFIG @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt ROOT = . @@ -28,7 +28,7 @@ endif # Default is 3, which suffices for 128-bit p # MOD = -DMAX_MOD_SZ=3 -LDLIBS = -lmpirxx -lmpir $(MY_LDLIBS) -lm -lpthread +LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) -lm -lpthread ifeq ($(USE_NTL),1) LDLIBS := -lntl $(LDLIBS) @@ -40,7 +40,7 @@ LDLIBS += -lrt endif CXX = g++ -CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH) +CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH) --std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = g++ diff --git a/Check-Offline.cpp b/Check-Offline.cpp index 44f5f75da..f28e86ba0 100644 --- a/Check-Offline.cpp +++ b/Check-Offline.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Check-Offline.cpp @@ -62,21 +62,27 @@ void check_bits(const T& key,int N,vector& dataF,DataFieldType fiel vector > Sa(N),Sb(N),Sc(N); int n = 0; - while (!dataF[0]->eof(DATA_BIT)) - { - for (int i = 0; i < N; i++) - dataF[i]->get_one(field_type, DATA_BIT, Sa[i]); - check_share(Sa, a, mac, N, key); - - if (!(a.is_zero() || a.is_one())) + try { + while (!dataF[0]->eof(DATA_BIT)) { - cout << n << ": " << a << " neither 0 or 1" << endl; - throw bad_value(); + for (int i = 0; i < N; i++) + dataF[i]->get_one(field_type, DATA_BIT, Sa[i]); + check_share(Sa, a, mac, N, key); + + if (!(a.is_zero() || a.is_one())) + { + cout << n << ": " << a << " neither 0 or 1" << endl; + throw bad_value(); + } + n++; } - n++; - } - cout << n << " bits of type " << T::type_string() << endl; + cout << n << " bits of type " << T::type_string() << endl; + } + catch (exception& e) + { + cout << "Error with bits of type " << T::type_string() << endl; + } } template @@ -85,20 +91,26 @@ void check_inputs(const T& key,int N,vector& dataF) T a, mac, x; vector< Share > Sa(N); - for (int player = 0; player < N; player++) - { - int n = 0; - while (!dataF[0]->input_eof(player)) - { - for (int i = 0; i < N; i++) - dataF[i]->get_input(Sa[i], x, player); - check_share(Sa, a, mac, N, key); - if (!a.equal(x)) - throw bad_value(); - n++; - } - cout << n << " input masks for player " << player << " of type " << T::type_string() << endl; - } + try { + for (int player = 0; player < N; player++) + { + int n = 0; + while (!dataF[0]->input_eof(player)) + { + for (int i = 0; i < N; i++) + dataF[i]->get_input(Sa[i], x, player); + check_share(Sa, a, mac, N, key); + if (!a.equal(x)) + throw bad_value(); + n++; + } + cout << n << " input masks for player " << player << " of type " << T::type_string() << endl; + } + } + catch (exception& e) + { + cout << "Error with inputs of type " << T::type_string() << endl; + } } int main(int argc, const char** argv) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 417a7bff0..2879d6cbb 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import compilerLib, program, instructions, types, library, floatingpoint import inspect diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 5d07e844f..91095dd97 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import itertools, time from collections import defaultdict, deque @@ -11,20 +11,20 @@ import Compiler.program import heapq, itertools import operator +import sys class StraightlineAllocator: """Allocate variables in a straightline program using n registers. It is based on the precondition that every register is only defined once.""" def __init__(self, n): - self.free = defaultdict(set) self.alloc = {} self.usage = Compiler.program.RegType.create_dict(lambda: 0) self.defined = {} self.dealloc = set() self.n = n - def alloc_reg(self, reg, persistent_allocation): + def alloc_reg(self, reg, free): base = reg.vectorbase if base in self.alloc: # already allocated @@ -32,8 +32,8 @@ def alloc_reg(self, reg, persistent_allocation): reg_type = reg.reg_type size = base.size - if not persistent_allocation and self.free[reg_type, size]: - res = self.free[reg_type, size].pop() + if free[reg_type, size]: + res = free[reg_type, size].pop() else: if self.usage[reg_type] < self.n: res = self.usage[reg_type] @@ -48,7 +48,7 @@ def alloc_reg(self, reg, persistent_allocation): else: base.i = self.alloc[base] - def dealloc_reg(self, reg, inst): + def dealloc_reg(self, reg, inst, free): self.dealloc.add(reg) base = reg.vectorbase @@ -57,14 +57,14 @@ def dealloc_reg(self, reg, inst): if i not in self.dealloc: # not all vector elements ready for deallocation return - self.free[reg.reg_type, base.size].add(self.alloc[base]) + free[reg.reg_type, base.size].add(self.alloc[base]) if inst.is_vec() and base.vector: for i in base.vector: self.defined[i] = inst else: self.defined[reg] = inst - def process(self, program, persistent_allocation=False): + def process(self, program, alloc_pool): for k,i in enumerate(reversed(program)): unused_regs = [] for j in i.get_def(): @@ -75,7 +75,7 @@ def process(self, program, persistent_allocation=False): (j,i,format_trace(i.caller))) else: # unused register - self.alloc_reg(j, persistent_allocation) + self.alloc_reg(j, alloc_pool) unused_regs.append(j) if unused_regs and len(unused_regs) == len(i.get_def()): # only report if all assigned registers are unused @@ -83,9 +83,9 @@ def process(self, program, persistent_allocation=False): (unused_regs,i,format_trace(i.caller)) for j in i.get_used(): - self.alloc_reg(j, persistent_allocation) + self.alloc_reg(j, alloc_pool) for j in i.get_def(): - self.dealloc_reg(j, i) + self.dealloc_reg(j, i, alloc_pool) if k % 1000000 == 0 and k > 0: print "Allocated registers for %d instructions at" % k, time.asctime() @@ -98,7 +98,7 @@ def process(self, program, persistent_allocation=False): return self.usage -def determine_scope(block): +def determine_scope(block, options): last_def = defaultdict(lambda: -1) used_from_scope = set() @@ -120,12 +120,16 @@ def read(reg, n): print '\tline %d: %s' % (n, instr) print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t') print '\tregister trace: %s' % format_trace(reg.caller, '\t\t') + if options.stop: + sys.exit(1) def write(reg, n): if last_def[reg] != -1: print 'Warning: double write at register', reg print '\tline %d: %s' % (n, instr) print '\ttrace: %s' % format_trace(instr.caller, '\t\t') + if options.stop: + sys.exit(1) last_def[reg] = n for n,instr in enumerate(block.instructions): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index b3df11b4e..8c6be3933 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt """ Functions for secure comparison of GF(p) types. @@ -68,10 +68,10 @@ def divide_by_two(res, x): """ Faster clear division by two using a cached value of 2^-1 mod p """ from program import Program import types - tape = Program.prog.curr_block - if tape not in inverse_of_two: - inverse_of_two[tape] = types.cint(1) / 2 - mulc(res, x, inverse_of_two[tape]) + block = Program.prog.curr_block + if len(inverse_of_two) == 0 or block not in inverse_of_two: + inverse_of_two[block] = types.cint(1) / 2 + mulc(res, x, inverse_of_two[block]) def LTZ(s, a, k, kappa): """ diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 3df53c679..32d7573b1 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.program import Program from Compiler.config import * diff --git a/Compiler/config.py b/Compiler/config.py index 42248b1aa..3a237cc5c 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -1,5 +1,5 @@ -# (C) 2016 University of Bristol. See License.txt - +# (C) 2017 University of Bristol. See License.txt + from collections import defaultdict #INIT_REG_MAX = 655360 diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index b6ad19803..2ca13df78 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.oram import * diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py index eb1bf519d..8373b1d60 100644 --- a/Compiler/exceptions.py +++ b/Compiler/exceptions.py @@ -1,5 +1,5 @@ -# (C) 2016 University of Bristol. See License.txt - +# (C) 2017 University of Bristol. See License.txt + class CompilerError(Exception): """Base class for compiler exceptions.""" pass diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 9ee2a868b..0e575925e 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from math import log, floor, ceil from Compiler.instructions import * @@ -404,8 +404,8 @@ def TruncPr(a, k, m, kappa=None): return shift_two(a, m) if kappa is None: - kappa = 40 - + kappa = 40 + b = two_power(k-1) + a r_prime, r_dprime = types.sint(), types.sint() comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)], diff --git a/Compiler/graph.py b/Compiler/graph.py index 97c152af6..7f8e7f200 100644 --- a/Compiler/graph.py +++ b/Compiler/graph.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import heapq from Compiler.exceptions import * diff --git a/Compiler/gs.py b/Compiler/gs.py index 8823c23bf..510f27c98 100644 --- a/Compiler/gs.py +++ b/Compiler/gs.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import sys import math diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 63ad1173f..1453152e9 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt """ This module is for classes of actual assembly instructions. @@ -446,6 +446,13 @@ class legendrec(base.Instruction): code = base.opcodes['LEGENDREC'] arg_format = ['cw','c'] +@base.vectorize +class digestc(base.Instruction): + r""" Clear truncated hash computation, $c_i = H(c_j)[bytes]$. """ + __slots__ = [] + code = base.opcodes['DIGESTC'] + arg_format = ['cw','c','int'] + ### ### Bitwise operations ### @@ -915,6 +922,11 @@ class print_float_plain(base.IOInstruction): code = base.opcodes['PRINTFLOATPLAIN'] arg_format = ['c', 'c', 'c', 'c'] +class print_int(base.IOInstruction): + r""" Print only the value of register \verb|ci| to stdout. """ + __slots__ = [] + code = base.opcodes['PRINTINT'] + arg_format = ['ci'] class print_char(base.IOInstruction): r""" Print a single character to stdout. """ @@ -952,43 +964,156 @@ class pubinput(base.PublicFileIOInstruction): @base.vectorize class readsocketc(base.IOInstruction): - """Read an int from socket and store in register""" + """Read a variable number of clear GF(p) values from socket for a specified client id and store in registers""" __slots__ = [] code = base.opcodes['READSOCKETC'] - arg_format = ['ciw', 'int'] + arg_format = tools.chain(['ci'], itertools.repeat('cw')) + + def has_var_args(self): + return True @base.vectorize class readsockets(base.IOInstruction): - """Read a secret share + MAC from socket and store in register""" + """Read a variable number of secret shares + MACs from socket for a client id and store in registers""" __slots__ = [] code = base.opcodes['READSOCKETS'] - arg_format = ['sw', 'int'] + arg_format = tools.chain(['ci'], itertools.repeat('sw')) + + def has_var_args(self): + return True + +@base.vectorize +class readsocketint(base.IOInstruction): + """Read variable number of 32-bit int from socket for a client id and store in registers""" + __slots__ = [] + code = base.opcodes['READSOCKETINT'] + arg_format = tools.chain(['ci'], itertools.repeat('ciw')) + + def has_var_args(self): + return True @base.vectorize class writesocketc(base.IOInstruction): - """Write int from register into socket""" + """ + Write a variable number of clear GF(p) values from registers into socket + for a specified client id, message_type + """ __slots__ = [] code = base.opcodes['WRITESOCKETC'] - arg_format = ['ci', 'int'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('c')) + + def has_var_args(self): + return True @base.vectorize class writesockets(base.IOInstruction): - """Write secret share + MAC from register into socket""" + """ + Write a variable number of secret shares + MACs from registers into a socket + for a specified client id, message_type + """ __slots__ = [] code = base.opcodes['WRITESOCKETS'] - arg_format = ['s', 'int'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('s')) + + def has_var_args(self): + return True + +@base.vectorize +class writesocketshare(base.IOInstruction): + """ + Write a variable number of secret shares (without MACs) from registers into socket + for a specified client id, message_type + """ + __slots__ = [] + code = base.opcodes['WRITESOCKETSHARE'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('s')) + + def has_var_args(self): + return True + +@base.vectorize +class writesocketint(base.IOInstruction): + """ + Write a variable number of 32-bit ints from registers into socket + for a specified client id, message_type + """ + __slots__ = [] + code = base.opcodes['WRITESOCKETINT'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('ci')) + + def has_var_args(self): + return True -class opensocket(base.IOInstruction): - """Open a server socket connection at the given port number""" +class listen(base.IOInstruction): + """Open a server socket on a party specific port number and listen for client connections (non-blocking)""" __slots__ = [] - code = base.opcodes['OPENSOCKET'] + code = base.opcodes['LISTEN'] arg_format = ['int'] -class closesocket(base.IOInstruction): - """Close a server socket connection""" +class acceptclientconnection(base.IOInstruction): + """Wait for a connection at the given port and write socket handle to register """ __slots__ = [] - code = base.opcodes['CLOSESOCKET'] - arg_format = [] + code = base.opcodes['ACCEPTCLIENTCONNECTION'] + arg_format = ['ciw', 'int'] + +class connectipv4(base.IOInstruction): + """Connect to server at IPv4 address in register \verb|cj| at given port. Write socket handle to register \verb|ci|""" + __slots__ = [] + code = base.opcodes['CONNECTIPV4'] + arg_format = ['ciw', 'ci', 'int'] + +class readclientpublickey(base.IOInstruction): + """Read a client public key as 8 32-bit ints for a specified client id""" + __slots__ = [] + code = base.opcodes['READCLIENTPUBLICKEY'] + arg_format = tools.chain(['ci'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class initsecuresocket(base.IOInstruction): + """Read a client public key as 8 32-bit ints for a specified client id, + negotiate a shared key via STS and use it for replay resistant comms""" + __slots__ = [] + code = base.opcodes['INITSECURESOCKET'] + arg_format = tools.chain(['ci'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class respsecuresocket(base.IOInstruction): + """Read a client public key as 8 32-bit ints for a specified client id, + negotiate a shared key via STS and use it for replay resistant comms""" + __slots__ = [] + code = base.opcodes['RESPSECURESOCKET'] + arg_format = tools.chain(['ci'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class writesharestofile(base.IOInstruction): + """Write shares to a file""" + __slots__ = [] + code = base.opcodes['WRITEFILESHARE'] + arg_format = itertools.repeat('s') + + def has_var_args(self): + return True + +class readsharesfromfile(base.IOInstruction): + """ + Read shares from a file. Pass in start posn, return finish posn, shares. + Finish posn will return: + -2 file not found + -1 eof reached + position in file after read finished + """ + __slots__ = [] + code = base.opcodes['READFILESHARE'] + arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw')) + + def has_var_args(self): + return True @base.gf2n @base.vectorize @@ -1173,7 +1298,7 @@ class gconvgf2n(base.Instruction): @base.gf2n @base.vectorize -class startopen(base.Instruction): +class startopen(base.VarArgsInstruction): """ Start opening secret register $s_i$. """ __slots__ = [] code = base.opcodes['STARTOPEN'] @@ -1183,12 +1308,9 @@ def execute(self): for arg in self.args[::-1]: program.curr_block.open_queue.append(arg.value) - def has_var_args(self): - return True - @base.gf2n @base.vectorize -class stopopen(base.Instruction): +class stopopen(base.VarArgsInstruction): """ Store previous opened value in $c_i$. """ __slots__ = [] code = base.opcodes['STOPOPEN'] @@ -1198,9 +1320,6 @@ def execute(self): for arg in self.args: arg.value = program.curr_block.open_queue.pop() - def has_var_args(self): - return True - ### ### CISC-style instructions ### diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 88ab56b14..7fe003a26 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import itertools from random import randint @@ -78,6 +78,7 @@ MODC = 0x36, MODCI = 0x37, LEGENDREC = 0x38, + DIGESTC = 0x39, GMULBITC = 0x136, GMULBITM = 0x137, # Open @@ -95,13 +96,18 @@ # Input INPUT = 0x60, STARTINPUT = 0x61, - STOPINPUT = 0x62, + STOPINPUT = 0x62, READSOCKETC = 0x63, READSOCKETS = 0x64, WRITESOCKETC = 0x65, WRITESOCKETS = 0x66, - OPENSOCKET = 0x67, - CLOSESOCKET = 0x68, + READSOCKETINT = 0x69, + WRITESOCKETINT = 0x6a, + WRITESOCKETSHARE = 0x6b, + LISTEN = 0x6c, + ACCEPTCLIENTCONNECTION = 0x6d, + CONNECTIPV4 = 0x6e, + READCLIENTPUBLICKEY = 0x6f, # Bitwise logic ANDC = 0x70, XORC = 0x71, @@ -131,6 +137,7 @@ SUBINT = 0x9C, MULINT = 0x9D, DIVINT = 0x9E, + PRINTINT = 0x9F, # Conversion CONVINT = 0xC0, CONVMODP = 0xC1, @@ -149,8 +156,13 @@ PRINTCHRINT = 0xBA, PRINTSTRINT = 0xBB, PRINTFLOATPLAIN = 0xBC, + WRITEFILESHARE = 0xBD, + READFILESHARE = 0xBE, GBITDEC = 0x184, GBITCOM = 0x185, + # Secure socket + INITSECURESOCKET = 0x1BA, + RESPSECURESOCKET = 0x1BB ) @@ -329,13 +341,11 @@ class RegType(object): @staticmethod def create_dict(init_value_fn): """ Create a dictionary with all the RegTypes as keys """ - return { - RegType.ClearModp : init_value_fn(), - RegType.SecretModp : init_value_fn(), - RegType.ClearGF2N : init_value_fn(), - RegType.SecretGF2N : init_value_fn(), - RegType.ClearInt : init_value_fn(), - } + res = defaultdict(init_value_fn) + # initialization for legacy + for t in RegType.Types: + res[t] + return res class ArgFormat(object): @classmethod @@ -481,7 +491,7 @@ def get_code(self): def get_encoding(self): enc = int_to_bytes(self.get_code()) - # add the number of registers to a start/stop open instruction + # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) for arg,format in zip(self.args, self.arg_format): @@ -508,6 +518,8 @@ def check_args(self): except ArgumentError as e: raise CompilerError('Invalid argument "%s" to instruction: %s' % (e.arg, self) + '\n' + e.msg) + except KeyError as e: + raise CompilerError('Incorrect number of arguments for instruction %s' % (self)) def get_used(self): """ Return the set of registers that are read in this instruction. """ @@ -537,8 +549,15 @@ def get_size(self): def add_usage(self, req_node): pass + # String version of instruction attempting to replicate encoded version def __str__(self): - return self.__class__.__name__ + ' ' + self.get_pre_arg() + ', '.join(str(a) for a in self.args) + + if self.has_var_args(): + varargCount = str(len(self.args)) + ', ' + else: + varargCount = '' + + return self.__class__.__name__ + ' ' + self.get_pre_arg() + varargCount + ', '.join(str(a) for a in self.args) def __repr__(self): return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' @@ -725,6 +744,11 @@ def get_relative_jump(self): return self.args[self.jump_arg] +class VarArgsInstruction(Instruction): + def has_var_args(self): + return True + + class CISC(Instruction): """ Base class for a CISC instruction. diff --git a/Compiler/library.py b/Compiler/library.py index 3b99e496f..7e03620a9 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat from Compiler.instructions import * @@ -72,9 +72,7 @@ def print_plain_str(ss): else: val = args[i] if isinstance(val, program.Tape.Register): - if val.reg_type == 'ci': - cint(val).print_reg_plain() - elif val.is_clear: + if val.is_clear: val.print_reg_plain() else: raise CompilerError('Cannot print secret value:', args[i]) @@ -355,7 +353,7 @@ def on_first_call(self, wrapped_function): parent_node = get_tape().req_node get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) block = get_tape().active_basicblock - block.persistent_allocation = True + block.alloc_pool = defaultdict(set) del parent_node.children[-1] self.node = get_tape().req_node print 'Compiling function', self.name diff --git a/Compiler/oram.py b/Compiler/oram.py index 335373290..99dccc818 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import random import math @@ -15,8 +15,6 @@ from Compiler.util import * -sys.setrecursionlimit(1000000) - print_access = False sint_bit_length = 6 max_demux_bits = 3 @@ -40,12 +38,6 @@ def maybe_stop_timer(n): if detailed_timing: stop_timer(n) -def reveal(a): - try: - return a.reveal() - except AttributeError: - return a - class Block(object): def __init__(self, value, lengths): self.value = self.value_type.hard_conv(value) @@ -53,8 +45,7 @@ def __init__(self, value, lengths): def get_slice(self): res = [] for length,start in zip(self.lengths, series(self.lengths)): - res.append(sum(b << i for i,b in \ - enumerate(self.bits[start:start+length]))) + res.append(util.bit_compose((self.bits[start:start+length]))) return res def __repr__(self): return '<' + str(self.value) + '>' @@ -150,11 +141,17 @@ def set_slice(self, value): self.value = self.lower + value * self.adjust + upper return self +block_types = { sint: intBlock, + sgf2n: gf2nBlock, +} + def get_block(x, y, *args): - if isinstance(x, sgf2n) or isinstance(y, sgf2n): - return gf2nBlock(x, y, *args) - else: - return intBlock(x, y, *args) + for t in block_types: + if isinstance(x, t): + return block_types[t](x, y, *args) + elif isinstance(y, t): + return block_types[t](x, y, *args) + raise CompilerError('appropiate block type not found') def get_bit(x, index, bit_length): if isinstance(x, sgf2n): @@ -242,14 +239,14 @@ def equal(self, other, length=None): return (1 - self.empty) * (other == self.value) return (1 - self.empty) * self.value.equal(other, length) def reveal(self): - return Value(self.value.reveal(), self.empty.reveal()) + return Value(reveal(self.value), reveal(self.empty)) def output(self): - @if_e(self.empty) - def f(): - print_str('<>') - @else_ - def f(): - print_str('<%s>', self.value) + # @if_e(self.empty) + # def f(): + # print_str('<>') + # @else_ + # def f(): + print_str('<%s:%s>', self.empty, self.value) def __index__(self): return int(self.value) def __repr__(self): @@ -344,12 +341,13 @@ def __mul__(self, other): def reveal(self): return Entry(x.reveal() for x in self) def output(self): - @if_e(self.is_empty) - def f(): - print_str('{empty=%s}', self.is_empty) - @else_ - def f(): - print_str('{%s: %s}', self.v, self.x) + # @if_e(self.is_empty) + # def f(): + # print_str('{empty=%s}', self.is_empty) + # @else_ + # def f(): + # print_str('{%s: %s}', self.v, self.x)\ + print_str('{%s: %s,empty=%s}', self.v, self.x, self.is_empty) class RefRAM(object): """ RAM reference. """ @@ -362,8 +360,8 @@ def f(): crash() self.size = oram.bucket_size self.entry_type = oram.entry_type - self.l = [Array(self.size, t, array.address + \ - index * oram.bucket_size) \ + self.l = [t.dynamic_array(self.size, t, array.address + \ + index * oram.bucket_size) \ for t,array in zip(self.entry_type,oram.ram.l)] self.index = index def init_mem(self, empty_entry): @@ -410,7 +408,7 @@ def reveal(self): Program.prog.curr_tape.start_new_basicblock() return res def output(self): - self.reveal().print_reg() + print_ln('%s', [x.reveal() for x in self]) def print_reg(self): print_ln('listing of RAM at index %s', self.index) Program.prog.curr_tape.start_new_basicblock() @@ -428,7 +426,7 @@ def __init__(self, size, entry_type, index=0): #print_reg(cint(0), 'r in') self.size = size self.entry_type = entry_type - self.l = [Array(self.size, t) for t in entry_type] + self.l = [t.dynamic_array(self.size, t) for t in entry_type] self.index = index class AbstractORAM(object): @@ -902,7 +900,7 @@ class List(object): def __init__(self, size, value_type, value_length=1, init_rounds=None): self.value_type = value_type self.value_length = value_length - self.l = [Array(size, value_type) \ + self.l = [value_type.dynamic_array(size, value_type) \ for i in range(value_length)] for l in self.l: l.assign_all(0) @@ -1322,8 +1320,10 @@ def get_value_size(value_type): """ Return element size. """ if value_type == sgf2n: return Program.prog.galois_length - else: + elif value_type == sint: return 127 - Program.prog.security + else: + return value_type.max_length def get_parallel(index_size, value_type, value_length): """ Returning the number of parallel readings feasible, based on @@ -1410,7 +1410,7 @@ def f(i): else: self.l[i] = [0] * self.elements_per_block time() - print_ln('packed ORAM init %s/%s', cint(i), real_init_rounds) + print_ln('packed ORAM init %s/%s', i, real_init_rounds) print 'index initialized, size', size def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple @@ -1425,18 +1425,17 @@ def translate_index(self, index): return 0, b, c else: return (index - rem) / self.entries_per_block, b, c - elif self.value_type == sgf2n: + else: index_bits = bit_decompose(index, log2(self.size)) l1 = self.log_entries_per_element l2 = self.log_entries_per_block - c = sum(bit << i for i,bit in enumerate(index_bits[:l1])) - b = sum(bit << i for i,bit in enumerate(index_bits[l1:l2])) + c = bit_compose(index_bits[:l1]) + b = bit_compose(index_bits[l1:l2]) if self.small: return 0, b, c else: - a = sum(bit << i for i,bit in enumerate(index_bits[l2:])) + a = bit_compose(index_bits[l2:]) return a, b, c - else: raise CompilerError('Cannot process indices of type', self.value_type) class Slicer(object): def __init__(self, pack, index): @@ -1624,11 +1623,11 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty): def test_oram(oram_type, N, value_type=sint, iterations=100): oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0) print 'initialized' - print_reg(cint(0), 'init') + print_ln('initialized') stop_timer() # synchronize Program.prog.curr_tape.start_new_basicblock(name='sync') - sint(0).reveal() + value_type(0).reveal() Program.prog.curr_tape.start_new_basicblock(name='sync') start_timer() #oram[value_type(0)] = -1 diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index 3e02d14ab..a6cc52b7d 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt if '_Array' not in dir(): from oram import * @@ -76,7 +76,10 @@ def XOR(a, b): elif isinstance(a, sgf2n) or isinstance(b, sgf2n): return a + b else: - return a + b - 2*a*b + try: + return a ^ b + except TypeError: + return a + b - 2*a*b def pow2_eq(a, i, bit_length=40): """ Test for equality with 2**i, when a is a power of 2 (gf2n only)""" diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 7c98beccf..394980f21 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from random import randint import math diff --git a/Compiler/program.py b/Compiler/program.py index 109772a34..843de14ae 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.config import * from Compiler.exceptions import * @@ -65,6 +65,7 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.n_threads = 1 self.free_threads = set() self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w') + self.types = {} Program.prog = self self.reset_values() @@ -230,7 +231,7 @@ def write_bytes(self, outfile=None): # runtime doesn't support 'new-style' parallelism yet old_style = True - nonempty_tapes = [t for t in self.tapes if not t.is_empty()] + nonempty_tapes = [t for t in self.tapes] sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name sch_file = open(sch_filename, 'w') @@ -327,12 +328,15 @@ def curr_block(self): """ The basic block that is currently being created. """ return self.curr_tape.active_basicblock - def malloc(self, size, mem_type): + def malloc(self, size, mem_type, reg_type=None): """ Allocate memory from the top """ if size == 0: return if isinstance(mem_type, type): + self.types[mem_type.reg_type] = mem_type mem_type = mem_type.reg_type + elif reg_type is not None: + self.types[mem_type] = reg_type key = size, mem_type if self.free_mem_blocks[key]: addr = self.free_mem_blocks[key].pop() @@ -346,7 +350,8 @@ def malloc(self, size, mem_type): def free(self, addr, mem_type): """ Free memory """ - if self.curr_block.persistent_allocation: + if self.curr_block.alloc_pool \ + is not self.curr_tape.basicblocks[0].alloc_pool: raise CompilerError('Cannot free memory within function block') size = self.allocated_mem_blocks.pop((addr,mem_type)) self.free_mem_blocks[size,mem_type].add(addr) @@ -354,10 +359,15 @@ def free(self, addr, mem_type): def finalize_memory(self): import library self.curr_tape.start_new_basicblock(None, 'memory-usage') + # reset register counter to 0 + self.curr_tape.init_registers() for mem_type,size in self.allocated_mem.items(): if size: #print "Memory of type '%s' of size %d" % (mem_type, size) - library.load_mem(size - 1, mem_type) + if mem_type in self.types: + self.types[mem_type].load_mem(size - 1, mem_type) + else: + library.load_mem(size - 1, mem_type) def public_input(self, x): self.public_input_file.write('%s\n' % str(x)) @@ -407,9 +417,9 @@ def __init__(self, parent, name, scope, exit_condition=None): self.children = [] if scope is not None: scope.children.append(self) - self.persistent_allocation = scope.persistent_allocation + self.alloc_pool = scope.alloc_pool else: - self.persistent_allocation = False + self.alloc_pool = defaultdict(set) def new_reg(self, reg_type, size=None): return self.parent.new_reg(reg_type, size=size) @@ -511,7 +521,7 @@ def optimize(self, options): print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks) for block in self.basicblocks: - al.determine_scope(block) + al.determine_scope(block, options) # merge open instructions # need to do this if there are several blocks @@ -563,15 +573,15 @@ def optimize(self, options): # allocate registers reg_counts = self.count_regs() - if filter(lambda n: n > REG_MAX, reg_counts) and not options.noreallocate: - print 'Tape register usage:' + if not options.noreallocate: + print 'Tape register usage:', reg_counts print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) print 'Re-allocating...' allocator = al.StraightlineAllocator(REG_MAX) def alloc_loop(block): for reg in block.used_from_scope: - allocator.alloc_reg(reg, block.persistent_allocation) + allocator.alloc_reg(reg, block.alloc_pool) for child in block.children: if child.instructions: alloc_loop(child) @@ -584,7 +594,7 @@ def alloc_loop(block): if isinstance(jump, (int,long)) and jump < 0 and \ block.exit_block.scope is not None: alloc_loop(block.exit_block.scope) - allocator.process(block.instructions, block.persistent_allocation) + allocator.process(block.instructions, block.alloc_pool) # offline data requirements print 'Compile offline data requirements...' @@ -614,10 +624,11 @@ def alloc_loop(block): if not self.is_empty(): # bit length requirement - self.basicblocks[-1].instructions.append( - Compiler.instructions.reqbl(self.req_bit_length['p'], add_to_prog=False)) - self.basicblocks[-1].instructions.append( - Compiler.instructions.greqbl(self.req_bit_length['2'], add_to_prog=False)) + for x in ('p', '2'): + if self.req_bit_length['p']: + self.basicblocks[-1].instructions.append( + Compiler.instructions.reqbl(self.req_bit_length['p'], + add_to_prog=False)) print 'Tape requires prime bit length', self.req_bit_length['p'] print 'Tape requires galois bit length', self.req_bit_length['2'] diff --git a/Compiler/tools.py b/Compiler/tools.py index d30891be7..b36ede6b1 100644 --- a/Compiler/tools.py +++ b/Compiler/tools.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import itertools diff --git a/Compiler/types.py b/Compiler/types.py index 1f86e3213..8ac97b53a 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.program import Tape from Compiler.exceptions import * @@ -11,6 +11,20 @@ import operator +class ClientMessageType: + """ Enum to define type of message sent to external client. Each may be array of length n.""" + # No client message type to be sent, for backwards compatibility - virtual machine relies on this value + NoType = 0 + # 3 x sint x n + TripleShares = 1 + # 1 x cint x n + ClearModpInt = 2 + # 1 x regint x n + Int32 = 3 + # 1 x cint (fixed point left shifted by precision) x n + ClearModpFix = 4 + + class MPCThread(object): def __init__(self, target, name, args = [], runtime_arg = None): """ Create a thread from a callable object. """ @@ -97,6 +111,10 @@ def read_mem_operation(self, other, *args, **kwargs): class _number(object): + @staticmethod + def bit_compose(bits): + return sum(b << i for i,b in enumerate(bits)) + def square(self): return self * self @@ -152,7 +170,6 @@ def cond_swap(self, a, b, t=None): else: return tuple(t.conv(r) for r in res) - class _register(Tape.Register, _number): MemValue = staticmethod(lambda value: MemValue(value)) @@ -340,6 +357,9 @@ def __or__(self, other): __rxor__ = __xor__ __ror__ = __or__ + def reveal(self): + return self + class cint(_clear, _int): " Clear mod p integer type. """ @@ -348,7 +368,25 @@ class cint(_clear, _int): reg_type = 'c' @vectorized_classmethod - def load_mem(cls, address): + def read_from_socket(cls, client_id, n=1): + res = [cls() for i in range(n)] + readsocketc(client_id, *res) + if n == 1: + return res[0] + else: + return res + + @vectorize + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + writesocketc(client_id, message_type, self) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of modp integers to socket """ + writesocketc(client_id, message_type, *values) + + @vectorized_classmethod + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, ldmc, ldmci) def store_in_mem(self, address): @@ -464,6 +502,13 @@ def legendre(self): legendrec(res, self) return res + def digest(self, num_bytes): + res = cint() + digestc(res, self, num_bytes) + return res + + + class cgf2n(_clear, _gf2n): __slots__ = [] @@ -478,7 +523,7 @@ def bit_compose(cls, bits, step=None): return res @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, gldmc, gldmci) def store_in_mem(self, address): @@ -560,7 +605,7 @@ def protect_memory(cls, start, end): protectmemint(regint(start), regint(end)) @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, ldmint, ldminti) def store_in_mem(self, address): @@ -581,14 +626,40 @@ def get_random(cls, bit_length): return res @vectorized_classmethod - def read_from_socket(cls): - res = cls() - readsocketc(res,0) + def read_from_socket(cls, client_id, n=1): + """ Receive n register values from socket """ + res = [cls() for i in range(n)] + readsocketint(client_id, *res) + if n == 1: + return res[0] + else: + return res + + @vectorized_classmethod + def read_client_public_key(cls, client_id): + """ Receive 8 register values from socket containing client public key.""" + res = [cls() for i in range(8)] + readclientpublickey(client_id, *res) return res + @vectorized_classmethod + def init_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8): + """ Use 8 register values containing client public key.""" + initsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8) + + @vectorized_classmethod + def resp_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8): + """ Receive 8 register values from socket containing client public key.""" + respsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8) + @vectorize - def write_to_socket(self): - writesocketc(self,0) + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + writesocketint(client_id, message_type, self) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of integers to socket """ + writesocketint(client_id, message_type, *values) @vectorize_init def __init__(self, val=None, size=None): @@ -614,7 +685,11 @@ def load_other(self, val): elif isinstance(val, regint): addint(self, val, regint(0)) else: - raise CompilerError("Cannot convert '%s' to integer" % type(val)) + try: + val.to_regint(self) + except AttributeError: + raise CompilerError("Cannot convert '%s' to integer" % \ + type(val)) @vectorize @read_mem_value @@ -652,10 +727,10 @@ def __rdiv__(self, other): return self.int_op(other, divint, True) def __mod__(self, other): - return cint(self) % other + return self - (self / other) * other def __rmod__(self, other): - return other % cint(self) + return regint(other) % self def __rpow__(self, other): return other**cint(self) @@ -679,10 +754,16 @@ def __ge__(self, other): return 1 - (self < other) def __lshift__(self, other): - return regint(cint(self) << other) + if isinstance(other, (int, long)): + return self * 2**other + else: + return regint(cint(self) << other) def __rshift__(self, other): - return regint(cint(self) >> other) + if isinstance(other, (int, long)): + return self / 2**other + else: + return regint(cint(self) >> other) def __rlshift__(self, other): return regint(other << cint(self)) @@ -706,6 +787,31 @@ def __xor__(self, other): def mod2m(self, *args, **kwargs): return cint(self).mod2m(*args, **kwargs) + def bit_decompose(self, bit_length=None): + res = [] + x = self + two = regint(2) + for i in range(bit_length or program.bit_length): + y = x / two + res.append(x - two * y) + x = y + return res + + @staticmethod + def bit_compose(bits): + two = regint(2) + res = 0 + for bit in reversed(bits): + res *= two + res += bit + return res + + def reveal(self): + return self + + def print_reg_plain(self): + print_int(self) + class _secret(_register): __slots__ = [] @@ -875,18 +981,54 @@ def get_raw_input_from(cls, player): stopinput(player, res) return res + @classmethod + def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): + """ Securely obtain shares of n values input by a client """ + # send shares of a triple to client + triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) + sint.write_shares_to_socket(client_id, triples, message_type) + + received = cint.read_from_socket(client_id, n) + y = [0] * n + for i in range(n): + y[i] = received[i] - triples[i * 3] + return y + @vectorized_classmethod - def read_from_socket(cls): - res = cls() - readsockets(res,0) - return res + def read_from_socket(cls, client_id, n=1): + """ Receive n shares and MAC shares from socket """ + res = [cls() for i in range(n)] + readsockets(client_id, *res) + if n == 1: + return res[0] + else: + return res + + @vectorize + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + """ Send share and MAC share to socket """ + writesockets(client_id, message_type, self) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of shares and MAC shares to socket """ + writesockets(client_id, message_type, *values) @vectorize - def write_to_socket(self): - writesockets(self,0) + def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType): + """ Send only share to socket """ + writesocketshare(client_id, message_type, self) @vectorized_classmethod - def load_mem(cls, address): + def write_shares_to_socket(cls, client_id, values, message_type=ClientMessageType.NoType, include_macs=False): + """ Send shares of a list of values to a specified client socket """ + if include_macs: + writesockets(client_id, message_type, *values) + else: + writesocketshare(client_id, message_type, *values) + + @vectorized_classmethod + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, ldms, ldmsi) def store_in_mem(self, address): @@ -1035,7 +1177,7 @@ def mul(self, other): return super(sgf2n, self).mul(other) @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, gldms, gldmsi) def store_in_mem(self, address): @@ -1100,9 +1242,10 @@ def bit_decompose(self, bit_length=None, step=1): bit_length = bit_length or program.galois_length random_bits = [self.get_random_bit() \ for i in range(0, bit_length, step)] + one = cgf2n(1) masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal() - masked_bits = masked.bit_decompose(bit_length) + masked_bits = masked.bit_decompose(bit_length,step=step) return [m + r for m,r in zip(masked_bits, random_bits)] @vectorize @@ -1456,6 +1599,29 @@ def load_mem(cls, address, mem_type=None): res.append(cint.load_mem(address)) return cfix(*res) + @vectorized_classmethod + def read_from_socket(cls, client_id, n=1): + """ Read one or more cfix values from a socket. + Sender will have already bit shifted and sent as cints.""" + cint_input = cint.read_from_socket(client_id, n) + if n == 1: + return cfix(cint_inputs) + else: + return map(cfix, cint_inputs) + + @vectorize + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + """ Send cfix to socket. Value is sent as bit shifted cint. """ + writesocketc(client_id, message_type, cint(self.v)) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of cfix values to socket. Values are sent as bit shifted cints. """ + def cfix_to_cint(fix_val): + return cint(fix_val.v) + cint_values = map(cfix_to_cint, values) + writesocketc(client_id, message_type, *cint_values) + @vectorize_init def __init__(self, v=None, size=None): f = self.f @@ -1613,6 +1779,13 @@ def set_precision(cls, f, k = None): else: cls.k = k + @classmethod + def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): + """ Securely obtain shares of n values input by a client. + Assumes client has already run bit shift to convert fixed point to integer.""" + sint_inputs = sint.receive_from_client(n, client_id, ClientMessageType.TripleShares) + return map(sfix, sint_inputs) + @vectorized_classmethod def load_mem(cls, address, mem_type=None): res = [] @@ -1787,7 +1960,7 @@ class sfloat(_number): error = 0 @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): res = [] for i in range(4): res.append(sint.load_mem(address + i * get_global_vector_size())) @@ -2075,10 +2248,13 @@ def __init__(self, length, value_type, address=None): if value_type in _types: value_type = _types[value_type] self.address = address - if address is None: - self.address = program.malloc(length, value_type.reg_type) self.length = length self.value_type = value_type + if address is None: + self.address = self._malloc() + + def _malloc(self): + return program.malloc(self.length, self.value_type) def delete(self): if program: @@ -2106,7 +2282,7 @@ def __getitem__(self, index): def f(i): res[i] = self[start+i*step] return res - return self.value_type.load_mem(self.get_address(index)) + return self._load(self.get_address(index)) def __setitem__(self, index, value): if isinstance(index, slice): @@ -2117,7 +2293,13 @@ def f(i): self[i] = value[source_index] source_index.iadd(1) return - self.value_type.conv(value).store_in_mem(self.get_address(index)) + self._store(self.value_type.conv(value), self.get_address(index)) + + def _load(self, address): + return self.value_type.load_mem(address) + + def _store(self, value, address): + value.store_in_mem(address) def __len__(self): return self.length @@ -2149,6 +2331,8 @@ def f(i): self[i] = mem_value return self +sint.dynamic_array = Array +sgf2n.dynamic_array = Array class Matrix(object): def __init__(self, rows, columns, value_type, address=None): @@ -2309,7 +2493,7 @@ def __init__(self, value): else: self.value_type = type(value) self.reg_type = self.value_type.reg_type - self.address = program.malloc(1, self.reg_type) + self.address = program.malloc(1, self.value_type) self.deleted = False self.write(value) @@ -2339,7 +2523,7 @@ def write(self, value): if not isinstance(self.register, self.value_type): raise CompilerError('Mismatch in register type, cannot write \ %s to %s' % (type(self.register), self.value_type)) - library.store_in_mem(self.register, self.address) + self.register.store_in_mem(self.address) self.last_write_block = program.curr_block return self diff --git a/Compiler/util.py b/Compiler/util.py index a4f9e3fd6..f44c0d714 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import math import operator @@ -54,7 +54,14 @@ def bit_decompose(a, bits): return a.bit_decompose(bits) def bit_compose(bits): - return sum(b << i for i,b in enumerate(bits)) + bits = list(bits) + try: + if bits: + return bits[0].bit_compose(bits) + else: + return 0 + except AttributeError: + return sum(b << i for i,b in enumerate(bits)) def series(a): sum = 0 @@ -103,3 +110,25 @@ def or_op(a, b): def pow2(bits): powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)] return tree_reduce(operator.mul, powers) + +def irepeat(l, n): + return reduce(operator.add, ([i] * n for i in l)) + +def int_len(x): + return len(bin(x)) - 2 + +def reveal(x): + if isinstance(x, str): + return x + try: + return x.reveal() + except AttributeError: + pass + try: + return [reveal(y) for y in x] + except TypeError: + pass + return x + +def is_constant(x): + return isinstance(x, (int, long, bool)) diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 3c20a4baf..c96e16d28 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Exceptions #define _Exceptions @@ -121,9 +121,39 @@ class file_error: public exception } }; class end_of_file: public exception - { virtual const char* what() const throw() - { return "End of file reached"; } + { string filename, context, ans; + public: + end_of_file(string pfilename="no filename", string pcontext="") : + filename(pfilename), context(pcontext) + { + ans="End of file when reading "; + ans+=filename; + ans+=" "; + ans+=context; + } + ~end_of_file()throw() { } + virtual const char* what() const throw() + { + return ans.c_str(); + } }; +class file_missing: public exception + { string filename, context, ans; + public: + file_missing(string pfilename="no filename", string pcontext="") : + filename(pfilename), context(pcontext) + { + ans="File missing : "; + ans+=filename; + ans+=" "; + ans+=context; + } + ~file_missing()throw() { } + virtual const char* what() const throw() + { + return ans.c_str(); + } + }; class Processor_Error: public exception { string msg; public: @@ -137,6 +167,11 @@ class Processor_Error: public exception return msg.c_str(); } }; +class Invalid_Instruction : public Processor_Error + { + public: + Invalid_Instruction(string m) : Processor_Error(m) {} + }; class max_mod_sz_too_small : public exception { int len; public: diff --git a/ExternalIO/README.md b/ExternalIO/README.md new file mode 100644 index 000000000..07b5b7757 --- /dev/null +++ b/ExternalIO/README.md @@ -0,0 +1,105 @@ +(C) 2017 University of Bristol. See License.txt. + +The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md). + +## I/O MPC Instructions + +### Connection Setup + +**listen**(*int port_num*) + +Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background. + +*port_num* - the port number to listen on. + +**acceptclientconnection**(*regint client_socket_id*, *int port_num*) + +Picks the first available client socket connection. Blocks if none available. + +*client_socket_id* - an identifier used to refer to the client socket. + +*port_num* - the port number identifies the socket server to accept connections on. + +### Data Exchange + +Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py). + +*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*) + +Read a share of an input from a client, blocking on the client send. + +*client_socket_id* - an identifier used to refer to the client socket. + +*number_of_inputs* - the number of inputs expected + +*[inputs]* - returned list of shares of private input. + +**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*) + +Write shares of values including macs to an external client. + +*client_socket_id* - an identifier used to refer to the client socket. + +*[values]* - list of shares of values to send to client. + +*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. + +See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message. + +*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*) + +Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf) + +*number_of_inputs* - the number of inputs expected + +*client_socket_id* - an identifier used to refer to the client socket. + +*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. + +*[inputs]* - returned list of shares of private input. + + +## Securing communications + +Two cryptographic protocols have been implemented for use in particular applications and are included here for completeness: + +1. Communication security using a Station to Station key agreement and libsodium Secret Box using a nonce counter for message ordering. +2. Authenticated Diffie-Hellman without message ordering. + + Please note these are **NOT** required to allow external client I/O. Your mileage may vary, for example in a web setting TLS may be sufficient to secure communications between processes. + +[client-setup.cpp](../client-setup.cpp) is a utility which is run to generate the key material for both the external clients and SPDZ parties for both protocols. + +#### MPC instructions + +**regint.init_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*) + +STS protocol initiator. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec. + +*client_socket_id* - an identifier used to refer to the client socket. + +*public_signing_key* - client public key supplied as list of 8 32-bit ints. + +**regint.resp_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*) + +STS protocol responder. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec. + +*client_socket_id* - an identifier used to refer to the client socket. + +*public_signing_key* - client public key supplied as list of 8 32-bit ints. + +*[regint public_key]* **regint.read_client_public_key**(*regint client_socket_id*) + +Instruction to read the client public key and run setup for the authenticated Diffie-Hellman encryption. All subsequent write_socket instructions are encrypted. Only the sint.read_from_socket instruction is encrypted. + +*client_socket_id* - an identifier used to refer to the client socket. + +*public_key* - client public key made available to mpc programs as list of 8 32-bit ints. + +## Working Examples + +See [bankers-bonus-client.cpp](./bankers-bonus-client.cpp) which acts as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output with no communications security. + +See [bankers-bonus-commsec-client.cpp](./bankers-bonus-commsec-client.cpp) which acts as a client to [bankers_bonus_commsec.mpc](../Programs/Source/bankers_bonus_commsec.mpc) which runs the same algorithm but includes both the available crypto protocols. + +More instructions on how to run these are provided in the *-client files. diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp new file mode 100644 index 000000000..661d24f8e --- /dev/null +++ b/ExternalIO/bankers-bonus-client.cpp @@ -0,0 +1,198 @@ +/* + * (C) 2017 University of Bristol. See License.txt + * + * Demonstrate external client inputing and receiving outputs from a SPDZ process, + * following the protocol described in https://eprint.iacr.org/2015/1006.pdf. + * + * Provides a client to bankers_bonus.mpc program to calculate which banker pays for lunch based on + * the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running + * the bankers_bonus.mpc program. + * + * Each connecting client: + * - sends a unique id to identify the client + * - sends an integer input (bonus value to compare) + * - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result). + * + * The result is returned authenticated with a share of a random value: + * - share of winning unique id [y] + * - share of random value [r] + * - share of winning unique id * random value [w] + * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] + * + * No communications security is used. + * + * To run with 2 parties / SPDZ engines: + * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). + * ./compile.py bankers_bonus + * ./Scripts/run-online bankers_bonus to run the engines. + * + * ./bankers-bonus-client.x 123 2 100 0 + * ./bankers-bonus-client.x 456 2 200 0 + * ./bankers-bonus-client.x 789 2 50 1 + * + * Expect winner to be second client with id 456. + */ + +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Networking/sockets.h" +#include "Tools/int.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" + +#include +#include +#include +#include + +// Send the private inputs masked with a random value. +// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. +// Add the private input value to triple[0] and send to each spdz engine. +void send_private_inputs(vector& values, vector& sockets, int nparties) +{ + int num_inputs = values.size(); + octetStream os; + vector< vector > triples(num_inputs, vector(3)); + vector triple_shares(3); + + // Receive num_inputs triples from SPDZ + for (int j = 0; j < nparties; j++) + { + os.reset_write_head(); + os.Receive(sockets[j]); + + for (int j = 0; j < num_inputs; j++) + { + for (int k = 0; k < 3; k++) + { + triple_shares[k].unpack(os); + triples[j][k] += triple_shares[k]; + } + } + } + + // Check triple relations (is a party cheating?) + for (int i = 0; i < num_inputs; i++) + { + if (triples[i][0] * triples[i][1] != triples[i][2]) + { + cerr << "Incorrect triple at " << i << ", aborting\n"; + exit(1); + } + } + // Send inputs + triple[0], so SPDZ can compute shares of each value + os.reset_write_head(); + for (int i = 0; i < num_inputs; i++) + { + gfp y = values[i] + triples[i][0]; + y.pack(os); + } + for (int j = 0; j < nparties; j++) + os.Send(sockets[j]); +} + +// Assumes that Scripts/setup-online.sh has been run to compute prime +void initialise_fields(const string& dir_prefix) +{ + int lg2; + bigint p; + + string filename = dir_prefix + "Params-Data"; + cout << "loading params from: " << filename << endl; + + ifstream inpf(filename.c_str()); + if (inpf.fail()) { throw file_error(filename.c_str()); } + inpf >> p; + inpf >> lg2; + + inpf.close(); + + gfp::init_field(p); + gf2n::init_field(lg2); +} + + +// Receive shares of the result and sum together. +// Also receive authenticating values. +gfp receive_result(vector& sockets, int nparties) +{ + vector output_values(3); + octetStream os; + for (int i = 0; i < nparties; i++) + { + os.reset_write_head(); + os.Receive(sockets[i]); + for (unsigned int j = 0; j < 3; j++) + { + gfp value; + value.unpack(os); + output_values[j] += value; + } + } + + if (output_values[0] * output_values[1] != output_values[2]) + { + cerr << "Unable to authenticate output value as correct, aborting." << endl; + exit(1); + } + return output_values[0]; +} + + +int main(int argc, char** argv) +{ + int my_client_id; + int nparties; + int salary_value; + int finish; + int port_base = 14000; + string host_name = "localhost"; + + if (argc < 5) { + cout << "Usage is bankers-bonus-client " + << " " + << "" << endl; + exit(0); + } + + my_client_id = atoi(argv[1]); + nparties = atoi(argv[2]); + salary_value = atoi(argv[3]); + finish = atoi(argv[4]); + if (argc > 5) + host_name = argv[5]; + if (argc > 6) + port_base = atoi(argv[6]); + + // init static gfp + string prep_data_prefix = get_prep_dir(nparties, 128, 40); + initialise_fields(prep_data_prefix); + + // Setup connections from this client to each party socket + vector sockets(nparties); + for (int i = 0; i < nparties; i++) + { + set_up_client_socket(sockets[i], host_name.c_str(), port_base + i); + } + cout << "Finish setup socket connections to SPDZ engines." << endl; + + // Map inputs into gfp + vector input_values_gfp(3); + input_values_gfp[0].assign(my_client_id); + input_values_gfp[1].assign(salary_value); + input_values_gfp[2].assign(finish); + + // Run the commputation + send_private_inputs(input_values_gfp, sockets, nparties); + cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; + + // Get the result back (client_id of winning client) + gfp result = receive_result(sockets, nparties); + + cout << "Winning client id is : " << result << endl; + + for (unsigned int i = 0; i < sockets.size(); i++) + close_client_socket(sockets[i]); + + return 0; +} diff --git a/ExternalIO/bankers-bonus-commsec-client.cpp b/ExternalIO/bankers-bonus-commsec-client.cpp new file mode 100644 index 000000000..9b99ad1e0 --- /dev/null +++ b/ExternalIO/bankers-bonus-commsec-client.cpp @@ -0,0 +1,407 @@ +/* + * (C) 2017 University of Bristol. See License.txt + * + * Demonstrate external client inputing and receiving outputs from a SPDZ process, + * following the protocol described in https://eprint.iacr.org/2015/1006.pdf. + * Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp + * for a simpler client with no crypto. + * + * Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on + * the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running + * the bankers_bonus.mpc program. + * + * Each connecting client: + * - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security. + * - sends a unique id to identify the client + * - sends an integer input (bonus value to compare) + * - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result). + * + * The result is returned authenticated with a share of a random value: + * - share of winning unique id [y] + * - share of random value [r] + * - share of winning unique id * random value [w] + * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] + * + * To run with 2 parties (SPDZ engines) and 3 external clients: + * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). + * ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients. + * ./compile.py bankers_bonus_commsec + * ./Scripts/run-online bankers_bonus_commsec to run the engines. + * + * ./bankers-bonus-commsec-client.x 0 2 100 0 + * ./bankers-bonus-commsec-client.x 1 2 200 0 + * ./bankers-bonus-commsec-client.x 2 2 50 1 + * + * Expect winner to be second client with id 1. + * Note here client id must match id used in generating client key material, Client-Keys-C. + */ + +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Networking/sockets.h" +#include "Networking/STS.h" +#include "Tools/int.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" + +#include +#include +#include +#include +#include + +typedef pair< vector, vector > keypair_t; // A pair of send/recv keys for talking to SPDZ +typedef vector< keypair_t > commsec_t; // A database of send/recv keys indexed by server +typedef struct { + unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES]; + unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES]; + vector client_publickey_ints; + vector< vector >server_publickey; +} sign_key_container_t; + +keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector& sockets, int server_id) +{ + STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey); + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream os; + + os.Receive(sockets[server_id]); + os.consume(m1.bytes, sizeof m1.bytes); + m2 = ke.recv_msg1(m1); + os.reset_write_head(); + os.append(m2.pubkey, sizeof m2.pubkey); + os.append(m2.sig, sizeof m2.sig); + os.Send(sockets[server_id]); + os.Receive(sockets[server_id]); + os.consume(m3.bytes, sizeof m3.bytes); + ke.recv_msg3(m3); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + return make_pair(sendKey,recvKey); +} + +keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector& sockets, int server_id) +{ + STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey); + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream os; + + m1 = ke.send_msg1(); + cout << "m1: "; + for (unsigned int j = 0; j < 32; j++) + cout << setfill('0') << setw(2) << hex << (int) m1.bytes[j]; + cout << dec << endl; + os.reset_write_head(); + os.append(m1.bytes, sizeof m1.bytes); + os.Send(sockets[server_id]); + + os.reset_write_head(); + os.Receive(sockets[server_id]); + os.consume(m2.pubkey, sizeof m2.pubkey); + os.consume(m2.sig, sizeof m2.sig); + m3 = ke.recv_msg2(m2); + + os.reset_write_head(); + os.append(m3.bytes, sizeof m3.bytes); + os.Send(sockets[server_id]); + + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + return make_pair(sendKey,recvKey); +} + +pair< vector, vector > sts_response_role(sign_key_container_t keys, vector& sockets, int server_id) +{ + pair< vector, vector > res; + try { + res = sts_response_role_exceptions(keys, sockets, server_id); + } catch(char const *e) { + cerr << "Error in STS: " << e << endl; + exit(1); + } + return res; +} + +pair< vector, vector > sts_initiator_role(sign_key_container_t keys, vector& sockets, int server_id) +{ + pair< vector, vector > res; + try { + res = sts_initiator_role_exceptions(keys, sockets, server_id); + } catch(char const *e) { + cerr << "Error in STS: " << e << endl; + exit(1); + } + return res; +} + +// Send the private inputs masked with a random value. +// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. +// Add the private input value to triple[0] and send to each spdz engine. +void send_private_inputs(vector& values, vector& sockets, int nparties, + commsec_t commsec, vector& keys) +{ + int num_inputs = values.size(); + octetStream os; + vector< vector > triples(num_inputs, vector(3)); + vector triple_shares(3); + + // Receive num_inputs triples from SPDZ + for (int j = 0; j < nparties; j++) + { + os.reset_write_head(); + os.Receive(sockets[j]); + os.decrypt_sequence(&commsec[j].second[0],0); + os.decrypt(keys[j]); + + for (int j = 0; j < num_inputs; j++) + { + for (int k = 0; k < 3; k++) + { + triple_shares[k].unpack(os); + triples[j][k] += triple_shares[k]; + } + } + } + // Check triple relations + for (int i = 0; i < num_inputs; i++) + { + if (triples[i][0] * triples[i][1] != triples[i][2]) + { + cerr << "Incorrect triple at " << i << ", aborting\n"; + exit(1); + } + } + // Send inputs + triple[0], so SPDZ can compute shares of each value + os.reset_write_head(); + for (int i = 0; i < num_inputs; i++) + { + gfp y = values[i] + triples[i][0]; + y.pack(os); + } + for (int j = 0; j < nparties; j++) { + os.encrypt_sequence(&commsec[j].first[0],0); + os.Send(sockets[j]); + } +} + +// Send public key in clear to each SPDZ engine. +void send_public_key(vector& pubkey, int socket) +{ + octetStream os; + os.reset_write_head(); + + for (unsigned int i = 0; i < pubkey.size(); i++) + { + os.store(pubkey[i]); + } + + os.Send(socket); +} + +// Assumes that Scripts/setup-online.sh has been run to compute prime +void initialise_fields(const string& dir_prefix) +{ + int lg2; + bigint p; + + string filename = dir_prefix + "Params-Data"; + cout << "loading params from: " << filename << endl; + + ifstream inpf(filename.c_str()); + if (inpf.fail()) { throw file_error(filename.c_str()); } + inpf >> p; + inpf >> lg2; + + inpf.close(); + + gfp::init_field(p); + gf2n::init_field(lg2); +} + +// Assumes that client-setup has been run to create key pairs for clients and parties +void generate_symmetric_keys(vector& keys, vector& client_public_key_ints, + sign_key_container_t *sts_key, const string& dir_prefix, int client_no) +{ + unsigned char client_publickey[crypto_box_PUBLICKEYBYTES]; + unsigned char client_secretkey[crypto_box_SECRETKEYBYTES]; + unsigned char server_publickey[crypto_box_PUBLICKEYBYTES]; + unsigned char scalarmult_q[crypto_scalarmult_BYTES]; + crypto_generichash_state h; + + // read client public/secret keys + SPDZ server public keys + ifstream keyfile; + stringstream client_filename; + client_filename << dir_prefix << "Client-Keys-C" << client_no; + keyfile.open(client_filename.str().c_str()); + if (keyfile.fail()) + throw file_error(client_filename.str()); + keyfile.read((char*)client_publickey, sizeof client_publickey); + if (keyfile.eof()) + throw end_of_file(client_filename.str(), "client public key" ); + + // Convert client public key unsigned char to int, reverse endianness + for(unsigned int j = 0; j < client_public_key_ints.size(); j++) { + int keybyte = 0; + for(unsigned int k = 0; k < 4; k++) { + keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8)); + } + client_public_key_ints[j] = keybyte; + } + + keyfile.read((char*)client_secretkey, sizeof client_secretkey); + if (keyfile.eof()) { + throw end_of_file(client_filename.str(), "client private key" ); + } + + keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES); + keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES); + // Convert client public key unsigned char to int, reverse endianness + sts_key->client_publickey_ints.resize(8); + for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) { + int keybyte = 0; + for(unsigned int k = 0; k < 4; k++) { + keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8)); + } + sts_key->client_publickey_ints[j] = keybyte; + } + + for (unsigned int i = 0; i < keys.size(); i++) + { + keys[i] = new octet[crypto_generichash_BYTES]; + keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES); + if (keyfile.eof()) + throw end_of_file(client_filename.str(), "server public key for party " + i); + keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES); + if (keyfile.eof()) + throw end_of_file(client_filename.str(), "server public signing key for party " + i); + + // Derive a shared key from this server's secret key and the client's public key + // shared key = h(q || client_secretkey || server_publickey) + if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) { + cerr << "Scalar mult failed\n"; + exit(1); + } + crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES); + crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q); + crypto_generichash_update(&h, client_publickey, sizeof client_publickey); + crypto_generichash_update(&h, server_publickey, sizeof server_publickey); + crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES); + } + keyfile.close(); + + cout << "My public key is: "; + for (unsigned int j = 0; j < 32; j++) + cout << setfill('0') << setw(2) << hex << (int) client_publickey[j]; + cout << dec << endl; +} + + +// Receive shares of the result and sum together. +// Also receive authenticating values. +gfp receive_result(vector& sockets, int nparties, commsec_t commsec, vector& keys) +{ + vector output_values(3); + octetStream os; + for (int i = 0; i < nparties; i++) + { + os.reset_write_head(); + os.Receive(sockets[i]); + + os.decrypt_sequence(&commsec[i].second[0],1); + os.decrypt(keys[i]); + + for (unsigned int j = 0; j < 3; j++) + { + gfp value; + value.unpack(os); + output_values[j] += value; + } + } + + if (output_values[0] * output_values[1] != output_values[2]) + { + cerr << "Unable to authenticate output value as correct, aborting." << endl; + exit(1); + } + return output_values[0]; +} + + +int main(int argc, char** argv) +{ + int my_client_id; + int nparties; + int salary_value; + int finish; + int port_base = 14000; + sign_key_container_t sts_key; + string host_name = "localhost"; + + if (argc < 5) { + cout << "Usage is external-client " + << " " + << "" << endl; + exit(0); + } + + my_client_id = atoi(argv[1]); + nparties = atoi(argv[2]); + salary_value = atoi(argv[3]); + finish = atoi(argv[4]); + if (argc > 5) + host_name = argv[5]; + if (argc > 6) + port_base = atoi(argv[6]); + + sts_key.server_publickey.resize(nparties); + for(int i = 0 ; i < nparties; i++) { + sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES); + } + + // init static gfp + string prep_data_prefix = get_prep_dir(nparties, 128, 40); + initialise_fields(prep_data_prefix); + + // Generate session keys to decrypt data sent from each spdz engine (party) + vector session_keys(nparties); + vector client_public_key_ints(8); + + generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id); + + // Setup connections from this client to each party socket and send the client public keys + vector sockets(nparties); + // vector< pair , vector > > commseckey(nparties); + commsec_t commseckey(nparties); + for (int i = 0; i < nparties; i++) + { + set_up_client_socket(sockets[i], host_name.c_str(), port_base + i); + send_public_key(sts_key.client_publickey_ints, sockets[i]); + send_public_key(client_public_key_ints, sockets[i]); + commseckey[i] = sts_initiator_role(sts_key, sockets, i); + } + cout << "Finish setup socket connections to SPDZ engines." << endl; + + // Map inputs into gfp + vector input_values_gfp(3); + input_values_gfp[0].assign(my_client_id); + input_values_gfp[1].assign(salary_value); + input_values_gfp[2].assign(finish); + + // Send the inputs to the SPDZ Engines + send_private_inputs(input_values_gfp, sockets, nparties, commseckey, session_keys); + cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; + + // Get the result back + gfp result = receive_result(sockets, nparties, commseckey, session_keys); + + cout << "Winning client id is : " << result << endl; + + for (unsigned int i = 0; i < sockets.size(); i++) + close_client_socket(sockets[i]); + + return 0; +} diff --git a/Fake-Offline.cpp b/Fake-Offline.cpp index 750598d18..cfac46eca 100644 --- a/Fake-Offline.cpp +++ b/Fake-Offline.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Math/gf2n.h" @@ -490,8 +490,6 @@ int main(int argc, const char** argv) bigint p; generate_online_setup(outf, prep_data_prefix, p, lgp, lg2); - generate_keys(prep_data_prefix, nplayers); - /* Find number players and MAC keys etc*/ gfp keyp,pp; keyp.assign_zero(); gf2n key2,p2; key2.assign_zero(); diff --git a/License.txt b/License.txt index e30817d3a..ed66199f6 100644 --- a/License.txt +++ b/License.txt @@ -1,6 +1,6 @@ University of Bristol : Open Access Software Licence -Copyright (c) 2016, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom. +Copyright (c) 2017, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom. All rights reserved diff --git a/Makefile b/Makefile index eff45dac5..2b89126d1 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt include CONFIG @@ -26,7 +26,7 @@ LIB = libSPDZ.a LIBSIMPLEOT = SimpleOT/libsimpleot.a -all: gen_input online offline +all: gen_input online offline externalIO online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x @@ -34,6 +34,8 @@ offline: $(OT_EXE) Check-Offline.x gen_input: gen_input_f2n.x gen_input_fp.x +externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x + Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) @@ -69,7 +71,14 @@ gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON) gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON) $(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS) +client-setup.x: client-setup.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -clean: - -rm */*.o *.o *.x core.* *.a gmon.out +bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +clean: + -rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out diff --git a/Math/Integer.cpp b/Math/Integer.cpp index b6dc06e7c..184952d99 100644 --- a/Math/Integer.cpp +++ b/Math/Integer.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Integer.cpp diff --git a/Math/Integer.h b/Math/Integer.h index b1730b340..89e00571f 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Integer.h @@ -15,10 +15,13 @@ using namespace std; class Integer { +protected: long a; public: + static string type_string() { return "integer"; } + Integer() { a = 0; } Integer(long a) : a(a) {} @@ -26,6 +29,31 @@ class Integer void assign_zero() { a = 0; } + long operator+(const Integer& other) const { return a + other.a; } + long operator-(const Integer& other) const { return a - other.a; } + long operator*(const Integer& other) const { return a * other.a; } + long operator/(const Integer& other) const { return a / other.a; } + + long operator>>(const Integer& other) const { return a >> other.a; } + long operator<<(const Integer& other) const { return a << other.a; } + + long operator^(const Integer& other) const { return a ^ other.a; } + long operator&(const Integer& other) const { return a ^ other.a; } + long operator|(const Integer& other) const { return a ^ other.a; } + + bool operator==(const Integer& other) const { return a == other.a; } + bool operator!=(const Integer& other) const { return a != other.a; } + bool operator<(const Integer& other) const { return a < other.a; } + bool operator<=(const Integer& other) const { return a <= other.a; } + bool operator>(const Integer& other) const { return a > other.a; } + bool operator>=(const Integer& other) const { return a >= other.a; } + + long operator^=(const Integer& other) { return a ^= other.a; } + + friend unsigned int& operator+=(unsigned int& x, const Integer& other) { return x += other.a; } + + friend ostream& operator<<(ostream& s, const Integer& x) { x.output(s, true); return s; } + void output(ostream& s,bool human) const; void input(istream& s,bool human); diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 29c8002dd..c13704b1b 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Math/Setup.h" #include "Math/gfp.h" @@ -111,8 +111,8 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i } string get_prep_dir(int nparties, int lg2p, int gf2ndegree) -{ - if (gf2ndegree == 0) +{ + if (gf2ndegree == 0) gf2ndegree = gf2n::default_length(); stringstream ss; ss << PREP_DIR << nparties << "-" << lg2p << "-" << gf2ndegree << "/"; diff --git a/Math/Setup.h b/Math/Setup.h index a3813f1d9..db82a5112 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Setup.h diff --git a/Math/Share.cpp b/Math/Share.cpp index 065c69a8d..57b138a70 100644 --- a/Math/Share.cpp +++ b/Math/Share.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Share.h" @@ -99,6 +99,23 @@ T combine(const vector< Share >& S) } + + +template +inline void Share::pack(octetStream& os) const +{ + a.pack(os); + mac.pack(os); +} + +template +inline void Share::unpack(octetStream& os) +{ + a.unpack(os); + mac.unpack(os); +} + + template bool check_macs(const vector< Share >& S,const T& key) { diff --git a/Math/Share.h b/Math/Share.h index 95382c11d..0320029c1 100644 --- a/Math/Share.h +++ b/Math/Share.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Share @@ -69,6 +69,19 @@ class Share void sub(const Share& S1,const Share& S2); void add(const Share& S1) { add(*this,S1); } + Share operator+(const Share& x) const + { Share res; res.add(*this, x); return res; } + template + Share operator*(const U& x) const + { Share res; res.mul(*this, x); return res; } + + Share& operator+=(const Share& x) { add(x); return *this; } + template + Share& operator*=(const U& x) { mul(*this, x); return *this; } + + Share operator<<(int i) { return this->operator*(T(1) << i); } + Share& operator<<=(int i) { return *this = *this << i; } + // Input and output from a stream // - Can do in human or machine only format (later should be faster) void output(ostream& s,bool human) const @@ -80,6 +93,11 @@ class Share mac.input(s,human); } + friend ostream& operator<<(ostream& s, const Share& x) { x.output(s, true); return s; } + + void pack(octetStream& os) const; + void unpack(octetStream& os); + /* Takes a vector of shares, one from each player and * determines the shared value * - i.e. Partially open the shares diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index d5ef56cca..7b43591e4 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Zp_Data.h" diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 0ddb14e94..954d0d4f7 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #ifndef _Zp_Data #define _Zp_Data diff --git a/Math/bigint.cpp b/Math/bigint.cpp index 238616da8..a6b38a968 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "bigint.h" diff --git a/Math/bigint.h b/Math/bigint.h index 99c6f4073..364f8b8c8 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _bigint #define _bigint diff --git a/Math/field_types.h b/Math/field_types.h index f1f9fb6db..bfeb519c6 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * types.h diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 967d06299..1e8a76983 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Math/gf2n.h" @@ -58,12 +58,12 @@ void gf2n_short::init_tables() void gf2n_short::init_field(int nn) { - if (nn == 0) - { - nn = default_length(); - cerr << "Using GF(2^" << nn << ")" << endl; - } - + if (nn == 0) + { + nn = default_length(); + cerr << "Using GF(2^" << nn << ")" << endl; + } + gf2n_short::init_tables(); int i,j=-1; for (i=0; i(aa)&mask; } void assign(const char* buffer) { a = *(word*)buffer; } @@ -93,8 +94,10 @@ class gf2n_short } gf2n_short() { a=0; } - gf2n_short(const gf2n_short& g) { assign(g); } - gf2n_short(int g) { assign(g); } + gf2n_short(word a) { assign(a); } + gf2n_short(long a) { assign(a); } + gf2n_short(int a) { assign(a); } + gf2n_short(const char* a) { assign(a); } ~gf2n_short() { ; } gf2n_short& operator=(const gf2n_short& g) @@ -167,7 +170,7 @@ class gf2n_short void input(istream& s,bool human); friend ostream& operator<<(ostream& s,const gf2n_short& x) - { s << hex << "0x" << x.a << dec; + { s << hex << showbase << x.a << dec; return s; } friend istream& operator>>(istream& s,gf2n_short& x) diff --git a/Math/gf2nlong.cpp b/Math/gf2nlong.cpp index 849726673..f465911c7 100644 --- a/Math/gf2nlong.cpp +++ b/Math/gf2nlong.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * gf2n_longlong.cpp diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 4211d8d7f..cd9350644 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * gf2nlong.h diff --git a/Math/gfp.cpp b/Math/gfp.cpp index da20448db..f0c8baa89 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Math/gfp.h" @@ -71,10 +71,15 @@ void gfp::SHL(const gfp& x,int n) { if (!x.is_zero()) { - bigint bi; - to_bigint(bi,x,false); - mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); - to_gfp(*this, bi); + if (n != 0) + { + bigint bi; + to_bigint(bi,x,false); + mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); + to_gfp(*this, bi); + } + else + assign(x); } else { @@ -87,10 +92,15 @@ void gfp::SHR(const gfp& x,int n) { if (!x.is_zero()) { - bigint bi; - to_bigint(bi,x); - mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); - to_gfp(*this, bi); + if (n != 0) + { + bigint bi; + to_bigint(bi,x); + mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); + to_gfp(*this, bi); + } + else + assign(x); } else { diff --git a/Math/gfp.h b/Math/gfp.h index 5ca3afb54..c5e8d945b 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #ifndef _gfp #define _gfp @@ -93,7 +93,7 @@ class gfp bool is_zero() const { return isZero(a,ZpD); } - bool is_one() const { return isOne(a,ZpD); } + bool is_one() const { return isOne(a,ZpD); } bool is_bit() const { return is_zero() or is_one(); } bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); } bool operator==(const gfp& y) const { return equal(y); } diff --git a/Math/modp.cpp b/Math/modp.cpp index 033edda32..604ce363c 100644 --- a/Math/modp.cpp +++ b/Math/modp.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Zp_Data.h" #include "modp.h" diff --git a/Math/modp.h b/Math/modp.h index b384c3a1d..bbab59911 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Modp #define _Modp diff --git a/Math/operators.h b/Math/operators.h index d4d34b4a9..b2910cb39 100644 --- a/Math/operators.h +++ b/Math/operators.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * operations.h @@ -17,15 +17,15 @@ T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; } template T operator+(const T& x, const U& y) { T res; res.add(x, y); return res; } -template -T operator*(const T& x, const U& y) { T res; res.mul(x, y); return res; } +template +T operator*(const T& x, const T& y) { T res; res.mul(x, y); return res; } template T operator-(const T& x, const U& y) { T res; res.sub(x, y); return res; } template T& operator+=(T& x, const U& y) { x.add(y); return x; } -template -T& operator*=(T& x, const U& y) { x.mul(y); return x; } +template +T& operator*=(T& x, const T& y) { x.mul(y); return x; } template T& operator-=(T& x, const U& y) { x.sub(y); return x; } diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 2325cec5d..fd081742b 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -1,40 +1,48 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Player.h" #include "Exceptions/Exceptions.h" +#include "Networking/STS.h" #include +#include -// Use printf rather than cout so valgrind can detect thread issues +using namespace std; + +CommsecKeysPackage::CommsecKeysPackage(vector playerpubs, + secret_signing_key mypriv, + public_signing_key mypub) +{ + player_public_keys = playerpubs; + my_secret_key = mypriv; + my_public_key = mypub; +} void Names::init(int player,int pnb,const char* servername) -{ +{ player_no=player; portnum_base=pnb; setup_names(servername); + keys = NULL; setup_server(); } - -void Names::init(int player,int pnb,vector Nms) +void Names::init(int player,int pnb,vector Nms) { - player_no=player; - portnum_base=pnb; - nplayers=Nms.size(); - names.resize(nplayers); - for (int i=0; i Nms) +void Names::init(int player,int pnb,vector Nms) { player_no=player; portnum_base=pnb; nplayers=Nms.size(); - names=Nms; + names.resize(nplayers); + for (int i=0; ikeys = keys; +} + void Names::setup_names(const char *servername) { int socket_num; @@ -79,11 +93,16 @@ void Names::setup_names(const char *servername) // Send my name octet my_name[512]; memset(my_name,0,512*sizeof(octet)); - gethostname((char*)my_name,512); + sockaddr_in address; + socklen_t size = sizeof address; + getsockname(socket_num, (sockaddr*)&address, &size); + char* name = inet_ntoa(address.sin_addr); + // max length of IP address with ending 0 + strncpy((char*)my_name, name, 16); fprintf(stderr, "My Name = %s\n",my_name); send(socket_num,my_name,512); cerr << "My number = " << player_no << endl; - + // Now get the set of names int i; receive(socket_num,nplayers); @@ -102,6 +121,7 @@ void Names::setup_names(const char *servername) void Names::setup_server() { server = new ServerSocket(portnum_base + player_no); + server->init(); } @@ -113,6 +133,7 @@ Names::Names(const Names& other) nplayers = other.nplayers; portnum_base = other.portnum_base; names = other.names; + keys = NULL; server = 0; } @@ -135,7 +156,7 @@ Player::Player(const Names& Nms, int id) : PlayerBase(Nms), send_to_self_socket( Player::~Player() -{ +{ /* Close down the sockets */ for (int i=0; i& names,int portnum_base,int id_base,ServerSocket& server) { - sockets.resize(nplayers); - // Set up the client side - for (int i=player_no; i& o,bool donthash) const { for (int i=0; iplayer_no) { o[player_no].Send(sockets[i]); } - else if (iplayer_no) + else if (i>player_no) { o[i].reset_write_head(); - o[i].Receive(sockets[i]); + o[i].Receive(sockets[i]); } } if (!donthash) @@ -240,7 +269,7 @@ void Player::Check_Broadcast() const Broadcast_Receive(h,true); for (int i=0; i other_player; - setup_sockets(Nms.names[other_player].c_str(), *Nms.server, Nms.portnum_base + other_player, id); + setup_sockets(other_player, Nms, Nms.portnum_base + other_player, id); } TwoPartyPlayer::~TwoPartyPlayer() -{ +{ + for(size_t i=0; i < my_secret_key.size(); i++) { + my_secret_key[i] = 0; + } close_client_socket(socket); } -void TwoPartyPlayer::setup_sockets(const char* hostname, ServerSocket& server, int pn, int id) +static pair sts_initiator(int socket, CommsecKeysPackage *keys, int other_player) { - if (is_server) - { - fprintf(stderr, "Setting up server with id %d\n",id); - socket = server.get_connection_socket(id); + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream socket_stream; + + // Start Station to Station Protocol + STS ke(&keys->player_public_keys[other_player][0], &keys->my_public_key[0], &keys->my_secret_key[0]); + m1 = ke.send_msg1(); + socket_stream.reset_write_head(); + socket_stream.append(m1.bytes, sizeof m1.bytes); + socket_stream.Send(socket); + socket_stream.Receive(socket); + socket_stream.consume(m2.pubkey, sizeof m2.pubkey); + socket_stream.consume(m2.sig, sizeof m2.sig); + m3 = ke.recv_msg2(m2); + socket_stream.reset_write_head(); + socket_stream.append(m3.bytes, sizeof m3.bytes); + socket_stream.Send(socket); + + // Use results of STS to generate send and receive keys. + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + keyinfo sendkeyinfo = make_pair(sendKey,0); + keyinfo recvkeyinfo = make_pair(recvKey,0); + return make_pair(sendkeyinfo,recvkeyinfo); +} + +static pair sts_responder(int socket, CommsecKeysPackage *keys, int other_player) + // secret_signing_key mykey, public_signing_key mypubkey, public_signing_key theirkey) +{ + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream socket_stream; + + // Start Station to Station Protocol for the responder + STS ke(&keys->player_public_keys[other_player][0], &keys->my_public_key[0], &keys->my_secret_key[0]); + socket_stream.Receive(socket); + socket_stream.consume(m1.bytes, sizeof m1.bytes); + m2 = ke.recv_msg1(m1); + socket_stream.reset_write_head(); + socket_stream.append(m2.pubkey, sizeof m2.pubkey); + socket_stream.append(m2.sig, sizeof m2.sig); + socket_stream.Send(socket); + socket_stream.Receive(socket); + socket_stream.consume(m3.bytes, sizeof m3.bytes); + ke.recv_msg3(m3); + + // Use results of STS to generate send and receive keys. + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + keyinfo sendkeyinfo = make_pair(sendKey,0); + keyinfo recvkeyinfo = make_pair(recvKey,0); + return make_pair(sendkeyinfo,recvkeyinfo); +} + +void TwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, int id) +{ + const char *hostname = nms.names[other_player].c_str(); + ServerSocket *server = nms.server; + if (is_server) { + fprintf(stderr, "Setting up server with id %d\n",id); + socket = server->get_connection_socket(id); + if(NULL != nms.keys) { + pair send_recv_pair = sts_responder(socket, nms.keys, other_player); + player_send_key = send_recv_pair.first; + player_recv_key = send_recv_pair.second; + } } - else - { - fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, pn, id); - set_up_client_socket(socket, hostname, pn); - ::send(socket, (unsigned char*)&id, sizeof(id)); + else { + fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, portNum, id); + set_up_client_socket(socket, hostname, portNum); + ::send(socket, (unsigned char*)&id, sizeof(id)); + if(NULL != nms.keys) { + pair send_recv_pair = sts_initiator(socket, nms.keys, other_player); + player_send_key = send_recv_pair.first; + player_recv_key = send_recv_pair.second; + } } } @@ -381,31 +481,37 @@ int TwoPartyPlayer::other_player_num() const return other_player; } -void TwoPartyPlayer::send(octetStream& o) const +void TwoPartyPlayer::send(octetStream& o) { + if(p2pcommsec) { + o.encrypt_sequence(&player_send_key.first[0], player_send_key.second); + player_send_key.second++; + } o.Send(socket); } -void TwoPartyPlayer::receive(octetStream& o) const +void TwoPartyPlayer::receive(octetStream& o) { o.reset_write_head(); o.Receive(socket); + if(p2pcommsec) { + o.decrypt_sequence(&player_recv_key.first[0], player_recv_key.second); + player_recv_key.second++; + } } -void TwoPartyPlayer::send_receive_player(vector& o) const +void TwoPartyPlayer::send_receive_player(vector& o) { { if (is_server) { - o[0].Send(socket); - o[1].reset_write_head(); - o[1].Receive(socket); + send(o[0]); + receive(o[1]); } else { - o[1].reset_write_head(); - o[1].Receive(socket); - o[0].Send(socket); + receive(o[1]); + send(o[0]); } } } diff --git a/Networking/Player.h b/Networking/Player.h index 500061197..7784ce8b6 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Player #define _Player @@ -23,6 +23,23 @@ using namespace std; #include "Networking/Receiver.h" #include "Networking/Sender.h" +typedef vector public_signing_key; +typedef vector secret_signing_key; +typedef vector chachakey; +typedef pair< chachakey, uint64_t > keyinfo; + +class CommsecKeysPackage { +public: + vector player_public_keys; + secret_signing_key my_secret_key; + public_signing_key my_public_key; + + CommsecKeysPackage(vector playerpubs, + secret_signing_key mypriv, + public_signing_key mypub); + ~CommsecKeysPackage(); +}; + /* Class to get the names off the server */ class Names { @@ -31,6 +48,8 @@ class Names int portnum_base; int player_no; + CommsecKeysPackage *keys; + void setup_names(const char *servername); void setup_server(); @@ -39,7 +58,6 @@ class Names mutable ServerSocket* server; - // Usual setup names void init(int player,int pnb,const char* servername); Names(int player,int pnb,const char* servername) { init(player,pnb,servername); } @@ -50,11 +68,10 @@ class Names void init(int player,int pnb,vector Nms); Names(int player,int pnb,vector Nms) { init(player,pnb,Nms); } - // Set up names from file -- reads the first nplayers names in the file void init(int player, int nplayers, int pnb, const string& hostsfile); Names(int player, int nplayers, int pnb, const string& hostsfile) { init(player, nplayers, pnb, hostsfile); } - + void set_keys( CommsecKeysPackage *keys ); Names() : nplayers(-1), portnum_base(-1), player_no(-1), server(0) { ; } Names(const Names& other); @@ -81,7 +98,6 @@ class PlayerBase int my_num() const { return player_no; } }; - class Player : public PlayerBase { protected: @@ -161,25 +177,31 @@ class TwoPartyPlayer : public PlayerBase { private: // setup sockets for comm. with only one other player - void setup_sockets(const char* hostname, ServerSocket& server, int pn, int id); + void setup_sockets(int other_player, const Names &nms, int portNum, int id); int socket; bool is_server; int other_player; + bool p2pcommsec; + + secret_signing_key my_secret_key; + map player_public_keys; + keyinfo player_send_key; + keyinfo player_recv_key; public: TwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0); ~TwoPartyPlayer(); - void send(octetStream& o) const; - void receive(octetStream& o) const; + void send(octetStream& o); + void receive(octetStream& o); int other_player_num() const; /* Send and receive to/from the other player * - o[0] contains my data, received data put in o[1] */ - void send_receive_player(vector& o) const; + void send_receive_player(vector& o); }; #endif diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp index c7527f3c8..b4353af2f 100644 --- a/Networking/Receiver.cpp +++ b/Networking/Receiver.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Receiver.cpp diff --git a/Networking/Receiver.h b/Networking/Receiver.h index f7d62d075..f81912ba9 100644 --- a/Networking/Receiver.h +++ b/Networking/Receiver.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Receiver.h diff --git a/Networking/STS.cpp b/Networking/STS.cpp new file mode 100644 index 000000000..c6310a3bf --- /dev/null +++ b/Networking/STS.cpp @@ -0,0 +1,230 @@ +// (C) 2017 University of Bristol. See License.txt + +#include "Networking/STS.h" +#include +#include +#include +#include +#include +#include +#include + +void STS::kdf_block(unsigned char *block) +{ + crypto_hash_sha512_state state; + crypto_hash_sha512_init(&state); + unsigned char ctrbytes[sizeof kdf_counter]; + kdf_counter++; + + // Little endian serialization + for(size_t i=0; i> i*8) & 0xFF); + } + crypto_hash_sha512_update(&state,ctrbytes,sizeof ctrbytes); + crypto_hash_sha512_update(&state,raw_secret,crypto_hash_sha512_BYTES); + crypto_hash_sha512_final(&state, block); +} + +vector STS::unsafe_derive_secret(size_t sz) +{ + // KDF ~ H(cnt || raw_secret) + vector resultSecret(sz + crypto_hash_sha512_BYTES - (sz % crypto_hash_sha512_BYTES)); + size_t total=0; + while(total < sz) { + unsigned char *block = &resultSecret[total]; + kdf_block(block); + total += crypto_hash_sha512_BYTES; + } + return resultSecret; +} + +STS::STS() +{ + phase = UNDEFINED; +} + +void STS::init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]) +{ + phase = UNKNOWN; + memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES); + memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); + kdf_counter = 0; +} + +STS::STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]) +{ + phase = UNKNOWN; + memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES); + memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); + kdf_counter = 0; +} + +STS::~STS() +{ + memset(their_public_sign_key, 0, crypto_sign_PUBLICKEYBYTES); + memset(my_private_sign_key, 0, crypto_sign_SECRETKEYBYTES); + memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); + memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(raw_secret, 0, crypto_hash_sha512_BYTES); + kdf_counter = 0; + phase = UNKNOWN; +} + +sts_msg1_t STS::send_msg1() +{ + sts_msg1_t m; + if(UNKNOWN != phase) { + throw "STS BAD PHASE"; + } + + crypto_box_keypair(ephemeral_public_key, ephemeral_private_key); + memcpy(m.bytes,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + phase = SENT1; + return m; +} + +// If the incoming signature is valid, compute: +// shared secret = H(DH(pubB,privA) || pubA || pubB) +// msg = Sign_{privED-A} (pubA || pubB ) +// +sts_msg3_t STS::recv_msg2(sts_msg2_t msg2) +{ + unsigned char *theirPublicKey = msg2.pubkey; + unsigned char *theirSig = msg2.sig; + unsigned char theirSigDec[crypto_sign_BYTES]; + unsigned char scalar_result[crypto_scalarmult_SCALARBYTES]; + const unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; + int ret; + crypto_hash_sha512_state state; + sts_msg3_t msg; + + if(SENT1 != phase) { + throw "STS BAD PHASE"; + } + ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey); + if(0 != ret) { + throw "crypto_scalarmult failed"; + } + + crypto_hash_sha512_init(&state); + crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES); + crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_final(&state,raw_secret); + + vector keKey = unsafe_derive_secret(crypto_stream_KEYBYTES); + vector expectedMessage; + expectedMessage.insert(expectedMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); + expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + + crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey[0]); + + int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key); + + if(badSig) { + throw "Bad signature received in message 2."; + } else { + unsigned char *mySigEnc = msg.bytes; + unsigned char mySig[crypto_sign_BYTES]; + vector signMessage; + signMessage.insert(signMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + signMessage.insert(signMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); + if(0 != crypto_sign_detached(mySig, NULL, &signMessage[0], signMessage.size(), my_private_sign_key)) { + throw "Signing failed."; + } + vector keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES); + crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey2[0]); + + phase = FINISHED; + return msg; + } +} + +sts_msg2_t STS::recv_msg1(sts_msg1_t msg1) +{ + unsigned char *theirPublicKey = msg1.bytes; + unsigned char scalar_result[crypto_scalarmult_SCALARBYTES]; + crypto_hash_sha512_state state; + sts_msg2_t m; + int ret; + + if(UNKNOWN != phase) { + throw "recv_msg1 called on non-unknown phase"; + } + + memcpy(their_ephemeral_public_key, theirPublicKey, crypto_box_PUBLICKEYBYTES); + + crypto_box_keypair(ephemeral_public_key, ephemeral_private_key); + memcpy(m.pubkey,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey); + if(0 != ret) { + throw "crypto_scalarmult failed when processing message 1"; + } + + crypto_hash_sha512_init(&state); + crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES); + crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_final(&state,raw_secret); + + vector livenessProof; + livenessProof.insert(livenessProof.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + livenessProof.insert(livenessProof.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); + unsigned char mySig[crypto_sign_BYTES]; + unsigned char *mySigEnc = m.sig; + vector keKey = unsafe_derive_secret(crypto_stream_KEYBYTES); + + unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; + if(0 != crypto_sign_detached(mySig, NULL, &livenessProof[0], livenessProof.size(), my_private_sign_key)) { + throw "Signing failed."; + } + crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey[0]); + + phase = SENT2; + return m; +} + +void STS::recv_msg3(sts_msg3_t msg3) +{ + unsigned char *theirSig=msg3.bytes; + unsigned char theirSigDec[crypto_sign_BYTES]; + vector expectedMessage; + if(SENT2 != phase) { + throw "recv_msg3 called out of order"; + } + + expectedMessage.insert(expectedMessage.end(), their_ephemeral_public_key , their_ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; + vector keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES); + + crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey2[0]); + int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key); + + if(badSig) { + throw "Bad signature received in message 3."; + } else { + phase = FINISHED; + } +} + +vector STS::derive_secret(size_t sz) +{ + if(phase != FINISHED) { + throw "Can not derive secrets till the key exchange has completed."; + } + return unsafe_derive_secret(sz); +} diff --git a/Networking/STS.h b/Networking/STS.h new file mode 100644 index 000000000..d588d6c99 --- /dev/null +++ b/Networking/STS.h @@ -0,0 +1,72 @@ +// (C) 2017 University of Bristol. See License.txt + +#ifndef _NETWORK_STS +#define _NETWORK_STS + +/* The Station to Station protocol + */ + +#include +#include +#include +#include + +using namespace std; + +typedef enum + { UNKNOWN // Have not started the interaction or have cleared the memory + , SENT1 // Sent initial message + , SENT2 // Received 1, sent 2 + , FINISHED // Done (received msg 2 & sent 3 or received msg 3) + , UNDEFINED // For arrays/vectors/etc of STS classes that are initialized later. +} phase_t; + +struct msg1_st { + unsigned char bytes[crypto_box_PUBLICKEYBYTES]; +}; +typedef struct msg1_st sts_msg1_t; +struct msg2_st { + unsigned char pubkey[crypto_box_PUBLICKEYBYTES]; + unsigned char sig[crypto_sign_BYTES]; +}; +typedef struct msg2_st sts_msg2_t; +struct msg3_st { + unsigned char bytes[crypto_sign_BYTES]; +}; +typedef struct msg3_st sts_msg3_t; + +class STS +{ + phase_t phase; + unsigned char their_public_sign_key[crypto_sign_PUBLICKEYBYTES]; + unsigned char my_public_sign_key[crypto_sign_PUBLICKEYBYTES]; + unsigned char my_private_sign_key[crypto_sign_SECRETKEYBYTES]; + unsigned char ephemeral_private_key[crypto_box_SECRETKEYBYTES]; + unsigned char ephemeral_public_key[crypto_box_PUBLICKEYBYTES]; + unsigned char their_ephemeral_public_key[crypto_box_PUBLICKEYBYTES]; + unsigned char raw_secret[crypto_hash_sha512_BYTES]; + uint64_t kdf_counter; + public: + STS(); + STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]); + ~STS(); + + void init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]); + + sts_msg1_t send_msg1(); + sts_msg3_t recv_msg2(sts_msg2_t msg2); + + sts_msg2_t recv_msg1(sts_msg1_t msg1); + void recv_msg3(sts_msg3_t msg3); + + vector derive_secret(size_t); + private: + vector unsafe_derive_secret(size_t); + void kdf_block(unsigned char *block); +}; + +#endif /* _NETWORK_STS */ diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp index 89dc0eedd..afeb2c499 100644 --- a/Networking/Sender.cpp +++ b/Networking/Sender.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Sender.cpp diff --git a/Networking/Sender.h b/Networking/Sender.h index ff95c8bbb..f07fcec8c 100644 --- a/Networking/Sender.h +++ b/Networking/Sender.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Sender.h diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index 653068a1b..d890ae8fb 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * ServerSocket.cpp @@ -57,7 +57,7 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum) sleep(1); } else - { cerr << "Bound on port " << Portnum << endl; } + { cerr << "ServerSocket is bound on port " << Portnum << endl; } } if (fl<0) { error("set_up_socket:bind"); } @@ -65,6 +65,11 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum) fl=listen(main_socket, 1000); if (fl<0) { error("set_up_socket:listen"); } + // Note: must not call virtual init() method in constructor: http://www.aristeia.com/EC3E/3E_item9.pdf +} + +void ServerSocket::init() +{ pthread_create(&thread, 0, accept_thread, this); } @@ -95,6 +100,15 @@ void ServerSocket::accept_clients() } } +int ServerSocket::get_connection_count() +{ + data_signal.lock(); + int connection_count = clients.size(); + data_signal.unlock(); + return connection_count; +} + + int ServerSocket::get_connection_socket(int id) { data_signal.lock(); @@ -108,8 +122,60 @@ int ServerSocket::get_connection_socket(int id) while (clients.find(id) == clients.end()) data_signal.wait(); - int client = clients[id]; + int client_socket = clients[id]; used.insert(id); data_signal.unlock(); - return client; + return client_socket; +} + +void* anonymous_accept_thread(void* server_socket) +{ + ((AnonymousServerSocket*)server_socket)->accept_clients(); + return 0; +} + +int AnonymousServerSocket::global_client_socket_count = 0; + +void AnonymousServerSocket::init() +{ + pthread_create(&thread, 0, anonymous_accept_thread, this); +} + +int AnonymousServerSocket::get_connection_count() +{ + return num_accepted_clients; +} + +void AnonymousServerSocket::accept_clients() +{ + while (true) + { + struct sockaddr dest; + memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ + int socksize = sizeof(dest); + int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize); + if (consocket<0) { error("set_up_socket:accept"); } + + data_signal.lock(); + client_connection_queue.push(consocket); + num_accepted_clients++; + data_signal.broadcast(); + data_signal.unlock(); + } +} + +int AnonymousServerSocket::get_connection_socket(int& client_id) +{ + data_signal.lock(); + + //while (clients.find(next_client_id) == clients.end()) + while (client_connection_queue.empty()) + data_signal.wait(); + + client_id = global_client_socket_count; + global_client_socket_count++; + int client_socket = client_connection_queue.front(); + client_connection_queue.pop(); + data_signal.unlock(); + return client_socket; } diff --git a/Networking/ServerSocket.h b/Networking/ServerSocket.h index 08a9cbc46..27c388ee7 100644 --- a/Networking/ServerSocket.h +++ b/Networking/ServerSocket.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * ServerSocket.h @@ -10,6 +10,7 @@ #include #include + #include using namespace std; #include @@ -19,6 +20,7 @@ using namespace std; class ServerSocket { +protected: int main_socket, portnum; map clients; set used; @@ -28,17 +30,51 @@ class ServerSocket // disable copying ServerSocket(const ServerSocket& other); + // receive id from client + int assign_client_id(int consocket); + public: ServerSocket(int Portnum); - ~ServerSocket(); + virtual ~ServerSocket(); - void accept_clients(); + virtual void init(); + + virtual void accept_clients(); // This depends on clients sending their id as int. // Has to be thread-safe. int get_connection_socket(int number); + // How many client connections have been made. + virtual int get_connection_count(); + void close_socket(); }; +/* + * ServerSocket where clients do not send any identifiers upon connecting. + */ +class AnonymousServerSocket : public ServerSocket +{ +private: + // Global no. of client sockets that have been returned - used to create identifiers + static int global_client_socket_count; + // No. of accepted connections in this instance + int num_accepted_clients; + queue client_connection_queue; + +public: + AnonymousServerSocket(int Portnum) : + ServerSocket(Portnum), num_accepted_clients(0) { }; + // override so clients do not send id + void accept_clients(); + void init(); + + virtual int get_connection_count(); + + // Get socket for the last client who connected + // Writes a unique client identifier (i.e. a counter) to client_id + int get_connection_socket(int& client_id); +}; + #endif /* NETWORKING_SERVERSOCKET_H_ */ diff --git a/Networking/data.h b/Networking/data.h index d131a6a9a..564cd61b9 100644 --- a/Networking/data.h +++ b/Networking/data.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Data #define _Data diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index f225a13df..2553fd58a 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "sockets.h" @@ -28,8 +28,6 @@ void error(const char *str1,const char *str2) throw bad_value(); } - - void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum) { @@ -57,7 +55,7 @@ void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int memset(my_name,0,512*sizeof(octet)); gethostname((char*)my_name,512); - /* bind serv information to mysocket + /* bind serv information to mysocket * - Just assume it will eventually wake up */ fl=1; @@ -82,21 +80,18 @@ void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int } - void close_server_socket(int consocket,int main_socket) { if (close(consocket)) { error("close(socket)"); } if (close(main_socket)) { error("close(main_socket"); }; } - - void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) { mysocket = socket(AF_INET, SOCK_STREAM, 0); if (mysocket<0) { error("set_up_socket:socket"); } - - /* disable Nagle's algorithm */ + + /* disable Nagle's algorithm */ int one=1; int fl= setsockopt(mysocket, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)); if (fl<0) { error("set_up_socket:setsockopt"); } @@ -106,17 +101,8 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) struct sockaddr_in dest; dest.sin_family = AF_INET; - dest.sin_port = htons(Portnum); // set destination port number - - /* - struct hostent *server; - server=gethostbyname(hostname); - if (server== NULL) - { error("set_up_socket:gethostbyname"); } - bcopy((char *)server->h_addr, - (char *)&dest.sin_addr.s_addr, - server->h_length); // set destination IP number - */ + dest.sin_port = htons(Portnum); // set destination port number + struct addrinfo hints, *ai=NULL,*rp; memset (&hints, 0, sizeof(hints)); hints.ai_family = AF_INET; @@ -140,13 +126,13 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) } } if (erp!=0) - { error("set_up_socket:getaddrinfo"); } + { error("set_up_socket:getaddrinfo"); } for (rp=ai; rp!=NULL; rp=rp->ai_next) { const struct in_addr *addr4 = &((const struct sockaddr_in*)ai->ai_addr)->sin_addr; - + if (ai->ai_family == AF_INET) - { memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr)); + { memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr)); continue; } } @@ -162,8 +148,6 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) if (fl<0) { error("set_up_socket:connect:",hostname); } } - - void close_client_socket(int socket) { if (close(socket)) @@ -174,8 +158,6 @@ void close_client_socket(int socket) } } - - unsigned long long sent_amount = 0, sent_counter = 0; @@ -195,7 +177,7 @@ void receive(int socket,int& a) while (i==0) { i=recv(socket,msg,1,0); if (i<0) { error("Receiving error - 2"); } - } + } a=msg[0]; } diff --git a/Networking/sockets.h b/Networking/sockets.h index a7c38fb08..a0f1d945c 100644 --- a/Networking/sockets.h +++ b/Networking/sockets.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _sockets #define _sockets diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 2788ed726..544286931 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OT/BaseOT.h" #include "Tools/random.h" @@ -34,7 +34,7 @@ OT_ROLE INV_ROLE(OT_ROLE role) return BOTH; } -void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE role) +void send_if_ot_sender(TwoPartyPlayer* P, vector& os, OT_ROLE role) { if (role == SENDER) { @@ -51,7 +51,7 @@ void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE } } -void send_if_ot_receiver(const TwoPartyPlayer* P, vector& os, OT_ROLE role) +void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE role) { if (role == RECEIVER) { diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 188801c35..e7c2d13ea 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _BASE_OT #define _BASE_OT @@ -26,8 +26,8 @@ enum OT_ROLE OT_ROLE INV_ROLE(OT_ROLE role); const char* role_to_str(OT_ROLE role); -void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE role); -void send_if_ot_receiver(const TwoPartyPlayer* P, vector& os, OT_ROLE role); +void send_if_ot_sender(TwoPartyPlayer* P, vector& os, OT_ROLE role); +void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE role); class BaseOT { diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index 012bcafce..797f9c31d 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * BitMatrix.cpp diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 3dae60702..409f1959f 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * BitMatrix.h diff --git a/OT/BitVector.cpp b/OT/BitVector.cpp index 285dfd6c1..64fe3eee8 100644 --- a/OT/BitVector.cpp +++ b/OT/BitVector.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OT/BitVector.h" diff --git a/OT/BitVector.h b/OT/BitVector.h index 54eac5a4f..671512b17 100644 --- a/OT/BitVector.h +++ b/OT/BitVector.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _BITVECTOR #define _BITVECTOR diff --git a/OT/NPartyTripleGenerator.cpp b/OT/NPartyTripleGenerator.cpp index 14a14f7ba..23c795edf 100644 --- a/OT/NPartyTripleGenerator.cpp +++ b/OT/NPartyTripleGenerator.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "NPartyTripleGenerator.h" diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 91a905729..ae3881a00 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef OT_NPARTYTRIPLEGENERATOR_H_ #define OT_NPARTYTRIPLEGENERATOR_H_ diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp index 8efd07af7..d264ed2c7 100644 --- a/OT/OTExtension.cpp +++ b/OT/OTExtension.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OTExtension.h" diff --git a/OT/OTExtension.h b/OT/OTExtension.h index 91724a6a0..e07367c19 100644 --- a/OT/OTExtension.h +++ b/OT/OTExtension.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _OTEXTENSION #define _OTEXTENSION diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 7894ae949..d96d6c241 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTExtensionWithMatrix.cpp diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index d9cf59bde..6bb52b575 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTExtensionWithMatrix.h diff --git a/OT/OTMachine.cpp b/OT/OTMachine.cpp index db5e2b6b0..6d11af886 100644 --- a/OT/OTMachine.cpp +++ b/OT/OTMachine.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Networking/Player.h" #include "OT/OTExtension.h" diff --git a/OT/OTMachine.h b/OT/OTMachine.h index 9706e68f5..9f80f3dbb 100644 --- a/OT/OTMachine.h +++ b/OT/OTMachine.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTMachine.h diff --git a/OT/OTMultiplier.cpp b/OT/OTMultiplier.cpp index 658ae6504..ead5d3e0c 100644 --- a/OT/OTMultiplier.cpp +++ b/OT/OTMultiplier.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTMultiplier.cpp diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 15d9530e4..5149a24b3 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTMultiplier.h diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp index 049eda821..4f10f08da 100644 --- a/OT/OTTripleSetup.cpp +++ b/OT/OTTripleSetup.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OTTripleSetup.h" diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index 354cd21e9..31809e58a 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef OT_TRIPLESETUP_H_ #define OT_TRIPLESETUP_H_ diff --git a/OT/OText_main.cpp b/OT/OText_main.cpp index fc2edaaf7..39867547a 100644 --- a/OT/OText_main.cpp +++ b/OT/OText_main.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OText_main.cpp diff --git a/OT/OutputCheck.h b/OT/OutputCheck.h index a598b8930..01681e79f 100644 --- a/OT/OutputCheck.h +++ b/OT/OutputCheck.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * check.h diff --git a/OT/Tools.cpp b/OT/Tools.cpp index 32e9208fe..a2cc03d65 100644 --- a/OT/Tools.cpp +++ b/OT/Tools.cpp @@ -1,9 +1,9 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools.h" #include "Math/gf2nlong.h" -void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len) +void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len) { PRNG G; G.ReSeed(); diff --git a/OT/Tools.h b/OT/Tools.h index 1038d6430..53ec588b5 100644 --- a/OT/Tools.h +++ b/OT/Tools.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _OTTOOLS #define _OTTOOLS @@ -12,7 +12,7 @@ /* * Generate a secure, random seed between 2 parties via commitment */ -void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len); +void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len); /* * GF(2^128) multiplication using Intel instructions diff --git a/OT/TripleMachine.cpp b/OT/TripleMachine.cpp index b1f72c9d1..5c1e6dbd6 100644 --- a/OT/TripleMachine.cpp +++ b/OT/TripleMachine.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * TripleMachine.cpp diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index d377bca22..d318b8142 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * TripleMachine.h diff --git a/Player-Online.cpp b/Player-Online.cpp index 246e845c9..b0397d485 100644 --- a/Player-Online.cpp +++ b/Player-Online.cpp @@ -1,11 +1,14 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Machine.h" +#include "Math/Setup.h" #include "Tools/ezOptionParser.h" +#include "Tools/Config.h" #include #include #include +#include using namespace std; int main(int argc, const char** argv) @@ -108,6 +111,15 @@ int main(int argc, const char** argv) "-b", // Flag token. "--max-broadcast" // Flag token. ); + opt.add( + "0", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use communications security between SPDZ players", // Help description. + "-c", // Flag token. + "--player-to-player-commsec" // Flag token. + ); opt.parse(argc, argv); @@ -156,6 +168,7 @@ int main(int argc, const char** argv) string memtype, hostname; int lg2, lgp, pnbase, opening_sum, max_broadcast; + int p2pcommsec; opt.get("--portnumbase")->getInt(pnbase); opt.get("--lgp")->getInt(lgp); @@ -164,11 +177,25 @@ int main(int argc, const char** argv) opt.get("--hostname")->getString(hostname); opt.get("--opening-sum")->getInt(opening_sum); opt.get("--max-broadcast")->getInt(max_broadcast); + opt.get("--player-to-player-commsec")->getInt(p2pcommsec); + int mynum; + sscanf((*allArgs[1]).c_str(), "%d", &mynum); + + CommsecKeysPackage *keys = NULL; + if(p2pcommsec) { + vector pubkeys; + secret_signing_key mykey; + public_signing_key mypublickey; + string prep_data_prefix = get_prep_dir(2, lgp, lg2); + Config::read_player_config(prep_data_prefix,mynum,pubkeys,mykey,mypublickey); + keys = new CommsecKeysPackage(pubkeys,mykey,mypublickey); + } + Machine(playerno, pnbase, hostname, progname, memtype, lgp, lg2, opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, - opt.get("--threads")->isSet, max_broadcast).run(); + opt.get("--threads")->isSet, max_broadcast, keys).run(); cerr << "Command line:"; for (int i = 0; i < argc; i++) diff --git a/Processor/Binary_File_IO.cpp b/Processor/Binary_File_IO.cpp new file mode 100644 index 000000000..5e4e4b8b0 --- /dev/null +++ b/Processor/Binary_File_IO.cpp @@ -0,0 +1,72 @@ +// (C) 2017 University of Bristol. See License.txt + +#include "Processor/Binary_File_IO.h" +#include "Math/gfp.h" + +/* + * Provides generalised file read and write methods for arrays of shares. + * Stateless and not optimised for multiple reads from file. + * Intended for application specific file IO. + */ + +template +void Binary_File_IO::write_to_file(const string filename, const vector< Share >& buffer) +{ + ofstream outf; + + outf.open(filename, ios::out | ios::binary | ios::app); + if (outf.fail()) { throw file_error(filename); } + + for (unsigned int i = 0; i < buffer.size(); i++) + { + buffer[i].output(outf, false); + } + + outf.close(); +} + +template +void Binary_File_IO::read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn) +{ + ifstream inf; + inf.open(filename, ios::in | ios::binary); + if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } + + int size_in_bytes = Share::size() * buffer.size(); + int n_read = 0; + char * read_buffer = new char[size_in_bytes]; + inf.seekg(start_posn); + do + { + inf.read(read_buffer + n_read, size_in_bytes - n_read); + n_read += inf.gcount(); + + if (inf.eof()) + { + stringstream ss; + ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes)."; + throw file_error(ss.str()); + } + if (inf.fail()) + { + stringstream ss; + ss << "IO problem when reading from disk"; + throw file_error(ss.str()); + } + } + while (n_read < size_in_bytes); + + end_posn = inf.tellg(); + + //Check if at end of file by getting 1 more char. + inf.get(); + if (inf.eof()) + end_posn = -1; + inf.close(); + + for (unsigned int i = 0; i < buffer.size(); i++) + buffer[i].assign(&read_buffer[i*Share::size()]); +} + +template void Binary_File_IO::write_to_file(const string filename, const vector< Share >& buffer); +template void Binary_File_IO::read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn); diff --git a/Processor/Binary_File_IO.h b/Processor/Binary_File_IO.h new file mode 100644 index 000000000..b0c6ce0b5 --- /dev/null +++ b/Processor/Binary_File_IO.h @@ -0,0 +1,43 @@ +// (C) 2017 University of Bristol. See License.txt + +#ifndef _FILE_IO_HEADER +#define _FILE_IO_HEADER + +#include "Exceptions/Exceptions.h" +#include "Math/Share.h" + +#include +#include +#include +#include + +using namespace std; + +/* + * Provides generalised file read and write methods for arrays of numeric data types. + * Stateless and not optimised for multiple reads from file. + * Intended for MPC application specific file IO. + */ + +class Binary_File_IO +{ + public: + + /* + * Append the buffer values as binary to the filename. + * Throws file_error. + */ + template + void write_to_file(const string filename, const vector< Share >& buffer); + + /* + * Read from posn in the filename the binary values until the buffer is full. + * Assumes file holds binary that maps into the type passed in. + * Returns the current posn in the file or -1 if at eof. + * Throws file_error. + */ + template + void read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn); +}; + +#endif diff --git a/Processor/Buffer.cpp b/Processor/Buffer.cpp index a5ca44526..788782055 100644 --- a/Processor/Buffer.cpp +++ b/Processor/Buffer.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Buffer.cpp diff --git a/Processor/Buffer.h b/Processor/Buffer.h index a9936d1e1..f8aae5758 100644 --- a/Processor/Buffer.h +++ b/Processor/Buffer.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Buffer.h diff --git a/Processor/Data_Files.cpp b/Processor/Data_Files.cpp index 96d64d0ab..eb0eed775 100644 --- a/Processor/Data_Files.cpp +++ b/Processor/Data_Files.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Data_Files.h" @@ -56,7 +56,7 @@ void DataPositions::print_cost() const file >> cost_per_item; if (cost_per_item < 0) break; - int items_used = files[i][j]; + long long items_used = files[i][j]; double cost = items_used * cost_per_item; total_cost += cost; cerr.fill(' '); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 2b5581121..c9167df13 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Data_Files #define _Data_Files @@ -73,10 +73,10 @@ class Data_Files DataPositions usage; - const string prep_data_dir; - public: + const string prep_data_dir; + static const char* dtype_names[N_DTYPE]; static const char* field_names[N_DATA_FIELD_TYPE]; static const char* long_field_names[N_DATA_FIELD_TYPE]; diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp new file mode 100644 index 000000000..6a5df0498 --- /dev/null +++ b/Processor/ExternalClients.cpp @@ -0,0 +1,179 @@ +// (C) 2017 University of Bristol. See License.txt + +#include "Processor/ExternalClients.h" +#include +#include +#include + +ExternalClients::ExternalClients(int party_num, const string& prep_data_dir): + party_num(party_num), prep_data_dir(prep_data_dir), server_connection_count(-1) +{ +} + +ExternalClients::~ExternalClients() +{ + // close client sockets + for (map::iterator it = external_client_sockets.begin(); + it != external_client_sockets.end(); it++) + { + if (close(it->second)) + { + error("failed to close external client connection socket)"); + } + } + for (map::iterator it = client_connection_servers.begin(); + it != client_connection_servers.end(); it++) + { + delete it->second; + } + for (map::iterator it = symmetric_client_keys.begin(); + it != symmetric_client_keys.end(); it++) + { + delete[] it->second; + } + for (map,uint64_t> >::iterator it_cs = symmetric_client_commsec_send_keys.begin(); + it_cs != symmetric_client_commsec_send_keys.end(); it_cs++) + { + memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size()); + } + for (map,uint64_t> >::iterator it_cs = symmetric_client_commsec_recv_keys.begin(); + it_cs != symmetric_client_commsec_recv_keys.end(); it_cs++) + { + memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size()); + } +} + +void ExternalClients::start_listening(int portnum_base) +{ + client_connection_servers[portnum_base] = new AnonymousServerSocket(portnum_base + get_party_num()); + client_connection_servers[portnum_base]->init(); + cerr << "Start listening on thread " << this_thread::get_id() << endl; + cerr << "Party " << get_party_num() << " is listening on port " << (portnum_base + get_party_num()) + << " for external client connections." << endl; +} + +int ExternalClients::get_client_connection(int portnum_base) +{ + map::iterator it = client_connection_servers.find(portnum_base); + if (it == client_connection_servers.end()) + { + cerr << "Thread " << this_thread::get_id() << " didn't find server." << endl; + return -1; + } + cerr << "Thread " << this_thread::get_id() << " found server." << endl; + int client_id, socket; + socket = client_connection_servers[portnum_base]->get_connection_socket(client_id); + external_client_sockets[client_id] = socket; + cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl; + return client_id; +} + +int ExternalClients::connect_to_server(int portnum_base, int ipv4_address) +{ + struct in_addr addr = { (unsigned int)ipv4_address }; + int csocket; + const char* address_str = inet_ntoa(addr); + cerr << "Party " << get_party_num() << " connecting to server at " << address_str << " on port " << portnum_base + get_party_num() << endl; + set_up_client_socket(csocket, address_str, portnum_base + get_party_num()); + cerr << "Party " << get_party_num() << " connected to server at " << address_str << " on port " << portnum_base + get_party_num() << endl; + int server_id = server_connection_count; + // server identifiers are -1, -2, ... to avoid conflict with client identifiers + server_connection_count--; + external_client_sockets[server_id] = csocket; + return server_id; +} + +void ExternalClients::curve25519_ints_to_bytes(unsigned char *bytes, const vector& key_ints) +{ + for(unsigned int j = 0; j < key_ints.size(); j++) { + for(unsigned int k = 0; k < 4; k++) { + bytes[j*sizeof(int) + k] = (key_ints[j] >> ((3-k)*8)) & 0xFF; + } + } +} + +// Generate sesssion key for a newly connected client, store in symmetric_client_keys +// public_key is expected to be size 8 and contain integer values of public key bytes. +// Assumes load_server_keys has been run. +void ExternalClients::generate_session_key_for_client(int client_id, const vector& public_key) +{ + assert(public_key.size() * sizeof(int) == crypto_box_PUBLICKEYBYTES); + + load_server_keys_once(); + + unsigned char client_publickey[crypto_box_PUBLICKEYBYTES]; + + curve25519_ints_to_bytes(client_publickey, public_key); + + cerr << "Recevied client public key for client " << dec << client_id << " :"; + for (unsigned int j = 0; j < crypto_box_PUBLICKEYBYTES; j++) + cerr << hex << (int) client_publickey[j] << " "; + cerr << dec << endl; + + unsigned char scalarmult_q_by_server[crypto_scalarmult_BYTES]; + crypto_generichash_state h; + + symmetric_client_keys[client_id] = new octet[crypto_generichash_BYTES]; + + // Derive a shared key from this server's secret key and the client's public key + // shared key = h(q || server_secretkey || client_publickey) + if (crypto_scalarmult(scalarmult_q_by_server, server_secretkey, client_publickey) != 0) { + cerr << "Scalar mult failed\n"; + exit(1); + } + crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES); + crypto_generichash_update(&h, scalarmult_q_by_server, sizeof scalarmult_q_by_server); + crypto_generichash_update(&h, client_publickey, sizeof client_publickey); + crypto_generichash_update(&h, server_publickey, sizeof server_publickey); + crypto_generichash_final(&h, symmetric_client_keys[client_id], crypto_generichash_BYTES); +} + +// Read pre-computed server keys from client-setup for this SPDZ engine. +// Only needs to be done once per run, but is only necessary if an external connection +// is being requested. +void ExternalClients::load_server_keys_once() +{ + if (server_keys_loaded) { + return; + } + + ifstream keyfile; + stringstream filename; + filename << prep_data_dir << "Player-SPDZ-Keys-P" << get_party_num(); + keyfile.open(filename.str().c_str()); + if (keyfile.fail()) + throw file_error(filename.str().c_str()); + + keyfile.read((char*)server_publickey, sizeof server_publickey); + if (keyfile.eof()) + throw end_of_file(filename.str(), "server public key" ); + keyfile.read((char*)server_secretkey, sizeof server_secretkey); + if (keyfile.eof()) + throw end_of_file(filename.str(), "server private key" ); + + bool loaded_ed25519 = true; + + keyfile.read((char*)server_publickey_ed25519, sizeof server_publickey_ed25519); + if (keyfile.eof() || keyfile.bad()) + loaded_ed25519 = false; + keyfile.read((char*)server_secretkey_ed25519, sizeof server_secretkey_ed25519); + if (keyfile.eof() || keyfile.bad()) + loaded_ed25519 = false; + + keyfile.close(); + + ed25519_keys_loaded = loaded_ed25519; + server_keys_loaded = true; +} + +void ExternalClients::require_ed25519_keys() +{ + if (!ed25519_keys_loaded) + throw "Ed25519 keys required but not found in player key files"; +} + +int ExternalClients::get_party_num() +{ + return party_num; +} + diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h new file mode 100644 index 000000000..12810e999 --- /dev/null +++ b/Processor/ExternalClients.h @@ -0,0 +1,65 @@ +// (C) 2017 University of Bristol. See License.txt + +#ifndef _ExternalClients +#define _ExternalClients + +#include "Networking/ServerSocket.h" +#include "Networking/sockets.h" +#include "Exceptions/Exceptions.h" +#include +#include +#include +#include +#include +#include + +/* + * Manage the reading and writing of data from/to external clients via Sockets. + * Generate the session keys for encryption/decryption of secret communication with external clients. + */ + +class ExternalClients +{ + map client_connection_servers; + + int party_num; + const string prep_data_dir; + int server_connection_count; + unsigned char server_publickey[crypto_box_PUBLICKEYBYTES]; + unsigned char server_secretkey[crypto_box_SECRETKEYBYTES]; + bool server_keys_loaded = false; + bool ed25519_keys_loaded = false; + + public: + + unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES]; + unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES]; + + // Maps holding per client values (indexed by unique 32-bit id) + std::map external_client_sockets; + std::map symmetric_client_keys; + std::map,uint64_t>> symmetric_client_commsec_send_keys; + std::map,uint64_t>> symmetric_client_commsec_recv_keys; + + ExternalClients(int party_num, const string& prep_data_dir); + ~ExternalClients(); + + void start_listening(int portnum_base); + + int get_client_connection(int portnum_base); + + int connect_to_server(int portnum_base, int ipv4_address); + + // return the socket for a given client or server identifier + int get_socket(int socket_id); + + void curve25519_ints_to_bytes(unsigned char bytes[crypto_box_PUBLICKEYBYTES], const vector& key_ints); + void generate_session_key_for_client(int client_id, const vector& public_key); + + void load_server_keys_once(); + + int get_party_num(); + void require_ed25519_keys(); +}; + +#endif diff --git a/Processor/Input.cpp b/Processor/Input.cpp index cc42ff40f..5628ad25c 100644 --- a/Processor/Input.cpp +++ b/Processor/Input.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Input.cpp diff --git a/Processor/Input.h b/Processor/Input.h index 69a870727..d324f8985 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Input.h diff --git a/Processor/InputTuple.h b/Processor/InputTuple.h index 67c38a126..740f32111 100644 --- a/Processor/InputTuple.h +++ b/Processor/InputTuple.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * InputTuple.h diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 8f104fa6a..bcf039ad0 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Instruction.h" @@ -6,6 +6,7 @@ #include "Processor/Processor.h" #include "Exceptions/Exceptions.h" #include "Tools/time-func.h" +#include "Tools/parse.h" #include #include @@ -13,28 +14,6 @@ #include -// Read a byte -int get_val(istream& s) -{ - char cc; - s.get(cc); - int a=cc; - if (a<0) { a+=256; } - return a; -} - -// Read a 4-byte integer -int get_int(istream& s) -{ - int n = 0; - for (int i=0; i<4; i++) - { n<<=8; - int t=get_val(s); - n+=t; - } - return n; -} - // Convert modp to signed bigint of a given bit length void to_signed_bigint(bigint& bi, const gfp& x, int len) { @@ -57,18 +36,10 @@ void to_signed_bigint(bigint& bi, const gfp& x, int len) } -void get_vector(int m, vector& start, istream& s) -{ - start.resize(m); - for (int i = 0; i < m; i++) - start[i] = get_int(s); -} - - void Instruction::parse(istream& s) { n=0; start.resize(0); - r[0]=0; r[1]=0; r[2]=0; + r[0]=0; r[1]=0; r[2]=0; r[3]=0; int pos=s.tellg(); opcode=get_int(s); @@ -78,6 +49,13 @@ void Instruction::parse(istream& s) if (size==0) size=1; + parse_operands(s, pos); +} + + +void BaseInstruction::parse_operands(istream& s, int pos) +{ + int num_var_args = 0; switch (opcode) { // instructions with 3 register operands @@ -182,6 +160,7 @@ void Instruction::parse(istream& s) case GRAWOUTPUT: case PRINTCHRINT: case PRINTSTRINT: + case PRINTINT: r[0]=get_int(s); break; // instructions with 3 registers + 1 integer operand @@ -227,6 +206,8 @@ void Instruction::parse(istream& s) case RUN_TAPE: case STARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT: + case DIGESTC: + case CONNECTIPV4: // write socket handle, read IPv4 address, portnum r[0]=get_int(s); r[1]=get_int(s); n = get_int(s); @@ -259,10 +240,7 @@ void Instruction::parse(istream& s) case GSTOPPRIVATEOUTPUT: case INPUTMASK: case GINPUTMASK: - case READSOCKETC: - case READSOCKETS: - case WRITESOCKETC: - case WRITESOCKETS: + case ACCEPTCLIENTCONNECTION: r[0]=get_int(s); n = get_int(s); break; @@ -272,49 +250,84 @@ void Instruction::parse(istream& s) case JMP: case START: case STOP: - case OPENSOCKET: + case LISTEN: n = get_int(s); break; // instructions with no operand case TIME: case CRASH: - case CLOSESOCKET: - break; + break; // instructions with 4 register operands case PRINTFLOATPLAIN: get_vector(4, start, s); break; - // open instructions + // open instructions + read/write instructions with variable length args case STARTOPEN: case STOPOPEN: case GSTARTOPEN: case GSTOPOPEN: - int m; - m = get_int(s); - get_vector(m, start, s); + case WRITEFILESHARE: + num_var_args = get_int(s); + get_vector(num_var_args, start, s); + break; + + // read from file, input is opcode num_args, + // start_file_posn (read), end_file_posn(write) var1, var2, ... + case READFILESHARE: + num_var_args = get_int(s) - 2; + r[0] = get_int(s); + r[1] = get_int(s); + get_vector(num_var_args, start, s); + break; + + // read from external client, input is : opcode num_args, client_id, var1, var2 ... + case READSOCKETC: + case READSOCKETS: + case READSOCKETINT: + case READCLIENTPUBLICKEY: + num_var_args = get_int(s) - 1; + r[0] = get_int(s); + get_vector(num_var_args, start, s); + break; + + // write to external client, input is : opcode num_args, client_id, message_type, var1, var2 ... + case WRITESOCKETC: + case WRITESOCKETS: + case WRITESOCKETSHARE: + case WRITESOCKETINT: + num_var_args = get_int(s) - 2; + r[0] = get_int(s); + r[1] = get_int(s); + get_vector(num_var_args, start, s); + break; + case INITSECURESOCKET: + case RESPSECURESOCKET: + num_var_args = get_int(s) - 1; + r[0] = get_int(s); + get_vector(num_var_args, start, s); break; // raw input case STOPINPUT: case GSTOPINPUT: // subtract player number argument - m = get_int(s) - 1; + num_var_args = get_int(s) - 1; n = get_int(s); - get_vector(m, start, s); + get_vector(num_var_args, start, s); break; case GBITDEC: case GBITCOM: - m = get_int(s) - 2; + num_var_args = get_int(s) - 2; r[0] = get_int(s); n = get_int(s); - get_vector(m, start, s); + get_vector(num_var_args, start, s); break; case PREP: case GPREP: // subtract extra argument - m = get_int(s) - 1; + num_var_args = get_int(s) - 1; s.read((char*)r, sizeof(r)); - start.resize(m); - for (int i = 0; i < m; i++) + start.resize(num_var_args); + for (int i = 0; i < num_var_args; i++) { start[i] = get_int(s); } break; case USE_PREP: @@ -341,8 +354,8 @@ void Instruction::parse(istream& s) break; default: ostringstream os; - os << "Invalid instruction " << hex << showbase << opcode << " at " << pos; - throw Processor_Error(os.str()); + os << "Invalid instruction " << hex << showbase << opcode << " at " << dec << pos; + throw Invalid_Instruction(os.str()); } } @@ -376,7 +389,7 @@ bool Instruction::get_offline_data_usage(DataPositions& usage) } } -RegType Instruction::get_reg_type() const +int BaseInstruction::get_reg_type() const { switch (opcode) { case LDMINT: @@ -386,6 +399,16 @@ RegType Instruction::get_reg_type() const case PUSHINT: case POPINT: case MOVINT: + case READSOCKETINT: + case WRITESOCKETINT: + case READCLIENTPUBLICKEY: + case INITSECURESOCKET: + case RESPSECURESOCKET: + case LDARG: + case LDINT: + case CONVMODP: + case GCONVGF2N: + case RAND: return INT; case PREP: case USE_PREP: @@ -402,7 +425,7 @@ RegType Instruction::get_reg_type() const } } -int Instruction::get_max_reg(RegType reg_type) const +int BaseInstruction::get_max_reg(int reg_type) const { if (get_reg_type() != reg_type) { return 0; } @@ -420,7 +443,7 @@ int Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const return 0; } -bool Instruction::is_direct_memory_access(SecrecyType sec_type) const +bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const { if (sec_type == SECRET) { @@ -825,10 +848,21 @@ void Instruction::execute(Processor& Proc) const case LEGENDREC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); Proc.temp.aa = mpz_legendre(Proc.temp.aa.get_mpz_t(), gfp::pr().get_mpz_t()); - //Proc.temp.aa = legendre; to_gfp(Proc.temp.ansp, Proc.temp.aa); Proc.write_Cp(r[0], Proc.temp.ansp); break; + case DIGESTC: + { + octetStream o; + to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); + + to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.temp.ansp.pack(o); + // keep first n bytes + to_gfp(Proc.temp.ansp, o.check_sum(n)); + Proc.write_Cp(r[0], Proc.temp.ansp); + } + break; case DIVCI: if (n == 0) throw Processor_Error("Division by immediate zero"); @@ -1455,6 +1489,12 @@ void Instruction::execute(Processor& Proc) const cout << res << flush; } break; + case PRINTINT: + if (Proc.P.my_num() == 0) + { + cout << Proc.read_Ci(r[0]) << flush; + } + break; case PRINTSTR: if (Proc.P.my_num() == 0) { @@ -1490,16 +1530,13 @@ void Instruction::execute(Processor& Proc) const case GUSE_PREP: break; case TIME: - cout << "Elapsed time: " << Proc.machine.timer[0].elapsed() << endl; + Proc.machine.time(); break; case START: - cout << "Starting timer " << n << " at " << Proc.machine.timer[n].elapsed() - << " after " << Proc.machine.timer[n].idle() << endl; - Proc.machine.timer[n].start(); + Proc.machine.start(n); break; case STOP: - Proc.machine.timer[n].stop(); - cout << "Stopped timer " << n << " at " << Proc.machine.timer[n].elapsed() << endl; + Proc.machine.stop(n); break; case RUN_TAPE: Proc.DataF.skip(Proc.machine.run_tape(r[0], n, r[1], -1)); @@ -1513,39 +1550,81 @@ void Instruction::execute(Processor& Proc) const // *** // TODO: read/write shared GF(2^n) data instructions // *** - case OPENSOCKET: - Proc.open_socket(n); + case LISTEN: + // listen for connections at port number n + Proc.external_clients.start_listening(n); + break; + case ACCEPTCLIENTCONNECTION: + { + // get client connection at port number n + my_num()) + int client_handle = Proc.external_clients.get_client_connection(n); + if (client_handle == -1) + { + stringstream ss; + ss << "No connection on port " << r[0] << endl; + throw Processor_Error(ss.str()); + } + Proc.write_Ci(r[0], client_handle); + break; + } + case CONNECTIPV4: + { + // connect to server at port n + my_num() + int ipv4 = Proc.read_Ci(r[1]); + int server_handle = Proc.external_clients.connect_to_server(n, ipv4); + Proc.write_Ci(r[0], server_handle); + break; + } + case READCLIENTPUBLICKEY: + Proc.read_client_public_key(Proc.read_Ci(r[0]), start); break; - case CLOSESOCKET: - Proc.close_socket(); + case INITSECURESOCKET: + Proc.init_secure_socket(Proc.read_Ci(r[i]), start); break; - case READSOCKETC: // n is *unused atm*, r[0] is register to write to - int dest; - Proc.read_socket(dest); - Proc.write_Ci(r[0], (long)dest); + case RESPSECURESOCKET: + Proc.resp_secure_socket(Proc.read_Ci(r[i]), start); + break; + case READSOCKETINT: + Proc.read_socket_ints(Proc.read_Ci(r[0]), start); + break; + case READSOCKETC: + Proc.read_socket_vector(Proc.read_Ci(r[0]), start); break; case READSOCKETS: - // read share then MAC share - Proc.read_socket(Proc.temp.ansp); - Proc.get_Sp_ref(r[0]).set_share(Proc.temp.ansp); - Proc.read_socket(Proc.temp.ansp); - Proc.get_Sp_ref(r[0]).set_mac(Proc.temp.ansp); + // read shares and MAC shares + Proc.read_socket_private(Proc.read_Ci(r[0]), start, true); break; case GREADSOCKETS: //Proc.get_S2_ref(r[0]).get_share().pack(socket_octetstream); //Proc.get_S2_ref(r[0]).get_mac().pack(socket_octetstream); break; - case WRITESOCKETC: // n is *unused atm*, r[0] is register to write to; - Proc.write_socket((int&)Proc.get_Ci_ref(r[0])); + case WRITESOCKETINT: + Proc.write_socket(INT, CLEAR, false, Proc.read_Ci(r[0]), r[1], start); + break; + case WRITESOCKETC: + Proc.write_socket(MODP, CLEAR, false, Proc.read_Ci(r[0]), r[1], start); break; case WRITESOCKETS: - Proc.write_socket(Proc.get_Sp_ref(r[0]).get_share()); - Proc.write_socket(Proc.get_Sp_ref(r[0]).get_mac()); + // Send shares + MACs + Proc.write_socket(MODP, SECRET, true, Proc.read_Ci(r[0]), r[1], start); + break; + case WRITESOCKETSHARE: + // Send only shares, no MACs + // N.B. doesn't make sense to have a corresponding read instruction for this + Proc.write_socket(MODP, SECRET, false, Proc.read_Ci(r[0]), r[1], start); break; /*case GWRITESOCKETS: Proc.get_S2_ref(r[0]).get_share().pack(socket_octetstream); Proc.get_S2_ref(r[0]).get_mac().pack(socket_octetstream); break;*/ + case WRITEFILESHARE: + // Write shares to file system + Proc.write_shares_to_file(start); + break; + case READFILESHARE: + // Read shares from file system + Proc.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start); + break; case PUBINPUT: Proc.public_input >> Proc.get_Ci_ref(r[0]); break; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 9975021f3..7ab07734c 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Instruction #define _Instruction @@ -11,7 +11,6 @@ #include using namespace std; -#include "Processor/Memory.h" #include "Processor/Data_Files.h" #include "Networking/Player.h" #include "Math/Integer.h" @@ -23,7 +22,7 @@ class Processor; /* * Opcode constants * - * Whenever these are changed the corresponding dict in Compiler/instructions.py + * Whenever these are changed the corresponding dict in Compiler/instructions_base.py * MUST also be changed. (+ the documentation) */ enum @@ -89,6 +88,7 @@ enum MODC = 0x36, MODCI = 0x37, LEGENDREC = 0x38, + DIGESTC = 0x39, // Open STARTOPEN = 0xA0, STOPOPEN = 0xA1, @@ -107,8 +107,13 @@ enum READSOCKETS = 0x64, WRITESOCKETC = 0x65, WRITESOCKETS = 0x66, - OPENSOCKET = 0x67, - CLOSESOCKET = 0x68, + READSOCKETINT = 0x69, + WRITESOCKETINT = 0x6a, + WRITESOCKETSHARE = 0x6b, + LISTEN = 0x6c, + ACCEPTCLIENTCONNECTION = 0x6d, + CONNECTIPV4 = 0x6e, + READCLIENTPUBLICKEY = 0x6f, // Bitwise logic ANDC = 0x70, XORC = 0x71, @@ -138,6 +143,7 @@ enum SUBINT = 0x9C, MULINT = 0x9D, DIVINT = 0x9E, + PRINTINT = 0x9F, // Conversion CONVINT = 0xC0, CONVMODP = 0xC1, @@ -156,6 +162,8 @@ enum PRINTCHRINT = 0xBA, PRINTSTRINT = 0xBB, PRINTFLOATPLAIN = 0xBC, + WRITEFILESHARE = 0xBD, + READFILESHARE = 0xBE, // GF(2^n) versions @@ -241,6 +249,9 @@ enum GRAWOUTPUT = 0x1B7, GSTARTPRIVATEOUTPUT = 0x1B8, GSTOPPRIVATEOUTPUT = 0x1B9, + // Commsec ops + INITSECURESOCKET = 0x1BA, + RESPSECURESOCKET = 0x1BB }; @@ -259,7 +270,6 @@ enum SecrecyType { MAX_SECRECY_TYPE }; - struct TempVars { gf2n ans2; Share Sans2; gfp ansp; Share Sansp; @@ -273,29 +283,38 @@ struct TempVars { }; -class Instruction +class BaseInstruction { +protected: int opcode; // The code int size; // Vector size - int r[3]; // Three possible registers + int r[4]; // Fixed parameter registers unsigned int n; // Possible immediate value vector start; // Values for a start/stop open - public: +public: + virtual ~BaseInstruction() {}; - // Reads a single instruction from the istream - void parse(istream& s); - - // Return whether usage is known - bool get_offline_data_usage(DataPositions& usage); + void parse_operands(istream& s, int pos); bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); } - RegType get_reg_type() const; + virtual int get_reg_type() const; bool is_direct_memory_access(SecrecyType sec_type) const; // Returns the maximal register used - int get_max_reg(RegType reg_type) const; + int get_max_reg(int reg_type) const; +}; + + +class Instruction : public BaseInstruction +{ +public: + // Reads a single instruction from the istream + void parse(istream& s); + + // Return whether usage is known + bool get_offline_data_usage(DataPositions& usage); // Returns the memory size used if applicable and known int get_mem(RegType reg_type, SecrecyType sec_type) const; diff --git a/Processor/Machine.cpp b/Processor/Machine.cpp index 061a4f416..549e06eed 100644 --- a/Processor/Machine.cpp +++ b/Processor/Machine.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Machine.h" @@ -17,12 +17,14 @@ using namespace std; Machine::Machine(int my_number, int PortnumBase, string hostname, string progname_str, string memtype, int lgp, int lg2, bool direct, - int opening_sum, bool parallel, bool receive_threads, int max_broadcast) + int opening_sum, bool parallel, bool receive_threads, int max_broadcast, + CommsecKeysPackage *commsec_keys) : my_number(my_number), nthreads(0), tn(0), numt(0), usage_unknown(false), progname(progname_str), direct(direct), opening_sum(opening_sum), parallel(parallel), receive_threads(receive_threads), max_broadcast(max_broadcast) { N.init(my_number,PortnumBase,hostname.c_str()); + N.set_keys(commsec_keys); if (opening_sum < 2) this->opening_sum = N.num_players(); @@ -106,36 +108,9 @@ Machine::Machine(int my_number, int PortnumBase, string hostname, if (pinp.fail()) { throw file_error(filename); } progs[i].parse(pinp); pinp.close(); - if (progs[i].direct_mem2_s() > M2.size_s()) - { - cerr << threadname << " needs more secret mod2 memory, resizing to " - << progs[i].direct_mem2_s() << endl; - M2.resize_s(progs[i].direct_mem2_s()); - } - if (progs[i].direct_memp_s() > Mp.size_s()) - { - cerr << threadname << " needs more secret modp memory, resizing to " - << progs[i].direct_memp_s() << endl; - Mp.resize_s(progs[i].direct_memp_s()); - } - if (progs[i].direct_mem2_c() > M2.size_c()) - { - cerr << threadname << " needs more clear mod2 memory, resizing to " - << progs[i].direct_mem2_c() << endl; - M2.resize_c(progs[i].direct_mem2_c()); - } - if (progs[i].direct_memp_c() > Mp.size_c()) - { - cerr << threadname << " needs more clear modp memory, resizing to " - << progs[i].direct_memp_c() << endl; - Mp.resize_c(progs[i].direct_memp_c()); - } - if (progs[i].direct_memi_c() > Mi.size_c()) - { - cerr << threadname << " needs more clear integer memory, resizing to " - << progs[i].direct_memi_c() << endl; - Mi.resize_c(progs[i].direct_memi_c()); - } + M2.minimum_size(GF2N, progs[i], threadname); + Mp.minimum_size(MODP, progs[i], threadname); + Mi.minimum_size(INT, progs[i], threadname); } progs[0].print_offline_cost(); @@ -179,6 +154,10 @@ Machine::Machine(int my_number, int PortnumBase, string hostname, DataPositions Machine::run_tape(int thread_number, int tape_number, int arg, int line_number) { + if (thread_number >= (int)tinfo.size()) + throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size())); + if (tape_number >= (int)progs.size()) + throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size())); pthread_mutex_lock(&t_mutex[thread_number]); tinfo[thread_number].prognum=tape_number; tinfo[thread_number].arg=arg; @@ -303,10 +282,7 @@ void Machine::run() cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl; cerr << "Finish timer: " << finish_timer.elapsed() << endl; cerr << "Process timer: " << proc_timer.elapsed() << endl; - cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; - timer.erase(0); - for (map::iterator it = timer.begin(); it != timer.end(); it++) - cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; + print_timers(); if (opening_sum < N.num_players() && !direct) cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl; @@ -359,4 +335,28 @@ void Machine::run() cerr << "End of prog" << endl; } +void BaseMachine::time() +{ + cout << "Elapsed time: " << timer[0].elapsed() << endl; +} + +void BaseMachine::start(int n) +{ + cout << "Starting timer " << n << " at " << timer[n].elapsed() + << " after " << timer[n].idle() << endl; + timer[n].start(); +} + +void BaseMachine::stop(int n) +{ + timer[n].stop(); + cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl; +} +void BaseMachine::print_timers() +{ + cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; + timer.erase(0); + for (map::iterator it = timer.begin(); it != timer.end(); it++) + cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; +} diff --git a/Processor/Machine.h b/Processor/Machine.h index 703d000b1..2fabbe47f 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Machine.h @@ -21,7 +21,19 @@ #include using namespace std; -class Machine +class BaseMachine +{ +protected: + std::map timer; + void print_timers(); + +public: + void time(); + void start(int n); + void stop(int n); +}; + +class Machine : public BaseMachine { /* The mutex's lock the C-threads and then only release * then we an MPC thread is ready to run on the C-thread. @@ -58,7 +70,6 @@ class Machine Memory Mp; Memory Mi; - std::map timer; vector join_timer; Timer finish_timer; @@ -73,7 +84,7 @@ class Machine Machine(int my_number, int PortnumBase, string hostname, string progname, string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel, - bool receive_threads, int max_broadcast); + bool receive_threads, int max_broadcast, CommsecKeysPackage *keys); DataPositions run_tape(int thread_number, int tape_number, int arg, int line_number); void join_tape(int thread_number); diff --git a/Processor/Memory.cpp b/Processor/Memory.cpp index 5ce3c564c..a1d1acb86 100644 --- a/Processor/Memory.cpp +++ b/Processor/Memory.cpp @@ -1,12 +1,31 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Memory.h" +#include "Processor/Instruction.h" #include "Math/gf2n.h" #include "Math/gfp.h" #include "Math/Integer.h" #include +template +void Memory::minimum_size(RegType reg_type, const Program& program, string threadname) +{ + const int* sizes = program.direct_mem(reg_type); + if (sizes[SECRET] > size_s()) + { + cerr << threadname << " needs more secret " << T::type_string() << " memory, resizing to " + << sizes[SECRET] << endl; + resize_s(sizes[SECRET]); + } + if (sizes[CLEAR] > size_c()) + { + cerr << threadname << " needs more clear " << T::type_string() << " memory, resizing to " + << sizes[CLEAR] << endl; + resize_c(sizes[CLEAR]); + } +} + #ifdef MEMPROTECT template void Memory::protect_s(unsigned int start, unsigned int end) diff --git a/Processor/Memory.h b/Processor/Memory.h index 21c1b1f87..3ad573451 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Memory #define _Memory @@ -14,6 +14,7 @@ template class Memory; template ostream& operator<<(ostream& s,const Memory& M); template istream& operator>>(istream& s,Memory& M); +#include "Processor/Program.h" #include "Math/Share.h" template class Memory @@ -72,6 +73,8 @@ class Memory { (void)start, (void)end; cerr << "Memory protection not activated" << endl; } #endif + void minimum_size(RegType reg_type, const Program& program, string threadname); + friend ostream& operator<< <>(ostream& s,const Memory& M); friend istream& operator>> <>(istream& s,Memory& M); diff --git a/Processor/Online-Thread.cpp b/Processor/Online-Thread.cpp index f3cfa51c0..e969c88a6 100644 --- a/Processor/Online-Thread.cpp +++ b/Processor/Online-Thread.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Program.h" @@ -48,14 +48,14 @@ void* Main_Func(void* ptr) if (machine.direct) { cerr << "Using direct communication. If computation stalls, use -m when compiling." << endl; - MC2 = new Direct_MAC_Check(*(tinfo->alpha2i), *(tinfo->Nms), num); - MCp = new Direct_MAC_Check(*(tinfo->alphapi), *(tinfo->Nms), num); + MC2 = new Direct_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num); + MCp = new Direct_MAC_Check(*(tinfo->alphapi),*(tinfo->Nms), num); } else if (machine.parallel) { cerr << "Using indirect communication with background threads." << endl; - MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i), *(tinfo->Nms), num, machine.opening_sum); - MCp = new Parallel_MAC_Check(*(tinfo->alphapi), *(tinfo->Nms), num, machine.opening_sum); + MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num, machine.opening_sum); + MCp = new Parallel_MAC_Check(*(tinfo->alphapi),*(tinfo->Nms), num, machine.opening_sum); } else { @@ -64,16 +64,14 @@ void* Main_Func(void* ptr) MCp = new MAC_Check(*(tinfo->alphapi), machine.opening_sum); } - Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine); + // Allocate memory for first program before starting the clock + Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine,progs[0]); Share a,b,c; bool flag=true; int program=-3; // int exec=0; - // Allocate memory for first program before starting the clock - Proc.reset(progs[0].num_regs2(),progs[0].num_regsp(),progs[0].num_regi(),tinfo->arg); - // synchronize cerr << "Locking for sync of thread " << num << endl; pthread_mutex_lock(&t_mutex[num]); @@ -103,7 +101,7 @@ void* Main_Func(void* ptr) else { // RUN PROGRAM //printf("\tClient %d about to run %d in execution %d\n",num,program,exec); - Proc.reset(progs[program].num_regs2(),progs[program].num_regsp(),progs[program].num_regi(),tinfo->arg); + Proc.reset(progs[program],tinfo->arg); // Bits, Triples, Squares, and Inverses skipping DataF.seekg(tinfo->pos); diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index 485a073e1..7714258c4 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Online_Thread #define _Online_Thread diff --git a/Processor/PrivateOutput.cpp b/Processor/PrivateOutput.cpp index 909871d4e..ec9c9c25d 100644 --- a/Processor/PrivateOutput.cpp +++ b/Processor/PrivateOutput.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * PrivateOutput.cpp diff --git a/Processor/PrivateOutput.h b/Processor/PrivateOutput.h index 52c7522ff..2957727fc 100644 --- a/Processor/PrivateOutput.h +++ b/Processor/PrivateOutput.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * PrivateOutput.h diff --git a/Processor/Processor.cpp b/Processor/Processor.cpp index a62c78e26..4245cb58a 100644 --- a/Processor/Processor.cpp +++ b/Processor/Processor.cpp @@ -1,19 +1,23 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Processor.h" +#include "Networking/STS.h" #include "Auth/MAC_Check.h" #include "Auth/fake-stuff.h" +#include +#include Processor::Processor(int thread_num,Data_Files& DataF,Player& P, MAC_Check& MC2,MAC_Check& MCp,Machine& machine, - int num_regs2,int num_regsp,int num_regi) -: thread_num(thread_num),socket_is_open(false),DataF(DataF),P(P),MC2(MC2),MCp(MCp),machine(machine), - input2(*this,MC2),inputp(*this,MCp),privateOutput2(*this),privateOutputp(*this),sent(0),rounds(0) + const Program& program) +: thread_num(thread_num),DataF(DataF),P(P),MC2(MC2),MCp(MCp),machine(machine), + input2(*this,MC2),inputp(*this,MCp),privateOutput2(*this),privateOutputp(*this),sent(0),rounds(0), + external_clients(ExternalClients(P.my_num(), DataF.prep_data_dir)),binary_file_io(Binary_File_IO()) { - reset(num_regs2,num_regsp,num_regi,0); + reset(program,0); public_input.open(get_filename("Programs/Public-Input/",false).c_str()); private_input.open(get_filename("Player-Data/Private-Input-",true).c_str()); @@ -27,7 +31,6 @@ Processor::~Processor() cerr << "Sent " << sent << " elements in " << rounds << " rounds" << endl; } - string Processor::get_filename(const char* prefix, bool use_number) { stringstream filename; @@ -43,16 +46,15 @@ string Processor::get_filename(const char* prefix, bool use_number) } -void Processor::reset(int num_regs2,int num_regsp,int num_regi,int arg) +void Processor::reset(const Program& program,int arg) { - reg_max2 = num_regs2; - reg_maxp = num_regsp; - reg_maxi = num_regi; + reg_max2 = program.num_reg(GF2N); + reg_maxp = program.num_reg(MODP); + reg_maxi = program.num_reg(INT); C2.resize(reg_max2); Cp.resize(reg_maxp); S2.resize(reg_max2); Sp.resize(reg_maxp); Ci.resize(reg_maxi); this->arg = arg; - close_socket(); #ifdef DEBUG rw2.resize(2*reg_max2); @@ -65,59 +67,323 @@ void Processor::reset(int num_regs2,int num_regsp,int num_regi,int arg) } #include "Networking/sockets.h" - -// Set up a server socket for some client -void Processor::open_socket(int portnum_base) +#include "Math/Setup.h" + +// Write socket (typically SPDZ engine -> external client), for different register types. +// RegType and SecrecyType determines how registers are read and the socket stream is packed. +// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to +// determine the data structure being sent in a message. +// Encryption is enabled if key material (for DH Auth Encryption and/or STS protocol) has been already setup. +void Processor::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, + int socket_id, int message_type, const vector& registers) { - if (!socket_is_open) + if (socket_id >= (int)external_clients.external_client_sockets.size()) { - socket_is_open = true; - sockaddr_in dest; - set_up_server_socket(dest, final_socket_fd, socket_fd, portnum_base + P.my_num()); + cerr << "No socket connection exists for client id " << socket_id << endl; + return; } -} + int m = registers.size(); + socket_stream.reset_write_head(); -void Processor::close_socket() -{ - if (socket_is_open) + //First 4 bytes is message_type (unless indicate not needed) + if (message_type != 0) { + socket_stream.store(message_type); + } + + for (int i = 0; i < m; i++) { - socket_is_open = false; - close_server_socket(final_socket_fd, socket_fd); + if (reg_type == MODP && secrecy_type == SECRET) { + // Send vector of secret shares and optionally macs + get_S_ref(registers[i]).get_share().pack(socket_stream); + if (send_macs) + get_S_ref(registers[i]).get_mac().pack(socket_stream); + } + else if (reg_type == MODP && secrecy_type == CLEAR) { + // Send vector of clear public field elements + get_C_ref(registers[i]).pack(socket_stream); + } + else if (reg_type == INT && secrecy_type == CLEAR) { + // Send vector of 32-bit clear ints + socket_stream.store((int&)get_Ci_ref(registers[i])); + } + else { + stringstream ss; + ss << "Write socket instruction with unknown reg type " << reg_type << + " and secrecy type " << secrecy_type << "." << endl; + throw Processor_Error(ss.str()); + } } -} -// Receive 32-bit int -void Processor::read_socket(int& x) -{ - octet bytes[4]; - receive(final_socket_fd, bytes, 4); - x = BYTES_TO_INT(bytes); + // Apply DH Auth encryption if session keys have been created. + map::iterator it = external_clients.symmetric_client_keys.find(socket_id); + if (it != external_clients.symmetric_client_keys.end()) { + socket_stream.encrypt(it->second); + } + + // Apply STS commsec encryption if session keys have been created. + try { + maybe_encrypt_sequence(socket_id); + socket_stream.Send(external_clients.external_client_sockets[socket_id]); + } + catch (bad_value& e) { + cerr << "Send error thrown when writing " << m << " values of type " << reg_type << " to socket id " + << socket_id << "." << endl; + } } -// Send 32-bit int -void Processor::write_socket(int x) + +// Receive vector of 32-bit clear ints +void Processor::read_socket_ints(int client_id, const vector& registers) { - octet bytes[4]; - INT_TO_BYTES(bytes, x); - send(final_socket_fd, bytes, 4); + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + return; + } + + int m = registers.size(); + socket_stream.reset_write_head(); + socket_stream.Receive(external_clients.external_client_sockets[client_id]); + maybe_decrypt_sequence(client_id); + for (int i = 0; i < m; i++) + { + int val; + socket_stream.get(val); + write_Ci(registers[i], (long)val); + } } -// Receive field element +// Receive vector of public field elements template -void Processor::read_socket(T& x) +void Processor::read_socket_vector(int client_id, const vector& registers) { + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + return; + } + + int m = registers.size(); socket_stream.reset_write_head(); - socket_stream.Receive(final_socket_fd); - x.unpack(socket_stream); + socket_stream.Receive(external_clients.external_client_sockets[client_id]); + maybe_decrypt_sequence(client_id); + for (int i = 0; i < m; i++) + { + get_C_ref(registers[i]).unpack(socket_stream); + } } -// Send field element +// Receive vector of field element shares over private channel template -void Processor::write_socket(const T& x) +void Processor::read_socket_private(int client_id, const vector& registers, bool read_macs) { + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + return; + } + int m = registers.size(); socket_stream.reset_write_head(); - x.pack(socket_stream); - socket_stream.Send(final_socket_fd); + socket_stream.Receive(external_clients.external_client_sockets[client_id]); + maybe_decrypt_sequence(client_id); + + map::iterator it = external_clients.symmetric_client_keys.find(client_id); + if (it != external_clients.symmetric_client_keys.end()) + { + socket_stream.decrypt(it->second); + } + for (int i = 0; i < m; i++) + { + temp.ansp.unpack(socket_stream); + get_Sp_ref(registers[i]).set_share(temp.ansp); + if (read_macs) + { + temp.ansp.unpack(socket_stream); + get_Sp_ref(registers[i]).set_mac(temp.ansp); + } + } +} + +// Read socket for client public key as 8 ints, calculate session key for client. +void Processor::read_client_public_key(int client_id, const vector& registers) { + + read_socket_ints(client_id, registers); + + // After read into registers, need to extract values + vector client_public_key (registers.size(), 0); + for(unsigned int i = 0; i < registers.size(); i++) { + client_public_key[i] = (int&)get_Ci_ref(registers[i]); + } + + external_clients.generate_session_key_for_client(client_id, client_public_key); +} + +void Processor::init_secure_socket_internal(int client_id, const vector& registers) { + external_clients.symmetric_client_commsec_send_keys.erase(client_id); + external_clients.symmetric_client_commsec_recv_keys.erase(client_id); + unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + + external_clients.load_server_keys_once(); + external_clients.require_ed25519_keys(); + + // Validate inputs and state + if(registers.size() != 8) { + throw "Invalid call to init_secure_socket."; + } + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + throw "No socket connection exists for client"; + } + + // Extract client long term public key into bytes + vector client_public_key (registers.size(), 0); + for(unsigned int i = 0; i < registers.size(); i++) { + client_public_key[i] = (int&)get_Ci_ref(registers[i]); + } + external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key); + + // Start Station to Station Protocol + STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519); + m1 = ke.send_msg1(); + socket_stream.reset_write_head(); + socket_stream.append(m1.bytes, sizeof m1.bytes); + socket_stream.Send(external_clients.external_client_sockets[client_id]); + socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id], + 96); + socket_stream.consume(m2.pubkey, sizeof m2.pubkey); + socket_stream.consume(m2.sig, sizeof m2.sig); + m3 = ke.recv_msg2(m2); + socket_stream.reset_write_head(); + socket_stream.append(m3.bytes, sizeof m3.bytes); + socket_stream.Send(external_clients.external_client_sockets[client_id]); + + // Use results of STS to generate send and receive keys. + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0); + external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0); +} + +void Processor::init_secure_socket(int client_id, const vector& registers) { + + try { + init_secure_socket_internal(client_id, registers); + } catch (char const *e) { + cerr << "STS initiator role failed with: " << e << endl; + throw Processor_Error("STS initiator failed"); + } +} + +void Processor::resp_secure_socket(int client_id, const vector& registers) { + try { + resp_secure_socket_internal(client_id, registers); + } catch (char const *e) { + cerr << "STS responder role failed with: " << e << endl; + throw Processor_Error("STS responder failed"); + } +} + +void Processor::resp_secure_socket_internal(int client_id, const vector& registers) { + external_clients.symmetric_client_commsec_send_keys.erase(client_id); + external_clients.symmetric_client_commsec_recv_keys.erase(client_id); + unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + + external_clients.load_server_keys_once(); + external_clients.require_ed25519_keys(); + + // Validate inputs and state + if(registers.size() != 8) { + throw "Invalid call to init_secure_socket."; + } + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + throw "No socket connection exists for client"; + } + vector client_public_key (registers.size(), 0); + for(unsigned int i = 0; i < registers.size(); i++) { + client_public_key[i] = (int&)get_Ci_ref(registers[i]); + } + external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key); + + // Start Station to Station Protocol for the responder + STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519); + socket_stream.reset_read_head(); + socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id], + 32); + socket_stream.consume(m1.bytes, sizeof m1.bytes); + m2 = ke.recv_msg1(m1); + socket_stream.reset_write_head(); + socket_stream.append(m2.pubkey, sizeof m2.pubkey); + socket_stream.append(m2.sig, sizeof m2.sig); + socket_stream.Send(external_clients.external_client_sockets[client_id]); + + socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id], + 64); + socket_stream.consume(m3.bytes, sizeof m3.bytes); + ke.recv_msg3(m3); + + // Use results of STS to generate send and receive keys. + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0); + external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0); +} + +// Read share data from a file starting at file_pos until registers filled. +// file_pos_register is written with new file position (-1 is eof). +// Tolerent to no file if no shares yet persisted. +template +void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { + string filename; + filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; + + unsigned int size = data_registers.size(); + + vector< Share > outbuf(size); + + int end_file_posn = start_file_posn; + + try { + binary_file_io.read_from_file(filename, outbuf, start_file_posn, end_file_posn); + + for (unsigned int i = 0; i < size; i++) + { + get_Sp_ref(data_registers[i]).set_share(outbuf[i].get_share()); + get_Sp_ref(data_registers[i]).set_mac(outbuf[i].get_mac()); + } + + write_Ci(end_file_pos_register, (long)end_file_posn); + } + catch (file_missing& e) { + cerr << "Got file missing error, will return -2. " << e.what() << endl; + write_Ci(end_file_pos_register, (long)-2); + } +} + +// Append share data in data_registers to end of file. Expects Persistence directory to exist. +template +void Processor::write_shares_to_file(const vector& data_registers) { + string filename; + filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; + + unsigned int size = data_registers.size(); + + vector< Share > inpbuf (size); + + for (unsigned int i = 0; i < size; i++) + { + inpbuf[i] = get_S_ref(data_registers[i]); + } + + binary_file_io.write_to_file(filename, inpbuf); } template @@ -180,12 +446,6 @@ void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& rounds++; } - - - - - - ostream& operator<<(ostream& s,const Processor& P) { s << "Processor State" << endl; @@ -196,7 +456,7 @@ ostream& operator<<(ostream& s,const Processor& P) P.read_C2(i).output(s,true); s << "\t"; P.read_S2(i).output(s,true); - s << endl; + s << endl; } s << "Char p Registers" << endl; s << "Val\tClearReg\tSharedReg" << endl; @@ -205,18 +465,37 @@ ostream& operator<<(ostream& s,const Processor& P) P.read_Cp(i).output(s,true); s << "\t"; P.read_Sp(i).output(s,true); - s << endl; + s << endl; } return s; } +void Processor::maybe_decrypt_sequence(int client_id) +{ + map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_recv_keys.find(client_id); + if (it_cs != external_clients.symmetric_client_commsec_recv_keys.end()) + { + socket_stream.decrypt_sequence(&it_cs->second.first[0], it_cs->second.second); + it_cs->second.second++; + } +} + +void Processor::maybe_encrypt_sequence(int client_id) +{ + map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_send_keys.find(client_id); + if (it_cs != external_clients.symmetric_client_commsec_send_keys.end()) + { + socket_stream.encrypt_sequence(&it_cs->second.first[0], it_cs->second.second); + it_cs->second.second++; + } +} template void Processor::POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); template void Processor::POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); template void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); template void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); -template void Processor::read_socket(gfp& x); -template void Processor::read_socket(gf2n& x); -template void Processor::write_socket(const gfp& x); -template void Processor::write_socket(const gf2n& x); +template void Processor::read_socket_private(int client_id, const vector& registers, bool send_macs); +template void Processor::read_socket_vector(int client_id, const vector& registers); +template void Processor::read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); +template void Processor::write_shares_to_file(const vector& data_registers); diff --git a/Processor/Processor.h b/Processor/Processor.h index a4a0d37e1..b1482a378 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Processor @@ -19,19 +19,43 @@ #include "Input.h" #include "PrivateOutput.h" #include "Machine.h" +#include "ExternalClients.h" +#include "Binary_File_IO.h" +#include "Instruction.h" #include -class Processor +class ProcessorBase +{ + // Stack + stack stacki; + +protected: + // Optional argument to tape + int arg; + +public: + void pushi(long x) { stacki.push(x); } + void popi(long& x) { x = stacki.top(); stacki.pop(); } + + int get_arg() const + { + return arg; + } + + void set_arg(int new_arg) + { + arg=new_arg; + } +}; + +class Processor : public ProcessorBase { vector C2; vector Cp; vector > S2; vector > Sp; vector Ci; - - // Stack - stack stacki; // This is the vector of partially opened values and shares we need to store // as the Open commands are split in two @@ -43,13 +67,8 @@ class Processor int reg_max2,reg_maxp,reg_maxi; int thread_num; - // Optional argument to tape - int arg; - - // For reading/reading data from a socket (i.e. external party to SPDZ) + // Data structure used for reading/writing data to/from a socket (i.e. an external party to SPDZ) octetStream socket_stream; - int socket_fd, final_socket_fd; - bool socket_is_open; #ifdef DEBUG vector rw2; @@ -91,14 +110,17 @@ class Processor int sent, rounds; + ExternalClients external_clients; + Binary_File_IO binary_file_io; + static const int reg_bytes = 4; - void reset(int num_regs2,int num_regsp,int num_regi,int arg); // Reset the state of the processor + void reset(const Program& program,int arg); // Reset the state of the processor string get_filename(const char* basename, bool use_number); Processor(int thread_num,Data_Files& DataF,Player& P, MAC_Check& MC2,MAC_Check& MCp,Machine& machine, - int num_regs2 = 256,int num_regsp = 256,int num_regi = 256); + const Program& program); ~Processor(); int get_thread_num() @@ -106,19 +128,6 @@ class Processor return thread_num; } - int get_arg() const - { - return arg; - } - - void set_arg(int new_arg) - { - arg=new_arg; - } - - void pushi(long x) { stacki.push(x); } - void popi(long& x) { x = stacki.top(); stacki.pop(); } - #ifdef DEBUG const gf2n& read_C2(int i) const { if (rw2[i]==0) @@ -226,16 +235,29 @@ class Processor template Share& get_S_ref(int i); template T& get_C_ref(int i); - // Access to sockets for reading clear/shared data - void open_socket(int portnum_base); - void close_socket(); - void read_socket(int& x); - void write_socket(int x); + // Access to external client sockets for reading clear/shared data + void read_socket_ints(int client_id, const vector& registers); + // Setup client public key + void read_client_public_key(int client_id, const vector& registers); + void init_secure_socket(int client_id, const vector& registers); + void init_secure_socket_internal(int client_id, const vector& registers); + void resp_secure_socket(int client_id, const vector& registers); + void resp_secure_socket_internal(int client_id, const vector& registers); + + void write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, + int socket_id, int message_type, const vector& registers); + template - void read_socket(T& x); + void read_socket_vector(int client_id, const vector& registers); template - void write_socket(const T& x); + void read_socket_private(int client_id, const vector& registers, bool send_macs); + // Read and write secret numeric data to file (name hardcoded at present) + template + void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); + template + void write_shares_to_file(const vector& data_registers); + // Access to PO (via calls to POpen start/stop) template void POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); @@ -245,6 +267,10 @@ class Processor // Print the processor state friend ostream& operator<<(ostream& s,const Processor& P); + + private: + void maybe_decrypt_sequence(int client_id); + void maybe_encrypt_sequence(int client_id); }; template<> inline Share& Processor::get_S_ref(int i) { return get_S2_ref(i); } diff --git a/Processor/Program.cpp b/Processor/Program.cpp index 9b132305b..1d3a47f97 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Program.h" @@ -7,23 +7,24 @@ void Program::compute_constants() { - max_reg2 = 0; - max_regp = 0; - max_regi = 0; for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) - for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++) - max_mem[reg_type][sec_type] = 0; + { + max_reg[reg_type] = 0; + for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++) + max_mem[reg_type][sec_type] = 0; + } for (unsigned int i=0; i= MAX_NUM_CLIENTS) + finish.reveal() == 0 + + winning_client_id = determine_winner(number_clients, client_values, client_ids) + + # print_ln('Found winner, index: %s.', winning_client_id.reveal()) + + write_winner_to_clients(client_sockets, number_clients, winning_client_id) + + return True + +main() diff --git a/Programs/Source/bankers_bonus_commsec.mpc b/Programs/Source/bankers_bonus_commsec.mpc new file mode 100644 index 000000000..a60a2e613 --- /dev/null +++ b/Programs/Source/bankers_bonus_commsec.mpc @@ -0,0 +1,117 @@ +# (C) 2017 University of Bristol. See License.txt +# coding: latin-1 +""" + Solve Bankers bonus, aka Millionaires problem. + to deduce the maximum value from a range of integer input. + + Demonstrate clients external to computing parties supplying input and receiving + an authenticated result. See bankers-bonus-commsec-client.cpp for client (and setup instructions). + + For an implementation without communications security see bankers_bonus.mpc. + + Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent + before calculating the maximum. + + Note each client connects in a single thread and so is potentially blocked. + + Each round / game will reset and so this runs indefinitiely. +""" + +from Compiler.types import sint, regint, Array, Matrix, MemValue +from Compiler.instructions import listen, acceptclientconnection +from Compiler.library import print_ln, do_while, if_e, else_, for_range +from Compiler.util import if_else + +PORTNUM = 14000 +MAX_NUM_CLIENTS = 8 + +def accept_client_input(): + """ + Wait for socket connection and read for client public key. + send share of random value, receive input and deduce share. + Expect 3 inputs: unique id, bonus value and flag to indicate end of this round. + """ + client_socket_id = regint() + acceptclientconnection(client_socket_id, PORTNUM) + + # Crypto setup + public_signing_key = regint.read_from_socket(client_socket_id, 8) + public_key = regint.read_client_public_key(client_socket_id) + regint.resp_secure_socket(client_socket_id,*public_signing_key) + + client_inputs = sint.receive_from_client(3, client_socket_id) + + return client_socket_id, client_inputs[0], client_inputs[1], client_inputs[2] + + +def determine_winner(number_clients, client_values, client_ids): + """Work out and return client_id which corresponds to max client_value""" + max_value = Array(1, sint) + max_value[0] = client_values[0] + win_client_id = Array(1, sint) + win_client_id[0] = client_ids[0] + + @for_range(number_clients-1) + def loop_body(i): + # Is this client input a new maximum, will be sint(1) if true, else sint(0) + is_new_max = max_value[0] < client_values[i+1] + # Keep latest max_value + max_value[0] = if_else(is_new_max, client_values[i+1], max_value[0]) + # Keep current winning client id + win_client_id[0] = if_else(is_new_max, client_ids[i+1], win_client_id[0]) + + return win_client_id[0] + + +def write_winner_to_clients(sockets, number_clients, winning_client_id): + """Send share of winning client id to all clients who joined game.""" + + # Setup authenticate result using share of random. + # client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result + rnd_from_triple = sint.get_random_triple()[0] + auth_result = winning_client_id * rnd_from_triple + + @for_range(number_clients) + def loop_body(i): + sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result]) + + +def main(): + """Listen in while loop for players to join a game. + Once maxiumum reached or have notified that round finished, run comparison and return result.""" + # Start listening for client socket connections + listen(PORTNUM) + print_ln('Listening for client connections on base port %s', PORTNUM) + + @do_while + def game_loop(): + print_ln('Starting a new round of the game.') + + # Clients socket id (integer). + client_sockets = Array(MAX_NUM_CLIENTS, regint) + # Number of clients + number_clients = MemValue(regint(0)) + # Clients secret input. + client_values = Array(MAX_NUM_CLIENTS, sint) + # Client ids to identity client + client_ids = Array(MAX_NUM_CLIENTS, sint) + + # Loop round waiting for each client to connect + @do_while + def client_connections(): + + client_sockets[number_clients], client_ids[number_clients], client_values[number_clients], finish = accept_client_input() + number_clients.write(number_clients+1) + + # continue while both expressions are false + return (number_clients >= MAX_NUM_CLIENTS) + finish.reveal() == 0 + + winning_client_id = determine_winner(number_clients, client_values, client_ids) + + print_ln('Found winner, index: %s.', winning_client_id.reveal()) + + write_winner_to_clients(client_sockets, number_clients, winning_client_id) + + return True + +main() diff --git a/Programs/Source/dijkstra_tutorial.mpc b/Programs/Source/dijkstra_tutorial.mpc index c595087b9..dbbef223d 100644 --- a/Programs/Source/dijkstra_tutorial.mpc +++ b/Programs/Source/dijkstra_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import dijkstra from path_oram import OptimalORAM diff --git a/Programs/Source/fixed_point_tutorial.mpc b/Programs/Source/fixed_point_tutorial.mpc index 5994b626c..640dbc911 100644 --- a/Programs/Source/fixed_point_tutorial.mpc +++ b/Programs/Source/fixed_point_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt program.bit_length = 80 print "program.bit_length: ", program.bit_length diff --git a/Programs/Source/gale-shapley_tutorial.mpc b/Programs/Source/gale-shapley_tutorial.mpc index 2b296d41c..52c345ec1 100644 --- a/Programs/Source/gale-shapley_tutorial.mpc +++ b/Programs/Source/gale-shapley_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler import gs from Compiler.path_oram import OptimalORAM diff --git a/Programs/Source/oram_tutorial.mpc b/Programs/Source/oram_tutorial.mpc index a3b9a62a1..c974578e1 100644 --- a/Programs/Source/oram_tutorial.mpc +++ b/Programs/Source/oram_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from path_oram import OptimalORAM diff --git a/Programs/Source/tpmpc_tutorial.mpc b/Programs/Source/tpmpc_tutorial.mpc index befedf441..b75afb0c8 100644 --- a/Programs/Source/tpmpc_tutorial.mpc +++ b/Programs/Source/tpmpc_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt """ Example programs used in the SPDZ tutorial at the TPMPC 2017 workshop in Bristol. diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc index 05f394d97..f4b455586 100644 --- a/Programs/Source/tutorial.mpc +++ b/Programs/Source/tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt def test(actual, expected): if isinstance(actual, (sint, sgf2n)): diff --git a/Programs/Source/vickrey.mpc b/Programs/Source/vickrey.mpc index 6494ee4c1..586c27a60 100644 --- a/Programs/Source/vickrey.mpc +++ b/Programs/Source/vickrey.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import util from Compiler import types diff --git a/README.md b/README.md index f593785a5..e15f636b9 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,15 @@ -(C) 2016 University of Bristol. See License.txt +(C) 2017 University of Bristol. See License.txt Software for the SPDZ and MASCOT secure multi-party computation protocols. See `Programs/Source/` for some example MPC programs, and `tutorial.md` for -a basic tutorial. More examples and documentation will be available in the -coming weeks. +a basic tutorial. See also https://www.cs.bris.ac.uk/Research/CryptographySecurity/SPDZ #### Requirements: - GCC - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) + - libsodium library, tested against 1.0.11 - CPU supporting AES-NI and PCLMUL - Python 2.x, ideally with `gmpy` package (for testing) diff --git a/Scripts/gen_input_f2n.cpp b/Scripts/gen_input_f2n.cpp index 4544d514c..78e8f1e34 100644 --- a/Scripts/gen_input_f2n.cpp +++ b/Scripts/gen_input_f2n.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include #include diff --git a/Scripts/gen_input_fp.cpp b/Scripts/gen_input_fp.cpp index cc47f8a4e..ab11fbed2 100644 --- a/Scripts/gen_input_fp.cpp +++ b/Scripts/gen_input_fp.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include #include diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index f946f3536..453bffcd2 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -1,11 +1,13 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt run_player() { port=$((RANDOM%10000+10000)) - >&2 echo Port $port bin=$1 shift + if ! test -e $SPDZROOT/logs; then + mkdir $SPDZROOT/logs + fi if test $bin = Player-Online.x; then params="$* -pn $port -h localhost" else @@ -14,13 +16,16 @@ run_player() { if test $bin = Player-KeyGen.x -a ! -e Player-Data/Params-Data; then ./Setup.x $players $size 40 fi - >&2 echo Parameters $params + >&2 echo Running $SPDZROOT/Server.x $players $port $SPDZROOT/Server.x $players $port & rem=$(($players - 2)) for i in $(seq 0 $rem); do + echo "trying with player $i" + >&2 echo Running $prefix $SPDZROOT/$bin $i $params $prefix $SPDZROOT/$bin $i $params 2>&1 | tee $SPDZROOT/logs/$i & done last_player=$(($players - 1)) + >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$last_player 2>&1 || return 1 } diff --git a/Scripts/run-online.sh b/Scripts/run-online.sh index eee9f6e5b..7865713d6 100755 --- a/Scripts/run-online.sh +++ b/Scripts/run-online.sh @@ -1,6 +1,6 @@ #!/bin/bash -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt HERE=$(cd `dirname $0`; pwd) SPDZROOT=$HERE/.. diff --git a/Scripts/setup-online.sh b/Scripts/setup-online.sh index 9f5328df5..85082e244 100755 --- a/Scripts/setup-online.sh +++ b/Scripts/setup-online.sh @@ -1,6 +1,6 @@ #!/bin/bash -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt HERE=$(cd `dirname $0`; pwd) SPDZROOT=$HERE/.. diff --git a/Server.cpp b/Server.cpp index 103887cf9..6189b9268 100644 --- a/Server.cpp +++ b/Server.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Networking/sockets.h" @@ -18,16 +18,46 @@ int nmachines; +/* + * Get the client ip number on the socket connection for client i. + */ +void get_ip(int num) +{ + struct sockaddr_storage addr; + socklen_t len = sizeof addr; + + getpeername(socket_num[num], (struct sockaddr*)&addr, &len); + + // supports both IPv4 and IPv6: + char ipstr[INET6_ADDRSTRLEN]; + if (addr.ss_family == AF_INET) { + struct sockaddr_in *s = (struct sockaddr_in *)&addr; + inet_ntop(AF_INET, &s->sin_addr, ipstr, sizeof ipstr); + } else { // AF_INET6 + struct sockaddr_in6 *s = (struct sockaddr_in6 *)&addr; + inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr); + } + + names[num]=new octet[512]; + strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN); + + cerr << "Client IP address: " << names[num] << endl; +} + + void get_name(int num) { // Now all machines are set up, send GO to start them. send(socket_num[num], GO); cerr << "Player " << num << " started." << endl; - // Receive Name - names[num]=new octet[512]; - receive(socket_num[num],names[num],512); - cerr << "Player " << num << " is on machine " << names[num] << endl; + // Receive name sent by client (legacy) - not used here + octet my_name[512]; + receive(socket_num[num],my_name,512); + cerr << "Player " << num << " sent name (info only) " << my_name << endl; + + // Get client IP + get_ip(num); } @@ -42,9 +72,6 @@ void send_names(int num) } - - - /* Takes command line arguments of - Number of machines connecting - Base PORTNUM address @@ -71,6 +98,7 @@ int main(int argc,char **argv) // port number one lower to avoid conflict with players ServerSocket server(PortnumBase - 1); + server.init(); // set up connections for (i=0; i +#include +#include + +namespace Config { + class ConfigError : public std::exception + { + std::string s; + + public: + ConfigError(std::string ss) : s(ss) {} + ~ConfigError() throw () {} + const char* what() const throw() { return s.c_str(); } + }; + + static void output(const vector &vec, ofstream &of) + { + copy(vec.begin(), vec.end(), ostreambuf_iterator(of)); + } + void print_vector(const vector &vec) + { + cerr << hex; + for(size_t i = 0; i < vec.size(); i ++ ) { + cerr << setfill('0') << setw(2) << (int)vec[i]; + } + cerr << dec << endl; + } + + uint64_t getW64le(ifstream &infile) + { + uint8_t buf[8]; + uint64_t res=0; + infile.read((char*)buf,sizeof buf); + + if (!infile.good()) + throw ConfigError("getW64le: could not read from config file"); + + for(size_t i = 0; i < sizeof buf ; i ++ ) { + res |= ((uint64_t)buf[i]) << i*8; + } + + return res; + } + + void putW64le(ofstream &outf, uint64_t nr) + { + char buf[8]; + for(int i=0;i<8;i++) { + char byte = (uint8_t)(nr >> (i*8)); + buf[i] = (char)byte; + } + outf.write(buf,sizeof buf); + } + + const string default_player_config_file_prefix = "Player-SPDZ-Keys-P"; + string player_config_file(int player_number) + { + stringstream filename; + filename << default_player_config_file_prefix << player_number; + return filename.str(); + } + + void read_player_config(string cfgdir,int my_number,vector pubkeys,secret_signing_key mykey, public_signing_key mypubkey) + { + string filename; + filename = cfgdir + player_config_file(my_number); + ifstream infile(filename.c_str(), ios::in | ios::binary); + + infile.seekg(crypto_box_PUBLICKEYBYTES + crypto_box_SECRETKEYBYTES); + mypubkey.resize(crypto_sign_PUBLICKEYBYTES); + infile.read((char*)&mypubkey[0], crypto_sign_PUBLICKEYBYTES); + mykey.resize(crypto_sign_SECRETKEYBYTES); + infile.read((char*)&mykey[0], crypto_sign_SECRETKEYBYTES); + + // If we've failed by this point, abort. After this point we'll + // just try to read optional content. + if (!infile.good()) { + throw ConfigError("Could not parse player config file."); + } + + // Deal gracefully with absence of additional key material + try { + uint64_t nrClients = getW64le(infile); + infile.ignore(nrClients * (crypto_sign_PUBLICKEYBYTES + crypto_box_PUBLICKEYBYTES)); + uint64_t nrPlayers = getW64le(infile); + pubkeys.resize(nrPlayers); + for(size_t i=0; i client_pubs, vector client_signing_pubs + , vector player_pubs, vector player_signing_pubs) + { + stringstream filename; + filename << config_dir << "Player-SPDZ-Keys-P" << player_number; + ofstream outf(filename.str().c_str(), ios::out | ios::binary); + if (outf.fail()) + throw file_error(filename.str().c_str()); + if(crypto_box_PUBLICKEYBYTES != my_pub.size() || + crypto_box_SECRETKEYBYTES != my_priv.size() || + crypto_sign_PUBLICKEYBYTES != my_signing_pub.size() || + crypto_sign_SECRETKEYBYTES != my_signing_priv.size()) { + throw "Invalid key sizes"; + } else if(client_pubs.size() != client_signing_pubs.size()) { + throw "Incorrect number of client keys"; + } else if(player_pubs.size() != player_signing_pubs.size()) { + throw "Incorrect number of player keys"; + } else { + for(size_t i = 0; i < client_pubs.size(); i++) { + if(crypto_box_PUBLICKEYBYTES != client_pubs[i].size() || + crypto_sign_PUBLICKEYBYTES != client_signing_pubs[i].size()) { + throw "Incorrect size of client key."; + } + } + for(size_t i = 0; i < player_pubs.size(); i++) { + if(crypto_box_PUBLICKEYBYTES != player_pubs[i].size() || + crypto_sign_PUBLICKEYBYTES != player_signing_pubs[i].size()) { + throw "Incorrect size of player key."; + } + } + } + // Write public and secret X25519 keys + output(my_pub, outf); + output(my_priv, outf); + output(my_signing_pub, outf); + output(my_signing_priv, outf); + + putW64le(outf, (uint64_t)client_pubs.size()); + // Write all client public keys + for (size_t j = 0; j < client_pubs.size(); j++) { + output(client_pubs[j], outf); + output(client_signing_pubs[j], outf); + } + putW64le(outf, (uint64_t)player_pubs.size()); + for (size_t j = 0; j < player_pubs.size(); j++) { + output(player_pubs[j], outf); + output(player_signing_pubs[j], outf); + } + outf.flush(); + outf.close(); + } +} diff --git a/Tools/Config.h b/Tools/Config.h new file mode 100644 index 000000000..1ad725b61 --- /dev/null +++ b/Tools/Config.h @@ -0,0 +1,20 @@ +#include "Tools/octetStream.h" +#include "Networking/Player.h" +#include +namespace Config { + typedef vector public_key; + typedef vector public_signing_key; + typedef vector secret_key; + typedef vector secret_signing_key; + void read_player_config(string cfgdir,int my_number,vector pubkeys,secret_signing_key mysecretkey, public_signing_key mypubkey); + void write_player_config_file(string config_dir + ,int player_number, public_key my_pub, secret_key my_priv + , public_signing_key my_signing_pub, secret_signing_key my_signing_priv + , vector client_pubs, vector client_signing_pubs + , vector player_pubs, vector player_signing_pubs); + uint64_t getW64le(ifstream &infile); + void putW64le(ofstream &outf, uint64_t nr); + extern const string default_player_config_file_prefix; + string player_config_file(int player_number); + void print_vector(const vector &vec); +} diff --git a/Tools/Lock.cpp b/Tools/Lock.cpp index 73b9152c8..79221017b 100644 --- a/Tools/Lock.cpp +++ b/Tools/Lock.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Lock.cpp diff --git a/Tools/Lock.h b/Tools/Lock.h index 06c3b4b88..59d533f41 100644 --- a/Tools/Lock.h +++ b/Tools/Lock.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Lock.h diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index a940305a1..040febaa1 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * MMO.cpp diff --git a/Tools/MMO.h b/Tools/MMO.h index 3f6fe3e8f..999383362 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * MMO.h diff --git a/Tools/Signal.cpp b/Tools/Signal.cpp index 420fdd157..6515190f3 100644 --- a/Tools/Signal.cpp +++ b/Tools/Signal.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Signal.cpp diff --git a/Tools/Signal.h b/Tools/Signal.h index 27fbacc5b..1507d5fa1 100644 --- a/Tools/Signal.h +++ b/Tools/Signal.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Signal.h diff --git a/Tools/WaitQueue.h b/Tools/WaitQueue.h index 20722fd2e..e07b9ef3e 100644 --- a/Tools/WaitQueue.h +++ b/Tools/WaitQueue.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * WaitQueue.h diff --git a/Tools/aes-ni.cpp b/Tools/aes-ni.cpp index e1125efdc..6af4e596c 100644 --- a/Tools/aes-ni.cpp +++ b/Tools/aes-ni.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "aes.h" diff --git a/Tools/aes.cpp b/Tools/aes.cpp index 23a983343..99b99e2ef 100644 --- a/Tools/aes.cpp +++ b/Tools/aes.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "aes.h" diff --git a/Tools/aes.h b/Tools/aes.h index 2924abb3c..5d25101e2 100644 --- a/Tools/aes.h +++ b/Tools/aes.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef __AES_H #define __AES_H diff --git a/Tools/ezOptionParser.h b/Tools/ezOptionParser.h index d2b09c056..7012a36b5 100644 --- a/Tools/ezOptionParser.h +++ b/Tools/ezOptionParser.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* This file is part of ezOptionParser. See MIT-LICENSE. @@ -52,7 +52,7 @@ static T fromString(const char* s) { return t; }; /* ################################################################### */ -static inline bool isdigit(const std::string & s, int i=0) { +static inline bool isdigit(const std::string & s, int i=0) { int n = s.length(); for(; i < n; ++i) switch(s[i]) { @@ -85,12 +85,12 @@ For example, -d < --dimension < --dmn, and also lower come before upper. The def static bool CmpOptStringPtr(std::string * s1, std::string * s2) { int c1,c2; const char *s=s1->c_str(); - for(c1=0; c1 < (long int)s1->size(); ++c1) + for(c1=0; c1 < (long int)s1->size(); ++c1) if (isalnum(s[c1])) // locale sensitive. break; s=s2->c_str(); - for(c2=0; c2 < (long int)s2->size(); ++c2) + for(c2=0; c2 < (long int)s2->size(); ++c2) if (isalnum(s[c2])) break; @@ -232,67 +232,67 @@ static void ToD(std::string ** strings, double * out, int n) { }; /* ################################################################### */ static void StringsToInts(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atoi(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToInts(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atoi(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToLongs(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atol(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToLongs(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atol(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToULongs(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(strtoul(strings[i].c_str(),0,0)); } }; /* ################################################################### */ static void StringsToULongs(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(strtoul(strings->at(i)->c_str(),0,0)); } }; /* ################################################################### */ static void StringsToFloats(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atof(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToFloats(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atof(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToDoubles(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atof(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToDoubles(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atof(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToStrings(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back( *strings->at(i) ); } }; @@ -335,7 +335,7 @@ static char** CommandLineToArgvA(char* CmdLine, int* _argc) { i = 0; j = 0; - while( (a = CmdLine[i]) ) { + while( (a = CmdLine[i]) ) { if(in_QM) { if( (a == '\"') || (a == '\'')) // rsz. Added single quote. @@ -498,71 +498,71 @@ void ezOptionValidator::reset() { type = NOTYPE; }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type) : s1(0), op(0), quiet(0), type(_type), size(0), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type) : s1(0), op(0), quiet(0), type(_type), size(0), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const char* list, int _size) : s1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const char* list, int _size) : s1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s1 = new char[size]; memcpy(s1, list, size); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned char* list, int _size) : u1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned char* list, int _size) : u1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u1 = new unsigned char[size]; memcpy(u1, list, size); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const short* list, int _size) : s2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const short* list, int _size) : s2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s2 = new short[size]; memcpy(s2, list, size*sizeof(short)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned short* list, int _size) : u2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned short* list, int _size) : u2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u2 = new unsigned short[size]; memcpy(u2, list, size*sizeof(unsigned short)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const int* list, int _size) : s4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const int* list, int _size) : s4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s4 = new int[size]; memcpy(s4, list, size*sizeof(int)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned int* list, int _size) : u4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned int* list, int _size) : u4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u4 = new unsigned int[size]; memcpy(u4, list, size*sizeof(unsigned int)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const long long* list, int _size) : s8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const long long* list, int _size) : s8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s8 = new long long[size]; memcpy(s8, list, size*sizeof(long long)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned long long* list, int _size) : u8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned long long* list, int _size) : u8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u8 = new unsigned long long[size]; memcpy(u8, list, size*sizeof(unsigned long long)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const float* list, int _size) : f(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const float* list, int _size) : f(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); f = new float[size]; memcpy(f, list, size*sizeof(float)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const double* list, int _size) : d(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const double* list, int _size) : d(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); d = new double[size]; memcpy(d, list, size*sizeof(double)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const char** list, int _size, bool _insensitive) : t(0), op(_op), quiet(0), type(_type), size(_size), insensitive(_insensitive) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const char** list, int _size, bool _insensitive) : t(0), op(_op), quiet(0), type(_type), size(_size), insensitive(_insensitive) { id = ezOptionParserIDGenerator::instance().next(); t = new std::string*[size]; int i=0; @@ -577,7 +577,7 @@ _type: s1, u1, s2, u2, ..., f, d, t _op: lt, gt, ..., in _list: comma-delimited string */ -ezOptionValidator::ezOptionValidator(const char* _type, const char* _op, const char* _list, bool _insensitive) : t(0), quiet(0), type(0), size(0), insensitive(_insensitive) { +ezOptionValidator::ezOptionValidator(const char* _type, const char* _op, const char* _list, bool _insensitive) : t(0), quiet(0), type(0), size(0), insensitive(_insensitive) { id = ezOptionParserIDGenerator::instance().next(); switch(_type[0]) { @@ -932,11 +932,11 @@ bool ezOptionValidator::isValid(const std::string * valueAsString) { /* ################################################################### */ class OptionGroup { public: - OptionGroup() : delim(0), expectArgs(0), isRequired(false), isSet(false) { } + OptionGroup() : delim(0), expectArgs(0), isRequired(false), isSet(false) { } ~OptionGroup() { - int i; - for(i=0; i < (long int)flags.size(); ++i) + int i; + for(i=0; i < (long int)flags.size(); ++i) delete flags[i]; flags.clear(); @@ -988,8 +988,8 @@ class OptionGroup { /* ################################################################### */ void OptionGroup::clearArgs() { int i,j; - for(i=0; i < (long int)args.size(); ++i) { - for(j=0; j < (long int)args[i]->size(); ++j) + for(i=0; i < (long int)args.size(); ++i) { + for(j=0; j < (long int)args[i]->size(); ++j) delete args[i]->at(j); delete args[i]; @@ -1209,7 +1209,7 @@ void OptionGroup::getMultiInts(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToInts(args[i], &out[i]); } @@ -1228,7 +1228,7 @@ void OptionGroup::getMultiLongs(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToLongs(args[i], &out[i]); } @@ -1247,7 +1247,7 @@ void OptionGroup::getMultiULongs(std::vector< std::vector >& out) } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToULongs(args[i], &out[i]); } @@ -1266,7 +1266,7 @@ void OptionGroup::getMultiFloats(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToFloats(args[i], &out[i]); } @@ -1285,7 +1285,7 @@ void OptionGroup::getMultiDoubles(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToDoubles(args[i], &out[i]); } @@ -1304,10 +1304,10 @@ void OptionGroup::getMultiStrings(std::vector< std::vector >& out) } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { - for(int j=0; j < (long int)args[i]->size(); ++j) + for(int j=0; j < (long int)args[i]->size(); ++j) out[i].push_back( *args[i]->at(j) ); } } @@ -1378,19 +1378,19 @@ void ezOptionParser::reset() { this->doublespace = 1; int i; - for(i=0; i < (long int)groups.size(); ++i) + for(i=0; i < (long int)groups.size(); ++i) delete groups[i]; groups.clear(); - for(i=0; i < (long int)unknownArgs.size(); ++i) + for(i=0; i < (long int)unknownArgs.size(); ++i) delete unknownArgs[i]; unknownArgs.clear(); - for(i=0; i < (long int)firstArgs.size(); ++i) + for(i=0; i < (long int)firstArgs.size(); ++i) delete firstArgs[i]; firstArgs.clear(); - for(i=0; i < (long int)lastArgs.size(); ++i) + for(i=0; i < (long int)lastArgs.size(); ++i) delete lastArgs[i]; lastArgs.clear(); @@ -1405,18 +1405,18 @@ void ezOptionParser::reset() { /* ################################################################### */ void ezOptionParser::resetArgs() { int i; - for(i=0; i < (long int)groups.size(); ++i) + for(i=0; i < (long int)groups.size(); ++i) groups[i]->clearArgs(); - for(i=0; i < (long int)unknownArgs.size(); ++i) + for(i=0; i < (long int)unknownArgs.size(); ++i) delete unknownArgs[i]; unknownArgs.clear(); - for(i=0; i < (long int)firstArgs.size(); ++i) + for(i=0; i < (long int)firstArgs.size(); ++i) delete firstArgs[i]; firstArgs.clear(); - for(i=0; i < (long int)lastArgs.size(); ++i) + for(i=0; i < (long int)lastArgs.size(); ++i) delete lastArgs[i]; lastArgs.clear(); }; @@ -1540,7 +1540,7 @@ bool ezOptionParser::exportFile(const char * filename, bool all) { bool quote; // Export the first args, except the program name, so start from 1. - for(i=1; i < (long int)firstArgs.size(); ++i) { + for(i=1; i < (long int)firstArgs.size(); ++i) { quote = ((firstArgs[i]->find_first_of(" \t") != std::string::npos) && (firstArgs[i]->find_first_of("\'\"") == std::string::npos)); if (quote) @@ -1557,7 +1557,7 @@ bool ezOptionParser::exportFile(const char * filename, bool all) { out.append("\n"); std::vector stringPtrs(groups.size()); - int m; + int m; int n = groups.size(); for(i=0; i < n; ++i) { stringPtrs[i] = groups[i]->flags[0]; @@ -1609,7 +1609,7 @@ bool ezOptionParser::exportFile(const char * filename, bool all) { } // Export the last args. - for(i=0; i < (long int)lastArgs.size(); ++i) { + for(i=0; i < (long int)lastArgs.size(); ++i) { quote = ( lastArgs[i]->find_first_of(" \t") != std::string::npos ); if (quote) out.append("\""); @@ -1804,18 +1804,18 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout std::map stringPtrToIndexMap; std::vector stringPtrs(groups.size()); - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { std::sort(groups[i]->flags.begin(), groups[i]->flags.end(), CmpOptStringPtr); stringPtrToIndexMap[groups[i]->flags[0]] = i; stringPtrs[i] = groups[i]->flags[0]; } - size_t j, k; + size_t j, k; std::string opts; std::vector sortedOpts; // Sort first flag of each group with other groups. std::sort(stringPtrs.begin(), stringPtrs.end(), CmpOptStringPtr); - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { //printf("DEBUG:%d: %d %d %s\n", __LINE__, i, stringPtrToIndexMap[stringPtrs[i]], stringPtrs[i]->c_str()); k = stringPtrToIndexMap[stringPtrs[i]]; opts.clear(); @@ -1823,7 +1823,7 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout opts.append(*groups[k]->flags[j]); opts.append(", "); - if ((long int)opts.size() > width) + if ((long int)opts.size() > width) opts.append("\n"); } // The last flag. No need to append comma anymore. @@ -1851,8 +1851,8 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout // Find longest opt flag string to set column start for help usage descriptions. int maxlen=0; if (layout == ALIGN) { - for(i=0; i < (long int)groups.size(); ++i) { - if (maxlen < (long int)sortedOpts[i].size()) + for(i=0; i < (long int)groups.size(); ++i) { + if (maxlen < (long int)sortedOpts[i].size()) maxlen = sortedOpts[i].size(); } } @@ -1861,7 +1861,7 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout int helpwidth; std::list::iterator cIter, insertionIter; size_t pos; - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { k = stringPtrToIndexMap[stringPtrs[i]]; if (layout == STAGGER) @@ -1876,13 +1876,13 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout for(insertionIter=desc.begin(), cIter=insertionIter++; cIter != desc.end(); cIter=insertionIter++) { - if ((long int)((*cIter)->size()) > helpwidth) { + if ((long int)((*cIter)->size()) > helpwidth) { // Get pointer to next string to insert new strings before it. std::string *rem = *cIter; // Remove this line and add back in pieces. desc.erase(cIter); // Loop until remaining string is short enough. - while ((long int)rem->size() > helpwidth) { + while ((long int)rem->size() > helpwidth) { // Find whitespace to split before helpwidth. if (rem->at(helpwidth) == ' ') { // If word ends exactly at helpwidth, then split after it. @@ -1940,7 +1940,7 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout bool ezOptionParser::gotExpected(std::vector & badOptions) { int i,j; - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { OptionGroup *g = groups[i]; // If was set, ensure number of args is correct. if (g->isSet) { @@ -1949,8 +1949,8 @@ bool ezOptionParser::gotExpected(std::vector & badOptions) { continue; } - for(j=0; j < (long int)g->args.size(); ++j) { - if ((g->expectArgs != -1) && (g->expectArgs != (long int)g->args[j]->size())) + for(j=0; j < (long int)g->args.size(); ++j) { + if ((g->expectArgs != -1) && (g->expectArgs != (long int)g->args[j]->size())) badOptions.push_back(*g->flags[0]); } } @@ -1962,7 +1962,7 @@ bool ezOptionParser::gotExpected(std::vector & badOptions) { bool ezOptionParser::gotRequired(std::vector & badOptions) { int i; - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { OptionGroup *g = groups[i]; // Simple case when required but user never set it. if (g->isRequired && (!g->isSet)) { @@ -1987,10 +1987,10 @@ bool ezOptionParser::gotValid(std::vector & badOptions, std::vector ezOptionValidator *v = validators[validatorid]; bool nextgroup = false; - for (int i = 0; i < (long int)g->args.size(); ++i) { + for (int i = 0; i < (long int)g->args.size(); ++i) { if (nextgroup) break; std::vector< std::string* > * args = g->args[i]; - for (int j = 0; j < (long int)args->size(); ++j) { + for (int j = 0; j < (long int)args->size(); ++j) { if (!v->isValid(args->at(j))) { badOptions.push_back(*g->flags[0]); badArgs.push_back(*args->at(j)); @@ -2013,7 +2013,7 @@ void ezOptionParser::parse(int argc, const char * argv[]) { std::cout << (*it).first << " => " << (*it).second << std::endl; */ - int i, k, firstOptIndex=0, lastOptIndex=0; + int i, k, firstOptIndex=0, lastOptIndex=0; std::string s; OptionGroup *g; @@ -2090,7 +2090,7 @@ void ezOptionParser::prettyPrint(std::string & out) { int i,j,k; out += "First Args:\n"; - for(i=0; i < (long int)firstArgs.size(); ++i) { + for(i=0; i < (long int)firstArgs.size(); ++i) { sprintf(tmp, "%d: %s\n", i+1, firstArgs[i]->c_str()); out += tmp; } @@ -2111,7 +2111,7 @@ void ezOptionParser::prettyPrint(std::string & out) { g = get(stringPtrs[i]->c_str()); out += "\n"; // The flag names: - for(j=0; j < (long int)g->flags.size()-1; ++j) { + for(j=0; j < (long int)g->flags.size()-1; ++j) { sprintf(tmp, "%s, ", g->flags[j]->c_str()); out += tmp; } @@ -2124,12 +2124,12 @@ void ezOptionParser::prettyPrint(std::string & out) { sprintf(tmp, "%s (default)\n", g->defaults.c_str()); out += tmp; } else { - for(k=0; k < (long int)g->args.size(); ++k) { - for(j=0; j < (long int)g->args[k]->size()-1; ++j) { + for(k=0; k < (long int)g->args.size(); ++k) { + for(j=0; j < (long int)g->args[k]->size()-1; ++j) { sprintf(tmp, "%s%c", g->args[k]->at(j)->c_str(), g->delim); out += tmp; } - sprintf(tmp, "%s\n", g->args[k]->back()->c_str()); + sprintf(tmp, "%s\n", g->args[k]->back()->c_str()); out += tmp; } } @@ -2144,13 +2144,13 @@ void ezOptionParser::prettyPrint(std::string & out) { } out += "\nLast Args:\n"; - for(i=0; i < (long int)lastArgs.size(); ++i) { + for(i=0; i < (long int)lastArgs.size(); ++i) { sprintf(tmp, "%d: %s\n", i+1, lastArgs[i]->c_str()); out += tmp; } out += "\nUnknown Args:\n"; - for(i=0; i < (long int)unknownArgs.size(); ++i) { + for(i=0; i < (long int)unknownArgs.size(); ++i) { sprintf(tmp, "%d: %s\n", i+1, unknownArgs[i]->c_str()); out += tmp; } diff --git a/Tools/int.h b/Tools/int.h index 78e253bcb..69e7a4943 100644 --- a/Tools/int.h +++ b/Tools/int.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * int.h diff --git a/Tools/mkpath.cpp b/Tools/mkpath.cpp index a09525c8e..5dc13ec06 100644 --- a/Tools/mkpath.cpp +++ b/Tools/mkpath.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools/mkpath.h" #include diff --git a/Tools/mkpath.h b/Tools/mkpath.h index 4eca401e3..d10be4aee 100644 --- a/Tools/mkpath.h +++ b/Tools/mkpath.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef TOOLS_MKPATH_H_ #define TOOLS_MKPATH_H_ diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index d3d83dbb7..553374104 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -1,7 +1,8 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include +#include #include "octetStream.h" #include @@ -10,7 +11,6 @@ #include "Exceptions/Exceptions.h" #include "Networking/data.h" - void octetStream::assign(const octetStream& os) { if (os.len>=mxlen) @@ -45,38 +45,31 @@ octetStream::octetStream(const octetStream& os) void octetStream::hash(octetStream& output) const { - blk_SHA_CTX ctx; - blk_SHA1_Init(&ctx); - blk_SHA1_Update(&ctx,data,len); - blk_SHA1_Final(output.data,&ctx); - output.len=HASH_SIZE; + crypto_generichash(output.data, crypto_generichash_BYTES_MIN, data, len, NULL, 0); + output.len=crypto_generichash_BYTES_MIN; } octetStream octetStream::hash() const { - octetStream h(HASH_SIZE); + octetStream h(crypto_generichash_BYTES_MIN); hash(h); return h; } -bigint octetStream::check_sum() const +bigint octetStream::check_sum(int req_bytes) const { - unsigned char hash[HASH_SIZE]; - - blk_SHA_CTX ctx; - blk_SHA1_Init(&ctx); - blk_SHA1_Update(&ctx,data,len); - blk_SHA1_Final(hash,&ctx); + unsigned char hash[req_bytes]; + crypto_generichash(hash, req_bytes, data, len, NULL, 0); bigint ans; - bigintFromBytes(ans,hash,HASH_SIZE); + bigintFromBytes(ans,hash,req_bytes); + // cout << ans << "\n"; return ans; } - bool octetStream::equals(const octetStream& a) const { if (len!=a.len) { return false; } @@ -89,9 +82,7 @@ bool octetStream::equals(const octetStream& a) const void octetStream::append_random(int num) { resize(len+num); - int randomData = open("/dev/urandom", O_RDONLY); - read(randomData, data+len, num*sizeof(unsigned char)); - close(randomData); + randombytes_buf(data+len, num); len+=num; } @@ -126,6 +117,13 @@ void octetStream::store(unsigned int l) len+=4; } +void octetStream::store(int l) +{ + resize(len+4); + INT_TO_BYTES(data+len,l); + len+=4; +} + void octetStream::get(unsigned int& l) { @@ -133,6 +131,12 @@ void octetStream::get(unsigned int& l) ptr+=4; } +void octetStream::get(int& l) +{ + l=BYTES_TO_INT(data+ptr); + ptr+=4; +} + void octetStream::store(const bigint& x) { @@ -165,8 +169,84 @@ void octetStream::get(bigint& ans) } } +// Construct the ciphertext as `crypto_secretbox(pt, counter||random)` +void octetStream::encrypt_sequence(const octet* key, uint64_t counter) +{ + octet nonce[crypto_secretbox_NONCEBYTES]; + int i; + int message_len_bytes = len; + randombytes_buf(nonce, sizeof nonce); + if(counter == UINT64_MAX) { + throw Processor_Error("Encryption would overflow counter. Too many messages."); + } else { + counter++; + } + for(i=0; i<8; i++) { + nonce[i] = uint8_t ((counter >> (8*i)) & 0xFF); + } + + resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES); + + // Encrypt data in-place + crypto_secretbox_easy(data, data, message_len_bytes, nonce, key); + // Adjust length to account for MAC, then append nonce + len += crypto_secretbox_MACBYTES; + append(nonce, sizeof nonce); +} + +void octetStream::decrypt_sequence(const octet* key, uint64_t counter) +{ + int ciphertext_len = len - crypto_box_NONCEBYTES; + const octet *nonce = data + ciphertext_len; + int i; + uint64_t recvCounter=0; + // Numbers are typically 24U + 16U so cast to int is safe. + if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES)) + { + throw Processor_Error("Cannot decrypt octetStream: ciphertext too short"); + } + for(i=7; i>=0; i--) { + recvCounter |= (uint64_t) *(nonce + i); + recvCounter = recvCounter << (i*8); + } + if(recvCounter != counter + 1) { + throw Processor_Error("Incorrect counter on stream. Possible MITM."); + } + if (crypto_secretbox_open_easy(data, data, ciphertext_len, nonce, key) != 0) + { + throw Processor_Error("octetStream decryption failed!"); + } + rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES); +} +void octetStream::encrypt(const octet* key) +{ + octet nonce[crypto_secretbox_NONCEBYTES]; + randombytes_buf(nonce, sizeof nonce); + int message_len_bytes = len; + resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES); + + // Encrypt data in-place + crypto_secretbox_easy(data, data, message_len_bytes, nonce, key); + // Adjust length to account for MAC, then append nonce + len += crypto_secretbox_MACBYTES; + append(nonce, sizeof nonce); +} +void octetStream::decrypt(const octet* key) +{ + int ciphertext_len = len - crypto_box_NONCEBYTES; + // Numbers are typically 24U + 16U so cast to int is safe. + if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES)) + { + throw Processor_Error("Cannot decrypt octetStream: ciphertext too short"); + } + if (crypto_secretbox_open_easy(data, data, ciphertext_len, data + ciphertext_len, key) != 0) + { + throw Processor_Error("octetStream decryption failed!"); + } + rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES); +} ostream& operator<<(ostream& s,const octetStream& o) { diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 13487decb..18ff3934b 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _octetStream #define _octetStream @@ -23,6 +23,9 @@ #include #include #include + +#include + using namespace std; @@ -50,11 +53,15 @@ class octetStream int get_length() const { return len; } octet* get_data() const { return data; } + bool done() const { return ptr == len; } + bool empty() const { return len == 0; } + int left() const { return len - ptr; } + octetStream hash() const; // output must have length at least HASH_SIZE void hash(octetStream& output) const; // The following produces a check sum for debugging purposes - bigint check_sum() const; + bigint check_sum(int req_bytes=crypto_hash_BYTES) const; void concat(const octetStream& os); @@ -83,13 +90,19 @@ class octetStream void get_bytes(octet* ans, int& l); //Assumes enough space in ans void store(unsigned int a); - void store(int a) { store((unsigned int) a); } + void store(int a); void get(unsigned int& a); - void get(int& a) { get((unsigned int&) a); } + void get(int& a); void store(const bigint& x); void get(bigint& ans); + // works for all statically allocated types + template + void serialize(const T& x) { append((octet*)&x, sizeof(x)); } + template + void unserialize(T& x) { consume((octet*)&x, sizeof(x)); } + void consume(octetStream& s,int l) { s.resize(l); consume(s.data,l); @@ -98,6 +111,20 @@ class octetStream void Send(int socket_num) const; void Receive(int socket_num); + void ReceiveExpected(int socket_num, int expected); + + // In-place authenticated encryption using sodium; key of length crypto_generichash_BYTES + // ciphertext = Enc(message) | MAC | counter + // + // This is much like 'encrypt' but uses a deterministic counter for the nonce, + // allowing enforcement of message order. + void encrypt_sequence(const octet* key, uint64_t counter); + void decrypt_sequence(const octet* key, uint64_t counter); + + // In-place authenticated encryption using sodium; key of length crypto_secretbox_KEYBYTES + // ciphertext = Enc(message) | MAC | nonce + void encrypt(const octet* key); + void decrypt(const octet* key); friend ostream& operator<<(ostream& s,const octetStream& o); friend class PRNG; @@ -157,5 +184,24 @@ inline void octetStream::Receive(int socket_num) receive(socket_num,data,len); } +inline void octetStream::ReceiveExpected(int socket_num, int expected) +{ + octet blen[4]; + receive(socket_num,blen,4); + + int nlen=decode_length(blen); + if (nlen != expected) { + cerr << "octetStream::ReceiveExpected: got " << nlen << + " length, expected " << expected << endl; + throw bad_value(); + } + + len=0; + resize(nlen); + len=nlen; + + receive(socket_num,data,len); +} + #endif diff --git a/Tools/parse.h b/Tools/parse.h new file mode 100644 index 000000000..0ced2e3fd --- /dev/null +++ b/Tools/parse.h @@ -0,0 +1,49 @@ +/* + * parse.h + * + */ + +#ifndef TOOLS_PARSE_H_ +#define TOOLS_PARSE_H_ + +#include +#include +using namespace std; + +// Read a byte +inline int get_val(istream& s) +{ + char cc; + s.get(cc); + int a=cc; + if (a<0) { a+=256; } + return a; +} + +// Read a 4-byte integer +inline int get_int(istream& s) +{ + int n = 0; + for (int i=0; i<4; i++) + { n<<=8; + int t=get_val(s); + n+=t; + } + return n; +} + +// Read several integers +inline void get_ints(int* res, istream& s, int count) +{ + for (int i = 0; i < count; i++) + res[i] = get_int(s); +} + +inline void get_vector(int m, vector& start, istream& s) +{ + start.resize(m); + for (int i = 0; i < m; i++) + start[i] = get_int(s); +} + +#endif /* TOOLS_PARSE_H_ */ diff --git a/Tools/pprint.h b/Tools/pprint.h new file mode 100644 index 000000000..3df479f13 --- /dev/null +++ b/Tools/pprint.h @@ -0,0 +1,13 @@ + +#include +#include + +using namespace std; + +inline void pprint_bytes(const char *label, unsigned char *bytes, int len) +{ + cout << label << ": "; + for (int j = 0; j < len; j++) + cout << setfill('0') << setw(2) << hex << (int) bytes[j]; + cout << dec << endl; +} diff --git a/Tools/random.cpp b/Tools/random.cpp index 6d1f236a8..c89532206 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -1,8 +1,9 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools/random.h" #include +#include #include using namespace std; @@ -18,9 +19,7 @@ PRNG::PRNG() : cnt(0) void PRNG::ReSeed() { - FILE* rD=fopen("/dev/urandom", "r"); - fread(seed,sizeof(octet),SEED_SIZE,rD); - fclose(rD); + randombytes_buf(seed, SEED_SIZE); InitSeed(); } diff --git a/Tools/random.h b/Tools/random.h index 0c868904c..43773478f 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _random #define _random diff --git a/Tools/sha1.cpp b/Tools/sha1.cpp index 2c555c33e..fe272332d 100644 --- a/Tools/sha1.cpp +++ b/Tools/sha1.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * SHA1 routine optimized to do word accesses rather than byte accesses, diff --git a/Tools/sha1.h b/Tools/sha1.h index 6d2cd1bd2..fad2af866 100644 --- a/Tools/sha1.h +++ b/Tools/sha1.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _SHA1 #define _SHA1 diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp index 6ab284461..f8dda0fe6 100644 --- a/Tools/time-func.cpp +++ b/Tools/time-func.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools/time-func.h" diff --git a/Tools/time-func.h b/Tools/time-func.h index f3357c23d..4a2667ef6 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _timer #define _timer diff --git a/check-passive.cpp b/check-passive.cpp index 53b9c4e8a..7f1209535 100644 --- a/check-passive.cpp +++ b/check-passive.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Math/gf2n.h" #include "Math/gfp.h" diff --git a/client-setup.cpp b/client-setup.cpp new file mode 100644 index 000000000..2c36658de --- /dev/null +++ b/client-setup.cpp @@ -0,0 +1,180 @@ +// (C) 2017 University of Bristol. See License.txt + +// Preprocessing stage to: +// Create the public/private key pairs for each client +// Create the public/private key pairs for each spdz engine +// For each client store the client keys + all spdz engine public keys +// in a file named Client-Keys-C +// For each spdz engine store the spdz engine keys + all client public keys +// in a file named Player-SPDZ-Keys-P +// + +#include + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Share.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" +#include "Exceptions/Exceptions.h" + +#include "Math/Setup.h" +#include "Processor/Data_Files.h" +#include "Tools/mkpath.h" +#include "Tools/ezOptionParser.h" +#include "Tools/Config.h" + +#include +#include +using namespace std; + +static void output(const vector &vec, ofstream &of) +{ + copy(vec.begin(), vec.end(), ostreambuf_iterator(of)); +} + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + + opt.syntax = "./client-setup.x [OPTIONS]\n"; + + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of external clients (default: nplayers)", // Help description. + "-nc", // Flag token. + "--numclients" // Flag token. + ); + opt.add( + "128", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(p) field (default: 128)", // Help description. + "-lgp", // Flag token. + "--lgp" // Flag token. + ); + opt.add( + "40", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(2^n) field (default: 40)", // Help description. + "-lg2", // Flag token. + "--lg2" // Flag token. + ); + opt.parse(argc, argv); + + string prep_data_prefix; + + string usage; + + int nplayers; + if (opt.firstArgs.size() == 2) + { + nplayers = atoi(opt.firstArgs[1]->c_str()); + } + else if (opt.lastArgs.size() == 1) + { + nplayers = atoi(opt.lastArgs[0]->c_str()); + } + else + { + cerr << "ERROR: invalid number of arguments\n"; + opt.getUsage(usage); + cout << usage; + return 1; + } + + int lg2, lgp, nclients; + opt.get("--numclients")->getInt(nclients); + if (nclients <= 0) + nclients = nplayers; + opt.get("--lgp")->getInt(lgp); + opt.get("--lg2")->getInt(lg2); + + cout << "nplayers = " << nplayers << endl; + cout << "nclients = " << nclients << endl; + cout << "lgp = " << lgp << endl; + cout << "lgp2 = " << lg2 << endl; + + prep_data_prefix = get_prep_dir(nplayers, lgp, lg2); + cout << "prep dir = " << prep_data_prefix << endl; + + vector client_publickeys; + vector client_secretkeys; + client_publickeys.resize(nclients); + client_secretkeys.resize(nclients); + for (int i = 0; i < nclients; i++) { + client_secretkeys[i].resize(crypto_box_SECRETKEYBYTES); + client_publickeys[i].resize(crypto_box_PUBLICKEYBYTES); + randombytes_buf(&client_secretkeys[i][0], client_secretkeys[i].size()); + crypto_scalarmult_base(&client_publickeys[i][0], &client_secretkeys[i][0]); + } + + vector client_signing_publickeys; + vector client_signing_secretkeys; + client_signing_publickeys.resize(nclients); + client_signing_secretkeys.resize(nclients); + for (int i = 0; i < nclients; i++) { + client_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES); + client_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES); + crypto_sign_keypair(&client_signing_publickeys[i][0], &client_signing_secretkeys[i][0]); + } + + vector server_publickeys; + vector server_secretkeys; + server_publickeys.resize(nplayers); + server_secretkeys.resize(nplayers); + for (int i = 0; i < nplayers; i++) { + server_publickeys[i].resize(crypto_box_PUBLICKEYBYTES); + server_secretkeys[i].resize(crypto_box_SECRETKEYBYTES); + randombytes_buf(&server_secretkeys[i][0], server_secretkeys[i].size()); + crypto_scalarmult_base(&server_publickeys[i][0], &server_secretkeys[i][0]); + } + vector server_signing_publickeys; + vector server_signing_secretkeys; + server_signing_publickeys.resize(nplayers); + server_signing_secretkeys.resize(nplayers); + for (int i = 0; i < nplayers; i++) { + server_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES); + server_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES); + crypto_sign_keypair(&server_signing_publickeys[i][0], &server_signing_secretkeys[i][0]); + } + + /* Write client files */ + for (int i = 0; i < nclients; i++) { + stringstream filename; + filename << prep_data_prefix << "Client-Keys-C" << i; + ofstream outf(filename.str().c_str()); + if (outf.fail()) + throw file_error(filename.str().c_str()); + // Write public key and secret key + output(client_publickeys[i],outf); + output(client_secretkeys[i],outf); + output(client_signing_publickeys[i],outf); + output(client_signing_secretkeys[i],outf); + int keycount = 2; + + // Write all spdz engine public keys + for (int j = 0; j < nplayers; j++) { + output(server_publickeys[j], outf); + output(server_signing_publickeys[j], outf); + keycount++; + } + outf.close(); + cout << "Wrote " << keycount << " keys to " << filename.str() << endl; + } + + /* Write spdz engine files */ + for (int i = 0; i < nplayers; i++) { + Config::write_player_config_file( prep_data_prefix, i + , server_publickeys[i], server_secretkeys[i] + , server_signing_publickeys[i], server_signing_secretkeys[i] + , client_publickeys, client_signing_publickeys + , server_publickeys, server_signing_publickeys); + } +} diff --git a/compile.py b/compile.py index 2b1258727..711417af7 100755 --- a/compile.py +++ b/compile.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt # ===== Compiler usage instructions ===== @@ -60,6 +60,8 @@ def main(): help="profile compilation") parser.add_option("-C", "--continous", action="store_true", dest="continuous", help="continuous computation") + parser.add_option("-s", "--stop", action="store_true", dest="stop", + help="stop on register errors") options,args = parser.parse_args() if len(args) < 1: parser.print_help() diff --git a/ot-offline.cpp b/ot-offline.cpp index 13c356339..23c4afaeb 100644 --- a/ot-offline.cpp +++ b/ot-offline.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OT-Offline.cpp diff --git a/tutorial.md b/tutorial.md index ef87f090f..5b8a12851 100644 --- a/tutorial.md +++ b/tutorial.md @@ -1,4 +1,4 @@ -(C) 2016 University of Bristol. See License.txt +(C) 2017 University of Bristol. See License.txt Suppose we want to add 2 integers mod p in clear, where p has 128 bits and compute over 2 parties inputs: P0, P1. @@ -130,6 +130,7 @@ inputs. The executables can be found after compiling SPDZ. Customizing those should be straightforward. Make sure you copy the output files to Player-Data /Private-Input-{i} files. +There is a sockets interface to provide input and output from external client processes. See the [ExternalIO directory](./ExternalIO/README.md). Other examples ==============