diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..c6c59de3a --- /dev/null +++ b/.gitignore @@ -0,0 +1,72 @@ +# Offline data, runtime logs # +############################## +Player-Data/* +Prep-Data/* +logs/* +Language-Definition/main.pdf + +# Personal CONFIG file # +############################## +CONFIG.mine + +# Compiled source # +################### +Programs/Bytecode/* +Programs/Schedules/* +Programs/Public-Input/* +*.com +*.class +*.dll +*.exe +*.x +*.o +*.so +*.pyc +*.bc +*.sch +*.a + +# Packages # +############ +# it's better to unpack these files and commit the raw source +# git has its own built in compression methods +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar +*.zip + +# Latex # +######### +*.aux +*.lof +*.log +*.lot +*.fls +*.out +*.toc +*.fmt +*.bbl +*.bcf +*.blg + + +# Logs and databases # +###################### +*.log +*.sql +*.sqlite + +# OS generated files # +###################### +*~ +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..0ccc7366b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "SimpleOT"] + path = SimpleOT + url = git@github.com:pascholl/SimpleOT.git diff --git a/Auth/MAC_Check.cpp b/Auth/MAC_Check.cpp new file mode 100644 index 000000000..873d1d6b0 --- /dev/null +++ b/Auth/MAC_Check.cpp @@ -0,0 +1,509 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Auth/MAC_Check.h" +#include "Auth/Subroutines.h" +#include "Exceptions/Exceptions.h" + +#include "Tools/random.h" +#include "Tools/time-func.h" +#include "Tools/int.h" + +#include "Math/gfp.h" +#include "Math/gf2n.h" + +#include + +enum mc_timer { SEND, RECV_ADD, BCAST, RECV_SUM, SEED, COMMIT, WAIT_SUMMER, RECV, SUM, SELECT, MAX_TIMER }; +const char* mc_timer_names[] = { + "sending", + "receiving and adding", + "broadcasting", + "receiving summed values", + "random seed", + "commit and open", + "wait for summer thread", + "receiving", + "summing", + "waiting for select()" +}; + + +template +MAC_Check::MAC_Check(const T& ai, int opening_sum, int max_broadcast, int send_player) : + base_player(send_player), opening_sum(opening_sum), max_broadcast(max_broadcast) +{ + popen_cnt=0; + alphai=ai; + values_opened=0; + timers.resize(MAX_TIMER); +} + +template +MAC_Check::~MAC_Check() +{ + for (unsigned int i = 0; i < timers.size(); i++) + if (timers[i].elapsed() > 0) + cerr << T::type_string() << " " << mc_timer_names[i] << ": " + << timers[i].elapsed() << endl; + + for (unsigned int i = 0; i < player_timers.size(); i++) + if (player_timers[i].elapsed() > 0) + cerr << T::type_string() << " waiting for " << i << ": " + << player_timers[i].elapsed() << endl; +} + +template +void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player, MAC_Check& MC) +{ + MC.player_timers.resize(P.num_players()); + vector& oss = MC.oss; + oss.resize(P.num_players()); + vector senders; + senders.reserve(P.num_players()); + + for (int relative_sender = positive_modulo(P.my_num() - send_player, P.num_players()) + sum_players; + relative_sender < last_sum_players; relative_sender += sum_players) + { + int sender = positive_modulo(send_player + relative_sender, P.num_players()); + senders.push_back(sender); + } + + for (int j = 0; j < (int)senders.size(); j++) + P.request_receive(senders[j], oss[j]); + + for (int j = 0; j < (int)senders.size(); j++) + { + int sender = senders[j]; + MC.player_timers[sender].start(); + P.wait_receive(sender, oss[j], true); + MC.player_timers[sender].stop(); + if ((unsigned)oss[j].get_length() < values.size() * T::size()) + { + stringstream ss; + ss << "Not enough information received, expected " + << values.size() * T::size() << " bytes, got " + << oss[j].get_length(); + throw Processor_Error(ss.str()); + } + MC.timers[SUM].start(); + for (unsigned int i=0; i(oss[j].consume(T::size())); + } + MC.timers[SUM].stop(); + } +} + +template +void MAC_Check::POpen_Begin(vector& values,const vector >& S,const Player& P) +{ + AddToMacs(S); + + for (unsigned int i=0; i= sum_players && my_relative_num < last_sum_players) + { + for (unsigned int i=0; i(values, P, sum_players, last_sum_players, base_player, *this); + else + add_openings(values, P, sum_players, last_sum_players, base_player, *this); + timers[RECV_ADD].stop(); + } + } + + if (P.my_num() == base_player) + { + os.reset_write_head(); + for (unsigned int i=0; i +void MAC_Check::POpen_End(vector& values,const vector >& S,const Player& P) +{ + S.size(); + int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players()); + if (my_relative_num * max_broadcast >= P.num_players()) + { + int sender = (base_player + my_relative_num / max_broadcast) % P.num_players(); + ReceiveValues(values, P, sender); + } + else + GetValues(values); + + popen_cnt += values.size(); + CheckIfNeeded(P); + + /* not compatible with continuous communication + send_player++; + if (send_player==P.num_players()) + { send_player=0; } + */ +} + + +template +void MAC_Check::AddToMacs(const vector >& shares) +{ + for (unsigned int i = 0; i < shares.size(); i++) + macs.push_back(shares[i].get_mac()); +} + + +template +void MAC_Check::AddToValues(vector& values) +{ + vals.insert(vals.end(), values.begin(), values.end()); +} + + +template +void MAC_Check::ReceiveValues(vector& values, const Player& P, int sender) +{ + timers[RECV_SUM].start(); + P.receive_player(sender, os, true); + timers[RECV_SUM].stop(); + for (unsigned int i = 0; i < values.size(); i++) + values[i].unpack(os); + AddToValues(values); +} + + +template +void MAC_Check::GetValues(vector& values) +{ + int size = values.size(); + if (popen_cnt + size > int(vals.size())) + { + stringstream ss; + ss << "wanted " << values.size() << " values from " << popen_cnt << ", only " << vals.size() << " in store"; + throw out_of_range(ss.str()); + } + values.clear(); + typename vector::iterator first = vals.begin() + popen_cnt; + values.insert(values.end(), first, first + size); +} + + +template +void MAC_Check::CheckIfNeeded(const Player& P) +{ + if (WaitingForCheck() >= POPEN_MAX) + Check(P); +} + + +template +void MAC_Check::AddToCheck(const T& mac, const T& value, const Player& P) +{ + CheckIfNeeded(P); + macs.push_back(mac); + vals.push_back(value); +} + + + +template +void MAC_Check::Check(const Player& P) +{ + if (WaitingForCheck() == 0) + return; + + //cerr << "In MAC Check : " << popen_cnt << endl; + octet seed[SEED_SIZE]; + timers[SEED].start(); + Create_Random_Seed(seed,P,SEED_SIZE); + timers[SEED].stop(); + PRNG G; + G.SetSeed(seed); + + Share sj; + T a,gami,h,temp; + a.assign_zero(); + gami.assign_zero(); + vector tau(P.num_players()); + for (int i=0; i +int mc_base_id(int function_id, int thread_num) +{ + return (function_id << 28) + ((T::field_type() + 1) << 24) + (thread_num << 16); +} + +template +Separate_MAC_Check::Separate_MAC_Check(const T& ai, Names& Nms, + int thread_num, int opening_sum, int max_broadcast, int send_player) : + MAC_Check(ai, opening_sum, max_broadcast, send_player), + check_player(Nms, mc_base_id(1, thread_num)) +{ +} + +template +void Separate_MAC_Check::Check(const Player& P) +{ + P.my_num(); + MAC_Check::Check(check_player); +} + + +template +void* run_summer_thread(void* summer) +{ + ((Summer*) summer)->run(); + return 0; +} + +template +Parallel_MAC_Check::Parallel_MAC_Check(const T& ai, Names& Nms, + int thread_num, int opening_sum, int max_broadcast, int base_player) : + Separate_MAC_Check(ai, Nms, thread_num, opening_sum, max_broadcast, base_player), + send_player(Nms, mc_base_id(2, thread_num)), + send_base_player(base_player) +{ + int sum_players = Nms.num_players(); + Player* summer_send_player = &send_player; + for (int i = 0; ; i++) + { + int last_sum_players = sum_players; + sum_players = (sum_players - 2 + opening_sum) / opening_sum; + int next_sum_players = (sum_players - 2 + opening_sum) / opening_sum; + if (sum_players == 0) + break; + Player* summer_receive_player = summer_send_player; + summer_send_player = new Player(Nms, mc_base_id(3, thread_num)); + summers.push_back(new Summer(sum_players, last_sum_players, next_sum_players, + summer_send_player, summer_receive_player, *this)); + pthread_create(&(summers[i]->thread), 0, run_summer_thread, summers[i]); + } + receive_player = summer_send_player; +} + +template +Parallel_MAC_Check::~Parallel_MAC_Check() +{ + for (unsigned int i = 0; i < summers.size(); i++) + { + summers[i]->input_queue.stop(); + pthread_join(summers[i]->thread, 0); + delete summers[i]; + } +} + +template +void Parallel_MAC_Check::POpen_Begin(vector& values, + const vector >& S, const Player& P) +{ + values.size(); + this->AddToMacs(S); + + int my_relative_num = positive_modulo(P.my_num() - send_base_player, P.num_players()); + int sum_players = (P.num_players() - 2 + this->opening_sum) / this->opening_sum; + int receiver = positive_modulo(send_base_player + my_relative_num % sum_players, P.num_players()); + + // use queue rather sending to myself + if (receiver == P.my_num()) + { + for (unsigned int i = 0; i < S.size(); i++) + values[i] = S[i].get_share(); + summers.front()->share_queue.push(values); + } + else + { + this->os.reset_write_head(); + for (unsigned int i=0; ios); + this->timers[SEND].start(); + send_player.send_to(receiver,this->os,true); + this->timers[SEND].stop(); + } + + for (unsigned int i = 0; i < summers.size(); i++) + summers[i]->input_queue.push(S.size()); + + this->values_opened += S.size(); + send_base_player = (send_base_player + 1) % send_player.num_players(); +} + +template +void Parallel_MAC_Check::POpen_End(vector& values, + const vector >& S, const Player& P) +{ + int last_size = 0; + this->timers[WAIT_SUMMER].start(); + summers.back()->output_queue.pop(last_size); + this->timers[WAIT_SUMMER].stop(); + if (int(values.size()) != last_size) + { + stringstream ss; + ss << "stopopen wants " << values.size() << " values, but I have " << last_size << endl; + throw Processor_Error(ss.str().c_str()); + } + + if (this->base_player == P.my_num()) + { + value_queue.pop(values); + if (int(values.size()) != last_size) + throw Processor_Error("wrong number of local values"); + else + this->AddToValues(values); + } + this->MAC_Check::POpen_End(values, S, *receive_player); + this->base_player = (this->base_player + 1) % send_player.num_players(); +} + + + +template +Direct_MAC_Check::Direct_MAC_Check(const T& ai, Names& Nms, int num) : Separate_MAC_Check(ai, Nms, num) { + open_counter = 0; +} + + +template +Direct_MAC_Check::~Direct_MAC_Check() { + cerr << T::type_string() << " open counter: " << open_counter << endl; +} + + +template +void Direct_MAC_Check::POpen_Begin(vector& values,const vector >& S,const Player& P) +{ + values.size(); + this->os.reset_write_head(); + for (unsigned int i=0; ios); + this->timers[SEND].start(); + P.send_all(this->os,true); + this->timers[SEND].stop(); + + this->AddToMacs(S); + for (unsigned int i=0; ivals.push_back(S[i].get_share()); +} + +template +void direct_add_openings(vector& values, const Player& P, vector& os) +{ + for (unsigned int i=0; i(os[j].consume(T::size())); +} + +template +void Direct_MAC_Check::POpen_End(vector& values,const vector >& S,const Player& P) +{ + S.size(); + oss.resize(P.num_players()); + this->GetValues(values); + + this->timers[RECV].start(); + + for (int j=0; jtimers[RECV].stop(); + open_counter++; + + if (T::t() == 2) + direct_add_openings(values, P, oss); + else + direct_add_openings(values, P, oss); + + for (unsigned int i = 0; i < values.size(); i++) + this->vals[this->popen_cnt+i] = values[i]; + + this->popen_cnt += values.size(); + this->CheckIfNeeded(P); +} + +template class MAC_Check; +template class Direct_MAC_Check; +template class Parallel_MAC_Check; + +template class MAC_Check; +template class Direct_MAC_Check; +template class Parallel_MAC_Check; + +#ifdef USE_GF2N_LONG +template class MAC_Check; +template class Direct_MAC_Check; +template class Parallel_MAC_Check; +#endif diff --git a/Auth/MAC_Check.h b/Auth/MAC_Check.h new file mode 100644 index 000000000..b61cc0c75 --- /dev/null +++ b/Auth/MAC_Check.h @@ -0,0 +1,141 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _MAC_Check +#define _MAC_Check + +/* Class for storing MAC Check data and doing the Check */ + +#include +#include +using namespace std; + +#include "Math/Share.h" +#include "Networking/Player.h" +#include "Networking/ServerSocket.h" +#include "Auth/Summer.h" +#include "Tools/time-func.h" + + +/* The MAX number of things we will partially open before running + * a MAC Check + * + * Keep this at much less than 1MB of data to be able to cope with + * multi-threaded players + * + */ +#define POPEN_MAX 1000000 + + +template +class MAC_Check +{ + protected: + + /* POpen Data */ + int popen_cnt; + vector macs; + vector vals; + int base_player; + int opening_sum; + int max_broadcast; + octetStream os; + + /* MAC Share */ + T alphai; + + void AddToMacs(const vector< Share >& shares); + void AddToValues(vector& values); + void ReceiveValues(vector& values, const Player& P, int sender); + void GetValues(vector& values); + void CheckIfNeeded(const Player& P); + int WaitingForCheck() + { return max(macs.size(), vals.size()); } + + public: + + int values_opened; + vector timers; + vector player_timers; + vector oss; + + MAC_Check(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0); + virtual ~MAC_Check(); + + /* Run protocols to partially open data and check the MACs are + * all OK. + * - Implicit assume that the amount of data being sent does + * not overload the OS + * Begin and End expect the same arrays values and S passed to them + * and they expect values to be of the same size as S. + */ + virtual void POpen_Begin(vector& values,const vector >& S,const Player& P); + virtual void POpen_End(vector& values,const vector >& S,const Player& P); + void AddToCheck(const T& mac, const T& value, const Player& P); + virtual void Check(const Player& P); + + int number() const { return values_opened; } + + const T& get_alphai() const { return alphai; } +}; + + +template + void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player, MAC_Check& MC); + + +template +class Separate_MAC_Check: public MAC_Check +{ + // Different channel for checks + Player check_player; + +protected: + // No sense to expose this + Separate_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum=10, int max_broadcast=10, int send_player=0); + virtual ~Separate_MAC_Check() {}; + +public: + virtual void Check(const Player& P); +}; + + +template +class Parallel_MAC_Check: public Separate_MAC_Check +{ + // Different channel for every round + Player send_player; + // Managed by Summer + Player* receive_player; + + vector< Summer* > summers; + + int send_base_player; + + WaitQueue< vector > value_queue; + +public: + Parallel_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum=10, int max_broadcast=10, int send_player=0); + virtual ~Parallel_MAC_Check(); + + virtual void POpen_Begin(vector& values,const vector >& S,const Player& P); + virtual void POpen_End(vector& values,const vector >& S,const Player& P); + + friend class Summer; +}; + + +template +class Direct_MAC_Check: public Separate_MAC_Check +{ + int open_counter; + vector oss; + +public: + Direct_MAC_Check(const T& ai, Names& Nms, int thread_num); + ~Direct_MAC_Check(); + + void POpen_Begin(vector& values,const vector >& S,const Player& P); + void POpen_End(vector& values,const vector >& S,const Player& P); +}; + +#endif diff --git a/Auth/Subroutines.cpp b/Auth/Subroutines.cpp new file mode 100644 index 000000000..efaff68f9 --- /dev/null +++ b/Auth/Subroutines.cpp @@ -0,0 +1,235 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Auth/Subroutines.h" + +#include "Tools/random.h" +#include "Exceptions/Exceptions.h" +#include "Tools/Commit.h" + +/* To ease readability as I re-write this program the following conventions + * will be used. + * For a variable v index by a player i + * Comm_v[i] is the commitment string for player i + * Open_v[i] is the opening data for player i + */ + + +// Special version for octetStreams +void Commit(vector< vector >& Comm_data, + vector& Open_data, + const vector< vector >& data,const Player& P,int num_runs) +{ + int my_number=P.my_num(); + for (int i=0; i >& data, + const vector< vector >& Comm_data, + const vector& My_Open_data, + const Player& P,int num_runs,int dont) +{ + int my_number=P.my_num(); + int num_players=P.num_players(); + vector Open_data(num_players); + for (int i=0; i >& data, + const vector< vector >& Comm_data, + const vector& My_Open_data, + const vector open, + const Player& P,int num_runs) +{ + int my_number=P.my_num(); + int num_players=P.num_players(); + vector Open_data(num_players); + for (int i=0; i& e, + vector& Comm_e,vector& Open_e, + const Player& P,int num_runs) +{ + PRNG G; + G.ReSeed(); + + e.resize(P.num_players()); + Comm_e.resize(P.num_players()); + Open_e.resize(P.num_players()); + + e[P.my_num()]=G.get_uint()%num_runs; + octetStream ee; ee.store(e[P.my_num()]); + Commit(Comm_e[P.my_num()],Open_e[P.my_num()],ee,P.my_num()); + P.Broadcast_Receive(Comm_e); +} + + +int Open_Challenge(vector& e,vector& Open_e, + const vector& Comm_e, + const Player& P,int num_runs) +{ + // Now open the challenge commitments and determine which run was for real + P.Broadcast_Receive(Open_e); + + int challenge=0; + octetStream ee; + for (int i = 0; i < P.num_players(); i++) + { if (i != P.my_num()) + { if (!Open(ee,Comm_e[i],Open_e[i],i)) + { throw invalid_commitment(); } + ee.get(e[i]); + } + challenge+=e[i]; + } + challenge = challenge % num_runs; + + return challenge; +} + + +template +void Create_Random(T& ans,const Player& P) +{ + PRNG G; + G.ReSeed(); + vector e(P.num_players()); + vector Comm_e(P.num_players()); + vector Open_e(P.num_players()); + + e[P.my_num()].randomize(G); + octetStream ee; + e[P.my_num()].pack(ee); + Commit(Comm_e[P.my_num()],Open_e[P.my_num()],ee,P.my_num()); + P.Broadcast_Receive(Comm_e); + + P.Broadcast_Receive(Open_e); + + ans.assign_zero(); + for (int i = 0; i < P.num_players(); i++) + { if (i != P.my_num()) + { if (!Open(ee,Comm_e[i],Open_e[i],i)) + { throw invalid_commitment(); } + e[i].unpack(ee); + } + ans.add(ans,e[i]); + } +} + + +void Create_Random_Seed(octet* seed,const Player& P,int len) +{ + PRNG G; + G.ReSeed(); + vector e(P.num_players()); + vector Comm_e(P.num_players()); + vector Open_e(P.num_players()); + + G.get_octetStream(e[P.my_num()],len); + Commit(Comm_e[P.my_num()],Open_e[P.my_num()],e[P.my_num()],P.my_num()); + P.Broadcast_Receive(Comm_e); + + P.Broadcast_Receive(Open_e); + + memset(seed,0,len*sizeof(octet)); + for (int i = 0; i < P.num_players(); i++) + { if (i != P.my_num()) + { if (!Open(e[i],Comm_e[i],Open_e[i],i)) + { throw invalid_commitment(); } + } + for (int j=0; j +void Commit_And_Open(vector& data,const Player& P) +{ + vector Comm_data(P.num_players()); + vector Open_data(P.num_players()); + + octetStream ee; + data[P.my_num()].pack(ee); + Commit(Comm_data[P.my_num()],Open_data[P.my_num()],ee,P.my_num()); + P.Broadcast_Receive(Comm_data); + + P.Broadcast_Receive(Open_data); + + for (int i = 0; i < P.num_players(); i++) + { if (i != P.my_num()) + { if (!Open(ee,Comm_data[i],Open_data[i],i)) + { throw invalid_commitment(); } + data[i].unpack(ee); + } + } +} + + +void Commit_To_Seeds(vector& G, + vector< vector >& seeds, + vector< vector >& Comm_seeds, + vector& Open_seeds, + const Player& P,int num_runs) +{ + seeds.resize(num_runs); + Comm_seeds.resize(num_runs); + Open_seeds.resize(num_runs); + for (int i=0; i& data,const Player& P); +template void Create_Random(gf2n& ans,const Player& P); + +#ifdef USE_GF2N_LONG +template void Commit_And_Open(vector& data,const Player& P); +template void Create_Random(gf2n_short& ans,const Player& P); +#endif + +template void Commit_And_Open(vector& data,const Player& P); +template void Create_Random(gfp& ans,const Player& P); diff --git a/Auth/Subroutines.h b/Auth/Subroutines.h new file mode 100644 index 000000000..6680718f4 --- /dev/null +++ b/Auth/Subroutines.h @@ -0,0 +1,140 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Subroutines +#define _Subroutines + +/* Defines subroutines for use in both KeyGen and Offline phase + * Mainly focused around commiting and decommitting to various + * bits of data + */ + +#include "Tools/random.h" +#include "Networking/Player.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Tools/Commit.h" + +/* Run just the Open Protocol for data[i][j] of type octetStream + * 0 <= i < num_runs + * 0 <= j < num_players + * On output data[i][j] contains all the data + * If dont!=-1 then dont open this run + */ +void Open(vector< vector >& data, + const vector< vector >& Comm_data, + const vector& My_Open_data, + const Player& P,int num_runs,int dont=-1); + +/* This one takes a vector open which contains 0 and 1 + * If 1 then we open this value, otherwise we do not + */ +void Open(vector< vector >& data, + const vector< vector >& Comm_data, + const vector& My_Open_data, + const vector open, + const Player& P,int num_runs); + + + + +/* This runs the Commit and Open Protocol for data[i][j] of type T + * 0 <= i < num_runs + * 0 <= j < num_players + * On input data[i][j] is only defined for j=my_number + */ +template +void Commit_And_Open(vector< vector >& data,const Player& P,int num_runs); + +template +void Commit_And_Open(vector& data,const Player& P); + +template +void Transmit_Data(vector< vector >& data,const Player& P,int num_runs); + +/* Functions to Commit and Open a Challenge Value */ +void Commit_To_Challenge(vector& e, + vector& Comm_e,vector& Open_e, + const Player& P,int num_runs); + +int Open_Challenge(vector& e,vector& Open_e, + const vector& Comm_e, + const Player& P,int num_runs); + + +/* Function to create a shared random value for T=gfp/gf2n */ +template +void Create_Random(T& ans,const Player& P); + +/* Produce a random seed of length len */ +void Create_Random_Seed(octet* seed,const Player& P,int len); + + + +/* Functions to Commit to Seed Values + * This also initialises the PRNG's in G + */ +void Commit_To_Seeds(vector& G, + vector< vector >& seeds, + vector< vector >& Comm_seeds, + vector& Open_seeds, + const Player& P,int num_runs); + + + +/* Run just the Commit Protocol for data[i][j] of type T + * 0 <= i < num_runs + * 0 <= j < num_players + * On input data[i][j] is only defined for j=my_number + */ +template +void Commit(vector< vector >& Comm_data, + vector& Open_data, + const vector< vector >& data,const Player& P,int num_runs) +{ + octetStream os; + int my_number=P.my_num(); + for (int i=0; i +void Open(vector< vector >& data, + const vector< vector >& Comm_data, + const vector& My_Open_data, + const Player& P,int num_runs,int dont=-1) +{ + octetStream os; + int my_number=P.my_num(); + int num_players=P.num_players(); + vector Open_data(num_players); + for (int i=0; i +Summer::Summer(int sum_players, int last_sum_players, int next_sum_players, + Player* send_player, Player* receive_player, Parallel_MAC_Check& MC) : + sum_players(sum_players), last_sum_players(last_sum_players), next_sum_players(next_sum_players), + base_player(0), MC(MC),send_player(send_player), receive_player(receive_player), + thread(0), stop(false), size(0) +{ + cout << "Setting up summation by " << sum_players << " players" << endl; +} + +template +Summer::~Summer() +{ + delete send_player; + if (timer.elapsed()) + cout << T::type_string() << " summation by " << sum_players << " players: " + << timer.elapsed() << endl; +} + +template +void Summer::run() +{ + octetStream os; + + while (true) + { + int size = 0; + if (!input_queue.pop(size)) + break; + + int my_relative_num = positive_modulo(send_player->my_num() - base_player, send_player->num_players()); + if (my_relative_num < sum_players) + { + // first summer takes inputs from queue + if (last_sum_players == send_player->num_players()) + share_queue.pop(values); + else + { + values.resize(size); + receive_player->receive_player(receive_player->my_num(),os,true); + for (int i = 0; i < size; i++) + values[i].unpack(os); + } + + timer.start(); + if (T::t() == 2) + add_openings(values, *receive_player, sum_players, + last_sum_players, base_player, MC); + else + add_openings(values, *receive_player, sum_players, + last_sum_players, base_player, MC); + timer.stop(); + + os.reset_write_head(); + for (int i = 0; i < size; i++) + values[i].pack(os); + + if (sum_players > 1) + { + int receiver = positive_modulo(base_player + my_relative_num % next_sum_players, + send_player->num_players()); + send_player->send_to(receiver, os, true); + } + else + { + send_player->send_all(os); + MC.value_queue.push(values); + } + } + + if (sum_players == 1) + output_queue.push(size); + + base_player = (base_player + 1) % send_player->num_players(); + } +} + +template class Summer; +template class Summer; + +#ifdef USE_GF2N_LONG +template class Summer; +#endif diff --git a/Auth/Summer.h b/Auth/Summer.h new file mode 100644 index 000000000..4a7c54046 --- /dev/null +++ b/Auth/Summer.h @@ -0,0 +1,47 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Summer.h + * + */ + +#ifndef OFFLINE_SUMMER_H_ +#define OFFLINE_SUMMER_H_ + +#include "Networking/Player.h" +#include "Tools/WaitQueue.h" +#include "Tools/time-func.h" + +#include +#include +using namespace std; + +template +class Parallel_MAC_Check; + +template +class Summer +{ + int sum_players, last_sum_players, next_sum_players; + int base_player; + Parallel_MAC_Check& MC; + Player* send_player; + Player* receive_player; + Timer timer; + +public: + vector values; + + pthread_t thread; + WaitQueue input_queue, output_queue; + bool stop; + int size; + WaitQueue< vector > share_queue; + + Summer(int sum_players, int last_sum_players, int next_sum_players, + Player* send_player, Player* receive_player, Parallel_MAC_Check& MC); + ~Summer(); + void run(); +}; + +#endif /* OFFLINE_SUMMER_H_ */ diff --git a/Auth/fake-stuff.cpp b/Auth/fake-stuff.cpp new file mode 100644 index 000000000..40b0f4261 --- /dev/null +++ b/Auth/fake-stuff.cpp @@ -0,0 +1,171 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Share.h" + +#include + +template +void make_share(vector >& Sa,const T& a,int N,const T& key,PRNG& G) +{ + T mac,x,y; + mac.mul(a,key); + Share S; + S.set_share(a); + S.set_mac(mac); + + for (int i=0; i +void check_share(vector >& Sa,T& value,T& mac,int N,const T& key) +{ + value.assign(0); + mac.assign(0); + + for (int i=0; i >& Sa,const gf2n& a,int N,const gf2n& key,PRNG& G); +template void make_share(vector >& Sa,const gfp& a,int N,const gfp& key,PRNG& G); + +template void check_share(vector >& Sa,gf2n& value,gf2n& mac,int N,const gf2n& key); +template void check_share(vector >& Sa,gfp& value,gfp& mac,int N,const gfp& key); + +#ifdef USE_GF2N_LONG +template void make_share(vector >& Sa,const gf2n_short& a,int N,const gf2n_short& key,PRNG& G); +template void check_share(vector >& Sa,gf2n_short& value,gf2n_short& mac,int N,const gf2n_short& key); +#endif + +// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40) +void expand_byte(gf2n_short& a,int b) +{ + gf2n_short x,xp; + x.assign(32+1); + xp.assign_one(); + a.assign_zero(); + + while (b!=0) + { if ((b&1)==1) + { a.add(a,xp); } + xp.mul(x); + b>>=1; + } +} + + +// Have previously worked out the linear equations we need to solve +void collapse_byte(int& b,const gf2n_short& aa) +{ + word w=aa.get(); + int e35=(w>>35)&1; + int e30=(w>>30)&1; + int e25=(w>>25)&1; + int e20=(w>>20)&1; + int e15=(w>>15)&1; + int e10=(w>>10)&1; + int e5=(w>>5)&1; + int e0=w&1; + int a[8]; + a[7]=e35; + a[6]=e30^a[7]; + a[5]=e25^a[7]; + a[4]=e20^a[5]^a[6]^a[7]; + a[3]=e15^a[7]; + a[2]=e10^a[3]^a[6]^a[7]; + a[1]=e5^a[3]^a[5]^a[7]; + a[0]=e0^a[1]^a[2]^a[3]^a[4]^a[5]^a[6]^a[7]; + + b=0; + for (int i=7; i>=0; i--) + { b=b<<1; + b+=a[i]; + } +} + +void generate_keys(const string& directory, int nplayers) +{ + PRNG G; + G.ReSeed(); + + gf2n mac2; + gfp macp; + mac2.assign_zero(); + macp.assign_zero(); + + ofstream outf; + + for (int i = 0; i < nplayers; i++) + { + stringstream filename; + filename << directory << "Player-MAC-Keys-P" << i; + mac2.randomize(G); + macp.randomize(G); + cout << "Writing to " << filename.str().c_str() << endl; + outf.open(filename.str().c_str()); + outf << nplayers << endl; + macp.output(outf,true); + outf << " "; + mac2.output(outf,true); + outf << endl; + outf.close(); + } +} + +void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers) +{ + gfp sharep; + gf2n share2; + keyp.assign_zero(); + key2.assign_zero(); + int i, tmpN = 0; + ifstream inpf; + + for (i = 0; i < nplayers; i++) + { + stringstream filename; + filename << directory << "Player-MAC-Keys-P" << i; + inpf.open(filename.str().c_str()); + if (inpf.fail()) + { + inpf.close(); + cout << "Error: No MAC key share found for player " << i << std::endl; + exit(1); + } + else + { + inpf >> tmpN; // not needed here + sharep.input(inpf,true); + share2.input(inpf,true); + inpf.close(); + } + std::cout << " Key " << i << "\t p: " << sharep << "\n\t 2: " << share2 << std::endl; + keyp.add(sharep); + key2.add(share2); + } + std::cout << "Final MAC keys :\t p: " << keyp << "\n\t\t 2: " << key2 << std::endl; +} diff --git a/Auth/fake-stuff.h b/Auth/fake-stuff.h new file mode 100644 index 000000000..3fc6783de --- /dev/null +++ b/Auth/fake-stuff.h @@ -0,0 +1,64 @@ +// (C) 2016 University of Bristol. See License.txt + + +#ifndef _fake_stuff +#define _fake_stuff + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Share.h" + +#include +using namespace std; + +template +void make_share(vector >& Sa,const T& a,int N,const T& key,PRNG& G); + +template +void check_share(vector >& Sa,T& value,T& mac,int N,const T& key); + +void expand_byte(gf2n_short& a,int b); +void collapse_byte(int& b,const gf2n_short& a); + +// Generate MAC key shares +void generate_keys(const string& directory, int nplayers); + +// Read MAC key shares and compute keys +void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers); + +template +class Files +{ +public: + ofstream* outf; + int N; + T key; + PRNG G; + Files(int N, const T& key, const string& prefix) : N(N), key(key) + { + outf = new ofstream[N]; + for (int i=0; i > Sa(N); + make_share(Sa,a,N,key,G); + for (int j=0; j for optimization +# AVX2 support (Haswell or later) changes the bit matrix transpose +ARCH = -mtune=native + +#use CONFIG.mine to overwrite DIR settings +-include CONFIG.mine + +ifeq ($(USE_GF2N_LONG),1) +GF2N_LONG = -DUSE_GF2N_LONG +endif + +# MAX_MOD_SZ must be at least ceil(len(p)/len(word))+1 +# Default is 3, which suffices for 128-bit p +# MOD = -DMAX_MOD_SZ=3 + +LDLIBS = -lmpirxx -lmpir $(MY_LDLIBS) -lm -lpthread + +ifeq ($(USE_NTL),1) +LDLIBS := -lntl $(LDLIBS) +endif + +OS := $(shell uname -s) +ifeq ($(OS), Linux) +LDLIBS += -lrt +endif + +CXX = g++ +CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH) +CPPFLAGS = $(CFLAGS) +LD = g++ + diff --git a/Check-Offline.cpp b/Check-Offline.cpp new file mode 100644 index 000000000..44f5f75da --- /dev/null +++ b/Check-Offline.cpp @@ -0,0 +1,203 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Check-Offline.cpp + * + */ + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Share.h" +#include "Auth/fake-stuff.h" +#include "Tools/ezOptionParser.h" +#include "Exceptions/Exceptions.h" + +#include "Math/Setup.h" +#include "Processor/Data_Files.h" + +#include +#include +#include +using namespace std; + +string PREP_DATA_PREFIX; + +template +void check_mult_triples(const T& key,int N,vector& dataF,DataFieldType field_type) +{ + T a,b,c,mac,res; + vector > Sa(N),Sb(N),Sc(N); + int n = 0; + + try { + while (!dataF[0]->eof(DATA_TRIPLE)) + { + for (int i = 0; i < N; i++) + dataF[i]->get_three(field_type, DATA_TRIPLE, Sa[i], Sb[i], Sc[i]); + check_share(Sa, a, mac, N, key); + check_share(Sb, b, mac, N, key); + check_share(Sc, c, mac, N, key); + + res.mul(a, b); + if (!res.equal(c)) + { + cout << n << ": " << c << " != " << a << " * " << b << endl; + throw bad_value(); + } + n++; + } + + cout << n << " triples of type " << T::type_string() << endl; + } + catch (exception& e) + { + cout << "Error with triples of type " << T::type_string() << endl; + } +} + +template +void check_bits(const T& key,int N,vector& dataF,DataFieldType field_type) +{ + T a,b,c,mac,res; + vector > Sa(N),Sb(N),Sc(N); + int n = 0; + + while (!dataF[0]->eof(DATA_BIT)) + { + for (int i = 0; i < N; i++) + dataF[i]->get_one(field_type, DATA_BIT, Sa[i]); + check_share(Sa, a, mac, N, key); + + if (!(a.is_zero() || a.is_one())) + { + cout << n << ": " << a << " neither 0 or 1" << endl; + throw bad_value(); + } + n++; + } + + cout << n << " bits of type " << T::type_string() << endl; +} + +template +void check_inputs(const T& key,int N,vector& dataF) +{ + T a, mac, x; + vector< Share > Sa(N); + + for (int player = 0; player < N; player++) + { + int n = 0; + while (!dataF[0]->input_eof(player)) + { + for (int i = 0; i < N; i++) + dataF[i]->get_input(Sa[i], x, player); + check_share(Sa, a, mac, N, key); + if (!a.equal(x)) + throw bad_value(); + n++; + } + cout << n << " input masks for player " << player << " of type " << T::type_string() << endl; + } +} + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + gfp::init_field(gfp::pr(), false); + + opt.syntax = "./Check-Offline.x [OPTIONS]\n"; + opt.example = "./Check-Offline.x 3 -lgp 64 -lg2 128\n"; + + opt.add( + "128", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(p) field (default: 128)", // Help description. + "-lgp", // Flag token. + "--lgp" // Flag token. + ); + opt.add( + "40", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(2^n) field (default: 40)", // Help description. + "-lg2", // Flag token. + "--lg2" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Read GF(p) triples in Montgomery representation (default: not set)", // Help description. + "-m", // Flag token. + "--usemont" // Flag token. + ); + opt.parse(argc, argv); + + string usage; + int lgp, lg2, nparties; + bool use_montgomery = false; + opt.get("--lgp")->getInt(lgp); + opt.get("--lg2")->getInt(lg2); + if (opt.isSet("--usemont")) + use_montgomery = true; + + if (opt.firstArgs.size() == 2) + nparties = atoi(opt.firstArgs[1]->c_str()); + else if (opt.lastArgs.size() == 1) + nparties = atoi(opt.lastArgs[0]->c_str()); + else + { + cerr << "ERROR: invalid number of arguments\n"; + opt.getUsage(usage); + cout << usage; + return 1; + } + + PREP_DATA_PREFIX = get_prep_dir(nparties, lgp, lg2); + read_setup(PREP_DATA_PREFIX); + + if (!use_montgomery) + { + // no montgomery + gfp::init_field(gfp::pr(), false); + } + + /* Find number players and MAC keys etc*/ + char filename[1024]; + gfp keyp,pp; keyp.assign_zero(); + gf2n key2,p2; key2.assign_zero(); + int N=1; + ifstream inpf; + for (int i= 0; i < nparties; i++) + { + sprintf(filename, (PREP_DATA_PREFIX + "Player-MAC-Keys-P%d").c_str(), i); + inpf.open(filename); + if (inpf.fail()) { throw file_error(filename); } + inpf >> N; + pp.input(inpf,true); + p2.input(inpf,true); + cout << " Key " << i << "\t p: " << pp << "\n\t 2: " << p2 << endl; + keyp.add(pp); + key2.add(p2); + inpf.close(); + } + cout << "--------------\n"; + cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl; + + vector dataF(N); + for (int i = 0; i < N; i++) + dataF[i] = new Data_Files(i, N, PREP_DATA_PREFIX); + check_mult_triples(key2, N, dataF, DATA_GF2N); + check_mult_triples(keyp, N, dataF, DATA_MODP); + check_inputs(key2, N, dataF); + check_inputs(keyp, N, dataF); + check_bits(key2, N, dataF, DATA_GF2N); + check_bits(keyp, N, dataF, DATA_MODP); + for (int i = 0; i < N; i++) + delete dataF[i]; +} diff --git a/Compiler/__init__.py b/Compiler/__init__.py new file mode 100644 index 000000000..417a7bff0 --- /dev/null +++ b/Compiler/__init__.py @@ -0,0 +1,31 @@ +# (C) 2016 University of Bristol. See License.txt + +import compilerLib, program, instructions, types, library, floatingpoint +import inspect +from config import * +from compilerLib import run + + +# add all instructions to the program VARS dictionary +compilerLib.VARS = {} +instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] + +instr_classes += [t[1] for t in inspect.getmembers(types, inspect.isclass)\ + if t[1].__module__ == types.__name__] + +instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ + if t[1].__module__ == library.__name__] + +for op in instr_classes: + compilerLib.VARS[op.__name__] = op + +# add open and input separately due to name conflict +compilerLib.VARS['open'] = instructions.asm_open +compilerLib.VARS['vopen'] = instructions.vasm_open +compilerLib.VARS['gopen'] = instructions.gasm_open +compilerLib.VARS['vgopen'] = instructions.vgasm_open +compilerLib.VARS['input'] = instructions.asm_input +compilerLib.VARS['ginput'] = instructions.gasm_input + +compilerLib.VARS['comparison'] = comparison +compilerLib.VARS['floatingpoint'] = floatingpoint diff --git a/Compiler/allocator.py b/Compiler/allocator.py new file mode 100644 index 000000000..4fffdc48d --- /dev/null +++ b/Compiler/allocator.py @@ -0,0 +1,653 @@ +# (C) 2016 University of Bristol. See License.txt + +import itertools, time +from collections import defaultdict, deque +from Compiler.exceptions import * +from Compiler.config import * +from Compiler.instructions import * +from Compiler.instructions_base import * +from Compiler.util import * +import Compiler.graph +import Compiler.program +import heapq, itertools +import operator + + +class StraightlineAllocator: + """Allocate variables in a straightline program using n registers. + It is based on the precondition that every register is only defined once.""" + def __init__(self, n): + self.free = defaultdict(set) + self.alloc = {} + self.usage = Compiler.program.RegType.create_dict(lambda: 0) + self.defined = {} + self.dealloc = set() + self.n = n + + def alloc_reg(self, reg, persistent_allocation): + base = reg.vectorbase + if base in self.alloc: + # already allocated + return + + reg_type = reg.reg_type + size = base.size + if not persistent_allocation and self.free[reg_type, size]: + res = self.free[reg_type, size].pop() + else: + if self.usage[reg_type] < self.n: + res = self.usage[reg_type] + self.usage[reg_type] += size + else: + raise RegisterOverflowError() + self.alloc[base] = res + + if base.vector: + for i,r in enumerate(base.vector): + r.i = self.alloc[base] + i + else: + base.i = self.alloc[base] + + def dealloc_reg(self, reg, inst): + self.dealloc.add(reg) + base = reg.vectorbase + + if base.vector and not inst.is_vec(): + for i in base.vector: + if i not in self.dealloc: + # not all vector elements ready for deallocation + return + self.free[reg.reg_type, base.size].add(self.alloc[base]) + if inst.is_vec() and base.vector: + for i in base.vector: + self.defined[i] = inst + else: + self.defined[reg] = inst + + def process(self, program, persistent_allocation=False): + for k,i in enumerate(reversed(program)): + unused_regs = [] + for j in i.get_def(): + if j.vectorbase in self.alloc: + if j in self.defined: + raise CompilerError("Double write on register %s " \ + "assigned by '%s' in %s" % \ + (j,i,format_trace(i.caller))) + else: + # unused register + self.alloc_reg(j, persistent_allocation) + unused_regs.append(j) + if unused_regs and len(unused_regs) == len(i.get_def()): + # only report if all assigned registers are unused + print "Register(s) %s never used, assigned by '%s' in %s" % \ + (unused_regs,i,format_trace(i.caller)) + + for j in i.get_used(): + self.alloc_reg(j, persistent_allocation) + for j in i.get_def(): + self.dealloc_reg(j, i) + + if k % 1000000 == 0 and k > 0: + print "Allocated registers for %d instructions at" % k, time.asctime() + + # print "Successfully allocated registers" + # print "modp usage: %d clear, %d secret" % \ + # (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp]) + # print "GF2N usage: %d clear, %d secret" % \ + # (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N]) + return self.usage + + +def determine_scope(block): + last_def = defaultdict(lambda: -1) + used_from_scope = set() + + def find_in_scope(reg, scope): + if scope is None: + return False + elif reg in scope.defined_registers: + return True + else: + return find_in_scope(reg, scope.scope) + + def read(reg, n): + if last_def[reg] == -1: + if find_in_scope(reg, block.scope): + used_from_scope.add(reg) + reg.can_eliminate = False + else: + print 'Warning: read before write at register', reg + print '\tline %d: %s' % (n, instr) + print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t') + print '\tregister trace: %s' % format_trace(reg.caller, '\t\t') + + def write(reg, n): + if last_def[reg] != -1: + print 'Warning: double write at register', reg + print '\tline %d: %s' % (n, instr) + print '\ttrace: %s' % format_trace(instr.caller, '\t\t') + last_def[reg] = n + + for n,instr in enumerate(block.instructions): + outputs,inputs = instr.get_def(), instr.get_used() + for reg in inputs: + if reg.vector and instr.is_vec(): + for i in reg.vector: + read(i, n) + else: + read(reg, n) + for reg in outputs: + if reg.vector and instr.is_vec(): + for i in reg.vector: + write(i, n) + else: + write(reg, n) + + block.used_from_scope = used_from_scope + block.defined_registers = set(last_def.iterkeys()) + +class Merger: + def __init__(self, block, options): + self.block = block + self.instructions = block.instructions + self.options = options + if options.max_parallel_open: + self.max_parallel_open = int(options.max_parallel_open) + else: + self.max_parallel_open = float('inf') + self.dependency_graph() + + def do_merge(self, merges_iter): + """ Merge an iterable of nodes in G, returning the number of merged + instructions and the index of the merged instruction. """ + instructions = self.instructions + mergecount = 0 + try: + n = next(merges_iter) + except StopIteration: + return mergecount, None + + def expand_vector_args(inst): + new_args = [] + for arg in inst.args: + if inst.is_vec(): + arg.create_vector_elements() + for reg in arg: + new_args.append(reg) + else: + new_args.append(arg) + return new_args + + for i in merges_iter: + if isinstance(instructions[n], startinput_class): + instructions[n].args[1] += instructions[i].args[1] + elif isinstance(instructions[n], (stopinput, gstopinput)): + if instructions[n].get_size() != instructions[i].get_size(): + raise NotImplemented() + else: + instructions[n].args += instructions[i].args[1:] + else: + if instructions[n].get_size() != instructions[i].get_size(): + # merge as non-vector instruction + instructions[n].args = expand_vector_args(instructions[n]) + \ + expand_vector_args(instructions[i]) + if instructions[n].is_vec(): + instructions[n].size = 1 + else: + instructions[n].args += instructions[i].args + + # join arg_formats if not special iterators + # if not isinstance(instructions[n].arg_format, (itertools.repeat, itertools.cycle)) and \ + # not isinstance(instructions[i].arg_format, (itertools.repeat, itertools.cycle)): + # instructions[n].arg_format += instructions[i].arg_format + instructions[i] = None + self.merge_nodes(n, i) + mergecount += 1 + + return mergecount, n + + def compute_max_depths(self, depth_of): + """ Compute the maximum 'depth' at which every instruction can be placed. + This is the minimum depth of any merge_node succeeding an instruction. + + Similar to DAG shortest paths algorithm. Traverses the graph in reverse + topological order, updating the max depth of each node's predecessors. + """ + G = self.G + merge_nodes_set = self.open_nodes + top_order = Compiler.graph.topological_sort(G) + max_depth_of = [None] * len(G) + max_depth = max(depth_of) + + for i in range(len(max_depth_of)): + if i in merge_nodes_set: + max_depth_of[i] = depth_of[i] - 1 + else: + max_depth_of[i] = max_depth + + for u in reversed(top_order): + for v in G.pred[u]: + if v not in merge_nodes_set: + max_depth_of[v] = min(max_depth_of[u], max_depth_of[v]) + return max_depth_of + + def merge_inputs(self): + merges = defaultdict(list) + remaining_input_nodes = [] + def do_merge(nodes): + if len(nodes) > 1000: + print 'Merging %d inputs...' % len(nodes) + self.do_merge(iter(nodes)) + for n in self.input_nodes: + inst = self.instructions[n] + merge = merges[inst.args[0],inst.__class__] + if len(merge) == 0: + remaining_input_nodes.append(n) + merge.append(n) + if len(merge) >= self.max_parallel_open: + do_merge(merge) + merge[:] = [] + for merge in merges.itervalues(): + if merge: + do_merge(merge) + self.input_nodes = remaining_input_nodes + + def compute_preorder(self, merges, rev_depth_of): + # find flexible nodes that can be on several levels + # and find sources on level 0 + G = self.G + merge_nodes_set = self.open_nodes + depth_of = self.depths + instructions = self.instructions + flex_nodes = defaultdict(dict) + starters = [] + for n in xrange(len(G)): + if n not in merge_nodes_set and \ + depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction): + #print n, depth_of[n], rev_depth_of[n] + flex_nodes[depth_of[n]].setdefault(rev_depth_of[n], set()).add(n) + elif len(G.pred[n]) == 0 and \ + not isinstance(self.instructions[n], RawInputInstruction): + starters.append(n) + if n % 10000000 == 0 and n > 0: + print "Processed %d nodes at" % n, time.asctime() + + inputs = defaultdict(list) + for node in self.input_nodes: + player = self.instructions[node].args[0] + inputs[player].append(node) + first_inputs = [l[0] for l in inputs.itervalues()] + other_inputs = [] + i = 0 + while True: + i += 1 + found = False + for l in inputs.itervalues(): + if i < len(l): + other_inputs.append(l[i]) + found = True + if not found: + break + other_inputs.reverse() + + preorder = [] + # magical preorder for topological search + max_depth = max(merges) + if max_depth > 10000: + print "Computing pre-ordering ..." + for i in xrange(max_depth, 0, -1): + preorder.append(G.get_attr(merges[i], 'stop')) + for j in flex_nodes[i-1].itervalues(): + preorder.extend(j) + preorder.extend(flex_nodes[0].get(i, [])) + preorder.append(merges[i]) + if i % 100000 == 0 and i > 0: + print "Done level %d at" % i, time.asctime() + preorder.extend(other_inputs) + preorder.extend(starters) + preorder.extend(first_inputs) + if max_depth > 10000: + print "Done at", time.asctime() + return preorder + + def compute_continuous_preorder(self, merges, rev_depth_of): + print 'Computing pre-ordering for continuous computation...' + preorder = [] + sources_for = defaultdict(list) + stops_in = defaultdict(list) + startinputs = [] + stopinputs = [] + for source in self.sources: + sources_for[rev_depth_of[source]].append(source) + for merge in merges.itervalues(): + stop = self.G.get_attr(merge, 'stop') + stops_in[rev_depth_of[stop]].append(stop) + for node in self.input_nodes: + if isinstance(self.instructions[node], startinput_class): + startinputs.append(node) + else: + stopinputs.append(node) + max_round = max(rev_depth_of) + for i in xrange(max_round, 0, -1): + preorder.extend(reversed(stops_in[i])) + preorder.extend(reversed(sources_for[i])) + # inputs at the beginning + preorder.extend(reversed(stopinputs)) + preorder.extend(reversed(sources_for[0])) + preorder.extend(reversed(startinputs)) + return preorder + + def longest_paths_merge(self, instruction_type=startopen_class, + merge_stopopens=True): + """ Attempt to merge instructions of type instruction_type (which are given in + merge_nodes) using longest paths algorithm. + + Returns the no. of rounds of communication required after merging (assuming 1 round/instruction). + + If merge_stopopens is True, will also merge associated stop_open instructions. + If reorder_between_opens is True, will attempt to place non-opens between start/stop opens. + + Doesn't use networkx. + """ + G = self.G + instructions = self.instructions + merge_nodes = self.open_nodes + depths = self.depths + if instruction_type is not startopen_class and merge_stopopens: + raise CompilerError('Cannot merge stopopens whilst merging %s instructions' % instruction_type) + if not merge_nodes and not self.input_nodes: + return 0 + + # merge opens at same depth + merges = defaultdict(list) + for node in merge_nodes: + merges[depths[node]].append(node) + + # after merging, the first element in merges[i] remains for each depth i, + # all others are removed from instructions and G + last_nodes = [None, None] + for i in sorted(merges): + merge = merges[i] + if len(merge) > 1000: + print 'Merging %d opens in round %d/%d' % (len(merge), i, len(merges)) + nodes = defaultdict(lambda: None) + for b in (False, True): + my_merge = (m for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b) + + if merge_stopopens: + my_stopopen = [G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b] + + mc, nodes[0,b] = self.do_merge(iter(my_merge)) + + if merge_stopopens: + mc, nodes[1,b] = self.do_merge(iter(my_stopopen)) + + # add edges to retain order of gf2n/modp start/stop opens + for j in (0,1): + node2 = nodes[j,True] + nodep = nodes[j,False] + if nodep is not None and node2 is not None: + G.add_edge(nodep, node2) + # add edge to retain order of opens over rounds + if last_nodes[j] is not None: + G.add_edge(last_nodes[j], node2 if nodep is None else nodep) + last_nodes[j] = nodep if node2 is None else node2 + merges[i] = last_nodes[0] + + self.merge_inputs() + + # compute preorder for topological sort + if merge_stopopens and self.options.reorder_between_opens: + if self.options.continuous or not merge_nodes: + rev_depths = self.compute_max_depths(self.real_depths) + preorder = self.compute_continuous_preorder(merges, rev_depths) + else: + rev_depths = self.compute_max_depths(self.depths) + preorder = self.compute_preorder(merges, rev_depths) + else: + preorder = None + + if len(instructions) > 100000: + print "Topological sort ..." + order = Compiler.graph.topological_sort(G, preorder) + instructions[:] = [instructions[i] for i in order if instructions[i] is not None] + if len(instructions) > 100000: + print "Done at", time.asctime() + + return len(merges) + + def dependency_graph(self, merge_class=startopen_class): + """ Create the program dependency graph. """ + block = self.block + options = self.options + open_nodes = set() + self.open_nodes = open_nodes + self.input_nodes = [] + colordict = defaultdict(lambda: 'gray', startopen='red', stopopen='red',\ + ldi='lightblue', ldm='lightblue', stm='blue',\ + mov='yellow', mulm='orange', mulc='orange',\ + triple='green', square='green', bit='green',\ + asm_input='lightgreen') + + G = Compiler.graph.SparseDiGraph(len(block.instructions)) + self.G = G + + reg_nodes = {} + last_def = defaultdict(lambda: -1) + last_mem_write = None + last_mem_read = None + warned_about_mem = [] + last_mem_write_of = defaultdict(list) + last_mem_read_of = defaultdict(list) + last_print_str = None + last = defaultdict(lambda: defaultdict(lambda: None)) + last_open = deque() + + depths = [0] * len(block.instructions) + self.depths = depths + parallel_open = defaultdict(lambda: 0) + next_available_depth = {} + self.sources = [] + self.real_depths = [0] * len(block.instructions) + + def add_edge(i, j): + from_merge = isinstance(block.instructions[i], merge_class) + to_merge = isinstance(block.instructions[j], merge_class) + G.add_edge(i, j) + is_source = G.get_attr(i, 'is_source') and G.get_attr(j, 'is_source') and not from_merge + G.set_attr(j, 'is_source', is_source) + for d in (self.depths, self.real_depths): + if d[j] < d[i]: + d[j] = d[i] + + def read(reg, n): + if last_def[reg] != -1: + add_edge(last_def[reg], n) + + def write(reg, n): + last_def[reg] = n + + def handle_mem_access(addr, reg_type, last_access_this_kind, + last_access_other_kind): + this = last_access_this_kind[addr,reg_type] + other = last_access_other_kind[addr,reg_type] + if this and other: + if this[-1] < other[0]: + del this[:] + this.append(n) + for inst in other: + add_edge(inst, n) + + def mem_access(n, instr, last_access_this_kind, last_access_other_kind): + addr = instr.args[1] + reg_type = instr.args[0].reg_type + if isinstance(addr, int): + for i in range(min(instr.get_size(), 100)): + addr_i = addr + i + handle_mem_access(addr_i, reg_type, last_access_this_kind, + last_access_other_kind) + if not warned_about_mem and (instr.get_size() > 100): + print 'WARNING: Order of memory instructions ' \ + 'not preserved due to long vector, errors possible' + warned_about_mem.append(True) + else: + handle_mem_access(addr, reg_type, last_access_this_kind, + last_access_other_kind) + if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction): + print 'WARNING: Order of memory instructions ' \ + 'not preserved, errors possible' + # hack + warned_about_mem.append(True) + + def keep_order(instr, n, t, arg_index=None): + if arg_index is None: + player = None + else: + player = instr.args[arg_index] + if last[t][player] is not None: + add_edge(last[t][player], n) + last[t][player] = n + + for n,instr in enumerate(block.instructions): + outputs,inputs = instr.get_def(), instr.get_used() + + G.add_node(n, is_source=True) + + # if options.debug: + # col = colordict[instr.__class__.__name__] + # G.add_node(n, color=col, label=str(instr)) + for reg in inputs: + if reg.vector and instr.is_vec(): + for i in reg.vector: + read(i, n) + else: + read(reg, n) + + for reg in outputs: + if reg.vector and instr.is_vec(): + for i in reg.vector: + write(i, n) + else: + write(reg, n) + + if isinstance(instr, merge_class): + open_nodes.add(n) + last_open.append(n) + G.add_node(n, merges=[]) + # the following must happen after adding the edge + self.real_depths[n] += 1 + depth = depths[n] + 1 + if int(options.max_parallel_open): + skipped_depths = set() + while parallel_open[depth] >= int(options.max_parallel_open): + skipped_depths.add(depth) + depth = next_available_depth.get(depth, depth + 1) + for d in skipped_depths: + next_available_depth[d] = depth + parallel_open[depth] += len(instr.args) * instr.get_size() + depths[n] = depth + + if isinstance(instr, stopopen_class): + startopen = last_open.popleft() + add_edge(startopen, n) + G.set_attr(startopen, 'stop', n) + G.set_attr(n, 'start', last_open) + G.add_node(n, merges=[]) + + if isinstance(instr, ReadMemoryInstruction): + if options.preserve_mem_order: + last_mem_read = n + if last_mem_write: + add_edge(last_mem_write, n) + else: + mem_access(n, instr, last_mem_read_of, last_mem_write_of) + elif isinstance(instr, WriteMemoryInstruction): + if options.preserve_mem_order: + last_mem_write = n + if last_mem_read: + add_edge(last_mem_read, n) + else: + mem_access(n, instr, last_mem_write_of, last_mem_read_of) + # keep I/O instructions in order + elif isinstance(instr, IOInstruction): + if last_print_str is not None: + add_edge(last_print_str, n) + last_print_str = n + elif isinstance(instr, PublicFileIOInstruction): + keep_order(instr, n, instr.__class__) + elif isinstance(instr, RawInputInstruction): + keep_order(instr, n, instr.__class__, 0) + self.input_nodes.append(n) + G.add_node(n, merges=[]) + player = instr.args[0] + if isinstance(instr, stopinput): + add_edge(last[startinput_class][player], n) + elif isinstance(instr, gstopinput): + add_edge(last[gstartinput][player], n) + elif isinstance(instr, startprivateoutput_class): + keep_order(instr, n, startprivateoutput_class, 2) + elif isinstance(instr, stopprivateoutput_class): + keep_order(instr, n, stopprivateoutput_class, 1) + elif isinstance(instr, prep_class): + keep_order(instr, n, instr.args[0]) + + if not G.pred[n]: + self.sources.append(n) + + if n % 100000 == 0 and n > 0: + print "Processed dependency of %d/%d instructions at" % \ + (n, len(block.instructions)), time.asctime() + + if len(open_nodes) > 1000: + print "Program has %d %s instructions" % (len(open_nodes), merge_class) + + def merge_nodes(self, i, j): + """ Merge node j into i, removing node j """ + G = self.G + if j in G[i]: + G.remove_edge(i, j) + if i in G[j]: + G.remove_edge(j, i) + G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]])) + G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]])) + G.get_attr(i, 'merges').append(j) + G.remove_node(j) + + def eliminate_dead_code(self): + instructions = self.instructions + G = self.G + merge_nodes = self.open_nodes + count = 0 + open_count = 0 + for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)): + # remove if instruction has result that isn't used + unused_result = not G.degree(i) and len(inst.get_def()) \ + and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \ + and not isinstance(inst, (DoNotEliminateInstruction)) + stop_node = G.get_attr(i, 'stop') + unused_startopen = stop_node != -1 and instructions[stop_node] is None + if unused_result or unused_startopen: + G.remove_node(i) + merge_nodes.discard(i) + instructions[i] = None + count += 1 + if unused_startopen: + open_count += len(inst.args) + if count > 0: + print 'Eliminated %d dead instructions, among which %d opens' % (count, open_count) + + def print_graph(self, filename): + f = open(filename, 'w') + print >>f, 'digraph G {' + for i in range(self.G.n): + for j in self.G[i]: + print >>f, '"%d: %s" -> "%d: %s";' % \ + (i, self.instructions[i], j, self.instructions[j]) + print >>f, '}' + f.close() + + def print_depth(self, filename): + f = open(filename, 'w') + for i in range(self.G.n): + print >>f, '%d: %s' % (self.depths[i], self.instructions[i]) + f.close() diff --git a/Compiler/comparison.py b/Compiler/comparison.py new file mode 100644 index 000000000..d0cfdc56b --- /dev/null +++ b/Compiler/comparison.py @@ -0,0 +1,548 @@ +# (C) 2016 University of Bristol. See License.txt + +""" +Functions for secure comparison of GF(p) types. +Most protocols come from [1], with a few subroutines described in [2]. + +Function naming of comparison routines is as in [1,2], with k always +representing the integer bit length, and kappa the statistical security +parameter. + +Most of these routines were implemented before the cint/sint classes, so use +the old-fasioned Register class and assembly instructions instead of operator +overloading. + +The PreMulC function has a few variants, depending on whether +preprocessing is only triples/bits, or inverse tuples or "special" +comparison-specific preprocessing is also available. + +[1] https://www1.cs.fau.de/filepool/publications/octavian_securescm/smcint-scn10.pdf +[2] https://www1.cs.fau.de/filepool/publications/octavian_securescm/SecureSCM-D.9.2.pdf +""" + +# Use constant rounds protocols instead of log rounds +const_rounds = False +# Set use_inv to use preprocessed inverse tuples for more efficient +# online phase comparisons. +use_inv = True +# If do_precomp is not set, use_inv uses standard inverse tuples, otherwise if +# both are set, use a list of "special" tuples of the form +# (r[i], r[i]^-1, r[i] * r[i-1]^-1) +do_precomp = True + +import instructions_base + +def set_variant(options): + """ Set flags based on the command-line option provided """ + global const_rounds, do_precomp, use_inv + variant = options.comparison + if variant == 'log': + const_rounds = False + elif variant == 'plain': + const_rounds = True + use_inv = False + elif variant == 'inv': + const_rounds = True + use_inv = True + do_precomp = True + elif variant == 'sinv': + const_rounds = True + use_inv = True + do_precomp = False + elif variant is not None: + raise CompilerError('Unknown comparison variant: %s' % variant) + +def ld2i(c, n): + """ Load immediate 2^n into clear GF(p) register c """ + t1 = program.curr_block.new_reg('c') + ldi(t1, 2 ** (n % 30)) + for i in range(n / 30): + t2 = program.curr_block.new_reg('c') + mulci(t2, t1, 2 ** 30) + t1 = t2 + movc(c, t1) + +inverse_of_two = {} + +def divide_by_two(res, x): + """ Faster clear division by two using a cached value of 2^-1 mod p """ + from program import Program + import types + tape = Program.prog.curr_tape + if len(inverse_of_two) == 0 or tape not in inverse_of_two: + inverse_of_two[tape] = types.cint(1) / 2 + mulc(res, x, inverse_of_two[tape]) + +def LTZ(s, a, k, kappa): + """ + s = (a ?< 0) + + k: bit length of a + """ + t = program.curr_block.new_reg('s') + Trunc(t, a, k, k - 1, kappa, True) + subsfi(s, t, 0) + +def Trunc(d, a, k, m, kappa, signed): + """ + d = a >> m + + k: bit length of a + m: compile-time integer + signed: True/False, describes a + """ + a_prime = program.curr_block.new_reg('s') + t = program.curr_block.new_reg('s') + c = [program.curr_block.new_reg('c') for i in range(3)] + c2m = program.curr_block.new_reg('c') + if m == 0: + movs(d, a) + return + elif m == 1: + Mod2(a_prime, a, k, kappa, signed) + else: + Mod2m(a_prime, a, k, m, kappa, signed) + subs(t, a, a_prime) + ldi(c[1], 1) + ld2i(c2m, m) + divc(c[2], c[1], c2m) + mulm(d, t, c[2]) + +def TruncRoundNearest(a, k, m, kappa): + """ + Returns a / 2^m, rounded to the nearest integer. + + k: bit length of m + m: compile-time integer + """ + from types import sint, cint + from library import reveal, load_int_to_secret + if m == 1: + lsb = sint() + Mod2(lsb, a, k, kappa, False) + return (a + lsb) / 2 + r_dprime = sint() + r_prime = sint() + r = [sint() for i in range(m)] + u = sint() + PRandM(r_dprime, r_prime, r, k, m, kappa) + c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime) + c_prime = c % (cint(1) << (m - 1)) + if const_rounds: + BitLTC1(u, c_prime, r[:-1], kappa) + else: + BitLTL(u, c_prime, r[:-1], kappa) + bit = ((c - c_prime) / (cint(1) << (m - 1))) % 2 + xor = bit + u - 2 * bit * u + prod = xor * r[-1] + # u_prime = xor * u + (1 - xor) * r[-1] + u_prime = bit * u + u - 2 * bit * u + r[-1] - prod + a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime + d = (a - a_prime) / (cint(1) << m) + rounding = xor + r[-1] - 2 * prod + return d + rounding + +def Mod2m(a_prime, a, k, m, kappa, signed): + """ + a_prime = a % 2^m + + k: bit length of a + m: compile-time integer + signed: True/False, describes a + """ + if m >= k: + movs(a_prime, a) + return + r_dprime = program.curr_block.new_reg('s') + r_prime = program.curr_block.new_reg('s') + r = [program.curr_block.new_reg('s') for i in range(m)] + c = program.curr_block.new_reg('c') + c_prime = program.curr_block.new_reg('c') + v = program.curr_block.new_reg('s') + u = program.curr_block.new_reg('s') + t = [program.curr_block.new_reg('s') for i in range(6)] + c2m = program.curr_block.new_reg('c') + c2k1 = program.curr_block.new_reg('c') + PRandM(r_dprime, r_prime, r, k, m, kappa) + ld2i(c2m, m) + mulm(t[0], r_dprime, c2m) + if signed: + ld2i(c2k1, k - 1) + addm(t[1], a, c2k1) + else: + t[1] = a + adds(t[2], t[0], t[1]) + adds(t[3], t[2], r_prime) + startopen(t[3]) + stopopen(c) + modc(c_prime, c, c2m) + if const_rounds: + BitLTC1(u, c_prime, r, kappa) + else: + BitLTL(u, c_prime, r, kappa) + mulm(t[4], u, c2m) + submr(t[5], c_prime, r_prime) + adds(a_prime, t[5], t[4]) + return r_dprime, r_prime, c, c_prime, u, t, c2k1 + +def PRandM(r_dprime, r_prime, b, k, m, kappa): + """ + r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1] + r_prime = random secret integer in range [0, 2^m - 1] + b = array containing bits of r_prime + """ + t = [[program.curr_block.new_reg('s') for j in range(2)] for i in range(m)] + t[0][1] = b[-1] + PRandInt(r_dprime, k + kappa - m) + # r_dprime is always multiplied by 2^m + program.curr_tape.require_bit_length(k + kappa) + bit(b[-1]) + for i in range(1,m): + adds(t[i][0], t[i-1][1], t[i-1][1]) + bit(b[-i-1]) + adds(t[i][1], t[i][0], b[-i-1]) + movs(r_prime, t[m-1][1]) + +def PRandInt(r, k): + """ + r = random secret integer in range [0, 2^k - 1] + """ + t = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(3)] + t[2][k-1] = r + bit(t[2][0]) + for i in range(1,k): + adds(t[0][i], t[2][i-1], t[2][i-1]) + bit(t[1][i]) + adds(t[2][i], t[0][i], t[1][i]) + +def BitLTC1(u, a, b, kappa): + """ + u = a (p_1 & p_2, g_2 | (p_2 & g_1)) + """ + if a is None: + return b + if b is None: + return a + t = [program.curr_block.new_reg('s') for i in range(3)] + if compute_p: + muls(t[0], a[0], b[0]) + muls(t[1], a[0], b[1]) + adds(t[2], a[1], t[1]) + return t[0], t[2] + +# from WP9 report +# length of a is even +def CarryOutAux(d, a, kappa): + k = len(a) + if k > 1 and k % 2 == 1: + a.append(None) + k += 1 + u = [None]*(k/2) + a = a[::-1] + if k > 1: + for i in range(k/2): + u[i] = carry(a[2*i+1], a[2*i], i != k/2-1) + CarryOutAux(d, u[:k/2][::-1], kappa) + else: + movs(d, a[0][1]) + +# carry out with carry-in bit c +def CarryOut(res, a, b, c, kappa): + """ + res = last carry bit in addition of a and b + + a: array of clear bits + b: array of secret bits (same length as a) + c: initial carry-in bit + """ + k = len(a) + d = [program.curr_block.new_reg('s') for i in range(k)] + t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)] + s = [program.curr_block.new_reg('s') for i in range(3)] + for i in range(k): + mulm(t[0][i], b[i], a[i]) + mulsi(t[1][i], t[0][i], 2) + addm(t[2][i], b[i], a[i]) + subs(t[3][i], t[2][i], t[1][i]) + d[i] = [t[3][i], t[0][i]] + mulsi(s[0], d[-1][0], c) + adds(s[1], d[-1][1], s[0]) + d[-1][1] = s[1] + + CarryOutAux(res, d[::-1], kappa) + +def BitLTL(res, a, b, kappa): + """ + res = a 0: + print "Initialized %d register variables at" % i, time.asctime() + + # first pass determines how many assembler registers are used + prog.FIRST_PASS = True + execfile(prog.infile, VARS) + + if instructions_base.Instruction.count != 0: + print 'instructions count', instructions_base.Instruction.count + instructions_base.Instruction.count = 0 + prog.FIRST_PASS = False + prog.reset_values() + # make compiler modules directly accessible + sys.path.insert(0, 'Compiler') + # create the tapes + execfile(prog.infile, VARS) + + # optimize the tapes + for tape in prog.tapes: + tape.optimize(options) + + # check program still does the same thing after optimizations + if emulate: + clearmem = list(prog.mem_c) + sharedmem = list(prog.mem_s) + prog.emulate() + if prog.mem_c != clearmem or prog.mem_s != sharedmem: + print 'Warning: emulated memory values changed after compiler optimization' + # raise CompilerError('Compiler optimization caused incorrect memory write.') + + if prog.main_thread_running: + prog.update_req(prog.curr_tape) + print 'Program requires:', repr(prog.req_num) + print 'Cost:', prog.req_num.cost() + print 'Memory size:', prog.allocated_mem + + # finalize the memory + prog.finalize_memory() + + return prog diff --git a/Compiler/config.py b/Compiler/config.py new file mode 100644 index 000000000..42248b1aa --- /dev/null +++ b/Compiler/config.py @@ -0,0 +1,58 @@ +# (C) 2016 University of Bristol. See License.txt + +from collections import defaultdict + +#INIT_REG_MAX = 655360 +INIT_REG_MAX = 1310720 +REG_MAX = 2 ** 32 +USER_MEM = 8192 +TMP_MEM = 8192 +TMP_MEM_BASE = USER_MEM +TMP_REG = 3 +TMP_REG_BASE = REG_MAX - TMP_REG + +P_VALUES = { -1: 2147483713, \ + 32: 2147565569, \ + 64: 9223372036855103489, \ + 128: 172035116406933162231178957667602464769, \ + 256: 57896044624266469032429686755131815517604980759976795324963608525438406557697, \ + 512: 6703903964971298549787012499123814115273848577471136527425966013026501536706464354255445443244279389455058889493431223951165286470575994074291745908195329 } + +BIT_LENGTHS = { -1: 24, + 32: 24, + 64: 32, + 128: 64, + 256: 64, + 512: 64 } + +STAT_SEC = { -1: 6, + 32: 6, + 64: 30, + 128: 40, + 256: 40, + 512: 40 } + + +COST = { 'modp': defaultdict(lambda: 0, + { 'triple': 0.00020652622883106154, + 'square': 0.00020652622883106154, + 'bit': 0.00020652622883106154, + 'inverse': 0.00020652622883106154, + 'PreMulC': 2 * 0.00020652622883106154, + }), + 'gf2n': defaultdict(lambda: 0, + { 'triple': 0.00020716801325875284, + 'square': 0.00020716801325875284, + 'inverse': 0.00020716801325875284, + 'bit': 1.4492753623188405e-07, + 'bittriple': 0.00004828818388140422, + 'bitgf2ntriple': 0.00020716801325875284, + 'PreMulC': 2 * 0.00020716801325875284, + }) +} + + +try: + from config_mine import * +except: + pass diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py new file mode 100644 index 000000000..eb1bf519d --- /dev/null +++ b/Compiler/exceptions.py @@ -0,0 +1,17 @@ +# (C) 2016 University of Bristol. See License.txt + +class CompilerError(Exception): + """Base class for compiler exceptions.""" + pass + +class RegisterOverflowError(CompilerError): + pass + +class MemoryOverflowError(CompilerError): + pass + +class ArgumentError(CompilerError): + """ Exception raised for errors in instruction argument parsing. """ + def __init__(self, arg, msg): + self.arg = arg + self.msg = msg \ No newline at end of file diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py new file mode 100644 index 000000000..cf4a6f519 --- /dev/null +++ b/Compiler/floatingpoint.py @@ -0,0 +1,517 @@ +# (C) 2016 University of Bristol. See License.txt + +from math import log, floor, ceil +from Compiler.instructions import * +import types +import comparison +import program + +## +## Helper functions for floating point arithmetic +## + + +def two_power(n): + if isinstance(n, int) and n < 31: + return 2**n + else: + max = types.cint(1) << 31 + res = 2**(n%31) + for i in range(n / 31): + res *= max + return res + +def EQZ(a, k, kappa): + r_dprime = types.sint() + r_prime = types.sint() + c = types.cint() + d = [None]*k + r = [types.sint() for i in range(k)] + comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) + startopen(a + two_power(k) * r_dprime + r_prime)# + 2**(k-1)) + stopopen(c) + for i,b in enumerate(bits(c, k)): + d[i] = b + r[i] - 2*b*r[i] + return 1 - KOR(d, kappa) + +def bits(a,m): + """ Get the bits of an int """ + if isinstance(a, int): + res = [None]*m + for i in range(m): + res[i] = a & 1 + a >>= 1 + else: + c = [[types.cint() for i in range(m)] for i in range(2)] + res = [types.cint() for i in range(m)] + modci(res[0], a, 2) + c[1][0] = a + for i in range(1,m): + subc(c[0][i], c[1][i-1], res[i-1]) + divci(c[1][i], c[0][i], 2) + modci(res[i], c[1][i], 2) + return res + +def carry(b, a, compute_p=True): + """ Carry propogation: + (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1)) + """ + if compute_p: + t1 = a[0]*b[0] + else: + t1 = None + t2 = a[1] + a[0]*b[1] + return (t1, t2) + +def or_op(a, b, void=None): + return a + b - a*b + +def mul_op(a, b, void=None): + return a * b + +def PreORC(a, kappa=None, m=None, raw=False): + k = len(a) + if k == 1: + return [a[0]] + m = m or k + if isinstance(a[0], types.sgf2n): + max_k = program.Program.prog.galois_length - 1 + else: + max_k = int(log(program.Program.prog.P) / log(2)) - kappa + if k <= max_k: + p = [None] * m + if m == k: + p[0] = a[0] + if isinstance(a[0], types.sgf2n): + b = comparison.PreMulC([3 - a[i] for i in range(k)]) + for i in range(m): + tmp = b[k-1-i] + if not raw: + tmp = tmp.bit_decompose()[0] + p[m-1-i] = 1 - tmp + else: + t = [types.sint() for i in range(m)] + b = comparison.PreMulC([a[i] + 1 for i in range(k)]) + for i in range(m): + comparison.Mod2(t[i], b[k-1-i], k, kappa, False) + p[m-1-i] = 1 - t[i] + return p + else: + # not constant-round anymore + s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)] + t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw) + return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0]) + +def PreOpL(op, items): + """ + Uses algorithm from SecureSCM WP9 deliverable. + + op must be a binary function that outputs a new register + """ + k = len(items) + logk = int(ceil(log(k,2))) + kmax = 2**logk + output = list(items) + for i in range(logk): + for j in range(kmax/(2**(i+1))): + y = two_power(i) + j*two_power(i+1) - 1 + for z in range(1, 2**i+1): + if y+z < k: + output[y+z] = op(output[y], output[y+z], j != 0) + return output + +def PreOpN(op, items): + """ Naive PreOp algorithm """ + k = len(items) + output = [None]*k + output[0] = items[0] + for i in range(1, k): + output[i] = op(output[i-1], items[i]) + return output + +def PreOR(a, kappa=None, raw=False): + if comparison.const_rounds: + return PreORC(a, kappa, raw=raw) + else: + return PreOpL(or_op, a) + +def KOpL(op, a): + k = len(a) + if k == 1: + return a[0] + else: + t1 = KOpL(op, a[:k/2]) + t2 = KOpL(op, a[k/2:]) + return op(t1, t2) + +def KORL(a, kappa): + """ log rounds k-ary OR """ + k = len(a) + if k == 1: + return a[0] + else: + t1 = KORL(a[:k/2], kappa) + t2 = KORL(a[k/2:], kappa) + return t1 + t2 - t1*t2 + +def KORC(a, kappa): + return PreORC(a, kappa, 1)[0] + +def KOR(a, kappa): + if comparison.const_rounds: + return KORC(a, kappa) + else: + return KORL(a, None) + +def KMul(a): + if comparison.const_rounds: + return comparison.KMulC(a) + else: + return KOpL(mul_op, a) + + +def Inv(a): + """ Invert a non-zero value """ + t = [types.sint() for i in range(3)] + c = [types.cint() for i in range(2)] + one = types.cint() + ldi(one, 1) + inverse(t[0], t[1]) + s = t[0]*a + asm_open(c[0], s) + # avoid division by zero for benchmarking + divc(c[1], one, c[0]) + #divc(c[1], c[0], one) + return c[1]*t[0] + +def BitAdd(a, b, bits_to_compute=None): + """ Add the bits a[k-1], ..., a[0] and b[k-1], ..., b[0], return k+1 + bits s[0], ... , s[k] """ + k = len(a) + if not bits_to_compute: + bits_to_compute = range(k) + d = [None] * k + for i in range(1,k): + #assert(a[i].value == 0 or a[i].value == 1) + #assert(b[i].value == 0 or b[i].value == 1) + t = a[i]*b[i] + d[i] = (a[i] + b[i] - 2*t, t) + #assert(d[i][0].value == 0 or d[i][0].value == 1) + d[0] = (None, a[0]*b[0]) + pg = PreOpL(carry, d) + c = [pair[1] for pair in pg] + + # (for testing) + def print_state(): + print 'a: ', + for i in range(k): + print '%d ' % a[i].value, + print '\nb: ', + for i in range(k): + print '%d ' % b[i].value, + print '\nd: ', + for i in range(k): + print '%d ' % d[i][0].value, + print '\n ', + for i in range(k): + print '%d ' % d[i][1].value, + print '\n\npg:', + for i in range(k): + print '%d ' % pg[i][0].value, + print '\n ', + for i in range(k): + print '%d ' % pg[i][1].value, + print '' + + for bit in c: + pass#assert(bit.value == 0 or bit.value == 1) + s = [None] * (k+1) + if 0 in bits_to_compute: + s[0] = a[0] + b[0] - 2*c[0] + bits_to_compute.remove(0) + #assert(c[0].value == a[0].value*b[0].value) + #assert(s[0].value == 0 or s[0].value == 1) + for i in bits_to_compute: + s[i] = a[i] + b[i] + c[i-1] - 2*c[i] + try: + pass#assert(s[i].value == 0 or s[i].value == 1) + except AssertionError: + print '#assertion failed in BitAdd for s[%d]' % i + print_state() + s[k] = c[k-1] + #print_state() + return s + +def BitDec(a, k, m, kappa, bits_to_compute=None): + r_dprime = types.sint() + r_prime = types.sint() + c = types.cint() + r = [types.sint() for i in range(m)] + comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) + #assert(r_prime.value == sum(r[i].value*2**i for i in range(m)) % comparison.program.P) + pow2 = two_power(k + kappa) + asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) + #rval = 2**m*r_dprime.value + r_prime.value + #assert(rval % 2**m == r_prime.value) + #assert(rval == (2**m*r_dprime.value + sum(r[i].value*2**i for i in range(m)) % comparison.program.P )) + try: + pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P) + except AssertionError: + print 'BitDec assertion failed' + print 'a =', a.value + print 'a mod 2^%d =' % k, (a.value % 2**k) + return BitAdd(list(bits(c,m)), r, bits_to_compute)[:-1] + + +def Pow2(a, l, kappa): + m = int(ceil(log(l, 2))) + t = BitDec(a, m, m, kappa) + x = [types.sint() for i in range(m)] + pow2k = [types.cint() for i in range(m)] + for i in range(m): + pow2k[i] = two_power(2**i) + t[i] = t[i]*pow2k[i] + 1 - t[i] + return KMul(t) + +def B2U(a, l, kappa): + pow2a = Pow2(a, l, kappa) + #assert(pow2a.value == 2**a.value) + r = [types.sint() for i in range(l)] + t = types.sint() + c = types.cint() + for i in range(l): + bit(r[i]) + comparison.PRandInt(t, kappa) + asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l))) + comparison.program.curr_tape.require_bit_length(l + kappa) + c = list(bits(c, l)) + x = [c[i] + r[i] - 2*c[i]*r[i] for i in range(l)] + #print ' '.join(str(b.value) for b in x) + y = PreOR(x, kappa) + #print ' '.join(str(b.value) for b in y) + return [1 - y[i] for i in range(l)], pow2a + +def Trunc(a, l, m, kappa, compute_modulo=False): + """ Oblivious truncation by secret m """ + if l == 1: + if compute_modulo: + return a * m, 1 + m + else: + return a * (1 - m) + r = [types.sint() for i in range(l)] + r_dprime = types.sint(0) + r_prime = types.sint(0) + rk = types.sint() + c = types.cint() + ci = [types.cint() for i in range(l)] + d = types.sint() + x, pow2m = B2U(m, l, kappa) + #assert(pow2m.value == 2**m.value) + #assert(sum(b.value for b in x) == m.value) + for i in range(l): + bit(r[i]) + t1 = two_power(i) * r[i] + t2 = t1*x[i] + r_prime += t2 + r_dprime += t1 - t2 + #assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P)) + comparison.PRandInt(rk, kappa) + r_dprime += two_power(l) * rk + #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P)) + asm_open(c, a + r_dprime + r_prime) + for i in range(1,l): + ci[i] = c % two_power(i) + #assert(ci[i].value == c.value % 2**i) + c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) + #assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P)) + lts(d, c_dprime, r_prime, l, kappa) + if compute_modulo: + b = c_dprime - r_prime + pow2m * d + return b, pow2m + else: + pow2inv = Inv(pow2m) + #assert(pow2inv.value * pow2m.value % comparison.program.P == 1) + b = (a - c_dprime + r_prime) * pow2inv - d + return b + +def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa): + t = comparison.TruncRoundNearest(a, length, length - target_length, kappa) + overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa) + s = (1 - overflow) * t + overflow * t / 2 + return s, overflow + +def Int2FL(a, gamma, l, kappa): + lam = gamma - 1 + s = types.sint() + comparison.LTZ(s, a, gamma, kappa) + z = EQZ(a, gamma, kappa) + a = (1 - 2 * s) * a + a_bits = BitDec(a, lam, lam, kappa) + a_bits.reverse() + b = PreOR(a_bits, kappa) + t = a * (1 + sum(2**i * (1 - b_i) for i,b_i in enumerate(b))) + p = - (lam - sum(b)) + if gamma - 1 > l: + if types.sfloat.round_nearest: + v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa) + p = p + overflow + else: + v = types.sint() + comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False) + else: + v = 2**(l-gamma+1) * t + p = (p + gamma - 1 - l) * (1 -z) + return v, p, z, s + +def FLRound(x, mode): + """ Rounding with floating point output. + *mode*: 0 -> floor, 1 -> ceil, -1 > trunc """ + v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen + a = types.sint() + comparison.LTZ(a, p1, k, x.kappa) + b = p1.less_than(-l + 1, k, x.kappa) + v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True) + c = EQZ(v2, l, x.kappa) + if mode == -1: + away_from_zero = 0 + mode = x.s + else: + away_from_zero = mode + s1 - 2 * mode * s1 + v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero + d = v.equal(two_power(l), l + 1, x.kappa) + v = d * two_power(l-1) + (1 - d) * v + v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1 + s = (1 - b * mode) * s1 + z = or_op(EQZ(v, l, x.kappa), z1) + v = v * (1 - z) + p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z) + return v, p, z, s + +def TruncPr(a, k, m, kappa=None): + """ Probabilistic truncation [a/2^m + u] + where Pr[u = 1] = (a % 2^m) / 2^m + """ + if kappa is None: + kappa = 40 + + b = two_power(k-1) + a + r_prime, r_dprime = types.sint(), types.sint() + comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)], + k, m, kappa) + two_to_m = two_power(m) + r = two_to_m * r_dprime + r_prime + c = (b + r).reveal() + c_prime = c % two_to_m + a_prime = c_prime - r_prime + d = (a - a_prime) / two_to_m + return d + +def SDiv(a, b, l, kappa): + theta = int(ceil(log(l / 3.5) / log(2))) + alpha = two_power(2*l) + beta = 1 / types.cint(two_power(l)) + w = types.cint(int(2.9142 * two_power(l))) - 2 * b + x = alpha - b * w + y = a * w + y = TruncPr(y, 2 * l, l, kappa) + x2 = types.sint() + comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False) + x1 = (x - x2) * beta + for i in range(theta-1): + y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa) + y = TruncPr(y, 2 * l + 1, l + 1, kappa) + x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa) + x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa) + x2 = types.sint() + comparison.Mod2m(x2, x, 2 * l, l, kappa, False) + x1 = (x - x2) * beta + y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa) + y = TruncPr(y, 2 * l + 1, l - 1, kappa) + return y + +def SDiv_mono(a, b, l, kappa): + theta = int(ceil(log(l / 3.5) / log(2))) + alpha = two_power(2*l) + w = types.cint(int(2.9142 * two_power(l))) - 2 * b + x = alpha - b * w + y = a * w + y = TruncPr(y, 2 * l + 1, l + 1, kappa) + for i in range(theta-1): + y = y * (alpha + x) + # keep y with l bits + y = TruncPr(y, 3 * l, 2 * l, kappa) + x = x**2 + # keep x with 2l bits + x = TruncPr(x, 4 * l, 2 * l, kappa) + y = y * (alpha + x) + y = TruncPr(y, 3 * l, 2 * l, kappa) + return y + + +def FPDiv(a, b, k, f, kappa): + theta = int(ceil(log(k/3.5))) + alpha = types.cint(1 * two_power(2*f)) + + w = AppRcr(b, k, f, kappa) + x = alpha - b * w + y = a * w + y = TruncPr(y, 2*k, f, kappa) + + for i in range(theta): + y = y * (alpha + x) + x = x * x + y = TruncPr(y, 2*k, 2*f, kappa) + x = TruncPr(x, 2*k, 2*f, kappa) + + y = y * (alpha + x) + y = TruncPr(y, 2*k, 2*f, kappa) + return y + +def AppRcr(b, k, f, kappa): + """ + Approximate reciprocal of [b]: + Given [b], compute [1/b] + """ + alpha = types.cint(int(2.9142 * (2**k))) + c, v = Norm(b, k, f, kappa) + d = alpha - 2 * c + w = d * v + w = TruncPr(w, 2 * k, 2 * (k - f)) + return w + +def Norm(b, k, f, kappa): + """ + Computes secret integer values [c] and [v_prime] st. + 2^{k-1} <= c < 2^k and c = b*v_prime + """ + temp = types.sint() + comparison.LTZ(temp, b, k, kappa) + sign = 1 - 2 * temp # 1 - 2 * [b < 0] + + x = sign * b + #x = |b| + bits = x.bit_decompose(k) + y = PreOR(bits) + + z = [0] * k + for i in range(k - 1): + z[i] = y[i] - y[i + 1] + + z[k - 1] = y[k - 1] + # z[i] = 0 for all i except when bits[i + 1] = first one + + #now reverse bits of z[i] + v = types.sint() + for i in range(k): + v += two_power(k - i - 1) * z[i] + c = x * v + v_prime = sign * v + return c, v_prime + + + + + + + + diff --git a/Compiler/graph.py b/Compiler/graph.py new file mode 100644 index 000000000..97c152af6 --- /dev/null +++ b/Compiler/graph.py @@ -0,0 +1,220 @@ +# (C) 2016 University of Bristol. See License.txt + +import heapq +from Compiler.exceptions import * + +class GraphError(CompilerError): + pass + +class SparseDiGraph(object): + """ Directed graph suitable when each node only has a small number of edges. + + Edges are stored as a list instead of a dictionary to save memory, leading + to slower searching for dense graphs. + + Node attributes must be specified in advance, as these are stored in the + same list as edges. + """ + def __init__(self, max_nodes, default_attributes=None): + """ max_nodes: maximum no of nodes + default_attributes: dict of node attributes and default values """ + if default_attributes is None: + default_attributes = { 'merges': None, 'stop': -1, 'start': -1, 'is_source': True } + self.default_attributes = default_attributes + self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes)))) + self.n = max_nodes + # each node contains list of default attributes, followed by outoing edges + self.nodes = [self.default_attributes.values() for i in range(self.n)] + self.pred = [[] for i in range(self.n)] + self.weights = {} + + def __len__(self): + return self.n + + def __getitem__(self, i): + """ Get list of the neighbours of node i """ + return self.nodes[i][len(self.default_attributes):] + + def __iter__(self): + pass #return iter(self.nodes) + + def __contains__(self, i): + return i >= 0 and i < self.n + + def add_node(self, i, **attr): + if i >= self.n: + raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n)) + node = self.nodes[i] + + for a,value in attr.items(): + if a in self.default_attributes: + node[self.attribute_pos[a]] = value + else: + raise CompilerError('Invalid attribute %s for graph node' % a) + + def set_attr(self, i, attr, value): + if attr in self.default_attributes: + self.nodes[i][self.attribute_pos[attr]] = value + else: + raise CompilerError('Invalid attribute %s for graph node' % attr) + + def get_attr(self, i, attr): + return self.nodes[i][self.attribute_pos[attr]] + + def remove_node(self, i): + """ Remove node i and all its edges """ + succ = self[i] + pred = self.pred[i] + for v in succ: + self.pred[v].remove(i) + #del self.weights[(i,v)] + for v in pred: + # find index to ensure attribute isn't removed instead + index = self[v].index(i) + len(self.default_attributes) + del self.nodes[v][index] + #del self.weights[(v,i)] + #self.nodes[v].remove(i) + self.pred[i] = [] + self.nodes[i] = self.default_attributes.values() + + def add_edge(self, i, j, weight=1): + if j not in self[i]: + self.nodes[i].append(j) + self.pred[j].append(i) + self.weights[(i,j)] = weight + + def add_edges_from(self, tuples): + for edge in tuples: + if len(edge) == 3: + # use weight + self.add_edge(edge[0], edge[1], edge[2]) + else: + self.add_edge(edge[0], edge[1]) + + def remove_edge(self, i, j): + jindex = self[i].index(j) + len(self.default_attributes) + del self.nodes[i][jindex] + self.pred[j].remove(i) + del self.weights[(i,j)] + + def remove_edges_from(self, pairs): + for i,j in pairs: + self.remove_edge(i, j) + + def degree(self, i): + return len(self.nodes[i]) - len(self.default_attributes) + + +def topological_sort(G, nbunch=None, pref=None): + seen={} + order_explored=[] # provide order and + explored={} # fast search without more general priorityDictionary + + if pref is None: + def get_children(node): + return G[node] + else: + def get_children(node): + if pref.has_key(node): + pref_set = set(pref[node]) + for i in G[node]: + if i not in pref_set: + yield i + for i in reversed(pref[node]): + yield i + else: + for i in G[node]: + yield i + + if nbunch is None: + nbunch = range(len(G)) + for v in nbunch: # process all vertices in G + if v in explored: + continue + fringe=[v] # nodes yet to look at + while fringe: + w=fringe[-1] # depth first search + if w in explored: # already looked down this branch + fringe.pop() + continue + seen[w]=1 # mark as seen + # Check successors for cycles and for new nodes + new_nodes=[] + for n in get_children(w): + if n not in explored: + if n in seen: #CYCLE !! + raise GraphError("Graph contains a cycle at %d (%s,%s)." % \ + (n, G[n], G.pred[n])) + new_nodes.append(n) + if new_nodes: # Add new_nodes to fringe + fringe.extend(new_nodes) + else: # No new nodes so w is fully explored + explored[w]=1 + order_explored.append(w) + fringe.pop() # done considering this node + + order_explored.reverse() # reverse order explored + return order_explored + +def dag_shortest_paths(G, source): + top_order = topological_sort(G) + dist = [None] * len(G) + dist[source] = 0 + for u in top_order: + if dist[u] is None: + continue + for v in G[u]: + if dist[v] is None or dist[v] > dist[u] + G.weights[(u,v)]: + dist[v] = dist[u] + G.weights[(u,v)] + return dist + +def reverse_dag_shortest_paths(G, source): + top_order = reversed(topological_sort(G)) + dist = [None] * len(G) + dist[source] = 0 + for u in top_order: + if u ==68273: + print 'dist[68273]', dist[u] + print 'pred[u]', G.pred[u] + if dist[u] is None: + continue + for v in G.pred[u]: + if dist[v] is None or dist[v] > dist[u] + G.weights[(v,u)]: + dist[v] = dist[u] + G.weights[(v,u)] + return dist + +def single_source_longest_paths(G, source, reverse=False): + # make weights negative, then do shortest paths + for edge in G.weights: + G.weights[edge] = -G.weights[edge] + if reverse: + dist = reverse_dag_shortest_paths(G, source) + else: + dist = dag_shortest_paths(G, source) + #dist = johnson(G, sources) + # reset weights + for edge in G.weights: + G.weights[edge] = -G.weights[edge] + for i,n in enumerate(dist): + if n is None: + dist[i] = 0 + else: + dist[i] = -dist[i] + #for k, v in dist.iteritems(): + # dist[k] = -v + return dist + + +def longest_paths(G, sources=None): + # make weights negative, then do shortest paths + for edge in G.weights: + G.weights[edge] = -G.weights[edge] + dist = {} + for source in sources: + print ('%s, ' % source), + dist[source] = dag_shortest_paths(G, source) + #dist = johnson(G, sources) + # reset weights + for edge in G.weights: + G.weights[edge] = -G.weights[edge] + return dist diff --git a/Compiler/instructions.py b/Compiler/instructions.py new file mode 100644 index 000000000..ab9d77358 --- /dev/null +++ b/Compiler/instructions.py @@ -0,0 +1,1310 @@ +# (C) 2016 University of Bristol. See License.txt + +""" This module is for classes of actual assembly instructions. + +All base classes, utility functions etc. should go in +instructions_base.py instead. This is for two reasons: +1) Easier generation of documentation +2) Ensures that 'from instructions import *' will only import assembly +instructions and nothing else. + +Note: every instruction should have a suitable docstring for auto-generation of +documentation +""" + +import itertools +import tools +from random import randint +from Compiler.config import * +from Compiler.exceptions import * +import Compiler.instructions_base as base + + +# avoid naming collision with input instruction +_python_input = input + +### +### Load and store instructions +### + +@base.gf2n +@base.vectorize +class ldi(base.Instruction): + r""" Assigns register $c_i$ the value $n$. """ + __slots__ = [] + code = base.opcodes['LDI'] + arg_format = ['cw','i'] + + def execute(self): + self.args[0].value = self.args[1] + +@base.gf2n +@base.vectorize +class ldsi(base.Instruction): + r""" Assigns register $s_i$ a share of the value $n$. """ + __slots__ = [] + code = base.opcodes['LDSI'] + arg_format = ['sw','i'] + + def execute(self): + self.args[0].value = self.args[1] + +@base.gf2n +@base.vectorize +class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): + r""" Assigns register $c_i$ the value in memory \verb+C[n]+. """ + __slots__ = ["code"] + code = base.opcodes['LDMC'] + arg_format = ['cw','int'] + + def execute(self): + self.args[0].value = program.mem_c[self.args[1]] + +@base.gf2n +@base.vectorize +class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): + r""" Assigns register $s_i$ the value in memory \verb+S[n]+. """ + __slots__ = ["code"] + code = base.opcodes['LDMS'] + arg_format = ['sw','int'] + + def execute(self): + self.args[0].value = program.mem_s[self.args[1]] + +@base.gf2n +@base.vectorize +class stmc(base.DirectMemoryWriteInstruction): + r""" Sets \verb+C[n]+ to be the value $c_i$. """ + __slots__ = ["code"] + code = base.opcodes['STMC'] + arg_format = ['c','int'] + + def execute(self): + program.mem_c[self.args[1]] = self.args[0].value + +@base.gf2n +@base.vectorize +class stms(base.DirectMemoryWriteInstruction): + r""" Sets \verb+S[n]+ to be the value $s_i$. """ + __slots__ = ["code"] + code = base.opcodes['STMS'] + arg_format = ['s','int'] + + def execute(self): + program.mem_s[self.args[1]] = self.args[0].value + +@base.vectorize +class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): + r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """ + __slots__ = ["code"] + code = base.opcodes['LDMINT'] + arg_format = ['ciw','int'] + + def execute(self): + self.args[0].value = program.mem_i[self.args[1]] + +@base.vectorize +class stmint(base.DirectMemoryWriteInstruction): + r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """ + __slots__ = ["code"] + code = base.opcodes['STMINT'] + arg_format = ['ci','int'] + + def execute(self): + program.mem_i[self.args[1]] = self.args[0].value + +# must have seperate instructions because address is always modp +@base.vectorize +class ldmci(base.ReadMemoryInstruction): + r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """ + code = base.opcodes['LDMCI'] + arg_format = ['cw','ci'] + + def execute(self): + self.args[0].value = program.mem_c[self.args[1].value] + +@base.vectorize +class ldmsi(base.ReadMemoryInstruction): + r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """ + code = base.opcodes['LDMSI'] + arg_format = ['sw','ci'] + + def execute(self): + self.args[0].value = program.mem_s[self.args[1].value] + +@base.vectorize +class stmci(base.WriteMemoryInstruction): + r""" Sets \verb+C[cj]+ to be the value $c_i$. """ + code = base.opcodes['STMCI'] + arg_format = ['c','ci'] + + def execute(self): + program.mem_c[self.args[1].value] = self.args[0].value + +@base.vectorize +class stmsi(base.WriteMemoryInstruction): + r""" Sets \verb+S[cj]+ to be the value $s_i$. """ + code = base.opcodes['STMSI'] + arg_format = ['s','ci'] + + def execute(self): + program.mem_s[self.args[1].value] = self.args[0].value + +@base.vectorize +class ldminti(base.ReadMemoryInstruction): + r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """ + code = base.opcodes['LDMINTI'] + arg_format = ['ciw','ci'] + + def execute(self): + self.args[0].value = program.mem_i[self.args[1].value] + +@base.vectorize +class stminti(base.WriteMemoryInstruction): + r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """ + code = base.opcodes['STMINTI'] + arg_format = ['ci','ci'] + + def execute(self): + program.mem_i[self.args[1].value] = self.args[0].value + +@base.vectorize +class gldmci(base.ReadMemoryInstruction): + r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """ + code = base.opcodes['LDMCI'] + 0x100 + arg_format = ['cgw','ci'] + + def execute(self): + self.args[0].value = program.mem_c[self.args[1].value] + +@base.vectorize +class gldmsi(base.ReadMemoryInstruction): + r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """ + code = base.opcodes['LDMSI'] + 0x100 + arg_format = ['sgw','ci'] + + def execute(self): + self.args[0].value = program.mem_s[self.args[1].value] + +@base.vectorize +class gstmci(base.WriteMemoryInstruction): + r""" Sets \verb+C[cj]+ to be the value $c_i$. """ + code = base.opcodes['STMCI'] + 0x100 + arg_format = ['cg','ci'] + + def execute(self): + program.mem_c[self.args[1].value] = self.args[0].value + +@base.vectorize +class gstmsi(base.WriteMemoryInstruction): + r""" Sets \verb+S[cj]+ to be the value $s_i$. """ + code = base.opcodes['STMSI'] + 0x100 + arg_format = ['sg','ci'] + + def execute(self): + program.mem_s[self.args[1].value] = self.args[0].value + +@base.gf2n +@base.vectorize +class protectmems(base.Instruction): + r""" Protects secret memory range $[ci_i,ci_j)$. """ + code = base.opcodes['PROTECTMEMS'] + arg_format = ['ci','ci'] + +@base.gf2n +@base.vectorize +class protectmemc(base.Instruction): + r""" Protects clear memory range $[ci_i,ci_j)$. """ + code = base.opcodes['PROTECTMEMC'] + arg_format = ['ci','ci'] + +@base.gf2n +@base.vectorize +class protectmemint(base.Instruction): + r""" Protects integer memory range $[ci_i,ci_j)$. """ + code = base.opcodes['PROTECTMEMINT'] + arg_format = ['ci','ci'] + +@base.gf2n +@base.vectorize +class movc(base.Instruction): + r""" Assigns register $c_i$ the value in the register $c_j$. """ + __slots__ = ["code"] + code = base.opcodes['MOVC'] + arg_format = ['cw','c'] + + def execute(self): + self.args[0].value = self.args[1].value + +@base.gf2n +@base.vectorize +class movs(base.Instruction): + r""" Assigns register $s_i$ the value in the register $s_j$. """ + __slots__ = ["code"] + code = base.opcodes['MOVS'] + arg_format = ['sw','s'] + + def execute(self): + self.args[0].value = self.args[1].value + +@base.vectorize +class movint(base.Instruction): + r""" Assigns register $ci_i$ the value in the register $ci_j$. """ + __slots__ = ["code"] + code = base.opcodes['MOVINT'] + arg_format = ['ciw','ci'] + +@base.vectorize +class pushint(base.Instruction): + r""" Pushes register $ci_i$ to the thread-local stack. """ + code = base.opcodes['PUSHINT'] + arg_format = ['ci'] + +@base.vectorize +class popint(base.Instruction): + r""" Pops from the thread-local stack to register $ci_i$. """ + code = base.opcodes['POPINT'] + arg_format = ['ciw'] + + +### +### Machine +### + +@base.vectorize +class ldtn(base.Instruction): + r""" Assigns register $c_i$ the number of the current thread. """ + code = base.opcodes['LDTN'] + arg_format = ['ciw'] + +@base.vectorize +class ldarg(base.Instruction): + r""" Assigns register $c_i$ the argument passed to the current thread. """ + code = base.opcodes['LDARG'] + arg_format = ['ciw'] + +@base.vectorize +class starg(base.Instruction): + r""" Assigns register $c_i$ to the argument. """ + code = base.opcodes['STARG'] + arg_format = ['ci'] + +@base.gf2n +class reqbl(base.Instruction): + r""" Require bit length $n". """ + code = base.opcodes['REQBL'] + arg_format = ['int'] + +class time(base.Instruction): + r""" Output epoch time. """ + code = base.opcodes['TIME'] + arg_format = [] + +class start(base.Instruction): + r""" Start timer. """ + code = base.opcodes['START'] + arg_format = ['i'] + +class stop(base.Instruction): + r""" Stop timer. """ + code = base.opcodes['STOP'] + arg_format = ['i'] + +class use(base.Instruction): + r""" Offline data usage. """ + code = base.opcodes['USE'] + arg_format = ['int','int','int'] + +class use_inp(base.Instruction): + r""" Input usage. """ + code = base.opcodes['USE_INP'] + arg_format = ['int','int','int'] + +class run_tape(base.Instruction): + r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """ + code = base.opcodes['RUN_TAPE'] + arg_format = ['int','int','int'] + +class join_tape(base.Instruction): + r""" Join thread $c_i$. """ + code = base.opcodes['JOIN_TAPE'] + arg_format = ['int'] + +class crash(base.IOInstruction): + r""" Crash runtime. """ + code = base.opcodes['CRASH'] + arg_format = [] + +@base.gf2n +class use_prep(base.Instruction): + r""" Input usage. """ + code = base.opcodes['USE_PREP'] + arg_format = ['str','int'] + +### +### Basic arithmetic +### + +@base.gf2n +@base.vectorize +class addc(base.AddBase): + r""" Clear addition $c_i=c_j+c_k$. """ + __slots__ = [] + code = base.opcodes['ADDC'] + arg_format = ['cw','c','c'] + +@base.gf2n +@base.vectorize +class adds(base.AddBase): + r""" Secret addition $s_i=s_j+s_k$. """ + __slots__ = [] + code = base.opcodes['ADDS'] + arg_format = ['sw','s','s'] + +@base.gf2n +@base.vectorize +class addm(base.AddBase): + r""" Mixed addition $s_i=s_j+c_k$. """ + __slots__ = [] + code = base.opcodes['ADDM'] + arg_format = ['sw','s','c'] + +@base.gf2n +@base.vectorize +class subc(base.SubBase): + r""" Clear subtraction $c_i=c_j-c_k$. """ + __slots__ = [] + code = base.opcodes['SUBC'] + arg_format = ['cw','c','c'] + +@base.gf2n +@base.vectorize +class subs(base.SubBase): + r""" Secret subtraction $s_i=s_j-s_k$. """ + __slots__ = [] + code = base.opcodes['SUBS'] + arg_format = ['sw','s','s'] + +@base.gf2n +@base.vectorize +class subml(base.SubBase): + r""" Mixed subtraction $s_i=s_j-c_k$. """ + __slots__ = [] + code = base.opcodes['SUBML'] + arg_format = ['sw','s','c'] + +@base.gf2n +@base.vectorize +class submr(base.SubBase): + r""" Mixed subtraction $s_i=c_j-s_k$. """ + __slots__ = [] + code = base.opcodes['SUBMR'] + arg_format = ['sw','c','s'] + +@base.gf2n +@base.vectorize +class mulc(base.MulBase): + r""" Clear multiplication $c_i=c_j \cdot c_k$. """ + __slots__ = [] + code = base.opcodes['MULC'] + arg_format = ['cw','c','c'] + +@base.gf2n +@base.vectorize +class mulm(base.MulBase): + r""" Mixed multiplication $s_i=c_j \cdot s_k$. """ + __slots__ = [] + code = base.opcodes['MULM'] + arg_format = ['sw','s','c'] + +@base.gf2n +@base.vectorize +class divc(base.Instruction): + r""" Clear division $c_i=c_j/c_k$. """ + __slots__ = [] + code = base.opcodes['DIVC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = self.args[1].value * pow(self.args[2].value, program.P-2, program.P) % program.P + +@base.gf2n +@base.vectorize +class modc(base.Instruction): + r""" Clear modular reduction $c_i=c_j/c_k$. """ + __slots__ = [] + code = base.opcodes['MODC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = self.args[1].value % self.args[2].value + +@base.vectorize +class legendrec(base.Instruction): + r""" Clear Legendre symbol computation, $c_i = (c_j / p)$. """ + __slots__ = [] + code = base.opcodes['LEGENDREC'] + arg_format = ['cw','c'] + +### +### Bitwise operations +### + +@base.gf2n +@base.vectorize +class andc(base.Instruction): + r""" Clear logical AND $c_i = c_j \land c_k$ """ + __slots__ = [] + code = base.opcodes['ANDC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = (self.args[1].value & self.args[2].value) % program.P + +@base.gf2n +@base.vectorize +class orc(base.Instruction): + r""" Clear logical OR $c_i = c_j \lor c_k$ """ + __slots__ = [] + code = base.opcodes['ORC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = (self.args[1].value | self.args[2].value) % program.P + +@base.gf2n +@base.vectorize +class xorc(base.Instruction): + r""" Clear logical XOR $c_i = c_j \oplus c_k$ """ + __slots__ = [] + code = base.opcodes['XORC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = (self.args[1].value ^ self.args[2].value) % program.P + +@base.vectorize +class notc(base.Instruction): + r""" Clear logical NOT $c_i = \lnot c_j$ """ + __slots__ = [] + code = base.opcodes['NOTC'] + arg_format = ['cw','c', 'int'] + + def execute(self): + self.args[0].value = (~self.args[1].value + 2 ** self.args[2]) % program.P + +@base.vectorize +class gnotc(base.Instruction): + r""" Clear logical NOT $cg_i = \lnot cg_j$ """ + __slots__ = [] + code = (1 << 8) + base.opcodes['NOTC'] + arg_format = ['cgw','cg'] + + def is_gf2n(self): + return True + + def execute(self): + self.args[0].value = ~self.args[1].value + +@base.vectorize +class gbitdec(base.Instruction): + r""" Store every $n$-th bit of $cg_i$ in $cg_j, \dots$. """ + __slots__ = [] + code = base.opcodes['GBITDEC'] + arg_format = tools.chain(['cg', 'int'], itertools.repeat('cgw')) + + def is_g2fn(self): + return True + + def has_var_args(self): + return True + +@base.vectorize +class gbitcom(base.Instruction): + r""" Store the bits $cg_j, \dots$ as every $n$-th bit of $cg_i$. """ + __slots__ = [] + code = base.opcodes['GBITCOM'] + arg_format = tools.chain(['cgw', 'int'], itertools.repeat('cg')) + + def is_g2fn(self): + return True + + def has_var_args(self): + return True + + +### +### Special GF(2) arithmetic instructions +### + +@base.vectorize +class gmulbitc(base.MulBase): + r""" Clear GF(2^n) by clear GF(2) multiplication """ + __slots__ = [] + code = base.opcodes['GMULBITC'] + arg_format = ['cgw','cg','cg'] + + def is_gf2n(self): + return True + +@base.vectorize +class gmulbitm(base.MulBase): + r""" Secret GF(2^n) by clear GF(2) multiplication """ + __slots__ = [] + code = base.opcodes['GMULBITM'] + arg_format = ['sgw','sg','cg'] + + def is_gf2n(self): + return True + +### +### Arithmetic with immediate values +### + +@base.gf2n +@base.vectorize +class addci(base.ClearImmediate): + """ Clear addition of immediate value $c_i=c_j+n$. """ + __slots__ = [] + code = base.opcodes['ADDCI'] + op = '__add__' + +@base.gf2n +@base.vectorize +class addsi(base.SharedImmediate): + """ Secret addition of immediate value $s_i=s_j+n$. """ + __slots__ = [] + code = base.opcodes['ADDSI'] + op = '__add__' + +@base.gf2n +@base.vectorize +class subci(base.ClearImmediate): + r""" Clear subtraction of immediate value $c_i=c_j-n$. """ + __slots__ = [] + code = base.opcodes['SUBCI'] + op = '__sub__' + +@base.gf2n +@base.vectorize +class subsi(base.SharedImmediate): + r""" Secret subtraction of immediate value $s_i=s_j-n$. """ + __slots__ = [] + code = base.opcodes['SUBSI'] + op = '__sub__' + +@base.gf2n +@base.vectorize +class subcfi(base.ClearImmediate): + r""" Clear subtraction from immediate value $c_i=n-c_j$. """ + __slots__ = [] + code = base.opcodes['SUBCFI'] + op = '__rsub__' + +@base.gf2n +@base.vectorize +class subsfi(base.SharedImmediate): + r""" Secret subtraction from immediate value $s_i=n-s_j$. """ + __slots__ = [] + code = base.opcodes['SUBSFI'] + op = '__rsub__' + +@base.gf2n +@base.vectorize +class mulci(base.ClearImmediate): + r""" Clear multiplication by immediate value $c_i=c_j \cdot n$. """ + __slots__ = [] + code = base.opcodes['MULCI'] + op = '__mul__' + +@base.gf2n +@base.vectorize +class mulsi(base.SharedImmediate): + r""" Secret multiplication by immediate value $s_i=s_j \cdot n$. """ + __slots__ = [] + code = base.opcodes['MULSI'] + op = '__mul__' + +@base.gf2n +@base.vectorize +class divci(base.ClearImmediate): + r""" Clear division by immediate value $c_i=c_j/n$. """ + __slots__ = [] + code = base.opcodes['DIVCI'] + def execute(self): + self.args[0].value = self.args[1].value * pow(self.args[2], program.P-2, program.P) % program.P + +@base.gf2n +@base.vectorize +class modci(base.ClearImmediate): + r""" Clear modular reduction by immediate value $c_i=c_j \mod{n}$. """ + __slots__ = [] + code = base.opcodes['MODCI'] + op = '__mod__' + +@base.gf2n +@base.vectorize +class andci(base.ClearImmediate): + r""" Clear logical AND with immediate value $c_i = c_j \land c_k$ """ + __slots__ = [] + code = base.opcodes['ANDCI'] + op = '__and__' + +@base.gf2n +@base.vectorize +class xorci(base.ClearImmediate): + r""" Clear logical XOR with immediate value $c_i = c_j \oplus c_k$ """ + __slots__ = [] + code = base.opcodes['XORCI'] + op = '__xor__' + +@base.gf2n +@base.vectorize +class orci(base.ClearImmediate): + r""" Clear logical OR with immediate value $c_i = c_j \vee c_k$ """ + __slots__ = [] + code = base.opcodes['ORCI'] + op = '__or__' + + +### +### Shift instructions +### + +@base.gf2n +@base.vectorize +class shlc(base.Instruction): + r""" Clear bitwise shift left $c_i = c_j << c_k$ """ + __slots__ = [] + code = base.opcodes['SHLC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = (self.args[1].value << self.args[2].value) % program.P + +@base.gf2n +@base.vectorize +class shrc(base.Instruction): + r""" Clear bitwise shift right $c_i = c_j >> c_k$ """ + __slots__ = [] + code = base.opcodes['SHRC'] + arg_format = ['cw','c','c'] + + def execute(self): + self.args[0].value = (self.args[1].value >> self.args[2].value) % program.P + +@base.gf2n +@base.vectorize +class shlci(base.ClearShiftInstruction): + r""" Clear bitwise shift left by immediate value $c_i = c_j << n$ """ + __slots__ = [] + code = base.opcodes['SHLCI'] + op = '__lshift__' + +@base.gf2n +@base.vectorize +class shrci(base.ClearShiftInstruction): + r""" Clear bitwise shift right by immediate value $c_i = c_j >> n$ """ + __slots__ = [] + code = base.opcodes['SHRCI'] + op = '__rshift__' + + +### +### Data access instructions +### + +@base.gf2n +@base.vectorize +class triple(base.DataInstruction): + r""" Load secret variables $s_i$, $s_j$ and $s_k$ + with the next multiplication triple. """ + __slots__ = ['data_type'] + code = base.opcodes['TRIPLE'] + arg_format = ['sw','sw','sw'] + data_type = 'triple' + + def execute(self): + self.args[0].value = randint(0,program.P) + self.args[1].value = randint(0,program.P) + self.args[2].value = (self.args[0].value * self.args[1].value) % program.P + +@base.vectorize +class gbittriple(base.DataInstruction): + r""" Load secret variables $s_i$, $s_j$ and $s_k$ + with the next GF(2) multiplication triple. """ + __slots__ = ['data_type'] + code = base.opcodes['GBITTRIPLE'] + arg_format = ['sgw','sgw','sgw'] + data_type = 'bittriple' + field_type = 'gf2n' + + def is_gf2n(self): + return True + +@base.vectorize +class gbitgf2ntriple(base.DataInstruction): + r""" Load secret variables $s_i$, $s_j$ and $s_k$ + with the next GF(2) and GF(2^n) multiplication triple. """ + code = base.opcodes['GBITGF2NTRIPLE'] + arg_format = ['sgw','sgw','sgw'] + data_type = 'bitgf2ntriple' + field_type = 'gf2n' + + def is_gf2n(self): + return True + +@base.gf2n +@base.vectorize +class bit(base.DataInstruction): + r""" Load secret variable $s_i$ + with the next secret bit. """ + __slots__ = [] + code = base.opcodes['BIT'] + arg_format = ['sw'] + data_type = 'bit' + + def execute(self): + self.args[0].value = randint(0,1) + +@base.gf2n +@base.vectorize +class square(base.DataInstruction): + r""" Load secret variables $s_i$ and $s_j$ + with the next squaring tuple. """ + __slots__ = [] + code = base.opcodes['SQUARE'] + arg_format = ['sw','sw'] + data_type = 'square' + + def execute(self): + self.args[0].value = randint(0,program.P) + self.args[1].value = (self.args[0].value * self.args[0].value) % program.P + +@base.gf2n +@base.vectorize +class inverse(base.DataInstruction): + r""" Load secret variables $s_i$, $s_j$ and $s_k$ + with the next inverse triple. """ + __slots__ = [] + code = base.opcodes['INV'] + arg_format = ['sw','sw'] + data_type = 'inverse' + + def execute(self): + self.args[0].value = randint(0,program.P) + import gmpy + self.args[1].value = int(gmpy.invert(self.args[0].value, program.P)) + +@base.gf2n +@base.vectorize +class inputmask(base.Instruction): + r""" Load secret $s_i$ with the next input mask for player $p$ and + write the mask on player $p$'s private output. """ + __slots__ = [] + code = base.opcodes['INPUTMASK'] + arg_format = ['sw', 'p'] + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', self.args[1]), \ + self.get_size()) + +@base.gf2n +@base.vectorize +class prep(base.Instruction): + r""" Custom preprocessed data """ + __slots__ = [] + code = base.opcodes['PREP'] + arg_format = tools.chain(['str'], itertools.repeat('sw')) + gf2n_arg_format = tools.chain(['str'], itertools.repeat('sgw')) + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment((self.field_type, self.args[0]), 1) + + def has_var_args(self): + return True + +### +### I/O +### + +@base.gf2n +@base.vectorize +class asm_input(base.IOInstruction): + r""" Receive input from player $p$ and put in register $s_i$. """ + __slots__ = [] + code = base.opcodes['INPUT'] + arg_format = ['sw', 'p'] + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', self.args[1]), \ + self.get_size()) + def execute(self): + self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P + +@base.gf2n +class startinput(base.RawInputInstruction): + r""" Receive inputs from player $p$. """ + __slots__ = [] + code = base.opcodes['STARTINPUT'] + arg_format = ['p', 'int'] + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', self.args[0]), \ + self.args[1]) + +class stopinput(base.RawInputInstruction): + r""" Receive inputs from player $p$ and put in registers. """ + __slots__ = [] + code = base.opcodes['STOPINPUT'] + arg_format = tools.chain(['p'], itertools.repeat('sw')) + + def has_var_args(self): + return True + +class gstopinput(base.RawInputInstruction): + r""" Receive inputs from player $p$ and put in registers. """ + __slots__ = [] + code = 0x100 + base.opcodes['STOPINPUT'] + arg_format = tools.chain(['p'], itertools.repeat('sgw')) + + def has_var_args(self): + return True + +@base.gf2n +@base.vectorize +class print_mem(base.IOInstruction): + r""" Print value in clear memory \verb|C[ci]| to stdout. """ + __slots__ = [] + code = base.opcodes['PRINTMEM'] + arg_format = ['c'] + + def execute(self): + pass + +@base.gf2n +@base.vectorize +class print_reg(base.IOInstruction): + r""" Print value of register \verb|ci| to stdout and optional 4-char comment. """ + __slots__ = [] + code = base.opcodes['PRINTREG'] + arg_format = ['c','i'] + + def __init__(self, reg, comment=''): + super(print_reg_class, self).__init__(reg, self.str_to_int(comment)) + + def execute(self): + pass + +@base.gf2n +@base.vectorize +class print_reg_plain(base.IOInstruction): + r""" Print only the value of register \verb|ci| to stdout. """ + __slots__ = [] + code = base.opcodes['PRINTREGPLAIN'] + arg_format = ['c'] + +class print_char(base.IOInstruction): + r""" Print a single character to stdout. """ + code = base.opcodes['PRINTCHR'] + arg_format = ['int'] + + def __init__(self, ch): + super(print_char, self).__init__(ord(ch)) + +class print_char4(base.IOInstruction): + r""" Print a 4 character string. """ + code = base.opcodes['PRINTSTR'] + arg_format = ['int'] + + def __init__(self, val): + super(print_char4, self).__init__(self.str_to_int(val)) + +@base.vectorize +class print_char_regint(base.IOInstruction): + r""" Print register $ci_i$ as a single character to stdout. """ + code = base.opcodes['PRINTCHRINT'] + arg_format = ['ci'] + +@base.vectorize +class print_char4_regint(base.IOInstruction): + r""" Print register $ci_i$ as a four character string to stdout. """ + code = base.opcodes['PRINTSTRINT'] + arg_format = ['ci'] + +@base.vectorize +class pubinput(base.PublicFileIOInstruction): + __slots__ = [] + code = base.opcodes['PUBINPUT'] + arg_format = ['ciw'] + +@base.vectorize +class readsocketc(base.IOInstruction): + """Read an int from socket and store in register""" + __slots__ = [] + code = base.opcodes['READSOCKETC'] + arg_format = ['ciw', 'int'] + +@base.vectorize +class readsockets(base.IOInstruction): + """Read a secret share + MAC from socket and store in register""" + __slots__ = [] + code = base.opcodes['READSOCKETS'] + arg_format = ['sw', 'int'] + +@base.vectorize +class writesocketc(base.IOInstruction): + """Write int from register into socket""" + __slots__ = [] + code = base.opcodes['WRITESOCKETC'] + arg_format = ['ci', 'int'] + +@base.vectorize +class writesockets(base.IOInstruction): + """Write secret share + MAC from register into socket""" + __slots__ = [] + code = base.opcodes['WRITESOCKETS'] + arg_format = ['s', 'int'] + +class opensocket(base.IOInstruction): + """Open a server socket connection at the given port number""" + __slots__ = [] + code = base.opcodes['OPENSOCKET'] + arg_format = ['int'] + +class closesocket(base.IOInstruction): + """Close a server socket connection""" + __slots__ = [] + code = base.opcodes['CLOSESOCKET'] + arg_format = [] + +@base.gf2n +@base.vectorize +class raw_output(base.PublicFileIOInstruction): + r""" Raw output of register \verb|ci| to file. """ + __slots__ = [] + code = base.opcodes['RAWOUTPUT'] + arg_format = ['c'] + +@base.gf2n +@base.vectorize +class startprivateoutput(base.Instruction): + r""" Initiate private output to $n$ of $s_j$ via $s_i$. """ + __slots__ = [] + code = base.opcodes['STARTPRIVATEOUTPUT'] + arg_format = ['sw','s','p'] + +@base.gf2n +@base.vectorize +class stopprivateoutput(base.Instruction): + r""" Previously iniated private output to $n$ via $c_i$. """ + __slots__ = [] + code = base.opcodes['STOPPRIVATEOUTPUT'] + arg_format = ['c','p'] + +@base.vectorize +class rand(base.Instruction): + __slots__ = [] + code = base.opcodes['RAND'] + arg_format = ['ciw','ci'] + +### +### Integer operations +### + +@base.vectorize +class ldint(base.Instruction): + __slots__ = [] + code = base.opcodes['LDINT'] + arg_format = ['ciw', 'i'] + +@base.vectorize +class addint(base.IntegerInstruction): + __slots__ = [] + code = base.opcodes['ADDINT'] + +@base.vectorize +class subint(base.IntegerInstruction): + __slots__ = [] + code = base.opcodes['SUBINT'] + +@base.vectorize +class mulint(base.IntegerInstruction): + __slots__ = [] + code = base.opcodes['MULINT'] + +@base.vectorize +class divint(base.IntegerInstruction): + __slots__ = [] + code = base.opcodes['DIVINT'] + +### +### Clear comparison instructions +### + +@base.vectorize +class eqzc(base.UnaryComparisonInstruction): + r""" Clear comparison $c_i = (c_j \stackrel{?}{==} 0)$. """ + __slots__ = [] + code = base.opcodes['EQZC'] + + def execute(self): + if self.args[1].value == 0: + self.args[0].value = 1 + else: + self.args[0].value = 0 + +@base.vectorize +class ltzc(base.UnaryComparisonInstruction): + r""" Clear comparison $c_i = (c_j \stackrel{?}{<} 0)$. """ + __slots__ = [] + code = base.opcodes['LTZC'] + +@base.vectorize +class ltc(base.IntegerInstruction): + r""" Clear comparison $c_i = (c_j \stackrel{?}{<} c_k)$. """ + __slots__ = [] + code = base.opcodes['LTC'] + +@base.vectorize +class gtc(base.IntegerInstruction): + r""" Clear comparison $c_i = (c_j \stackrel{?}{>} c_k)$. """ + __slots__ = [] + code = base.opcodes['GTC'] + +@base.vectorize +class eqc(base.IntegerInstruction): + r""" Clear comparison $c_i = (c_j \stackrel{?}{==} c_k)$. """ + __slots__ = [] + code = base.opcodes['EQC'] + + +### +### Jumps etc +### + +class jmp(base.JumpInstruction): + """ Unconditional relative jump of $n+1$ instructions. """ + __slots__ = [] + code = base.opcodes['JMP'] + arg_format = ['int'] + jump_arg = 0 + + def execute(self): + pass + +class jmpi(base.JumpInstruction): + """ Unconditional relative jump of $c_i+1$ instructions. """ + __slots__ = [] + code = base.opcodes['JMPI'] + arg_format = ['ci'] + jump_arg = 0 + +class jmpnz(base.JumpInstruction): + r""" Jump $n+1$ instructions if $c_i \neq 0$. + + e.g. + jmpnz(c, n) : advance n+1 instructions if c is non-zero + jmpnz(c, 0) : do nothing + jmpnz(c, -1): infinite loop if c is non-zero + """ + __slots__ = [] + code = base.opcodes['JMPNZ'] + arg_format = ['ci', 'int'] + jump_arg = 1 + + def execute(self): + pass + +class jmpeqz(base.JumpInstruction): + r""" Jump $n+1$ instructions if $c_i == 0$. """ + __slots__ = [] + code = base.opcodes['JMPEQZ'] + arg_format = ['ci', 'int'] + jump_arg = 1 + + def execute(self): + pass + +### +### Conversions +### + +@base.gf2n +@base.vectorize +class convint(base.Instruction): + """ Convert from integer register $ci_j$ to clear modp register $c_i$. """ + __slots__ = [] + code = base.opcodes['CONVINT'] + arg_format = ['cw', 'ci'] + +@base.vectorize +class convmodp(base.Instruction): + """ Convert from clear modp register $c_j$ to integer register $ci_i$. """ + __slots__ = [] + code = base.opcodes['CONVMODP'] + arg_format = ['ciw', 'c', 'int'] + def __init__(self, *args, **kwargs): + bitlength = kwargs.get('bitlength', program.bit_length) + super(convmodp_class, self).__init__(*(args + (bitlength,))) + +@base.vectorize +class gconvgf2n(base.Instruction): + """ Convert from clear modp register $c_j$ to integer register $ci_i$. """ + __slots__ = [] + code = base.opcodes['GCONVGF2N'] + arg_format = ['ciw', 'cg'] + +### +### Other instructions +### + +@base.gf2n +@base.vectorize +class startopen(base.Instruction): + """ Start opening secret register $s_i$. """ + __slots__ = [] + code = base.opcodes['STARTOPEN'] + arg_format = itertools.repeat('s') + + def execute(self): + for arg in self.args[::-1]: + program.curr_block.open_queue.append(arg.value) + + def has_var_args(self): + return True + +@base.gf2n +@base.vectorize +class stopopen(base.Instruction): + """ Store previous opened value in $c_i$. """ + __slots__ = [] + code = base.opcodes['STOPOPEN'] + arg_format = itertools.repeat('cw') + + def execute(self): + for arg in self.args: + arg.value = program.curr_block.open_queue.pop() + + def has_var_args(self): + return True + +### +### CISC-style instructions +### + +# rename 'open' to avoid conflict with built-in open function +@base.gf2n +@base.vectorize +class asm_open(base.CISC): + """ Open the value in $s_j$ and assign it to $c_i$. """ + __slots__ = [] + arg_format = ['cw','s'] + + def expand(self): + startopen(self.args[1]) + stopopen(self.args[0]) + + +@base.gf2n +@base.vectorize +class muls(base.CISC): + """ Secret multiplication $s_i = s_j \cdot s_k$. """ + __slots__ = [] + arg_format = ['sw','s','s'] + + def expand(self): + s = [program.curr_block.new_reg('s') for i in range(9)] + c = [program.curr_block.new_reg('c') for i in range(3)] + triple(s[0], s[1], s[2]) + subs(s[3], self.args[1], s[0]) + subs(s[4], self.args[2], s[1]) + startopen(s[3], s[4]) + stopopen(c[0], c[1]) + mulm(s[5], s[1], c[0]) + mulm(s[6], s[0], c[1]) + mulc(c[2], c[0], c[1]) + adds(s[7], s[2], s[5]) + adds(s[8], s[7], s[6]) + addm(self.args[0], s[8], c[2]) + +@base.gf2n +@base.vectorize +class sqrs(base.CISC): + """ Secret squaring $s_i = s_j \cdot s_j$. """ + __slots__ = [] + arg_format = ['sw', 's'] + + def expand(self): + s = [program.curr_block.new_reg('s') for i in range(6)] + c = [program.curr_block.new_reg('c') for i in range(2)] + square(s[0], s[1]) + subs(s[2], self.args[1], s[0]) + asm_open(c[0], s[2]) + mulc(c[1], c[0], c[0]) + mulm(s[3], self.args[1], c[0]) + adds(s[4], s[3], s[3]) + adds(s[5], s[1], s[4]) + subml(self.args[0], s[5], c[1]) + + +@base.gf2n +@base.vectorize +class lts(base.CISC): + """ Secret comparison $s_i = (s_j < s_k)$. """ + __slots__ = [] + arg_format = ['sw', 's', 's', 'int', 'int'] + + def expand(self): + a = program.curr_block.new_reg('s') + subs(a, self.args[1], self.args[2]) + comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) + +@base.vectorize +class g2muls(base.CISC): + r""" Secret GF(2) multiplication """ + __slots__ = [] + arg_format = ['sgw','sg','sg'] + + def expand(self): + s = [program.curr_block.new_reg('sg') for i in range(9)] + c = [program.curr_block.new_reg('cg') for i in range(3)] + gbittriple(s[0], s[1], s[2]) + gsubs(s[3], self.args[1], s[0]) + gsubs(s[4], self.args[2], s[1]) + gstartopen(s[3], s[4]) + gstopopen(c[0], c[1]) + gmulbitm(s[5], s[1], c[0]) + gmulbitm(s[6], s[0], c[1]) + gmulbitc(c[2], c[0], c[1]) + gadds(s[7], s[2], s[5]) + gadds(s[8], s[7], s[6]) + gaddm(self.args[0], s[8], c[2]) + +#@base.vectorize +#class gmulbits(base.CISC): +# r""" Secret $GF(2^n) \times GF(2)$ multiplication """ +# __slots__ = [] +# arg_format = ['sgw','sg','sg'] +# +# def expand(self): +# s = [program.curr_block.new_reg('s') for i in range(9)] +# c = [program.curr_block.new_reg('c') for i in range(3)] +# g2ntriple(s[0], s[1], s[2]) +# subs(s[3], self.args[1], s[0]) +# subs(s[4], self.args[2], s[1]) +# startopen(s[3], s[4]) +# stopopen(c[0], c[1]) +# mulm(s[5], s[1], c[0]) +# mulm(s[6], s[0], c[1]) +# mulc(c[2], c[0], c[1]) +# adds(s[7], s[2], s[5]) +# adds(s[8], s[7], s[6]) +# addm(self.args[0], s[8], c[2]) + +# hack for circular dependency +from Compiler import comparison diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py new file mode 100644 index 000000000..420631a8b --- /dev/null +++ b/Compiler/instructions_base.py @@ -0,0 +1,742 @@ +# (C) 2016 University of Bristol. See License.txt + +import itertools +from random import randint +import time +import inspect +import functools +from Compiler.exceptions import * +from Compiler.config import * +from Compiler import util + + +### +### Opcode constants +### +### Whenever these are changed the corresponding enums in Processor/instruction.h +### MUST also be changed. (+ the documentation) +### +opcodes = dict( + # Load/store + LDI = 0x1, + LDSI = 0x2, + LDMC = 0x3, + LDMS = 0x4, + STMC = 0x5, + STMS = 0x6, + LDMCI = 0x7, + LDMSI = 0x8, + STMCI = 0x9, + STMSI = 0xA, + MOVC = 0xB, + MOVS = 0xC, + PROTECTMEMS = 0xD, + PROTECTMEMC = 0xE, + PROTECTMEMINT = 0xF, + LDMINT = 0xCA, + STMINT = 0xCB, + LDMINTI = 0xCC, + STMINTI = 0xCD, + PUSHINT = 0xCE, + POPINT = 0xCF, + MOVINT = 0xD0, + # Machine + LDTN = 0x10, + LDARG = 0x11, + REQBL = 0x12, + STARG = 0x13, + TIME = 0x14, + START = 0x15, + STOP = 0x16, + USE = 0x17, + USE_INP = 0x18, + RUN_TAPE = 0x19, + JOIN_TAPE = 0x1A, + CRASH = 0x1B, + USE_PREP = 0x1C, + # Addition + ADDC = 0x20, + ADDS = 0x21, + ADDM = 0x22, + ADDCI = 0x23, + ADDSI = 0x24, + SUBC = 0x25, + SUBS = 0x26, + SUBML = 0x27, + SUBMR = 0x28, + SUBCI = 0x29, + SUBSI = 0x2A, + SUBCFI = 0x2B, + SUBSFI = 0x2C, + # Multiplication/division + MULC = 0x30, + MULM = 0x31, + MULCI = 0x32, + MULSI = 0x33, + DIVC = 0x34, + DIVCI = 0x35, + MODC = 0x36, + MODCI = 0x37, + LEGENDREC = 0x38, + GMULBITC = 0x136, + GMULBITM = 0x137, + # Open + STARTOPEN = 0xA0, + STOPOPEN = 0xA1, + # Data access + TRIPLE = 0x50, + BIT = 0x51, + SQUARE = 0x52, + INV = 0x53, + GBITTRIPLE = 0x154, + GBITGF2NTRIPLE = 0x155, + INPUTMASK = 0x56, + PREP = 0x57, + # Input + INPUT = 0x60, + STARTINPUT = 0x61, + STOPINPUT = 0x62, + READSOCKETC = 0x63, + READSOCKETS = 0x64, + WRITESOCKETC = 0x65, + WRITESOCKETS = 0x66, + OPENSOCKET = 0x67, + CLOSESOCKET = 0x68, + # Bitwise logic + ANDC = 0x70, + XORC = 0x71, + ORC = 0x72, + ANDCI = 0x73, + XORCI = 0x74, + ORCI = 0x75, + NOTC = 0x76, + # Bitwise shifts + SHLC = 0x80, + SHRC = 0x81, + SHLCI = 0x82, + SHRCI = 0x83, + # Branching and comparison + JMP = 0x90, + JMPNZ = 0x91, + JMPEQZ = 0x92, + EQZC = 0x93, + LTZC = 0x94, + LTC = 0x95, + GTC = 0x96, + EQC = 0x97, + JMPI = 0x98, + # Integers + LDINT = 0x9A, + ADDINT = 0x9B, + SUBINT = 0x9C, + MULINT = 0x9D, + DIVINT = 0x9E, + # Conversion + CONVINT = 0xC0, + CONVMODP = 0xC1, + GCONVGF2N = 0x1C1, + # IO + PRINTMEM = 0xB0, + PRINTREG = 0XB1, + RAND = 0xB2, + PRINTREGPLAIN = 0xB3, + PRINTCHR = 0xB4, + PRINTSTR = 0xB5, + PUBINPUT = 0xB6, + RAWOUTPUT = 0xB7, + STARTPRIVATEOUTPUT = 0xB8, + STOPPRIVATEOUTPUT = 0xB9, + PRINTCHRINT = 0xBA, + PRINTSTRINT = 0xBB, + GBITDEC = 0x184, + GBITCOM = 0x185, +) + + +def int_to_bytes(x): + """ 32 bit int to big-endian 4 byte conversion. """ + return [(x >> 8*i) % 256 for i in (3,2,1,0)] + + +global_vector_size = 1 +global_vector_size_depth = 0 +global_instruction_type_stack = ['modp'] + +def set_global_vector_size(size): + global global_vector_size, global_vector_size_depth + if size == 1: + return + if global_vector_size == 1 or global_vector_size == size: + global_vector_size = size + global_vector_size_depth += 1 + else: + raise CompilerError('Cannot set global vector size when already set') + +def set_global_instruction_type(t): + if t == 'modp' or t == 'gf2n': + global_instruction_type_stack.append(t) + else: + raise CompilerError('Invalid type %s for setting global instruction type') + +def reset_global_vector_size(): + global global_vector_size, global_vector_size_depth + if global_vector_size_depth > 0: + global_vector_size_depth -= 1 + if global_vector_size_depth == 0: + global_vector_size = 1 + +def reset_global_instruction_type(): + global_instruction_type_stack.pop() + +def get_global_vector_size(): + return global_vector_size + +def get_global_instruction_type(): + return global_instruction_type_stack[-1] + + +def vectorize(instruction, global_dict=None): + """ Decorator to vectorize instructions. """ + + if global_dict is None: + global_dict = inspect.getmodule(instruction).__dict__ + + class Vectorized_Instruction(instruction): + __slots__ = ['size'] + def __init__(self, size, *args, **kwargs): + self.size = size + super(Vectorized_Instruction, self).__init__(*args, **kwargs) + for arg,f in zip(self.args, self.arg_format): + if issubclass(ArgFormats[f], RegisterArgFormat): + arg.set_size(size) + def get_code(self): + return (self.size << 9) + self.code + def get_pre_arg(self): + return "%d, " % self.size + def is_vec(self): + return self.size > 1 + def get_size(self): + return self.size + def expand(self): + set_global_vector_size(self.size) + super(Vectorized_Instruction, self).expand() + reset_global_vector_size() + + @functools.wraps(instruction) + def maybe_vectorized_instruction(*args, **kwargs): + if global_vector_size == 1: + return instruction(*args, **kwargs) + else: + return Vectorized_Instruction(global_vector_size, *args, **kwargs) + maybe_vectorized_instruction.vec_ins = Vectorized_Instruction + maybe_vectorized_instruction.std_ins = instruction + + vectorized_name = 'v' + instruction.__name__ + Vectorized_Instruction.__name__ = vectorized_name + global_dict[vectorized_name] = Vectorized_Instruction + global_dict[instruction.__name__ + '_class'] = instruction + return maybe_vectorized_instruction + + +def gf2n(instruction): + """ Decorator to create GF_2^n instruction corresponding to a given + modp instruction. + + Adds the new GF_2^n instruction to the globals dictionary. Also adds a + vectorized GF_2^n instruction if a modp version exists. """ + global_dict = inspect.getmodule(instruction).__dict__ + + if global_dict.has_key('v' + instruction.__name__): + vectorized = True + else: + vectorized = False + + if isinstance(instruction, type) and issubclass(instruction, Instruction): + instruction_cls = instruction + else: + try: + instruction_cls = global_dict[instruction.__name__ + '_class'] + except KeyError: + raise CompilerError('Cannot decorate instruction %s' % instruction) + + class GF2N_Instruction(instruction_cls): + __doc__ = instruction_cls.__doc__.replace('c_', 'c^g_').replace('s_', 's^g_') + __slots__ = [] + field_type = 'gf2n' + if isinstance(instruction_cls.code, int): + code = (1 << 8) + instruction_cls.code + + # set modp registers in arg_format to GF2N registers + if 'gf2n_arg_format' in instruction_cls.__dict__: + arg_format = instruction_cls.gf2n_arg_format + elif isinstance(instruction_cls.arg_format, itertools.repeat): + __f = instruction_cls.arg_format.next() + if __f != 'int' and __f != 'p': + arg_format = itertools.repeat(__f[0] + 'g' + __f[1:]) + else: + __format = [] + for __f in instruction_cls.arg_format: + if __f in ('int', 'p', 'ci', 'str'): + __format.append(__f) + else: + __format.append(__f[0] + 'g' + __f[1:]) + arg_format = __format + + def is_gf2n(self): + return True + + def expand(self): + set_global_instruction_type('gf2n') + super(GF2N_Instruction, self).expand() + reset_global_instruction_type() + + GF2N_Instruction.__name__ = 'g' + instruction_cls.__name__ + if vectorized: + vec_GF2N = vectorize(GF2N_Instruction, global_dict) + + @functools.wraps(instruction) + def maybe_gf2n_instruction(*args, **kwargs): + if get_global_instruction_type() == 'gf2n': + if vectorized: + return vec_GF2N(*args, **kwargs) + else: + return GF2N_Instruction(*args, **kwargs) + else: + return instruction(*args, **kwargs) + + # If instruction is vectorized, new GF2N instruction must also be + if vectorized: + global_dict[GF2N_Instruction.__name__] = vec_GF2N + else: + global_dict[GF2N_Instruction.__name__] = GF2N_Instruction + + global_dict[instruction.__name__ + '_class'] = instruction_cls + return maybe_gf2n_instruction + #return instruction + + +class RegType(object): + """ enum-like static class for Register types """ + ClearModp = 'c' + SecretModp = 's' + ClearGF2N = 'cg' + SecretGF2N = 'sg' + ClearInt = 'ci' + + Types = [ClearModp, SecretModp, ClearGF2N, SecretGF2N, ClearInt] + + @staticmethod + def create_dict(init_value_fn): + """ Create a dictionary with all the RegTypes as keys """ + return { + RegType.ClearModp : init_value_fn(), + RegType.SecretModp : init_value_fn(), + RegType.ClearGF2N : init_value_fn(), + RegType.SecretGF2N : init_value_fn(), + RegType.ClearInt : init_value_fn(), + } + +class ArgFormat(object): + @classmethod + def check(cls, arg): + return NotImplemented + + @classmethod + def encode(cls, arg): + return NotImplemented + +class RegisterArgFormat(ArgFormat): + @classmethod + def check(cls, arg): + if not isinstance(arg, program.curr_tape.Register): + raise ArgumentError(arg, 'Invalid register argument') + if arg.i > REG_MAX: + raise ArgumentError(arg, 'Register index too large') + if arg.program != program.curr_tape: + raise ArgumentError(arg, 'Register from other tape, trace: %s' % \ + util.format_trace(arg.caller)) + if arg.reg_type != cls.reg_type: + raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \ + (arg.reg_type, cls.reg_type)) + + @classmethod + def encode(cls, arg): + return int_to_bytes(arg.i) + +class ClearModpAF(RegisterArgFormat): + reg_type = RegType.ClearModp + +class SecretModpAF(RegisterArgFormat): + reg_type = RegType.SecretModp + +class ClearGF2NAF(RegisterArgFormat): + reg_type = RegType.ClearGF2N + +class SecretGF2NAF(RegisterArgFormat): + reg_type = RegType.SecretGF2N + +class ClearIntAF(RegisterArgFormat): + reg_type = RegType.ClearInt + +class IntArgFormat(ArgFormat): + @classmethod + def check(cls, arg): + if not isinstance(arg, (int, long)): + raise ArgumentError(arg, 'Expected an integer-valued argument') + + @classmethod + def encode(cls, arg): + return int_to_bytes(arg) + +class ImmediateModpAF(IntArgFormat): + @classmethod + def check(cls, arg): + super(ImmediateModpAF, cls).check(arg) + if arg >= 2**31 or arg < -2**31: + raise ArgumentError(arg, 'Immediate value outside of 32-bit range') + +class ImmediateGF2NAF(IntArgFormat): + @classmethod + def check(cls, arg): + # bounds checking for GF(2^n)??? + super(ImmediateGF2NAF, cls).check(arg) + +class PlayerNoAF(IntArgFormat): + @classmethod + def check(cls, arg): + super(PlayerNoAF, cls).check(arg) + if arg > 256: + raise ArgumentError(arg, 'Player number > 256') + +class String(ArgFormat): + length = 12 + + @classmethod + def check(cls, arg): + if not isinstance(arg, str): + raise ArgumentError(arg, 'Argument is not string') + if len(arg) > cls.length: + raise ArgumentError(arg, 'String longer than ' + cls.length) + if '\0' in arg: + raise ArgumentError(arg, 'String contains zero-byte') + + @classmethod + def encode(cls, arg): + return arg + '\0' * (cls.length - len(arg)) + +ArgFormats = { + 'c': ClearModpAF, + 's': SecretModpAF, + 'cw': ClearModpAF, + 'sw': SecretModpAF, + 'cg': ClearGF2NAF, + 'sg': SecretGF2NAF, + 'cgw': ClearGF2NAF, + 'sgw': SecretGF2NAF, + 'ci': ClearIntAF, + 'ciw': ClearIntAF, + 'i': ImmediateModpAF, + 'ig': ImmediateGF2NAF, + 'int': IntArgFormat, + 'p': PlayerNoAF, + 'str': String, +} + +def format_str_is_reg(format_str): + return issubclass(ArgFormats[format_str], RegisterArgFormat) + +def format_str_is_writeable(format_str): + return format_str_is_reg(format_str) and format_str[-1] == 'w' + + +class Instruction(object): + """ + Base class for a RISC-type instruction. Has methods for checking arguments, + getting byte encoding, emulating the instruction, etc. + """ + __slots__ = ['args', 'arg_format', 'code', 'caller'] + count = 0 + + def __init__(self, *args, **kwargs): + """ Create an instruction and append it to the program list. """ + self.args = list(args) + self.check_args() + if not program.FIRST_PASS: + if kwargs.get('add_to_prog', True): + program.curr_block.instructions.append(self) + if program.DEBUG: + self.caller = [frame[1:] for frame in inspect.stack()[1:]] + else: + self.caller = None + if program.EMULATE: + self.execute() + + Instruction.count += 1 + if Instruction.count % 100000 == 0: + print "Compiled %d lines at" % self.__class__.count, time.asctime() + + def get_code(self): + return self.code + + def get_encoding(self): + enc = int_to_bytes(self.get_code()) + # add the number of registers to a start/stop open instruction + if self.has_var_args(): + enc += int_to_bytes(len(self.args)) + for arg,format in zip(self.args, self.arg_format): + enc += ArgFormats[format].encode(arg) + return enc + + def get_bytes(self): + return bytearray(self.get_encoding()) + + def execute(self): + """ Emulate execution of this instruction """ + raise NotImplementedError('execute method must be implemented') + + def check_args(self): + """ Check the args match up with that specified in arg_format """ + for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)): + if arg is None: + if not isinstance(self.arg_format, (list, tuple)): + break # end of optional arguments + else: + raise CompilerError('Incorrect number of arguments for instruction %s' % (self)) + try: + ArgFormats[f].check(arg) + except ArgumentError as e: + raise CompilerError('Invalid argument "%s" to instruction: %s' + % (e.arg, self) + '\n' + e.msg) + + def get_used(self): + """ Return the set of registers that are read in this instruction. """ + return set(arg for arg,w in zip(self.args, self.arg_format) if \ + format_str_is_reg(w) and not format_str_is_writeable(w)) + + def get_def(self): + """ Return the set of registers that are written to in this instruction. """ + return set(arg for arg,w in zip(self.args, self.arg_format) if \ + format_str_is_writeable(w)) + + def get_pre_arg(self): + return "" + + def has_var_args(self): + return False + + def is_vec(self): + return False + + def is_gf2n(self): + return False + + def get_size(self): + return 1 + + def add_usage(self, req_node): + pass + + def __str__(self): + return self.__class__.__name__ + ' ' + self.get_pre_arg() + ', '.join(str(a) for a in self.args) + + def __repr__(self): + return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' + +### +### Basic arithmetic +### + +class AddBase(Instruction): + __slots__ = [] + + def execute(self): + self.args[0].value = (self.args[1].value + self.args[2].value) % program.P + +class SubBase(Instruction): + __slots__ = [] + + def execute(self): + self.args[0].value = (self.args[1].value - self.args[2].value) % program.P + +class MulBase(Instruction): + __slots__ = [] + + def execute(self): + self.args[0].value = (self.args[1].value * self.args[2].value) % program.P + +### +### Basic arithmetic with immediate values +### + +class ImmediateBase(Instruction): + __slots__ = ['op'] + + def execute(self): + exec('self.args[0].value = self.args[1].value.%s(self.args[2]) %% program.P' % self.op) + +class SharedImmediate(ImmediateBase): + __slots__ = [] + arg_format = ['sw', 's', 'i'] + +class ClearImmediate(ImmediateBase): + __slots__ = [] + arg_format = ['cw', 'c', 'i'] + + +### +### Memory access instructions +### + +class DirectMemoryInstruction(Instruction): + __slots__ = [] + def __init__(self, *args, **kwargs): + super(DirectMemoryInstruction, self).__init__(*args, **kwargs) + +class ReadMemoryInstruction(Instruction): + __slots__ = [] + +class WriteMemoryInstruction(Instruction): + __slots__ = [] + +class DirectMemoryWriteInstruction(DirectMemoryInstruction, \ + WriteMemoryInstruction): + __slots__ = [] + def __init__(self, *args, **kwargs): + if program.curr_tape.prevent_direct_memory_write: + raise CompilerError('Direct memory writing prevented') + super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs) + +### +### I/O instructions +### + +class DoNotEliminateInstruction(Instruction): + """ What do you think? """ + __slots__ = [] + +class IOInstruction(DoNotEliminateInstruction): + """ Instruction that uses stdin/stdout during runtime. These are linked + to prevent instruction reordering during optimization. """ + __slots__ = [] + + @classmethod + def str_to_int(cls, s): + """ Convert a 4 character string to an integer. """ + if len(s) > 4: + raise CompilerError('String longer than 4 characters') + n = 0 + for c in reversed(s.ljust(4)): + n <<= 8 + n += ord(c) + return n + +class AsymmetricCommunicationInstruction(DoNotEliminateInstruction): + """ Instructions involving sending from or to only one party. """ + __slots__ = [] + +class RawInputInstruction(AsymmetricCommunicationInstruction): + """ Raw input instructions. """ + __slots__ = [] + +class PublicFileIOInstruction(DoNotEliminateInstruction): + """ Instruction to reads/writes public information from/to files. """ + __slots__ = [] + +### +### Data access instructions +### + +class DataInstruction(Instruction): + __slots__ = [] + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment((self.field_type, self.data_type), self.get_size()) + +### +### Integer operations +### + +class IntegerInstruction(Instruction): + """ Base class for integer operations. """ + __slots__ = [] + arg_format = ['ciw', 'ci', 'ci'] + +### +### Clear comparison instructions +### + +class UnaryComparisonInstruction(Instruction): + """ Base class for unary comparisons. """ + __slots__ = [] + arg_format = ['ciw', 'ci'] + +### +### Clear shift instructions +### + +class ClearShiftInstruction(ClearImmediate): + __slots__ = [] + + def check_args(self): + super(ClearShiftInstruction, self).check_args() + if program.galois_length > 64: + bits = 127 + else: + # assume 64-bit machine + bits = 63 + if self.args[2] > bits: + raise CompilerError('Shifting by more than %d bits ' + 'not implemented' % bits) + +### +### Jumps etc +### + +class dummywrite(Instruction): + """ Dummy instruction to create source node in the dependency graph, + preventing read-before-write warnings. """ + __slots__ = [] + + def __init__(self, *args, **kwargs): + self.arg_format = [arg.reg_type + 'w' for arg in args] + super(dummywrite, self).__init__(*args, **kwargs) + + def execute(self): + pass + + def get_encoding(self): + return [] + +class JumpInstruction(Instruction): + __slots__ = ['jump_arg'] + + def set_relative_jump(self, value): + if value == -1: + raise CompilerException('Jump by -1 would cause infinite loop') + self.args[self.jump_arg] = value + + def get_relative_jump(self): + return self.args[self.jump_arg] + + +class CISC(Instruction): + """ + Base class for a CISC instruction. + + Children must implement expand(self) to process the instruction. + """ + __slots__ = [] + code = None + + def __init__(self, *args): + self.args = args + self.check_args() + #if EMULATE: + # self.expand() + if not program.FIRST_PASS: + self.expand() + + def expand(self): + """ Expand this into a sequence of RISC instructions. """ + raise NotImplementedError('expand method must be implemented') diff --git a/Compiler/library.py b/Compiler/library.py new file mode 100644 index 000000000..e5f1c4217 --- /dev/null +++ b/Compiler/library.py @@ -0,0 +1,1115 @@ +# (C) 2016 University of Bristol. See License.txt + +from Compiler.types import cint,sint,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types +from Compiler.instructions import * +from Compiler.util import tuplify,untuplify +from Compiler import instructions,instructions_base,comparison,program +import inspect,math +import random +import collections + +def get_program(): + return instructions.program +def get_tape(): + return get_program().curr_tape +def get_block(): + return get_program().curr_block + +def vectorize(function): + def vectorized_function(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], program.Tape.Register): + instructions_base.set_global_vector_size(args[0].size) + res = function(*args, **kwargs) + instructions_base.reset_global_vector_size() + elif 'size' in kwargs: + instructions_base.set_global_vector_size(kwargs['size']) + del kwargs['size'] + res = function(*args, **kwargs) + instructions_base.reset_global_vector_size() + else: + res = function(*args, **kwargs) + return res + vectorized_function.__name__ = function.__name__ + return vectorized_function + +def set_instruction_type(function): + def instruction_typed_function(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], program.Tape.Register): + if args[0].is_gf2n: + instructions_base.set_global_instruction_type('gf2n') + else: + instructions_base.set_global_instruction_type('modp') + res = function(*args, **kwargs) + instructions_base.reset_global_instruction_type() + else: + res = function(*args, **kwargs) + return res + instruction_typed_function.__name__ = function.__name__ + return instruction_typed_function + + +def print_str(s, *args): + """ Print a string, with optional args for adding variables/registers with %s """ + def print_plain_str(ss): + """ Print a plain string (no custom formatting options) """ + i = 1 + while 4*i < len(ss): + print_char4(ss[4*(i-1):4*i]) + i += 1 + i = 4*(i-1) + while i < len(ss): + print_char(ss[i]) + i += 1 + + if len(args) != s.count('%s'): + raise CompilerError('Incorrect number of arguments for string format:', s) + substrings = s.split('%s') + for i,ss in enumerate(substrings): + print_plain_str(ss) + if i < len(args): + if isinstance(args[i], MemValue): + val = args[i].register + else: + val = args[i] + + if isinstance(val, program.Tape.Register): + if val.reg_type == 'ci': + cint(val).print_reg_plain() + elif val.is_clear: + val.print_reg_plain() + else: + raise CompilerError('Cannot print secret value:', args[i]) + elif isinstance(val, list): + print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val) + else: + try: + val.output() + except AttributeError: + print_plain_str(str(val)) + +def print_ln(s='', *args): + """ Print line, with optional args for adding variables/registers with %s """ + print_str(s, *args) + print_char('\n') + +def runtime_error(msg='', *args): + """ Print an error message and abort the runtime. """ + print_str('User exception: ') + print_ln(msg, *args) + crash() + +def public_input(): + res = regint() + pubinput(res) + return res + +# mostly obsolete functions +# use the equivalent from types.py + +def load_int(value, size=None): + return regint(value, size=size) + +def load_int_to_secret(value, size=None): + return sint(value, size=size) + +def load_int_to_secret_vector(vector): + res = sint(size=len(vector)) + for i,val in enumerate(vector): + ldsi(res[i], val) + return res + +@vectorize +def load_float_to_secret(value, sec=40): + def _bit_length(x): + return len(bin(x).lstrip('-0b')) + + num,den = value.as_integer_ratio() + exp = int(round(math.log(den, 2))) + + nbits = _bit_length(num) + if nbits > sfloat.vlen: + num >>= (nbits - sfloat.vlen) + exp -= (nbits - sfloat.vlen) + elif nbits < sfloat.vlen: + num <<= (sfloat.vlen - nbits) + exp += (sfloat.vlen - nbits) + + if _bit_length(exp) > sfloat.plen: + raise CompilerException('Cannot load floating point to secret: overflow') + if num < 0: + s = load_int_to_secret(1) + z = load_int_to_secret(0) + else: + s = load_int_to_secret(0) + if num == 0: + z = load_int_to_secret(1) + else: + z = load_int_to_secret(0) + v = load_int_to_secret(num) + p = load_int_to_secret(exp) + return sfloat(v, p, s, z) + +def load_clear_mem(address): + return cint.load_mem(address) + +def load_secret_mem(address): + return sint.load_mem(address) + +def load_mem(address, value_type): + if value_type in _types: + value_type = _types[value_type] + return value_type.load_mem(address) + +@vectorize +def store_in_mem(value, address): + if isinstance(value, int): + value = load_int(value) + if isinstance(value, _register): + value.store_in_mem(address) + else: + # legacy + if value.is_clear: + if isinstance(address, cint): + stmci(value, address) + else: + stmc(value, address) + else: + if isinstance(address, cint): + stmsi(value, address) + else: + stms(value, address) + +@set_instruction_type +@vectorize +def reveal(secret): + if isinstance(secret, _number): + return secret.reveal() + if secret.is_gf2n: + res = cgf2n() + else: + res = cint() + instructions.asm_open(res, secret) + return res + +@vectorize +def compare_secret(a, b, length, sec=40): + res = sint() + instructions.lts(res, a, b, length, sec) + +def get_input_from(player, size=None): + return sint.get_input_from(player, size=size) + +def get_random_triple(size=None): + return sint.get_random_triple(size=size) + +def get_random_bit(size=None): + return sint.get_random_bit(size=size) + +def get_random_square(size=None): + return sint.get_random_square(size=size) + +def get_random_inverse(size=None): + return sint.get_random_inverse(size=size) + +def get_random_int(bits, size=None): + return sint.get_random_int(bits, size=size) + +@vectorize +def get_thread_number(): + res = regint() + ldtn(res) + return res + +@vectorize +def get_arg(): + res = regint() + ldarg(res) + return res + +def make_array(l): + if isinstance(l, program.Tape.Register): + res = Array(1, type(l)) + res[0] = l + else: + l = list(l) + res = Array(len(l), type(l[0]) if l else cint) + res.assign(l) + return res + + +class FunctionTapeCall: + def __init__(self, thread, base, bases): + self.thread = thread + self.base = base + self.bases = bases + def start(self): + self.thread.start(self.base) + return self + def join(self): + self.thread.join() + instructions.program.free(self.base, 'ci') + for reg_type,addr in self.bases.iteritems(): + get_program().free(addr, reg_type.reg_type) + +class Function: + def __init__(self, function, name=None, compile_args=[]): + self.type_args = {} + self.function = function + self.name = name + if name is None: + self.name = self.function.__name__ + '-' + str(id(function)) + self.compile_args = compile_args + def __call__(self, *args): + args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args) + get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x) + if len(args) not in self.type_args: + # first call + type_args = collections.defaultdict(list) + for i,arg in enumerate(args): + type_args[get_reg_type(arg)].append(i) + def wrapped_function(*compile_args): + base = get_arg() + bases = dict((t, regint.load_mem(base + i)) \ + for i,t in enumerate(type_args)) + runtime_args = [None] * len(args) + for t,i_args in type_args.iteritems(): + for i,i_arg in enumerate(i_args): + runtime_args[i_arg] = t.load_mem(bases[t] + i) + return self.function(*(list(compile_args) + runtime_args)) + self.on_first_call(wrapped_function) + self.type_args[len(args)] = type_args + type_args = self.type_args[len(args)] + base = instructions.program.malloc(len(type_args), 'ci') + bases = dict((t, get_program().malloc(len(type_args[t]), t)) \ + for t in type_args) + for i,reg_type in enumerate(type_args): + store_in_mem(bases[reg_type], base + i) + for j,i_arg in enumerate(type_args[reg_type]): + if get_reg_type(args[i_arg]) != reg_type: + raise CompilerError('type mismatch') + store_in_mem(args[i_arg], bases[reg_type] + j) + return self.on_call(base, bases) + +class FunctionTape(Function): + # not thread-safe + def on_first_call(self, wrapped_function): + self.thread = MPCThread(wrapped_function, self.name, + args=self.compile_args) + def on_call(self, base, bases): + return FunctionTapeCall(self.thread, base, bases) + +def function_tape(function): + return FunctionTape(function) + +def function_tape_with_compile_args(*args): + def wrapper(function): + return FunctionTape(function, compile_args=args) + return wrapper + +def memorize(x): + if isinstance(x, (tuple, list)): + return tuple(memorize(i) for i in x) + else: + return MemValue(x) + +def unmemorize(x): + if isinstance(x, (tuple, list)): + return tuple(unmemorize(i) for i in x) + else: + return x.read() + +class FunctionBlock(Function): + def on_first_call(self, wrapped_function): + old_block = get_tape().active_basicblock + parent_node = get_tape().req_node + get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) + block = get_tape().active_basicblock + block.persistent_allocation = True + del parent_node.children[-1] + self.node = get_tape().req_node + print 'Compiling function', self.name + result = wrapped_function(*self.compile_args) + if result is not None: + self.result = memorize(result) + else: + self.result = None + print 'Done compiling function', self.name + p_return_address = get_tape().program.malloc(1, 'ci') + get_tape().function_basicblocks[block] = p_return_address + return_address = regint.load_mem(p_return_address) + get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False)) + self.last_sub_block = get_tape().active_basicblock + get_tape().close_scope(old_block, parent_node, 'end-' + self.name) + old_block.set_exit(instructions.jmp(0, add_to_prog=False), get_tape().active_basicblock) + self.basic_block = block + + def on_call(self, base, bases): + if base is not None: + instructions.starg(regint(base)) + block = self.basic_block + if block not in get_tape().function_basicblocks: + raise CompilerError('unknown function') + old_block = get_tape().active_basicblock + old_block.set_exit(instructions.jmp(0, add_to_prog=False), block) + p_return_address = get_tape().function_basicblocks[block] + return_address = get_tape().new_reg('ci') + old_block.return_address_store = instructions.ldint(return_address, 0) + instructions.stmint(return_address, p_return_address) + get_tape().start_new_basicblock(name='call-' + self.name) + get_tape().active_basicblock.set_return(old_block, self.last_sub_block) + get_tape().req_node.children.append(self.node) + if self.result is not None: + return unmemorize(self.result) + +def function_block(function): + return FunctionBlock(function) + +def function_block_with_compile_args(*args): + def wrapper(function): + return FunctionBlock(function, compile_args=args) + return wrapper + +def method_block(function): + # If you use this, make sure to use MemValue for all member + # variables. + compiled_functions = {} + def wrapper(self, *args): + if self in compiled_functions: + return compiled_functions[self](*args) + else: + name = '%s-%s-%d' % (type(self).__name__, function.__name__, \ + id(self)) + block = FunctionBlock(function, name=name, compile_args=(self,)) + compiled_functions[self] = block + return block(*args) + return wrapper + +def cond_swap(x,y): + b = x < y + if isinstance(x, sfloat): + res = ([], []) + for i,j in enumerate(('v','p','z','s')): + xx = x.__getattribute__(j) + yy = y.__getattribute__(j) + bx = b * xx + by = b * yy + res[0].append(bx + yy - by) + res[1].append(xx - bx + by) + return sfloat(*res[0]), sfloat(*res[1]) + bx = b * x + by = b * y + return bx + y - by, x - bx + by + +def sort(a): + res = a + + for i in range(len(a)): + for j in reversed(range(i)): + res[j], res[j+1] = cond_swap(res[j], res[j+1]) + + return res + +def odd_even_merge(a): + if len(a) == 2: + a[0], a[1] = cond_swap(a[0], a[1]) + else: + even = a[::2] + odd = a[1::2] + odd_even_merge(even) + odd_even_merge(odd) + a[0] = even[0] + for i in range(1, len(a) / 2): + a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i]) + a[-1] = odd[-1] + +def odd_even_merge_sort(a): + if len(a) == 1: + return + elif len(a) % 2 == 0: + lower = a[:len(a)/2] + upper = a[len(a)/2:] + odd_even_merge_sort(lower) + odd_even_merge_sort(upper) + a[:] = lower + upper + odd_even_merge(a) + else: + raise CompilerError('Length of list must be power of two') + +def chunky_odd_even_merge_sort(a): + for i,j in enumerate(a): + j.store_in_mem(i * j.sizeof()) + l = 1 + while l < len(a): + l *= 2 + k = 1 + while k < l: + k *= 2 + def round(): + for i in range(len(a)): + a[i] = type(a[i]).load_mem(i * a[i].sizeof()) + for i in range(len(a) / l): + for j in range(l / k): + base = i * l + j + step = l / k + if k == 2: + a[base], a[base+step] = cond_swap(a[base], a[base+step]) + else: + b = a[base:base+k*step:step] + for m in range(base + step, base + (k - 1) * step, 2 * step): + a[m], a[m+step] = cond_swap(a[m], a[m+step]) + for i in range(len(a)): + a[i].store_in_mem(i * a[i].sizeof()) + chunk = MPCThread(round, 'sort-%d-%d-%03x' % (l,k,random.randrange(256**3))) + chunk.start() + chunk.join() + #round() + for i in range(len(a)): + a[i] = type(a[i]).load_mem(i * a[i].sizeof()) + +def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False): + if n is None: + n = len(a) + a_base = instructions.program.malloc(n, 's') + for i,j in enumerate(a): + store_in_mem(j, a_base + i) + instructions.program.restart_main_thread() + else: + a_base = a + tmp_base = instructions.program.malloc(n, 's') + chunks = {} + threads = [] + + def run_threads(): + for thread in threads: + thread.start() + for thread in threads: + thread.join() + del threads[:] + + def run_chunk(size, base): + if size not in chunks: + def swap_list(list_base): + for i in range(size / 2): + base = list_base + 2 * i + x, y = cond_swap(load_secret_mem(base), + load_secret_mem(base + 1)) + store_in_mem(x, base) + store_in_mem(y, base + 1) + chunks[size] = FunctionTape(swap_list, 'sort-%d-%03x' % + (size, random.randrange(256**3))) + return chunks[size](base) + + def run_round(size): + # minimize number of chunk sizes + n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) + lower_size = size / n_chunks / 2 * 2 + n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 + # print len(to_swap) == lower_size * n_lower_size + \ + # (lower_size + 2) * (n_chunks - n_lower_size), \ + # len(to_swap), n_chunks, lower_size, n_lower_size + base = 0 + round_threads = [] + for i in range(n_lower_size): + round_threads.append(run_chunk(lower_size, tmp_base + base)) + base += lower_size + for i in range(n_chunks - n_lower_size): + round_threads.append(run_chunk(lower_size + 2, tmp_base + base)) + base += lower_size + 2 + run_threads_in_rounds(round_threads) + + postproc_chunks = [] + wrap_chunks = {} + post_threads = [] + pre_threads = [] + + def load_and_store(x, y, to_right): + if to_right: + store_in_mem(load_secret_mem(x), y) + else: + store_in_mem(load_secret_mem(y), x) + + def run_setup(k, a_addr, step, tmp_addr): + if k == 2: + def mem_op(preproc, a_addr, step, tmp_addr): + load_and_store(a_addr, tmp_addr, preproc) + load_and_store(a_addr + step, tmp_addr + 1, preproc) + res = 2 + else: + def mem_op(preproc, a_addr, step, tmp_addr): + instructions.program.curr_tape.merge_opens = False +# for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)): + for i in range(k - 2): + m = a_addr + step + i * step + load_and_store(m, tmp_addr + i, preproc) + res = k - 2 + if not use_chunk_wraps or k <= 4: + mem_op(True, a_addr, step, tmp_addr) + postproc_chunks.append((mem_op, (a_addr, step, tmp_addr))) + else: + if k not in wrap_chunks: + pre_chunk = FunctionTape(mem_op, 'pre-%d-%03x' % (k,random.randrange(256**3)), + compile_args=[True]) + post_chunk = FunctionTape(mem_op, 'post-%d-%03x' % (k,random.randrange(256**3)), + compile_args=[False]) + wrap_chunks[k] = (pre_chunk, post_chunk) + pre_chunk, post_chunk = wrap_chunks[k] + pre_threads.append(pre_chunk(a_addr, step, tmp_addr)) + post_threads.append(post_chunk(a_addr, step, tmp_addr)) + return res + + def run_threads_in_rounds(all_threads): + for thread in all_threads: + if len(threads) == n_threads: + run_threads() + threads.append(thread) + run_threads() + del all_threads[:] + + def run_postproc(): + run_threads_in_rounds(post_threads) + for chunk,args in postproc_chunks: + chunk(False, *args) + postproc_chunks[:] = [] + + l = 1 + while l < n: + l *= 2 + k = 1 + while k < l: + k *= 2 + size = 0 + instructions.program.curr_tape.merge_opens = False + for i in range(n / l): + for j in range(l / k): + base = i * l + j + step = l / k + size += run_setup(k, a_base + base, step, tmp_base + size) + run_threads_in_rounds(pre_threads) + run_round(size) + run_postproc() + + if isinstance(a, list): + instructions.program.restart_main_thread() + for i in range(n): + a[i] = load_secret_mem(a_base + i) + instructions.program.free(a_base, 's') + instructions.program.free(tmp_base, 's') + +def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7): + if n is None: + n = len(a) + a_base = instructions.program.malloc(n, 's') + for i,j in enumerate(a): + store_in_mem(j, a_base + i) + instructions.program.restart_main_thread() + else: + a_base = a + tmp_base = instructions.program.malloc(n, 's') + tmp_i = instructions.program.malloc(1, 'ci') + chunks = {} + threads = [] + + def run_threads(): + for thread in threads: + thread.start() + for thread in threads: + thread.join() + del threads[:] + + def run_threads_in_rounds(all_threads): + for thread in all_threads: + if len(threads) == n_threads: + run_threads() + threads.append(thread) + run_threads() + del all_threads[:] + + def run_chunk(size, base): + if size not in chunks: + def swap_list(list_base): + for i in range(size / 2): + base = list_base + 2 * i + x, y = cond_swap(load_secret_mem(base), + load_secret_mem(base + 1)) + store_in_mem(x, base) + store_in_mem(y, base + 1) + chunks[size] = FunctionTape(swap_list, 'sort-%d-%03x' % + (size, random.randrange(256**3))) + return chunks[size](base) + + def run_round(size): + # minimize number of chunk sizes + n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) + lower_size = size / n_chunks / 2 * 2 + n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 + # print len(to_swap) == lower_size * n_lower_size + \ + # (lower_size + 2) * (n_chunks - n_lower_size), \ + # len(to_swap), n_chunks, lower_size, n_lower_size + base = 0 + round_threads = [] + for i in range(n_lower_size): + round_threads.append(run_chunk(lower_size, tmp_base + base)) + base += lower_size + for i in range(n_chunks - n_lower_size): + round_threads.append(run_chunk(lower_size + 2, tmp_base + base)) + base += lower_size + 2 + run_threads_in_rounds(round_threads) + + l = 1 + while l < n: + l *= 2 + k = 1 + while k < l: + k *= 2 + def load_and_store(x, y): + if to_tmp: + store_in_mem(load_secret_mem(x), y) + else: + store_in_mem(load_secret_mem(y), x) + def outer(i): + def inner(j): + base = j + step = l / k + if k == 2: + tmp_addr = regint.load_mem(tmp_i) + load_and_store(base, tmp_addr) + load_and_store(base + step, tmp_addr + 1) + store_in_mem(tmp_addr + 2, tmp_i) + else: + def inner2(m): + tmp_addr = regint.load_mem(tmp_i) + load_and_store(m, tmp_addr) + store_in_mem(tmp_addr + 1, tmp_i) + range_loop(inner2, base + step, base + (k - 1) * step, step) + range_loop(inner, a_base + i * l, a_base + i * l + l / k) + instructions.program.curr_tape.merge_opens = False + to_tmp = True + store_in_mem(tmp_base, tmp_i) + range_loop(outer, n / l) + if k == 2: + run_round(n) + else: + run_round(n / k * (k - 2)) + instructions.program.curr_tape.merge_opens = False + to_tmp = False + store_in_mem(tmp_base, tmp_i) + range_loop(outer, n / l) + + if isinstance(a, list): + instructions.program.restart_main_thread() + for i in range(n): + a[i] = load_secret_mem(a_base + i) + instructions.program.free(a_base, 's') + instructions.program.free(tmp_base, 's') + instructions.program.free(tmp_i, 'ci') + + +def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32): + l = sorted_length + while l < len(a): + l *= 2 + k = 1 + while k < l: + k *= 2 + n_outer = len(a) / l + n_inner = l / k + n_innermost = 1 if k == 2 else k / 2 - 1 + @for_range_parallel(n_parallel / n_innermost / n_inner, n_outer) + def loop(i): + @for_range_parallel(n_parallel / n_innermost, n_inner) + def inner(j): + base = i*l + j + step = l/k + if k == 2: + a[base], a[base+step] = cond_swap(a[base], a[base+step]) + else: + @for_range_parallel(n_parallel, n_innermost) + def f(i): + m1 = step + i * 2 * step + m2 = m1 + base + a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step]) + +def mergesort(A): + B = Array(len(A), sint) + + def merge(i_left, i_right, i_end): + i0 = MemValue(i_left) + i1 = MemValue(i_right) + @for_range(i_left, i_end) + def loop(j): + if_then(and_(lambda: i0 < i_right, + or_(lambda: i1 >= i_end, + lambda: regint(reveal(A[i0] <= A[i1]))))) + B[j] = A[i0] + i0.iadd(1) + else_then() + B[j] = A[i1] + i1.iadd(1) + end_if() + + width = MemValue(1) + @do_while + def width_loop(): + @for_range(0, len(A), 2 * width) + def merge_loop(i): + merge(i, i + width, i + 2 * width) + A.assign(B) + width.imul(2) + return width < len(A) + +def range_loop(loop_body, start, stop=None, step=None): + if stop is None: + stop = start + start = 0 + if step is None: + step = 1 + def loop_fn(i): + loop_body(i) + return i + step + if isinstance(step, int): + if step > 0: + condition = lambda x: x < stop + elif step < 0: + condition = lambda x: x > stop + else: + raise CompilerError('step must not be zero') + else: + b = step > 0 + condition = lambda x: b * (x < stop) + (1 - b) * (x > stop) + while_loop(loop_fn, condition, start) + if isinstance(start, int) and isinstance(stop, int) \ + and isinstance(step, int): + # known loop count + if condition(start): + get_tape().req_node.children[-1].aggregator = \ + lambda x: ((stop - start) / step) * x[0] + +def for_range(start, stop=None, step=None): + def decorator(loop_body): + range_loop(loop_body, start, stop, step) + return loop_body + return decorator + +def for_range_parallel(n_parallel, n_loops): + return map_reduce_single(n_parallel, n_loops, \ + lambda *x: [], lambda *x: []) + +def map_reduce_single(n_parallel, n_loops, initializer, reducer, mem_state=None): + if not isinstance(n_parallel, int): + raise CompilerException('Number of parallel executions' \ + 'must be constant') + n_parallel = n_parallel or 1 + if mem_state is None: + # default to list of MemValues to allow varying types + mem_state = [MemValue(x) for x in initializer()] + use_array = False + else: + # use Arrays for multithread version + use_array = True + def decorator(loop_body): + if isinstance(n_loops, int): + loop_rounds = n_loops / n_parallel \ + if n_parallel < n_loops else 0 + else: + loop_rounds = n_loops / n_parallel + def write_state_to_memory(r): + if use_array: + mem_state.assign(r) + else: + # cannot do mem_state = [...] due to scope issue + for j,x in enumerate(r): + mem_state[j].write(x) + # will be optimized out if n_loops <= n_parallel + @for_range(loop_rounds) + def f(i): + state = tuplify(initializer()) + for k in range(n_parallel): + j = i * n_parallel + k + state = reducer(tuplify(loop_body(j)), state) + r = reducer(mem_state, state) + write_state_to_memory(r) + if isinstance(n_loops, int): + state = mem_state + for j in range(loop_rounds * n_parallel, n_loops): + state = reducer(tuplify(loop_body(j)), state) + else: + @for_range(loop_rounds * n_parallel, n_loops) + def f(j): + r = reducer(tuplify(loop_body(j)), mem_state) + write_state_to_memory(r) + state = mem_state + for i,x in enumerate(state): + if use_array: + mem_state[i] = x + else: + mem_state[i].write(x) + def returner(): + return untuplify(tuple(state)) + return returner + return decorator + +def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}): + return map_reduce(n_threads, n_parallel, n_loops, \ + lambda *x: [], lambda *x: [], thread_mem_req) + +def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ + thread_mem_req={}): + n_threads = n_threads or 1 + if n_threads == 1 or n_loops == 1: + dec = map_reduce_single(n_parallel, n_loops, initializer, reducer) + if thread_mem_req: + thread_mem = Array(thread_mem_req[regint], regint) + return lambda loop_body: dec(lambda i: loop_body(i, thread_mem)) + else: + return dec + def decorator(loop_body): + thread_rounds = n_loops / n_threads + remainder = n_loops % n_threads + for t in thread_mem_req: + if t != regint: + raise CompilerError('Not implemented for other than regint') + args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') + state = tuple(initializer()) + def f(inc): + if thread_mem_req: + thread_mem = Array(thread_mem_req[regint], regint, \ + args[get_arg()].address + 2) + mem_state = Array(len(state), type(state[0]) \ + if state else cint, args[get_arg()][1]) + base = args[get_arg()][0] + @map_reduce_single(n_parallel, thread_rounds + inc, \ + initializer, reducer, mem_state) + def f(i): + if thread_mem_req: + return loop_body(base + i, thread_mem) + else: + return loop_body(base + i) + prog = get_program() + threads = [] + if thread_rounds: + tape = prog.new_tape(f, (0,), 'multithread') + for i in range(n_threads - remainder): + mem_state = make_array(initializer()) + args[remainder + i][0] = i * thread_rounds + if len(mem_state): + args[remainder + i][1] = mem_state.address + threads.append(prog.run_tape(tape, remainder + i)) + if remainder: + tape1 = prog.new_tape(f, (1,), 'multithread1') + for i in range(remainder): + mem_state = make_array(initializer()) + args[i][0] = (n_threads - remainder + i) * thread_rounds + i + if len(mem_state): + args[i][1] = mem_state.address + threads.append(prog.run_tape(tape1, i)) + for thread in threads: + prog.join_tape(thread) + if state: + if thread_rounds: + for i in range(n_threads - remainder): + state = reducer(Array(len(state), type(state[0]), \ + args[remainder + i][1]), state) + if remainder: + for i in range(remainder): + state = reducer(Array(len(state), type(state[0]).reg_type, \ + args[i][1]), state) + def returner(): + return untuplify(state) + return returner + return decorator + +def map_sum(n_threads, n_parallel, n_loops, n_items, value_types): + value_types = tuplify(value_types) + if len(value_types) == 1: + value_types *= n_items + elif len(value_types) != n_items: + raise CompilerError('Incorrect number of value_types.') + initializer = lambda: [t(0) for t in value_types] + def summer(x,y): + return tuple(a + b for a,b in zip(x,y)) + return map_reduce(n_threads, n_parallel, n_loops, initializer, summer) + +def foreach_enumerate(a): + for x in a: + get_program().public_input(' '.join(str(y) for y in tuplify(x))) + def decorator(loop_body): + @for_range(len(a)) + def f(i): + loop_body(i, *(public_input() for j in range(len(tuplify(a[0]))))) + return f + return decorator + +def while_loop(loop_body, condition, arg): + if not callable(condition): + raise CompilerError('Condition must be callable') + # store arg in stack + pre_condition = condition(arg) + if not isinstance(pre_condition, (bool,int)) or pre_condition: + pushint(arg if isinstance(arg,regint) else regint(arg)) + def loop_fn(): + result = loop_body(regint.pop()) + pushint(result) + return condition(result) + if_statement(pre_condition, lambda: do_while(loop_fn)) + regint.pop() + +def while_do(condition, *args): + def decorator(loop_body): + while_loop(loop_body, condition, *args) + return loop_body + return decorator + +def do_loop(condition, loop_fn): + # store initial condition to stack + pushint(condition if isinstance(condition,regint) else regint(condition)) + def wrapped_loop(): + # save condition to stack + new_cond = regint.pop() + # run the loop + condition = loop_fn(new_cond) + pushint(condition) + return condition + do_while(wrapped_loop) + regint.pop() + +def do_while(loop_fn): + scope = instructions.program.curr_block + parent_node = get_tape().req_node + # possibly unknown loop count + get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \ + name='begin-loop') + loop_block = instructions.program.curr_block + condition = loop_fn() + if callable(condition): + condition = condition() + branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False) + instructions.program.curr_block.set_exit(branch, loop_block) + get_tape().close_scope(scope, parent_node, 'end-loop') + return loop_fn + +def if_then(condition): + class State: pass + state = State() + if callable(condition): + condition = condition() + state.condition = regint.conv(condition) + state.start_block = instructions.program.curr_block + state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ + name='if-block') + state.has_else = False + instructions.program.curr_tape.if_states.append(state) + +def else_then(): + try: + state = instructions.program.curr_tape.if_states[-1] + except IndexError: + raise CompilerError('No open if block') + if state.has_else: + raise CompilerError('else block already defined') + # run the else block + state.if_exit_block = instructions.program.curr_block + state.req_child.add_node(get_tape(), 'else-block') + instructions.program.curr_tape.start_new_basicblock(state.start_block, \ + name='else-block') + state.else_block = instructions.program.curr_block + state.has_else = True + +def end_if(): + try: + state = instructions.program.curr_tape.if_states.pop() + except IndexError: + raise CompilerError('No open if/else block') + branch = instructions.jmpeqz(regint.conv(state.condition), 0, \ + add_to_prog=False) + # start next block + get_tape().close_scope(state.start_block, state.req_child.parent, 'end-if') + if state.has_else: + # jump to else block if condition == 0 + state.start_block.set_exit(branch, state.else_block) + # set if block to skip else + jump = instructions.jmp(0, add_to_prog=False) + state.if_exit_block.set_exit(jump, instructions.program.curr_block) + else: + # set start block's conditional jump to next block + state.start_block.set_exit(branch, instructions.program.curr_block) + # nothing to compute without else + state.req_child.aggregator = lambda x: x[0] + +def if_statement(condition, if_fn, else_fn=None): + if condition is True or condition is False: + # condition known at compile time + if condition: + if_fn() + elif else_fn is not None: + else_fn() + else: + state = if_then(condition) + if_fn() + if else_fn is not None: + else_then() + else_fn() + end_if() + +def if_(condition): + def decorator(body): + if_then(condition) + body() + end_if() + return decorator + +def if_e(condition): + def decorator(body): + if_then(condition) + body() + return decorator + +def else_(body): + else_then() + body() + end_if() + +def and_(*terms): + # not thread-safe + p_res = instructions.program.malloc(1, 'ci') + for term in terms: + if_then(term()) + store_in_mem(1, p_res) + for term in terms: + else_then() + store_in_mem(0, p_res) + end_if() + def load_result(): + res = regint.load_mem(p_res) + instructions.program.free(p_res, 'ci') + return res + return load_result + +def or_(*terms): + # not thread-safe + p_res = instructions.program.malloc(1, 'ci') + res = regint() + for term in terms: + if_then(term()) + store_in_mem(1, p_res) + else_then() + store_in_mem(0, p_res) + for term in terms: + end_if() + def load_result(): + res = regint.load_mem(p_res) + instructions.program.free(p_res, 'ci') + return res + return load_result + +def not_(term): + return lambda: 1 - term() + +def start_timer(timer_id=0): + get_tape().start_new_basicblock(name='pre-start-timer') + start(timer_id) + get_tape().start_new_basicblock(name='post-start-timer') + +def stop_timer(timer_id=0): + get_tape().start_new_basicblock(name='pre-stop-timer') + stop(timer_id) + get_tape().start_new_basicblock(name='post-stop-timer') diff --git a/Compiler/program.py b/Compiler/program.py new file mode 100644 index 000000000..aaf3fa7f7 --- /dev/null +++ b/Compiler/program.py @@ -0,0 +1,902 @@ +# (C) 2016 University of Bristol. See License.txt + +from Compiler.config import * +from Compiler.exceptions import * +from Compiler.instructions_base import RegType +import Compiler.instructions +import Compiler.instructions_base +import compilerLib +import allocator as al +import random +import time +import sys, os, errno +import inspect +from collections import defaultdict +import itertools +import math + + +data_types = dict( + triple = 0, + square = 1, + bit = 2, + inverse = 3, + bittriple = 4, + bitgf2ntriple = 5 +) + +field_types = dict( + modp = 0, + gf2n = 1, +) + + +class Program(object): + """ A program consists of a list of tapes and a scheduled order + of execution for these tapes. + + These are created by executing a file containing appropriate instructions + and threads. """ + def __init__(self, name, options, param=-1, assemblymode=False): + self.options = options + self.init_names(name, assemblymode) + self.P = P_VALUES[param] + self.param = param + self.bit_length = BIT_LENGTHS[param] + print 'Default bit length:', self.bit_length + self.security = STAT_SEC[param] + print 'Default security parameter:', self.security + self.galois_length = int(options.galois) + print 'Galois length:', self.galois_length + self.schedule = [('start', [])] + self.main_ctr = 0 + self.tapes = [] + self._curr_tape = None + self.EMULATE = True # defaults + self.FIRST_PASS = False + self.DEBUG = False + self.main_thread_running = False + self.allocated_mem = RegType.create_dict(lambda: USER_MEM) + self.free_mem_blocks = defaultdict(set) + self.allocated_mem_blocks = {} + self.req_num = None + self.tape_stack = [] + self.n_threads = 1 + self.free_threads = set() + self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % name, 'w') + Program.prog = self + + self.reset_values() + + def max_par_tapes(self): + """ Upper bound on number of tapes that will be run in parallel. + (Excludes empty tapes) """ + if self.n_threads > 1: + if len(self.schedule) > 1: + raise CompilerError('Static and dynamic parallelism not compatible') + return self.n_threads + res = 1 + running = defaultdict(lambda: 0) + for action,tapes in self.schedule: + tapes = [t[0] for t in tapes if not t[0].is_empty()] + if action == 'start': + for tape in tapes: + running[tape] += 1 + elif action == 'stop': + for tape in tapes: + running[tape] -= 1 + else: + raise CompilerError('Invalid schedule action') + res = max(res, sum(running.itervalues())) + return res + + def init_names(self, name, assemblymode): + # ignore path to file - source must be in Programs/Source + if 'Programs' in os.listdir(os.getcwd()): + # compile prog in ./Programs/Source directory + self.programs_dir = os.getcwd() + '/Programs' + else: + # assume source is in main SPDZ directory + self.programs_dir = sys.path[0] + '/Programs' + print 'Compiling program in', self.programs_dir + + # create extra directories if needed + for dirname in ['Public-Input', 'Bytecode', 'Schedules']: + if not os.path.exists(self.programs_dir + '/' + dirname): + os.mkdir(self.programs_dir + '/' + dirname) + + name = name.split('/')[-1] + if name.endswith('.mpc'): + self.name = name[:-4] + else: + self.name = name + + if assemblymode: + self.infile = self.programs_dir + '/Source/' + self.name + '.asm' + else: + self.infile = self.programs_dir + '/Source/' + self.name + '.mpc' + + def new_tape(self, function, args=[], name=None): + if name is None: + name = function.__name__ + name = "%s-%s" % (self.name, name) + # make sure there is a current tape + self.curr_tape + tape_index = len(self.tapes) + name += "-%d" % tape_index + self.tape_stack.append(self.curr_tape) + self.curr_tape = Tape(name, self) + self.curr_tape.prevent_direct_memory_write = True + self.tapes.append(self.curr_tape) + function(*args) + self.finalize_tape(self.curr_tape) + if self.tape_stack: + self.curr_tape = self.tape_stack.pop() + return tape_index + + def run_tape(self, tape_index, arg): + if self.curr_tape is not self.tapes[0]: + raise CompilerError('Compiler does not support ' \ + 'recursive spawning of threads') + if self.free_threads: + thread_number = self.free_threads.pop() + else: + thread_number = self.n_threads + self.n_threads += 1 + self.curr_tape.start_new_basicblock(name='pre-run_tape') + Compiler.instructions.run_tape(thread_number, arg, tape_index) + self.curr_tape.start_new_basicblock(name='post-run_tape') + self.curr_tape.req_node.children.append(self.tapes[tape_index].req_tree) + return thread_number + + def join_tape(self, thread_number): + self.curr_tape.start_new_basicblock(name='pre-join_tape') + Compiler.instructions.join_tape(thread_number) + self.curr_tape.start_new_basicblock(name='post-join_tape') + self.free_threads.add(thread_number) + + def start_thread(self, thread, arg): + if self.main_thread_running: + # wait for main thread to finish + self.schedule_wait(self.curr_tape) + self.main_thread_running = False + + # compile thread if not been used already + if thread.tape not in self.tapes: + self.curr_tape = thread.tape + self.tapes.append(thread.tape) + thread.target(*thread.args) + + # add thread to schedule + self.schedule_start(thread.tape, arg) + self.curr_tape = None + + def stop_thread(self, thread): + tape = thread.tape + self.schedule_wait(tape) + + def update_req(self, tape): + if self.req_num is None: + self.req_num = tape.req_num + else: + self.req_num += tape.req_num + + def read_memory(self, filename): + """ Read the clear and shared memory from a file """ + f = open(filename) + n = int(f.next()) + self.mem_c = [0]*n + self.mem_s = [0]*n + mem = self.mem_c + done_c = False + for line in f: + line = line.split(' ') + a = int(line[0]) + b = int(line[1]) + if a != -1: + mem[a] = b + elif done_c: + break + else: + mem = self.mem_s + done_c = True + + def get_memory(self, mem_type, i): + if mem_type == 'c': + return self.mem_c[i] + elif mem_type == 's': + return self.mem_s[i] + raise CompilerError('Invalid memory type') + + def reset_values(self): + """ Reset register and memory values. """ + for tape in self.tapes: + tape.reset_registers() + self.mem_c = range(USER_MEM + TMP_MEM) + self.mem_s = range(USER_MEM + TMP_MEM) + random.seed(0) + + def write_bytes(self, outfile=None): + """ Write all non-empty threads and schedule to files. """ + # runtime doesn't support 'new-style' parallelism yet + old_style = True + + nonempty_tapes = [t for t in self.tapes if not t.is_empty()] + + sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name + sch_file = open(sch_filename, 'w') + print 'Writing to', sch_filename + sch_file.write(str(self.max_par_tapes()) + '\n') + sch_file.write(str(len(nonempty_tapes)) + '\n') + sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') + + # assign tapes indices (needed for scheduler) + for i,tape in enumerate(nonempty_tapes): + tape.index = i + + for sch in self.schedule: + # schedule may still contain empty tapes: ignore these + tapes = filter(lambda x: not x[0].is_empty(), sch[1]) + # no empty line + if not tapes: + continue + line = ' '.join(str(t[0].index) + + (':' + str(t[1]) if t[1] is not None else '') for t in tapes) + if old_style: + if sch[0] == 'start': + sch_file.write('%d %s\n' % (len(tapes), line)) + else: + sch_file.write('%s %d %s\n' % (tapes[0], len(tapes), line)) + + sch_file.write('0\n') + sch_file.write(' '.join(sys.argv) + '\n') + for tape in self.tapes: + tape.write_bytes() + + def schedule_start(self, tape, arg=None): + """ Schedule the start of a thread. """ + if self.schedule[-1][0] == 'start': + self.schedule[-1][1].append((tape, arg)) + else: + self.schedule.append(('start', [(tape, arg)])) + + def schedule_wait(self, tape): + """ Schedule the end of a thread. """ + if self.schedule[-1][0] == 'stop': + self.schedule[-1][1].append((tape, None)) + else: + self.schedule.append(('stop', [(tape, None)])) + self.finalize_tape(tape) + self.update_req(tape) + + def finalize_tape(self, tape): + if not tape.purged: + tape.optimize(self.options) + tape.write_bytes() + if self.options.asmoutfile: + tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.purge() + + def emulate(self): + """ Emulate execution of entire program. """ + self.reset_values() + for sch in self.schedule: + if sch[0] == 'start': + for tape in sch[1]: + self._curr_tape = tape + for block in tape.basicblocks: + for line in block.instructions: + line.execute() + + def restart_main_thread(self): + if self.main_thread_running: + # wait for main thread to finish + self.schedule_wait(self._curr_tape) + self.main_thread_running = False + name = '%s-%d' % (self.name, self.main_ctr) + self._curr_tape = Tape(name, self) + self.tapes.append(self._curr_tape) + self.main_ctr += 1 + # add to schedule + self.schedule_start(self._curr_tape) + self.main_thread_running = True + + @property + def curr_tape(self): + """ The tape that is currently running.""" + if self._curr_tape is None: + # Create a new main thread if necessary + self.restart_main_thread() + return self._curr_tape + + @curr_tape.setter + def curr_tape(self, value): + self._curr_tape = value + + @property + def curr_block(self): + """ The basic block that is currently being created. """ + return self.curr_tape.active_basicblock + + def malloc(self, size, mem_type): + """ Allocate memory from the top """ + if size == 0: + return + if isinstance(mem_type, type): + mem_type = mem_type.reg_type + key = size, mem_type + if self.free_mem_blocks[key]: + addr = self.free_mem_blocks[key].pop() + else: + addr = self.allocated_mem[mem_type] + self.allocated_mem[mem_type] += size + if len(str(addr)) != len(str(addr + size)): + print "Memory of type '%s' now of size %d" % (mem_type, addr + size) + self.allocated_mem_blocks[addr,mem_type] = size + return addr + + def free(self, addr, mem_type): + """ Free memory """ + if self.curr_block.persistent_allocation: + raise CompilerError('Cannot free memory within function block') + size = self.allocated_mem_blocks.pop((addr,mem_type)) + self.free_mem_blocks[size,mem_type].add(addr) + + def finalize_memory(self): + import library + self.curr_tape.start_new_basicblock(None, 'memory-usage') + for mem_type,size in self.allocated_mem.items(): + if size: + #print "Memory of type '%s' of size %d" % (mem_type, size) + library.load_mem(size - 1, mem_type) + + def public_input(self, x): + self.public_input_file.write('%s\n' % str(x)) + + def set_bit_length(self, bit_length): + self.bit_length = bit_length + print 'Changed bit length for comparisons etc. to', bit_length + + def set_security(self, security): + self.security = security + print 'Changed statistical security for comparison etc. to', security + +class Tape: + """ A tape contains a list of basic blocks, onto which instructions are added. """ + def __init__(self, name, program, param=-1): + """ Set prime p and the initial instructions and registers. """ + self.program = program + self.init_names(name) + self.P = P_VALUES[param] + self.init_registers() + self.req_tree = self.ReqNode(name) + self.req_node = self.req_tree + self.basicblocks = [] + self.purged = False + self.active_basicblock = None + self.start_new_basicblock() + self._is_empty = False + self.merge_opens = True + self.if_states = [] + self.req_bit_length = defaultdict(lambda: 0) + self.function_basicblocks = {} + self.functions = [] + self.prevent_direct_memory_write = False + + class BasicBlock(object): + def __init__(self, parent, name, scope, exit_condition=None): + self.parent = parent + self.P = parent.P + self.instructions = [] + self.name = name + self.index = len(parent.basicblocks) + self.open_queue = [] + self.exit_condition = exit_condition + self.exit_block = None + self.previous_block = None + self.scope = scope + self.children = [] + if scope is not None: + scope.children.append(self) + self.persistent_allocation = scope.persistent_allocation + else: + self.persistent_allocation = False + + def new_reg(self, reg_type, size=None): + return self.parent.new_reg(reg_type, size=size) + + def set_return(self, previous_block, sub_block): + self.previous_block = previous_block + self.sub_block = sub_block + + def adjust_return(self): + offset = self.sub_block.get_offset(self) + self.previous_block.return_address_store.args[1] = offset + + def set_exit(self, condition, exit_true=None): + """ Sets the block which we start from next, depending on the condition. + + (Default is to go to next block in the list) + """ + self.exit_condition = condition + self.exit_block = exit_true + for reg in condition.get_used(): + reg.can_eliminate = False + + def add_jump(self): + """ Add the jump for this block's exit condition to list of + instructions (must be done after merging) """ + self.instructions.append(self.exit_condition) + + def get_offset(self, next_block): + return next_block.offset - (self.offset + len(self.instructions)) + + def adjust_jump(self): + """ Set the correct relative jump offset """ + offset = self.get_offset(self.exit_block) + self.exit_condition.set_relative_jump(offset) + #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) + def __str__(self): + return self.name + + def is_empty(self): + """ Returns True if the list of basic blocks is empty. + + Note: False is returned even when tape only contains basic + blocks with no instructions. However, these are removed when + optimize is called. """ + if not self.purged: + self._is_empty = (len(self.basicblocks) == 0) + return self._is_empty + + def start_new_basicblock(self, scope=False, name=''): + # use False because None means no scope + if scope is False: + scope = self.active_basicblock + suffix = '%s-%d' % (name, len(self.basicblocks)) + sub = self.BasicBlock(self, self.name + '-' + suffix, scope) + self.basicblocks.append(sub) + self.active_basicblock = sub + self.req_node.add_block(sub) + print 'Compiling basic block', sub.name + + def init_registers(self): + self.reset_registers() + self.reg_counter = RegType.create_dict(lambda: 0) + + def init_names(self, name): + # ignore path to file - source must be in Programs/Source + name = name.split('/')[-1] + if name.endswith('.asm'): + self.name = name[:-4] + else: + self.name = name + self.infile = self.program.programs_dir + '/Source/' + self.name + '.asm' + self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' + + def purge(self): + self._is_empty = (len(self.basicblocks) == 0) + del self.reg_values + del self.basicblocks + del self.active_basicblock + self.purged = True + + def unpurged(function): + def wrapper(self, *args, **kwargs): + if self.purged: + print '%s called on purged block %s, ignoring' % \ + (function.__name__, self.name) + return + return function(self, *args, **kwargs) + return wrapper + + @unpurged + def optimize(self, options): + if len(self.basicblocks) == 0: + print 'Tape %s is empty' % self.name + return + + if self.if_states: + raise CompilerError('Unclosed if/else blocks') + + print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks) + + for block in self.basicblocks: + al.determine_scope(block) + + # merge open instructions + # need to do this if there are several blocks + if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: + for i,block in enumerate(self.basicblocks): + if len(block.instructions) > 0: + print 'Processing basic block %s, %d/%d, %d instructions' % \ + (block.name, i, len(self.basicblocks), \ + len(block.instructions)) + # the next call is necessary for allocation later even without merging + merger = al.Merger(block, options) + if options.dead_code_elimination: + if len(block.instructions) > 10000: + print 'Eliminate dead code...' + merger.eliminate_dead_code() + if options.merge_opens and self.merge_opens: + if len(block.instructions) == 0: + block.used_from_scope = set() + block.defined_registers = set() + continue + if len(block.instructions) > 10000: + print 'Merging open instructions...' + numrounds = merger.longest_paths_merge() + if numrounds > 0: + print 'Program requires %d rounds of communication' % numrounds + numinv = sum(len(i.args) for i in block.instructions if isinstance(i, Compiler.instructions.startopen_class)) + if numinv > 0: + print 'Program requires %d invocations' % numinv + if options.dead_code_elimination: + block.instructions = filter(lambda x: x is not None, block.instructions) + if not (options.merge_opens and self.merge_opens): + print 'Not merging open instructions in tape %s' % self.name + + # add jumps + offset = 0 + for block in self.basicblocks: + if block.exit_condition is not None: + block.add_jump() + block.offset = offset + offset += len(block.instructions) + for block in self.basicblocks: + if block.exit_block is not None: + block.adjust_jump() + if block.previous_block is not None: + block.adjust_return() + + # now remove any empty blocks (must be done after setting jumps) + self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks) + + # allocate registers + reg_counts = self.count_regs() + if filter(lambda n: n > REG_MAX, reg_counts) and not options.noreallocate: + print 'Tape register usage:' + print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) + print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) + print 'Re-allocating...' + allocator = al.StraightlineAllocator(REG_MAX) + def alloc_loop(block): + for reg in block.used_from_scope: + allocator.alloc_reg(reg, block.persistent_allocation) + for child in block.children: + if child.instructions: + alloc_loop(child) + for i,block in enumerate(reversed(self.basicblocks)): + if len(block.instructions) > 10000: + print 'Allocating %s, %d/%d' % \ + (block.name, i, len(self.basicblocks)) + if block.exit_condition is not None: + jump = block.exit_condition.get_relative_jump() + if isinstance(jump, (int,long)) and jump < 0 and \ + block.exit_block.scope is not None: + alloc_loop(block.exit_block.scope) + allocator.process(block.instructions, block.persistent_allocation) + + # offline data requirements + print 'Compile offline data requirements...' + self.req_num = self.req_tree.aggregate() + print 'Tape requires', self.req_num + for req,num in self.req_num.items(): + if num == float('inf'): + num = -1 + if req[1] in data_types: + self.basicblocks[-1].instructions.append( + Compiler.instructions.use(field_types[req[0]], \ + data_types[req[1]], num, \ + add_to_prog=False)) + elif req[1] == 'input': + self.basicblocks[-1].instructions.append( + Compiler.instructions.use_inp(field_types[req[0]], \ + req[2], num, \ + add_to_prog=False)) + elif req[0] == 'modp': + self.basicblocks[-1].instructions.append( + Compiler.instructions.use_prep(req[1], num, \ + add_to_prog=False)) + elif req[0] == 'gf2n': + self.basicblocks[-1].instructions.append( + Compiler.instructions.guse_prep(req[1], num, \ + add_to_prog=False)) + + if not self.is_empty(): + # bit length requirement + self.basicblocks[-1].instructions.append( + Compiler.instructions.reqbl(self.req_bit_length['p'], add_to_prog=False)) + self.basicblocks[-1].instructions.append( + Compiler.instructions.greqbl(self.req_bit_length['2'], add_to_prog=False)) + print 'Tape requires prime bit length', self.req_bit_length['p'] + print 'Tape requires galois bit length', self.req_bit_length['2'] + + @unpurged + def _get_instructions(self): + return itertools.chain.\ + from_iterable(b.instructions for b in self.basicblocks) + + @unpurged + def get_encoding(self): + """ Get the encoding of the program, in human-readable format. """ + return [i.get_encoding() for i in self._get_instructions() if i is not None] + + @unpurged + def get_bytes(self): + """ Get the byte encoding of the program as an actual string of bytes. """ + return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None) + + @unpurged + def write_encoding(self, filename): + """ Write the readable encoding to a file. """ + print 'Writing to', filename + f = open(filename, 'w') + for line in self.get_encoding(): + f.write(str(line) + '\n') + f.close() + + @unpurged + def write_str(self, filename): + """ Write the sequence of instructions to a file. """ + print 'Writing to', filename + f = open(filename, 'w') + n = 0 + for block in self.basicblocks: + if block.instructions: + f.write('# %s\n' % block.name) + for line in block.instructions: + f.write('%s # %d\n' % (line, n)) + n += 1 + f.close() + + @unpurged + def write_bytes(self, filename=None): + """ Write the program's byte encoding to a file. """ + if filename is None: + filename = self.outfile + if not filename.endswith('.bc'): + filename += '.bc' + if not 'Bytecode' in filename: + filename = self.program.programs_dir + '/Bytecode/' + filename + print 'Writing to', filename + f = open(filename, 'w') + f.write(self.get_bytes()) + f.close() + + def new_reg(self, reg_type, size=None): + return self.Register(reg_type, self, size=size) + + def count_regs(self, reg_type=None): + if reg_type is None: + return self.reg_counter + else: + return self.reg_counter[reg_type] + + def reset_registers(self): + """ Reset register values to zero. """ + self.reg_values = RegType.create_dict(lambda: [0] * INIT_REG_MAX) + + def get_value(self, reg_type, i): + return self.reg_values[reg_type][i] + + def __str__(self): + return self.name + + class ReqNum(defaultdict): + def __init__(self, init={}): + super(Tape.ReqNum, self).__init__(lambda: 0, init) + def __add__(self, other): + res = Tape.ReqNum() + for i,count in self.items(): + res[i] += count + for i,count in other.items(): + res[i] += count + return res + def __mul__(self, other): + res = Tape.ReqNum() + for i in self: + res[i] = other * self[i] + return res + __rmul__ = __mul__ + def set_all(self, value): + res = Tape.ReqNum() + for i in self: + res[i] = value + return res + def max(self, other): + res = Tape.ReqNum() + for i in self: + res[i] = max(self[i], other[i]) + for i in other: + res[i] = max(self[i], other[i]) + return res + def cost(self): + return sum(num * COST[req[0]][req[1]] for req,num in self.items() \ + if req[1] != 'input') + def __str__(self): + return ", ".join('%s inputs in %s from player %d' \ + % (num, req[0], req[2]) \ + if req[1] == 'input' \ + else '%s %ss in %s' % (num, req[1], req[0]) \ + for req,num in self.items()) + def __repr__(self): + return repr(dict(self)) + + class ReqNode(object): + __slots__ = ['num', 'children', 'name', 'blocks'] + def __init__(self, name): + self.children = [] + self.name = name + self.blocks = [] + def aggregate(self, *args): + self.num = Tape.ReqNum() + for block in self.blocks: + for inst in block.instructions: + inst.add_usage(self) + res = reduce(lambda x,y: x + y.aggregate(self.name), + self.children, self.num) + return res + def increment(self, data_type, num=1): + self.num[data_type] += num + def add_block(self, block): + self.blocks.append(block) + + class ReqChild(object): + __slots__ = ['aggregator', 'nodes', 'parent'] + def __init__(self, aggregator, parent): + self.aggregator = aggregator + self.nodes = [] + self.parent = parent + def aggregate(self, name): + res = self.aggregator([node.aggregate() for node in self.nodes]) + return res + def add_node(self, tape, name): + new_node = Tape.ReqNode(name) + self.nodes.append(new_node) + tape.req_node = new_node + + def open_scope(self, aggregator, scope=False, name=''): + child = self.ReqChild(aggregator, self.req_node) + self.req_node.children.append(child) + child.add_node(self, '%s-%d' % (name, len(self.basicblocks))) + self.start_new_basicblock(name=name) + return child + + def close_scope(self, outer_scope, parent_req_node, name): + self.req_node = parent_req_node + self.start_new_basicblock(outer_scope, name) + + def require_bit_length(self, bit_length, t='p'): + if t == 'p': + self.req_bit_length[t] = max(bit_length + 1, \ + self.req_bit_length[t]) + if self.program.param != -1 and bit_length >= self.program.param: + raise CompilerError('Inadequate bit length %d for prime, ' \ + 'program requires %d bits' % \ + (self.program.param, self.req_bit_length['p'])) + else: + self.req_bit_length[t] = max(bit_length, self.req_bit_length) + + class Register(object): + """ + Class for creating new registers. The register's index is automatically assigned + based on the block's reg_counter dictionary. + + The 'value' property is for emulation. + """ + __slots__ = ["reg_type", "program", "i", "value", "_is_active", \ + "size", "vector", "vectorbase", "caller", \ + "can_eliminate"] + + def __init__(self, reg_type, program, value=None, size=None, i=None): + """ Creates a new register. + reg_type must be one of those defined in RegType. """ + if Compiler.instructions_base.get_global_instruction_type() == 'gf2n': + if reg_type == RegType.ClearModp: + reg_type = RegType.ClearGF2N + elif reg_type == RegType.SecretModp: + reg_type = RegType.SecretGF2N + self.reg_type = reg_type + self.program = program + if size is None: + size = Compiler.instructions_base.get_global_vector_size() + self.size = size + if i: + self.i = i + else: + self.i = program.reg_counter[reg_type] + program.reg_counter[reg_type] += size + self.vector = [] + self.vectorbase = self + if value is not None: + self.value = value + self._is_active = False + self.can_eliminate = True + if Program.prog.DEBUG: + self.caller = [frame[1:] for frame in inspect.stack()[1:]] + else: + self.caller = None + if self.i % 1000000 == 0 and self.i > 0: + print "Initialized %d registers at" % self.i, time.asctime() + + def set_size(self, size): + if self.size == size: + return + elif self.size == 1 and self.vectorbase is self: + if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS: + # create vector register in assembly mode + self.size = size + self.vector = [self] + for i in range(1,size): + reg = compilerLib.VARS['%s%d' % (self.reg_type, self.i + i)] + reg.set_vectorbase(self) + self.vector.append(reg) + else: + raise CompilerError('Cannot find %s in VARS' % str(self)) + else: + raise CompilerError('Cannot reset size of vector register') + + def set_vectorbase(self, vectorbase): + if self.vectorbase != self: + raise CompilerError('Cannot assign one register' \ + 'to several vectors') + self.vectorbase = vectorbase + + def create_vector_elements(self): + if self.vector: + return + elif self.size == 1: + self.vector = [self] + return + self.vector = [self] + for i in range(1,self.size): + reg = Tape.Register(self.reg_type, self.program, size=1, i=self.i+i) + reg.set_vectorbase(self) + self.vector.append(reg) + + def __getitem__(self, index): + if not self.vector: + self.create_vector_elements() + return self.vector[index] + + def __len__(self): + return self.size + + def activate(self): + """ Activating a register signals that it will at some point be used + in the program. + + Inactive registers are reserved for temporaries for CISC instructions. """ + if not self._is_active: + self._is_active = True + + @property + def value(self): + return self.program.reg_values[self.reg_type][self.i] + + @value.setter + def value(self, val): + while (len(self.program.reg_values[self.reg_type]) <= self.i): + self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX + self.program.reg_values[self.reg_type][self.i] = val + + @property + def is_active(self): + return self._is_active + + @property + def is_gf2n(self): + return self.reg_type == RegType.ClearGF2N or \ + self.reg_type == RegType.SecretGF2N + + @property + def is_clear(self): + return self.reg_type == RegType.ClearModp or \ + self.reg_type == RegType.ClearGF2N or \ + self.reg_type == RegType.ClearInt + + def __str__(self): + return self.reg_type + str(self.i) + + __repr__ = __str__ diff --git a/Compiler/tools.py b/Compiler/tools.py new file mode 100644 index 000000000..d30891be7 --- /dev/null +++ b/Compiler/tools.py @@ -0,0 +1,9 @@ +# (C) 2016 University of Bristol. See License.txt + +import itertools + +class chain(object): + def __init__(self, *args): + self.args = args + def __iter__(self): + return itertools.chain(*self.args) diff --git a/Compiler/types.py b/Compiler/types.py new file mode 100644 index 000000000..810735d38 --- /dev/null +++ b/Compiler/types.py @@ -0,0 +1,2013 @@ +# (C) 2016 University of Bristol. See License.txt + +from Compiler.program import Tape +from Compiler.exceptions import * +from Compiler.instructions import * +from Compiler.instructions_base import * +from floatingpoint import two_power +import comparison, floatingpoint +import math +import util +import operator + + +class MPCThread(object): + def __init__(self, target, name, args = [], runtime_arg = None): + """ Create a thread from a callable object. """ + if not callable(target): + raise CompilerError('Target %s for thread %s is not callable' % (target,name)) + self.name = name + self.tape = Tape(program.name + '-' + name, program) + self.target = target + self.args = args + self.runtime_arg = runtime_arg + self.running = 0 + + def start(self, runtime_arg = None): + self.running += 1 + program.start_thread(self, runtime_arg or self.runtime_arg) + + def join(self): + if not self.running: + raise CompilerError('Thread %s is not running' % self.name) + self.running -= 1 + program.stop_thread(self) + + +def vectorize(operation): + def vectorized_operation(self, *args, **kwargs): + if len(args): + if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ + and args[0].size != self.size: + raise CompilerError('Different vector sizes of operands') + set_global_vector_size(self.size) + res = operation(self, *args, **kwargs) + reset_global_vector_size() + return res + return vectorized_operation + +def vectorized_classmethod(function): + def vectorized_function(cls, *args, **kwargs): + size = None + if 'size' in kwargs: + size = kwargs.pop('size') + if size: + set_global_vector_size(size) + res = function(cls, *args, **kwargs) + reset_global_vector_size() + else: + res = function(cls, *args, **kwargs) + return res + return classmethod(vectorized_function) + +def vectorize_init(function): + def vectorized_init(*args, **kwargs): + size = None + if len(args) > 1 and (isinstance(args[1], Tape.Register) or \ + isinstance(args[1], sfloat)): + size = args[1].size + if 'size' in kwargs and kwargs['size'] is not None \ + and kwargs['size'] != size: + raise CompilerError('Mismatch in vector size') + if 'size' in kwargs and kwargs['size']: + size = kwargs['size'] + if size is not None: + set_global_vector_size(size) + res = function(*args, **kwargs) + reset_global_vector_size() + else: + res = function(*args, **kwargs) + return res + return vectorized_init + +def set_instruction_type(operation): + def instruction_typed_operation(self, *args, **kwargs): + set_global_instruction_type(self.instruction_type) + res = operation(self, *args, **kwargs) + reset_global_instruction_type() + return res + return instruction_typed_operation + +def read_mem_value(operation): + def read_mem_operation(self, other, *args, **kwargs): + if isinstance(other, MemValue): + other = other.read() + return operation(self, other, *args, **kwargs) + return read_mem_operation + + +class _number(object): + def square(self): + return self * self + + def __add__(self, other): + if other is 0 or other is 0L: + return self + else: + return self.add(other) + + def __mul__(self, other): + if other is 0 or other is 0L: + return 0 + elif other is 1 or other is 1L: + return self + else: + return self.mul(other) + + __radd__ = __add__ + __rmul__ = __mul__ + + @vectorize + def __pow__(self, exp): + if isinstance(exp, int) and exp >= 0: + if exp == 0: + return self.__class__(1) + exp = bin(exp)[3:] + res = self + for i in exp: + res = res.square() + if i == '1': + res *= self + return res + else: + return NotImplemented + +class _int(object): + def if_else(self, a, b): + return self * (a - b) + b + + def cond_swap(self, a, b): + prod = self * (a - b) + return a - prod, b + prod + +class _gf2n(object): + def if_else(self, a, b): + return b ^ self * self.hard_conv(a ^ b) + + def cond_swap(self, a, b, t=None): + prod = self * self.hard_conv(a ^ b) + res = a ^ prod, b ^ prod + if t is None: + return res + else: + return tuple(t.conv(r) for r in res) + + +class _register(Tape.Register, _number): + @vectorized_classmethod + def conv(cls, val): + if isinstance(val, MemValue): + val = val.read() + if isinstance(val, cls): + return val + elif not isinstance(val, _register): + try: + return type(val)(cls.conv(v) for v in val) + except TypeError: + pass + return cls(val) + + @vectorized_classmethod + @read_mem_value + def hard_conv(cls, val): + if type(val) == cls: + return val + elif not isinstance(val, _register): + try: + return val.hard_conv_me(cls) + except AttributeError: + try: + return type(val)(cls.hard_conv(v) for v in val) + except TypeError: + pass + return cls(val) + + @vectorized_classmethod + @set_instruction_type + def _load_mem(cls, address, direct_inst, indirect_inst): + res = cls() + if isinstance(address, _register): + indirect_inst(res, regint.conv(address)) + else: + direct_inst(res, address) + return res + + @set_instruction_type + @vectorize + def _store_in_mem(self, address, direct_inst, indirect_inst): + if isinstance(address, _register): + indirect_inst(self, regint.conv(address)) + else: + direct_inst(self, address) + + @classmethod + def prep_res(cls, other): + return cls() + + def __init__(self, reg_type, val, size): + super(_register, self).__init__(reg_type, program.curr_tape, size=size) + if isinstance(val, (int, long)): + self.load_int(val) + elif val is not None: + self.load_other(val) + + def sizeof(self): + return self.size + + +class _clear(_register): + __slots__ = [] + + @vectorized_classmethod + @set_instruction_type + def protect_memory(cls, start, end): + program.curr_tape.start_new_basicblock(name='protect-memory') + protectmemc(regint(start), regint(end)) + + @set_instruction_type + @vectorize + def load_other(self, val): + if isinstance(val, type(self)): + movc(self, val) + else: + self.convert_from(val) + + @vectorize + @read_mem_value + def convert_from(self, val): + if not isinstance(val, regint): + val = regint(val) + convint(self, val) + + @set_instruction_type + @vectorize + def print_reg(self, comment=''): + print_reg(self, comment) + + @set_instruction_type + @vectorize + def print_reg_plain(self): + print_reg_plain(self) + + @set_instruction_type + @vectorize + def raw_output(self): + raw_output(self) + + @set_instruction_type + @read_mem_value + @vectorize + def clear_op(self, other, c_inst, ci_inst, reverse=False): + cls = self.__class__ + res = self.prep_res(other) + if isinstance(other, cls): + c_inst(res, self, other) + elif isinstance(other, (int, long)): + if self.in_immediate_range(other): + ci_inst(res, self, other) + else: + if reverse: + c_inst(res, cls(other), self) + else: + c_inst(res, self, cls(other)) + else: + return NotImplemented + return res + + @set_instruction_type + @read_mem_value + @vectorize + def coerce_op(self, other, inst, reverse=False): + cls = self.__class__ + res = cls() + if isinstance(other, (int, long)): + other = cls(other) + elif not isinstance(other, cls): + return NotImplemented + if reverse: + inst(res, other, self) + else: + inst(res, self, other) + return res + + def add(self, other): + return self.clear_op(other, addc, addci) + + def mul(self, other): + return self.clear_op(other, mulc, mulci) + + def __sub__(self, other): + return self.clear_op(other, subc, subci) + + def __rsub__(self, other): + return self.clear_op(other, subc, subcfi, True) + + def __div__(self, other): + return self.clear_op(other, divc, divci) + + def __rdiv__(self, other): + return self.coerce_op(other, divc, True) + + def __eq__(self, other): + if isinstance(other, (_clear,int,long)): + return regint(self) == other + else: + return NotImplemented + + def __ne__(self, other): + return 1 - (self == other) + + def __and__(self, other): + return self.clear_op(other, andc, andci) + + def __xor__(self, other): + return self.clear_op(other, xorc, xorci) + + def __or__(self, other): + return self.clear_op(other, orc, orci) + + __rand__ = __and__ + __rxor__ = __xor__ + __ror__ = __or__ + + +class cint(_clear, _int): + " Clear mod p integer type. """ + __slots__ = [] + instruction_type = 'modp' + reg_type = 'c' + + @vectorized_classmethod + def load_mem(cls, address): + return cls._load_mem(address, ldmc, ldmci) + + def store_in_mem(self, address): + self._store_in_mem(address, stmc, stmci) + + @staticmethod + def in_immediate_range(value): + return value < 2**31 and value >= -2**31 + + def __init__(self, val=None, size=None): + super(cint, self).__init__('c', val=val, size=size) + + @vectorize + def load_int(self, val): + if val: + # +1 for sign + program.curr_tape.require_bit_length(1 + int(math.ceil(math.log(abs(val))))) + if self.in_immediate_range(val): + ldi(self, val) + else: + max = 2**31 - 1 + sign = abs(val) / val + val = abs(val) + chunks = [] + while val: + mod = val % max + val = (val - mod) / max + chunks.append(mod) + sum = cint(sign * chunks.pop()) + for i,chunk in enumerate(reversed(chunks)): + sum *= max + if i == len(chunks) - 1: + addci(self, sum, sign * chunk) + elif chunk: + sum += sign * chunk + + def __mod__(self, other): + return self.clear_op(other, modc, modci) + + def __rmod__(self, other): + return self.coerce_op(other, modc, True) + + def __lt__(self, other): + if isinstance(other, (type(self),int,long)): + return regint(self) < other + else: + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (type(self),int,long)): + return regint(self) > other + else: + return NotImplemented + + def __le__(self, other): + return 1 - (self > other) + + def __ge__(self, other): + return 1 - (self < other) + + def __lshift__(self, other): + return self.clear_op(other, shlc, shlci) + + def __rshift__(self, other): + return self.clear_op(other, shrc, shrci) + + def __neg__(self): + return 0 - self + + @vectorize + def __invert__(self): + res = cint() + notc(res, self, program.bit_length) + return res + + def __rpow__(self, base): + if base == 2: + return 1 << self + else: + return NotImplemented + + @vectorize + def __rlshift__(self, other): + return cint(other) << self + + @vectorize + def __rrshift__(self, other): + return cint(other) >> self + + @read_mem_value + def mod2m(self, other, bit_length=None, signed=None): + return self % 2**other + + @read_mem_value + def right_shift(self, other, bit_length=None): + return self >> other + + @read_mem_value + def greater_than(self, other, bit_length=None): + return self > other + + def pow2(self, bit_length=None): + return 2**self + + def bit_decompose(self, bit_length=None): + if bit_length == 0: + return [] + bit_length = bit_length or program.bit_length + return floatingpoint.bits(self, bit_length) + + def legendre(self): + res = cint() + legendrec(res, self) + return res + + +class cgf2n(_clear, _gf2n): + __slots__ = [] + instruction_type = 'gf2n' + reg_type = 'cg' + + @classmethod + def bit_compose(cls, bits, step=None): + size = bits[0].size + res = cls(size=size) + vgbitcom(size, res, step or 1, *bits) + return res + + @vectorized_classmethod + def load_mem(cls, address): + return cls._load_mem(address, gldmc, gldmci) + + def store_in_mem(self, address): + self._store_in_mem(address, gstmc, gstmci) + + @staticmethod + def in_immediate_range(value): + return value < 2**32 and value >= 0 + + def __init__(self, val=None, size=None): + super(cgf2n, self).__init__('cg', val=val, size=size) + + @vectorize + def load_int(self, val): + if val < 0: + raise CompilerError('Negative GF2n immediate') + if self.in_immediate_range(val): + gldi(self, val) + else: + chunks = [] + while val: + mod = val % 2**32 + val >>= 32 + chunks.append(mod) + sum = cgf2n(chunks.pop()) + for i,chunk in enumerate(reversed(chunks)): + sum <<= 32 + if i == len(chunks) - 1: + gaddci(self, sum, chunk) + elif chunk: + sum += chunk + + def __mul__(self, other): + return super(cgf2n, self).__mul__(other) + + def __neg__(self): + return self + + @vectorize + def __invert__(self): + res = cgf2n() + gnotc(res, self) + return res + + @vectorize + def __lshift__(self, other): + if isinstance(other, int): + res = cgf2n() + gshlci(res, self, other) + return res + else: + return NotImplemented + + @vectorize + def __rshift__(self, other): + if isinstance(other, int): + res = cgf2n() + gshrci(res, self, other) + return res + else: + return NotImplemented + + @vectorize + def bit_decompose(self, bit_length=None, step=None): + bit_length = bit_length or program.galois_length + step = step or 1 + res = [type(self)() for _ in range(bit_length / step)] + gbitdec(self, step, *res) + return res + +class regint(_register, _int): + __slots__ = [] + reg_type = 'ci' + instruction_type = 'modp' + + @classmethod + def protect_memory(cls, start, end): + program.curr_tape.start_new_basicblock(name='protect-memory') + protectmemint(regint(start), regint(end)) + + @vectorized_classmethod + def load_mem(cls, address): + return cls._load_mem(address, ldmint, ldminti) + + def store_in_mem(self, address): + self._store_in_mem(address, stmint, stminti) + + @vectorized_classmethod + def pop(cls): + res = cls() + popint(res) + return res + + @vectorized_classmethod + def get_random(cls, bit_length): + if isinstance(bit_length, int): + bit_length = regint(bit_length) + res = cls() + rand(res, bit_length) + return res + + @vectorized_classmethod + def read_from_socket(cls): + res = cls() + readsocketc(res,0) + return res + + @vectorize + def write_to_socket(self): + writesocketc(self,0) + + @vectorize_init + def __init__(self, val=None, size=None): + super(regint, self).__init__(self.reg_type, val=val, size=size) + + def load_int(self, val): + if cint.in_immediate_range(val): + ldint(self, val) + else: + lower = val % 2**32 + upper = val >> 32 + if lower >= 2**31: + lower -= 2**32 + upper += 1 + addint(self, regint(upper) * regint(2**16)**2, regint(lower)) + + @read_mem_value + def load_other(self, val): + if isinstance(val, cint): + convmodp(self, val) + elif isinstance(val, cgf2n): + gconvgf2n(self, val) + elif isinstance(val, regint): + addint(self, val, regint(0)) + else: + raise CompilerError("Cannot convert '%s' to integer" % type(val)) + + @vectorize + @read_mem_value + def int_op(self, other, inst, reverse=False): + if isinstance(other, _secret): + return NotImplemented + elif not isinstance(other, type(self)): + other = type(self)(other) + res = regint() + if reverse: + inst(res, other, self) + else: + inst(res, self, other) + return res + + def add(self, other): + return self.int_op(other, addint) + + def __sub__(self, other): + return self.int_op(other, subint) + + def __rsub__(self, other): + return self.int_op(other, subint, True) + + def mul(self, other): + return self.int_op(other, mulint) + + def __neg__(self): + return 0 - self + + def __div__(self, other): + return self.int_op(other, divint) + + def __rdiv__(self, other): + return self.int_op(other, divint, True) + + def __mod__(self, other): + return cint(self) % other + + def __rmod__(self, other): + return other % cint(self) + + def __rpow__(self, other): + return other**cint(self) + + def __eq__(self, other): + return self.int_op(other, eqc) + + def __ne__(self, other): + return 1 - (self == other) + + def __lt__(self, other): + return self.int_op(other, ltc) + + def __gt__(self, other): + return self.int_op(other, gtc) + + def __le__(self, other): + return 1 - (self > other) + + def __ge__(self, other): + return 1 - (self < other) + + def __lshift__(self, other): + return regint(cint(self) << other) + + def __rshift__(self, other): + return regint(cint(self) >> other) + + def __rlshift__(self, other): + return regint(other << cint(self)) + + def __rrshift__(self, other): + return regint(other >> cint(self)) + + def __and__(self, other): + return regint(other & cint(self)) + + def __or__(self, other): + return regint(other | cint(self)) + + def __xor__(self, other): + return regint(other ^ cint(self)) + + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ + + def mod2m(self, *args, **kwargs): + return cint(self).mod2m(*args, **kwargs) + + +class _secret(_register): + __slots__ = [] + + @vectorized_classmethod + @set_instruction_type + def protect_memory(cls, start, end): + program.curr_tape.start_new_basicblock(name='protect-memory') + protectmems(regint(start), regint(end)) + + @vectorized_classmethod + @set_instruction_type + def get_input_from(cls, player): + res = cls() + asm_input(res, player) + return res + + @vectorized_classmethod + @set_instruction_type + def get_random_triple(cls): + res = (cls(), cls(), cls()) + triple(*res) + return res + + @vectorized_classmethod + @set_instruction_type + def get_random_bit(cls): + res = cls() + bit(res) + return res + + @vectorized_classmethod + @set_instruction_type + def get_random_square(cls): + res = (cls(), cls()) + square(*res) + return res + + @vectorized_classmethod + @set_instruction_type + def get_random_inverse(cls): + res = (cls(), cls()) + inverse(*res) + return res + + @vectorized_classmethod + @set_instruction_type + def get_random_input_mask_for(cls, player): + res = cls() + inputmask(res, player) + return res + + def __init__(self, reg_type, val=None, size=None): + if isinstance(val, self.clear_type): + size = val.size + super(_secret, self).__init__(reg_type, val=val, size=size) + + @set_instruction_type + @vectorize + def load_int(self, val): + if self.clear_type.in_immediate_range(val): + ldsi(self, val) + else: + self.load_clear(self.clear_type(val)) + + @vectorize + def load_clear(self, val): + addm(self, self.__class__(0), val) + + @set_instruction_type + @read_mem_value + @vectorize + def load_other(self, val): + if isinstance(val, self.clear_type): + self.load_clear(val) + elif isinstance(val, type(self)): + movs(self, val) + else: + self.load_clear(self.clear_type(val)) + + @set_instruction_type + @read_mem_value + @vectorize + def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False): + cls = self.__class__ + res = self.prep_res(other) + if isinstance(other, regint): + other = res.clear_type(other) + if isinstance(other, cls): + s_inst(res, self, other) + elif isinstance(other, res.clear_type): + if reverse: + m_inst(res, other, self) + else: + m_inst(res, self, other) + elif isinstance(other, (int, long)): + if self.clear_type.in_immediate_range(other): + si_inst(res, self, other) + else: + if reverse: + m_inst(res, res.clear_type(other), self) + else: + m_inst(res, self, res.clear_type(other)) + else: + return NotImplemented + return res + + def add(self, other): + return self.secret_op(other, adds, addm, addsi) + + def mul(self, other): + return self.secret_op(other, muls, mulm, mulsi) + + def __sub__(self, other): + return self.secret_op(other, subs, subml, subsi) + + def __rsub__(self, other): + return self.secret_op(other, subs, submr, subsfi, True) + + @vectorize + def __div__(self, other): + return self * (self.clear_type(1) / other) + + @vectorize + def __rdiv__(self, other): + a,b = self.get_random_inverse() + return other * a / (a * self).reveal() + + @set_instruction_type + @vectorize + def square(self): + res = self.__class__() + sqrs(res, self) + return res + + @set_instruction_type + @vectorize + def reveal(self): + res = self.clear_type() + asm_open(res, self) + return res + + @set_instruction_type + def reveal_to(self, player): + masked = self.__class__() + startprivateoutput(masked, self, player) + stopprivateoutput(masked.reveal(), player) + + +class sint(_secret, _int): + " Shared mod p integer type. """ + __slots__ = [] + instruction_type = 'modp' + clear_type = cint + reg_type = 's' + + @vectorized_classmethod + def get_random_int(cls, bits): + res = sint() + comparison.PRandInt(res, bits) + return res + + @classmethod + def get_raw_input_from(cls, player): + res = cls() + startinput(player, 1) + stopinput(player, res) + return res + + @vectorized_classmethod + def read_from_socket(cls): + res = cls() + readsockets(res,0) + return res + + @vectorize + def write_to_socket(self): + writesockets(self,0) + + @vectorized_classmethod + def load_mem(cls, address): + return cls._load_mem(address, ldms, ldmsi) + + def store_in_mem(self, address): + self._store_in_mem(address, stms, stmsi) + + def __init__(self, val=None, size=None): + super(sint, self).__init__('s', val=val, size=size) + + @vectorize + def __neg__(self): + return 0 - self + + @read_mem_value + @vectorize + def __lt__(self, other, bit_length=None, security=None): + res = sint() + comparison.LTZ(res, self - other, bit_length or program.bit_length + 1, + security or program.security) + return res + + @read_mem_value + @vectorize + def __gt__(self, other, bit_length=None, security=None): + res = sint() + comparison.LTZ(res, other - self, bit_length or program.bit_length + 1, + security or program.security) + return res + + def __le__(self, other, bit_length=None, security=None): + return 1 - self.greater_than(other, bit_length, security) + + def __ge__(self, other, bit_length=None, security=None): + return 1 - self.less_than(other, bit_length, security) + + @read_mem_value + @vectorize + def __eq__(self, other, bit_length=None, security=None): + return floatingpoint.EQZ(self - other, bit_length or program.bit_length, + security or program.security) + + def __ne__(self, other, bit_length=None, security=None): + return 1 - self.equal(other, bit_length, security) + + less_than = __lt__ + greater_than = __gt__ + less_equal = __le__ + greater_equal = __ge__ + equal = __eq__ + not_equal = __ne__ + + @vectorize + def __mod__(self, modulus): + if isinstance(modulus, (int, long)): + l = math.log(modulus, 2) + if 2**int(round(l)) == modulus: + return self.mod2m(int(l)) + raise NotImplementedError('Modulo only implemented for powers of two.') + + @read_mem_value + def mod2m(self, m, bit_length=None, security=None, signed=True): + bit_length = bit_length or program.bit_length + security = security or program.security + if isinstance(m, int): + if m == 0: + return 0 + if m >= bit_length: + return self + res = sint() + if m == 1: + comparison.Mod2(res, self, bit_length, security, signed) + else: + comparison.Mod2m(res, self, bit_length, m, security, signed) + else: + res, pow2 = floatingpoint.Trunc(self, bit_length, m, security, True) + return res + + @vectorize + def __rpow__(self, base): + if base == 2: + return self.pow2() + else: + return NotImplemented + + def pow2(self, bit_length=None, security=None): + return floatingpoint.Pow2(self, bit_length or program.bit_length, \ + security or program.security) + + def __lshift__(self, other): + return self * 2**other + + @vectorize + @read_mem_value + def __rshift__(self, other, bit_length=None, security=None): + bit_length = bit_length or program.bit_length + security = security or program.security + if isinstance(other, int): + if other == 0: + return self + res = sint() + comparison.Trunc(res, self, bit_length, other, security, True) + return res + elif isinstance(other, sint): + return floatingpoint.Trunc(self, bit_length, other, security) + else: + return floatingpoint.Trunc(self, bit_length, sint(other), security) + + right_shift = __rshift__ + + def __rlshift__(self, other): + return other * 2**self + + @vectorize + def __rrshift__(self, other): + return floatingpoint.Trunc(other, program.bit_length, self, program.security) + + def bit_decompose(self, bit_length=None, security=None): + if bit_length == 0: + return [] + bit_length = bit_length or program.bit_length + security = security or program.security + return floatingpoint.BitDec(self, bit_length, bit_length, security) + +class sgf2n(_secret, _gf2n): + __slots__ = [] + instruction_type = 'gf2n' + clear_type = cgf2n + reg_type = 'sg' + + @classmethod + def get_raw_input_from(cls, player): + res = cls() + gstartinput(player, 1) + gstopinput(player, res) + return res + + def add(self, other): + if isinstance(other, sgf2nint): + return NotImplemented + else: + return super(sgf2n, self).add(other) + + def mul(self, other): + if isinstance(other, (sgf2nint)): + return NotImplemented + else: + return super(sgf2n, self).mul(other) + + @vectorized_classmethod + def load_mem(cls, address): + return cls._load_mem(address, gldms, gldmsi) + + def store_in_mem(self, address): + self._store_in_mem(address, gstms, gstmsi) + + def __init__(self, val=None, size=None): + super(sgf2n, self).__init__('sg', val=val, size=size) + + def __neg__(self): + return self + + @vectorize + def __invert__(self): + return self ^ cgf2n(2**program.galois_length - 1) + + def __xor__(self, other): + if other is 0 or other is 0L: + return self + else: + return super(sgf2n, self).add(other) + + __rxor__ = __xor__ + + @vectorize + def __and__(self, other): + if isinstance(other, (int, long)): + other_bits = [(other >> i) & 1 \ + for i in range(program.galois_length)] + else: + other_bits = other.bit_decompose() + self_bits = self.bit_decompose() + return sum((x * y) << i \ + for i,(x,y) in enumerate(zip(self_bits, other_bits))) + + __rand__ = __and__ + + @vectorize + def __lshift__(self, other): + return self * cgf2n(1 << other) + + @vectorize + def right_shift(self, other, bit_length=None): + bits = self.bit_decompose(bit_length) + return sum(b << i for i,b in enumerate(bits[other:])) + + def equal(self, other, bit_length=None, expand=1): + bits = [1 - bit for bit in (self - other).bit_decompose(bit_length)][::expand] + while len(bits) > 1: + bits.insert(0, bits.pop() * bits.pop()) + return bits[0] + + def not_equal(self, other, bit_length=None): + return 1 - self.equal(other, bit_length) + + __eq__ = equal + __ne__ = not_equal + + @vectorize + def bit_decompose(self, bit_length=None, step=1): + if bit_length == 0: + return [] + bit_length = bit_length or program.galois_length + random_bits = [self.get_random_bit() \ + for i in range(0, bit_length, step)] + one = cgf2n(1) + masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal() + masked_bits = masked.bit_decompose(bit_length) + return [m + r for m,r in zip(masked_bits, random_bits)] + + @vectorize + def bit_decompose_embedding(self): + + random_bits = [self.get_random_bit() \ + for i in range(8)] + one = cgf2n(1) + wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35] + masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal() + return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)] + +sint.basic_type = sint +sgf2n.basic_type = sgf2n + + +class sgf2nint(sgf2n): + bits = None + + @classmethod + def compose(cls, bits): + bits = list(bits) + if len(bits) > cls.n_bits: + raise CompilerError('Too many bits') + res = cls() + res.bits = bits + [0] * (cls.n_bits - len(bits)) + gmovs(res, sum(b << i for i,b in enumerate(bits))) + return res + + @staticmethod + def bit_adder(a, b): + a, b = list(a), list(b) + a += [0] * (len(b) - len(a)) + b += [0] * (len(a) - len(b)) + lower = [] + for (ai,bi) in zip(a,b): + if ai is 0 or bi is 0: + lower.append(ai + bi) + a.pop(0) + b.pop(0) + else: + break + d = [(ai + bi, ai * bi) for (ai,bi) in zip(a,b)] + carry = lambda y,x,*args: \ + (x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1])) + if d: + carries = (0,) + zip(*floatingpoint.PreOpL(carry, d))[1] + else: + carries = [] + return lower + [ai + bi + carry for (ai,bi,carry) in zip(a,b,carries)] + + @staticmethod + def full_adder(a, b, carry): + s = a + b + return s + carry, util.or_op(a * b, s * carry) + + @staticmethod + def half_adder(a, b): + return a + b, a * b + + @staticmethod + def bit_comparator(a, b): + op = lambda y,x,*args: (util.if_else(x[1], x[0], y[0]), \ + util.if_else(x[1], 1, y[1])) + return floatingpoint.KOpL(op, [(bi, ai + bi) for (ai,bi) in zip(a,b)]) + + @staticmethod + def get_highest_different_bits(a, b, index): + diff = [ai + bi for (ai,bi) in reversed(zip(a,b))] + preor = floatingpoint.PreOR(diff, raw=True) + highest_diff = [x - y for (x,y) in reversed(zip(preor, [0] + preor))] + raw = sum(map(operator.mul, highest_diff, (a,b)[index])) + return raw.bit_decompose()[0] + + def load_int(self, other): + if -2**(self.n_bits-1) <= other < 2**(self.n_bits-1): + sgf2n.load_int(self, other + 2**self.n_bits if other < 0 else other) + else: + raise CompilerError('Invalid signed %d-bit integer: %d' % \ + (self.n_bits, other)) + + def load_other(self, other): + if isinstance(other, sgf2nint): + gmovs(self, self.compose(other.bit_decompose(self.n_bits))) + elif isinstance(other, sgf2n): + gmovs(self, other) + else: + gaddm(self, sgf2n(0), cgf2n(other)) + + def add(self, other): + if type(other) == sgf2n: + raise CompilerError('Unclear addition') + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + return self.compose(self.bit_adder(a, b)) + + def mul(self, other): + if type(other) == sgf2n: + raise CompilerError('Unclear multiplication') + self_bits = self.bit_decompose() + if isinstance(other, (int, long)): + other_bits = util.bit_decompose(other, self.n_bits) + bit_matrix = [[x * y for y in self_bits] for x in other_bits] + else: + other = sgf2n(other) + products = [x * other for x in self_bits] + bit_matrix = [util.bit_decompose(x, self.n_bits) for x in products] + columns = [filter(None, (bit_matrix[j][i-j] \ + for j in range(min(len(bit_matrix), i + 1)))) \ + for i in range(len(bit_matrix[0]))] + # Wallace tree + while max(len(c) for c in columns) > 2: + new_columns = [[] for i in range(len(columns) + 1)] + for i,col in enumerate(columns): + while len(col) > 2: + s, carry = self.full_adder(*(col.pop() for i in range(3))) + new_columns[i].append(s) + new_columns[i+1].append(carry) + if len(col) == 2: + s, carry = self.half_adder(*(col.pop() for i in range(2))) + new_columns[i].append(s) + new_columns[i+1].append(carry) + else: + new_columns[i].extend(col) + columns = new_columns[:-1] + for col in columns: + col.extend([0] * (2 - len(col))) + return self.compose(self.bit_adder(*zip(*columns))) + + def __sub__(self, other): + if type(other) == sgf2n: + raise CompilerError('Unclear subtraction') + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)] + borrow = lambda y,x,*args: \ + (x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1])) + borrows = (0,) + zip(*floatingpoint.PreOpL(borrow, d))[1] + return self.compose(ai + bi + borrow \ + for (ai,bi,borrow) in zip(a,b,borrows)) + + def __rsub__(self, other): + raise NotImplementedError() + + def __div__(self, other): + raise NotImplementedError() + + def __rdiv__(self, other): + raise NotImplementedError() + + def __lshift__(self, other): + return self.compose(([0] * other + self.bit_decompose())[:self.n_bits]) + + def __rshift__(self, other): + return self.compose(self.bit_decompose()[other:]) + + def bit_decompose(self, n_bits=None, *args): + if self.bits is None: + self.bits = sgf2n(self).bit_decompose(self.n_bits) + if n_bits is None: + return self.bits[:] + else: + return self.bits[:n_bits] + [self.fill_bit()] * (n_bits - self.n_bits) + + def fill_bit(self): + return self.bits[-1] + + @staticmethod + def prep_comparison(a, b): + a[-1], b[-1] = b[-1], a[-1] + + def comparison(self, other, const_rounds=False, index=None): + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + self.prep_comparison(a, b) + if const_rounds: + return self.get_highest_different_bits(a, b, index) + else: + return self.bit_comparator(a, b) + + def __lt__(self, other): + if program.options.comparison == 'log': + x, not_equal = self.comparison(other) + return util.if_else(not_equal, x, 0) + else: + return self.comparison(other, True, 1) + + def __le__(self, other): + if program.options.comparison == 'log': + x, not_equal = self.comparison(other) + return util.if_else(not_equal, x, 1) + else: + return 1 - self.comparison(other, True, 0) + + def __ge__(self, other): + return 1 - (self < other) + + def __gt__(self, other): + return 1 - (self <= other) + + def __neg__(self): + return 1 + self.compose(1 ^ b for b in self.bit_decompose()) + +class sgf2nuint(sgf2nint): + def load_int(self, other): + if 0 <= other < 2**self.n_bits: + sgf2n.load_int(self, other) + else: + raise CompilerError('Invalid unsigned %d-bit integer: %d' % \ + (self.n_bits, other)) + + def fill_bit(self): + return 0 + + @staticmethod + def prep_comparison(a, b): + pass + +class sgf2nuint32(sgf2nuint): + n_bits = 32 + +class sgf2nint32(sgf2nint): + n_bits = 32 + +def get_sgf2nint(n): + class sgf2nint_spec(sgf2nint): + n_bits = n + #sgf2nint_spec.__name__ = 'sgf2unint' + str(n) + return sgf2nint_spec + +def get_sgf2nuint(n): + class sgf2nuint_spec(sgf2nint): + n_bits = n + #sgf2nuint_spec.__name__ = 'sgf2nuint' + str(n) + return sgf2nuint_spec + +class sgf2nfloat(sgf2n): + @classmethod + def set_precision(cls, vlen, plen): + cls.vlen = vlen + cls.plen = plen + class v_type(sgf2nuint): + n_bits = 2 * vlen + 1 + class p_type(sgf2nint): + n_bits = plen + class pdiff_type(sgf2nuint): + n_bits = plen + cls.v_type = v_type + cls.p_type = p_type + cls.pdiff_type = pdiff_type + + def __init__(self, val, p=None, z=None, s=None): + super(sgf2nfloat, self).__init__() + if p is None and type(val) == sgf2n: + bits = val.bit_decompose(self.vlen + self.plen + 1) + self.v = self.v_type.compose(bits[:self.vlen]) + self.p = self.p_type.compose(bits[self.vlen:-1]) + self.s = bits[-1] + self.z = util.tree_reduce(operator.mul, (1 - b for b in self.v.bits)) + else: + if p is None: + v, p, z, s = sfloat.convert_float(val, self.vlen, self.plen) + # correct sfloat + p += self.vlen - 1 + v_bits = util.bit_decompose(v, self.vlen) + p_bits = util.bit_decompose(p, self.plen) + self.v = self.v_type.compose(v_bits) + self.p = self.p_type.compose(p_bits) + self.z = z + self.s = s + else: + self.v, self.p, self.z, self.s = val, p, z, s + v_bits = val.bit_decompose()[:self.vlen] + p_bits = p.bit_decompose()[:self.plen] + gmovs(self, util.bit_compose(v_bits + p_bits + [self.s])) + + def add(self, other): + a = self.p < other.p + b = self.p == other.p + c = self.v < other.v + other_dominates = (b.if_else(c, a)) + pmax, pmin = a.cond_swap(self.p, other.p, self.p_type) + vmax, vmin = other_dominates.cond_swap(self.v, other.v, self.v_type) + s3 = self.s ^ other.s + pdiff = self.pdiff_type(pmax - pmin) + d = self.vlen < pdiff + pow_delta = util.pow2(d.if_else(0, pdiff).bit_decompose(util.log2(self.vlen))) + v3 = vmax + v4 = self.v_type(sgf2n(vmax) * pow_delta) + self.v_type(s3.if_else(-vmin, vmin)) + v = self.v_type(sgf2n(d.if_else(v3, v4) << self.vlen) / pow_delta) + v >>= self.vlen - 1 + h = floatingpoint.PreOR(v.bits[self.vlen+1::-1]) + tmp = sum(util.if_else(b, 0, 1 << i) for i,b in enumerate(h)) + pow_p0 = 1 + self.v_type(tmp) + v = (v * pow_p0) >> 2 + p = pmax - sum(self.p_type.compose([1 - b]) for b in h) + 1 + v = self.z.if_else(other.v, other.z.if_else(self.v, v)) + z = v == 0 + p = z.if_else(0, self.z.if_else(other.p, other.z.if_else(self.p, p))) + s = other_dominates.if_else(other.s, self.s) + s = self.z.if_else(other.s, other.z.if_else(self.s, s)) + return sgf2nfloat(v, p, z, s) + + def mul(self, other): + v = (self.v * other.v) >> (self.vlen - 1) + b = v.bits[self.vlen] + v = b.if_else(v >> 1, v) + p = self.p + other.p + self.p_type.compose([b]) + s = self.s + other.s + z = util.or_op(self.z, other.z) + return sgf2nfloat(v, p, z, s) + +sgf2nfloat.set_precision(24, 8) + +class sfloat(_number): + """ Shared floating point data type, representing (1 - 2s)*(1 - z)*v*2^p. + + v: significand + p: exponent + z: zero flag + s: sign bit + """ + __slots__ = ['v', 'p', 'z', 's', 'size'] + # single precision + vlen = 24 + plen = 8 + kappa = 40 + round_nearest = False + error = 0 + + @vectorized_classmethod + def load_mem(cls, address): + res = [] + for i in range(4): + res.append(sint.load_mem(address + i * get_global_vector_size())) + return sfloat(*res) + + @classmethod + def set_error(cls, error): + cls.error += error - cls.error * error + + @staticmethod + def convert_float(v, vlen, plen): + if v < 0: + s = 1 + else: + s = 0 + if v == 0: + v = 0 + p = 0 + z = 1 + else: + p = int(math.floor(math.log(abs(v), 2))) - vlen + 1 + v = int(round(abs(v) * 2 ** (-p))) + if v == 2 ** vlen: + p += 1 + v /= 2 + z = 0 + if abs(p) >= 2 ** plen: + raise CompilerError('Cannot convert %s to float ' \ + 'with %d exponent bits' % (v, plen)) + return v, p, z, s + + @vectorize_init + def __init__(self, v, p=None, z=None, s=None, size=None): + self.size = get_global_vector_size() + if p is None: + if isinstance(v, sfloat): + p = v.p + z = v.z + s = v.s + v = v.v + elif isinstance(v, sint): + v, p, z, s = floatingpoint.Int2FL(v, program.bit_length, + self.vlen, self.kappa) + else: + v, p, z, s = self.convert_float(v, self.vlen, self.plen) + if isinstance(v, int): + if not ((v >= 2**(self.vlen-1) and v < 2**(self.vlen)) or v == 0): + raise CompilerError('Floating point number malformed: significand') + self.v = library.load_int_to_secret(v) + else: + self.v = v + if isinstance(p, int): + if not (p >= -2**self.plen and p < (2**self.plen - 1)): + raise CompilerError('Floating point number malformed: exponent') + self.p = library.load_int_to_secret(p) + else: + self.p = p + if isinstance(z, int): + if not (z == 0 or z == 1): + raise CompilerError('Floating point number malformed: zero bit') + self.z = sint() + ldsi(self.z, z) + else: + self.z = z + if isinstance(s, int): + if not (s == 0 or s == 1): + raise CompilerError('Floating point number malformed: sign') + self.s = sint() + ldsi(self.s, s) + else: + self.s = s + + def store_in_mem(self, address): + for i,x in enumerate((self.v, self.p, self.z, self.s)): + x.store_in_mem(address + i * get_global_vector_size()) + + def sizeof(self): + return self.size * 4 + + @vectorize + def add(self, other): + if isinstance(other, sfloat): + a,c,d,e = [sint() for i in range(4)] + t = sint() + t2 = sint() + v1 = self.v + v2 = other.v + p1 = self.p + p2 = other.p + s1 = self.s + s2 = other.s + z1 = self.z + z2 = other.z + a = p1.less_than(p2, self.plen, self.kappa) + b = floatingpoint.EQZ(p1 - p2, self.plen, self.kappa) + c = v1.less_than(v2, self.vlen, self.kappa) + ap1 = a*p1 + ap2 = a*p2 + aneg = 1 - a + bneg = 1 - b + cneg = 1 - c + av1 = a*v1 + av2 = a*v2 + cv1 = c*v1 + cv2 = c*v2 + pmax = ap2 + p1 - ap1 + pmin = p2 - ap2 + ap1 + vmax = bneg*(av2 + v1 - av1) + b*(cv2 + v1 - cv1) + vmin = bneg*(av1 + v2 - av2) + b*(cv1 + v2 - cv2) + s3 = s1 + s2 - 2 * s1 * s2 + comparison.LTZ(d, self.vlen + pmin - pmax + sfloat.round_nearest, + self.plen, self.kappa) + pow_delta = floatingpoint.Pow2((1 - d) * (pmax - pmin), + self.vlen + 1 + sfloat.round_nearest, + self.kappa) + # deviate from paper for more precision + #v3 = 2 * (vmax - s3) + 1 + v3 = vmax + v4 = vmax * pow_delta + (1 - 2 * s3) * vmin + v = (d * v3 + (1 - d) * v4) * two_power(self.vlen + sfloat.round_nearest) \ + * floatingpoint.Inv(pow_delta) + comparison.Trunc(t, v, 2 * self.vlen + 1 + sfloat.round_nearest, + self.vlen - 1, self.kappa, False) + v = t + u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest, + self.vlen + 2 + sfloat.round_nearest, self.kappa, + range(1 + sfloat.round_nearest, + self.vlen + 2 + sfloat.round_nearest)) + # using u[0] doesn't seem necessary + h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1], self.kappa) + p0 = self.vlen + 1 - sum(h) + pow_p0 = 1 + sum([two_power(i) * (1 - h[i]) for i in range(len(h))]) + if self.round_nearest: + t2, overflow = \ + floatingpoint.TruncRoundNearestAdjustOverflow(pow_p0 * v, + self.vlen + 3, + self.vlen, + self.kappa) + p0 = p0 - overflow + else: + comparison.Trunc(t2, pow_p0 * v, self.vlen + 2, 2, self.kappa, False) + v = t2 + # deviate for more precision + #p = pmax - p0 + 1 - d + p = pmax - p0 + 1 + zz = self.z*other.z + zprod = 1 - self.z - other.z + zz + v = zprod*t2 + self.z*v2 + other.z*v1 + z = floatingpoint.EQZ(v, self.vlen, self.kappa) + p = (zprod*p + self.z*p2 + other.z*p1)*(1 - z) + s = (1 - b)*(a*other.s + aneg*self.s) + b*(c*other.s + cneg*self.s) + s = zprod*s + (other.z - zz)*self.s + (self.z - zz)*other.s + return sfloat(v, p, z, s) + else: + return NotImplemented + + @vectorize + def mul(self, other): + if isinstance(other, sfloat): + v1 = sint() + v2 = sint() + b = sint() + c2expl = cint() + comparison.ld2i(c2expl, self.vlen) + if sfloat.round_nearest: + v1 = comparison.TruncRoundNearest(self.v*other.v, 2*self.vlen, + self.vlen-1, self.kappa) + else: + comparison.Trunc(v1, self.v*other.v, 2*self.vlen, self.vlen-1, self.kappa, False) + t = v1 - c2expl + comparison.LTZ(b, t, self.vlen+1, self.kappa) + comparison.Trunc(v2, b*v1 + v1, self.vlen+1, 1, self.kappa, False) + z = self.z + other.z - self.z*other.z # = OR(z1, z2) + s = self.s + other.s - 2*self.s*other.s # = XOR(s1,s2) + p = (self.p + other.p - b + self.vlen)*(1 - z) + return sfloat(v2, p, z, s) + else: + return NotImplemented + + def __sub__(self, other): + return self + -other + + def __rsub__(self, other): + raise NotImplementedError() + + def __div__(self, other): + v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1), + self.vlen, self.kappa) + b = v.less_than(two_power(self.vlen-1), self.vlen + 1, self.kappa) + overflow = v.greater_equal(two_power(self.vlen), self.vlen + 1, self.kappa) + underflow = v.less_than(two_power(self.vlen-2), self.vlen + 1, self.kappa) + v = (v + b * v) * (1 - overflow) * (1 - underflow) + \ + overflow * (2**self.vlen - 1) + \ + underflow * (2**(self.vlen-1)) * (1 - self.z) + p = (1 - self.z) * (self.p - other.p - self.vlen - b + 1) + z = self.z + s = self.s + other.s - 2 * self.s * other.s + sfloat.set_error(other.z) + return sfloat(v, p, z, s) + + @vectorize + def __neg__(self): + return sfloat(self.v, self.p, self.z, (1 - self.s) * (1 - self.z)) + + @vectorize + def __lt__(self, other): + if isinstance(other, sfloat): + z1 = self.z + z2 = other.z + s1 = self.s + s2 = other.s + a = self.p.less_than(other.p, self.plen, self.kappa) + c = floatingpoint.EQZ(self.p - other.p, self.plen, self.kappa) + d = ((1 - 2*self.s)*self.v).less_than((1 - 2*other.s)*other.v, self.vlen + 1, self.kappa) + cd = c*d + ca = c*a + b1 = cd + a - ca + b2 = cd + 1 + ca - c - a + s12 = self.s*other.s + z12 = self.z*other.z + b = (z1 - z12)*(1 - s2) + (z2 - z12)*s1 + (1 + z12 - z1 - z2)*(s1 - s12 + (1 + s12 - s1 - s2)*b1 + s12*b2) + return b + else: + return NotImplemented + + def __ge__(self, other): + return 1 - (self < other) + + @vectorize + def __eq__(self, other): + # the sign can be both ways for zeroes + both_zero = self.z * other.z + return floatingpoint.EQZ(self.v - other.v, self.vlen, self.kappa) * \ + floatingpoint.EQZ(self.p - other.p, self.plen, self.kappa) * \ + (1 - self.s - other.s + 2 * self.s * other.s) * \ + (1 - both_zero) + both_zero + + def __ne__(self, other): + return 1 - (self == other) + + def value(self): + """ Gets actual floating point value, if emulation is enabled. """ + return (1 - 2*self.s.value)*(1 - self.z.value)*self.v.value/float(2**self.p.value) + + +_types = { + 'c': cint, + 's': sint, + 'sg': sgf2n, + 'cg': cgf2n, + 'ci': regint, +} + + +class Array(object): + def __init__(self, length, value_type, address=None): + if value_type in _types: + value_type = _types[value_type] + self.address = address + if address is None: + self.address = program.malloc(length, value_type.reg_type) + self.length = length + self.value_type = value_type + + def delete(self): + if program: + program.free(self.address, self.value_type.reg_type) + + def get_address(self, index): + if isinstance(index, int) and self.length is not None: + index += self.length * (index < 0) + if index >= self.length or index < 0: + raise IndexError('index %s, length %s' % \ + (str(index), str(self.length))) + return self.address + index + + def get_slice(self, index): + if index.stop is None and self.length is None: + raise CompilerError('Cannot slice array of unknown length') + return index.start or 0, index.stop or self.length, index.step or 1 + + def __getitem__(self, index): + if isinstance(index, slice): + start, stop, step = self.get_slice(index) + res_length = (stop - start - 1) / step + 1 + res = Array(res_length, self.value_type) + @library.for_range(res_length) + def f(i): + res[i] = self[start+i*step] + return res + return self.value_type.load_mem(self.get_address(index)) + + def __setitem__(self, index, value): + if isinstance(index, slice): + start, stop, step = self.get_slice(index) + source_index = MemValue(0) + @library.for_range(start, stop, step) + def f(i): + self[i] = value[source_index] + source_index.iadd(1) + return + self.value_type.conv(value).store_in_mem(self.get_address(index)) + + def __len__(self): + return self.length + + def __iter__(self): + for i in range(self.length): + yield self[i] + + def assign(self, other): + if isinstance(other, Array): + def loop(i): + self[i] = other[i] + library.range_loop(loop, len(self)) + elif isinstance(other, Tape.Register): + if len(other) == self.length: + self[0] = other + else: + raise CompilerError('Length mismatch between array and vector') + else: + for i,j in enumerate(other): + self[i] = j + return self + + def assign_all(self, value): + mem_value = MemValue(value) + n_loops = 8 if len(self) > 2**20 else 1 + @library.for_range_multithread(n_loops, 1024, len(self)) + def f(i): + self[i] = mem_value + return self + + +class Matrix(object): + def __init__(self, rows, columns, value_type): + self.rows = rows + self.columns = columns + if value_type in _types: + value_type = _types[value_type] + self.value_type = value_type + self.address = Array(rows * columns, value_type).address + + def __getitem__(self, index): + return Array(self.columns, self.value_type, \ + self.address + index * self.columns) + + def __len__(self): + return self.rows + + def assign_all(self, value): + @library.for_range(len(self)) + def f(i): + self[i].assign_all(value) + return self + + +class SubMultiArray(object): + def __init__(self, sizes, value_type, address, index): + self.sizes = sizes + self.value_type = value_type + self.address = address + index * reduce(operator.mul, self.sizes) + + def __getitem__(self, index): + if len(self.sizes) == 2: + return Array(self.sizes[1], self.value_type, \ + self.address + index * self.sizes[0]) + else: + return SubMultiArray(self.sizes[1:], self.value_type, \ + self.address, index) + +class MultiArray(object): + def __init__(self, sizes, value_type): + self.sizes = sizes + self.value_type = value_type + self.array = Array(reduce(operator.mul, sizes), \ + value_type) + if len(sizes) < 2: + raise CompilerError('Use Array') + + def __getitem__(self, index): + return SubMultiArray(self.sizes[1:], self.value_type, \ + self.array.address, index) + +class VectorArray(object): + def __init__(self, length, value_type, vector_size, address=None): + self.array = Array(length * vector_size, value_type, address) + self.vector_size = vector_size + self.value_type = value_type + + def __getitem__(self, index): + return self.value_type.load_mem(self.array.address + \ + index * self.vector_size, + size=self.vector_size) + + def __setitem__(self, index, value): + if value.size != self.vector_size: + raise CompilerError('vector size mismatch') + value.store_in_mem(self.array.address + index * self.vector_size) + +class _mem(_number): + __add__ = lambda self,other: self.read() + other + __sub__ = lambda self,other: self.read() - other + __mul__ = lambda self,other: self.read() * other + __div__ = lambda self,other: self.read() / other + __mod__ = lambda self,other: self.read() % other + __pow__ = lambda self,other: self.read() ** other + __neg__ = lambda self,other: -self.read() + __lt__ = lambda self,other: self.read() < other + __gt__ = lambda self,other: self.read() > other + __le__ = lambda self,other: self.read() <= other + __ge__ = lambda self,other: self.read() >= other + __eq__ = lambda self,other: self.read() == other + __ne__ = lambda self,other: self.read() != other + __and__ = lambda self,other: self.read() & other + __xor__ = lambda self,other: self.read() ^ other + __or__ = lambda self,other: self.read() | other + __lshift__ = lambda self,other: self.read() << other + __rshift__ = lambda self,other: self.read() >> other + + __radd__ = lambda self,other: other + self.read() + __rsub__ = lambda self,other: other - self.read() + __rmul__ = lambda self,other: other * self.read() + __rdiv__ = lambda self,other: other / self.read() + __rmod__ = lambda self,other: other % self.read() + __rand__ = lambda self,other: other & self.read() + __rxor__ = lambda self,other: other ^ self.read() + __ror__ = lambda self,other: other | self.read() + + __iadd__ = lambda self,other: self.write(self.read() + other) + __isub__ = lambda self,other: self.write(self.read() - other) + __imul__ = lambda self,other: self.write(self.read() * other) + __idiv__ = lambda self,other: self.write(self.read() / other) + __imod__ = lambda self,other: self.write(self.read() % other) + __ipow__ = lambda self,other: self.write(self.read() ** other) + __iand__ = lambda self,other: self.write(self.read() & other) + __ixor__ = lambda self,other: self.write(self.read() ^ other) + __ior__ = lambda self,other: self.write(self.read() | other) + __ilshift__ = lambda self,other: self.write(self.read() << other) + __irshift__ = lambda self,other: self.write(self.read() >> other) + + iadd = __iadd__ + isub = __isub__ + imul = __imul__ + idiv = __idiv__ + imod = __imod__ + ipow = __ipow__ + iand = __iand__ + ixor = __ixor__ + ior = __ior__ + ilshift = __ilshift__ + irshift = __irshift__ + + store_in_mem = lambda self,address: self.read().store_in_mem(address) + +class MemValue(_mem): + __slots__ = ['last_write_block', 'reg_type', 'register', 'address', 'deleted'] + + def __init__(self, value): + self.last_write_block = None + if isinstance(value, int): + self.value_type = regint + value = regint(value) + elif isinstance(value, MemValue): + self.value_type = value.value_type + else: + self.value_type = type(value) + self.reg_type = self.value_type.reg_type + self.address = program.malloc(1, self.reg_type) + self.deleted = False + self.write(value) + + def delete(self): + program.free(self.address, self.reg_type) + self.deleted = True + + def check(self): + if self.deleted: + raise CompilerError('MemValue deleted') + + def read(self): + self.check() + if program.curr_block != self.last_write_block: + self.register = library.load_mem(self.address, self.value_type) + self.last_write_block = program.curr_block + return self.register + + def write(self, value): + self.check() + if isinstance(value, MemValue): + self.register = value.read() + elif isinstance(value, (int,long)): + self.register = self.value_type(value) + else: + self.register = value + if not isinstance(self.register, self.value_type): + raise CompilerError('Mismatch in register type, cannot write \ + %s to %s' % (type(self.register), self.value_type)) + library.store_in_mem(self.register, self.address) + self.last_write_block = program.curr_block + return self + + def reveal(self): + if self.register.is_clear: + return self.read() + else: + return self.read().reveal() + + less_than = lambda self,other,bit_length=None,security=None: \ + self.read().less_than(other,bit_length,security) + greater_than = lambda self,other,bit_length=None,security=None: \ + self.read().greater_than(other,bit_length,security) + less_equal = lambda self,other,bit_length=None,security=None: \ + self.read().less_equal(other,bit_length,security) + greater_equal = lambda self,other,bit_length=None,security=None: \ + self.read().greater_equal(other,bit_length,security) + equal = lambda self,other,bit_length=None,security=None: \ + self.read().equal(other,bit_length,security) + not_equal = lambda self,other,bit_length=None,security=None: \ + self.read().not_equal(other,bit_length,security) + + pow2 = lambda self,*args,**kwargs: self.read().pow2(*args, **kwargs) + mod2m = lambda self,*args,**kwargs: self.read().mod2m(*args, **kwargs) + right_shift = lambda self,*args,**kwargs: self.read().right_shift(*args, **kwargs) + + bit_decompose = lambda self,*args,**kwargs: self.read().bit_decompose(*args, **kwargs) + + if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) + + +class MemFloat(_mem): + def __init__(self, *args): + value = sfloat(*args) + self.v = MemValue(value.v) + self.p = MemValue(value.p) + self.z = MemValue(value.z) + self.s = MemValue(value.s) + + def write(self, *args): + value = sfloat(*args) + self.v.write(value.v) + self.p.write(value.p) + self.z.write(value.z) + self.s.write(value.s) + + def read(self): + return sfloat(self.v, self.p, self.z, self.s) + +def getNamedTupleType(*names): + class NamedTuple(object): + class NamedTupleArray(object): + def __init__(self, size, t): + import types + self.arrays = [types.Array(size, t) for i in range(len(names))] + def __getitem__(self, index): + return NamedTuple(array[index] for array in self.arrays) + def __setitem__(self, index, item): + for array,value in zip(self.arrays, item): + array[index] = value + @classmethod + def get_array(cls, size, t): + return cls.NamedTupleArray(size, t) + def __init__(self, *args): + if len(args) == 1: + args = args[0] + for name, value in zip(names, args): + self.__dict__[name] = value + def __iter__(self): + for name in names: + yield self.__dict__[name] + def __add__(self, other): + return NamedTuple(i + j for i,j in zip(self, other)) + def __sub__(self, other): + return NamedTuple(i - j for i,j in zip(self, other)) + def __xor__(self, other): + return NamedTuple(i ^ j for i,j in zip(self, other)) + def __mul__(self, other): + return NamedTuple(other * i for i in self) + __rmul__ = __mul__ + __rxor__ = __xor__ + def reveal(self): + return self.__type__(x.reveal() for x in self) + return NamedTuple + + +import library diff --git a/Compiler/util.py b/Compiler/util.py new file mode 100644 index 000000000..a4f9e3fd6 --- /dev/null +++ b/Compiler/util.py @@ -0,0 +1,105 @@ +# (C) 2016 University of Bristol. See License.txt + +import math +import operator + +def format_trace(trace, prefix=' '): + if trace is None: + return '' + else: + return ''.join('\n%sFile "%s", line %s, in %s\n%s %s' % + (prefix,i[0],i[1],i[2],prefix,i[3][0].strip()) \ + for i in reversed(trace)) + +def tuplify(x): + if isinstance(x, (list, tuple)): + return tuple(x) + else: + return (x,) + +def untuplify(x): + if len(x) == 1: + return x[0] + else: + return x + +def greater_than(a, b, bits): + if isinstance(a, int) and isinstance(b, int): + return a > b + else: + return a.greater_than(b, bits) + +def pow2(a, bits): + if isinstance(a, int): + return 2**a + else: + return a.pow2(bits) + +def mod2m(a, b, bits, signed): + if isinstance(a, int): + return a % 2**b + else: + return a.mod2m(b, bits, signed=signed) + +def right_shift(a, b, bits): + if isinstance(a, int): + return a >> b + else: + return a.right_shift(b, bits) + +def bit_decompose(a, bits): + if isinstance(a, (int,long)): + return [int((a >> i) & 1) for i in range(bits)] + else: + return a.bit_decompose(bits) + +def bit_compose(bits): + return sum(b << i for i,b in enumerate(bits)) + +def series(a): + sum = 0 + for i in a: + yield sum + sum += i + yield sum + +def if_else(cond, a, b): + try: + if isinstance(cond, (bool, int)): + if cond: + return a + else: + return b + return cond.if_else(a, b) + except: + print cond, a, b + raise + +def cond_swap(cond, a, b): + if isinstance(cond, (bool, int)): + if cond: + return a, b + else: + return b, a + return cond.cond_swap(a, b) + +def log2(x): + #print 'Compute log2 of', x + return int(math.ceil(math.log(x, 2))) + +def tree_reduce(function, sequence): + n = len(sequence) + if n == 1: + return sequence[0] + else: + reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)] + return tree_reduce(function, reduced + sequence[n/2*2:]) + +def or_op(a, b): + return a + b - a * b + +OR = or_op + +def pow2(bits): + powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)] + return tree_reduce(operator.mul, powers) diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h new file mode 100644 index 000000000..3c20a4baf --- /dev/null +++ b/Exceptions/Exceptions.h @@ -0,0 +1,162 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Exceptions +#define _Exceptions + +#include +#include +#include +using namespace std; + +class not_implemented: public exception + { virtual const char* what() const throw() + { return "Case not implemented"; } + }; +class division_by_zero: public exception + { virtual const char* what() const throw() + { return "Division by zero"; } + }; +class invalid_plaintext: public exception + { virtual const char* what() const throw() + { return "Inconsistent plaintext space"; } + }; +class rep_mismatch: public exception + { virtual const char* what() const throw() + { return "Representation mismatch"; } + }; +class pr_mismatch: public exception + { virtual const char* what() const throw() + { return "Prime mismatch"; } + }; +class params_mismatch: public exception + { virtual const char* what() const throw() + { return "FHE params mismatch"; } + }; +class field_mismatch: public exception + { virtual const char* what() const throw() + { return "Plaintext Field mismatch"; } + }; +class level_mismatch: public exception + { virtual const char* what() const throw() + { return "Level mismatch"; } + }; +class invalid_length: public exception + { virtual const char* what() const throw() + { return "Invalid length"; } + }; +class invalid_commitment: public exception + { virtual const char* what() const throw() + { return "Invalid Commitment"; } + }; +class IO_Error: public exception + { string msg, ans; + public: + IO_Error(string m) : msg(m) + { ans="IO-Error : "; + ans+=msg; + } + ~IO_Error()throw() { } + virtual const char* what() const throw() + { + return ans.c_str(); + } + }; +class broadcast_invalid: public exception + { virtual const char* what() const throw() + { return "Inconsistent broadcast at some point"; } + }; +class bad_keygen: public exception + { string msg; + public: + bad_keygen(string m) : msg(m) {} + ~bad_keygen()throw() { } + virtual const char* what() const throw() + { string ans="KeyGen has gone wrong: "+msg; + return ans.c_str(); + } + }; +class bad_enccommit: public exception + { virtual const char* what() const throw() + { return "Error in EncCommit"; } + }; +class invalid_params: public exception + { virtual const char* what() const throw() + { return "Invalid Params"; } + }; +class bad_value: public exception + { virtual const char* what() const throw() + { return "Some value is wrong somewhere"; } + }; +class Offline_Check_Error: public exception + { string msg; + public: + Offline_Check_Error(string m) : msg(m) {} + ~Offline_Check_Error()throw() { } + virtual const char* what() const throw() + { string ans="Offline-Check-Error : "; + ans+=msg; + return ans.c_str(); + } + }; +class mac_fail: public exception + { virtual const char* what() const throw() + { return "MacCheck Failure"; } + }; +class invalid_program: public exception + { virtual const char* what() const throw() + { return "Invalid Program"; } + }; +class file_error: public exception + { string filename, ans; + public: + file_error(string m="") : filename(m) + { + ans="File Error : "; + ans+=filename; + } + ~file_error()throw() { } + virtual const char* what() const throw() + { + return ans.c_str(); + } + }; +class end_of_file: public exception + { virtual const char* what() const throw() + { return "End of file reached"; } + }; +class Processor_Error: public exception + { string msg; + public: + Processor_Error(string m) + { + msg = "Processor-Error : " + m; + } + ~Processor_Error()throw() { } + virtual const char* what() const throw() + { + return msg.c_str(); + } + }; +class max_mod_sz_too_small : public exception + { int len; + public: + max_mod_sz_too_small(int len) : len(len) {} + ~max_mod_sz_too_small() throw() {} + virtual const char* what() const throw() + { stringstream out; + out << "MAX_MOD_SZ too small for desired bit length of p, " + << "must be at least ceil(len(p)/len(word))+1, " + << "in this case: " << len; + return out.str().c_str(); + } + }; +class crash_requested: public exception + { virtual const char* what() const throw() + { return "Crash requested by program"; } + }; +class memory_exception : public exception {}; +class how_would_that_work : public exception {}; + + + +#endif diff --git a/Fake-Offline.cpp b/Fake-Offline.cpp new file mode 100644 index 000000000..750598d18 --- /dev/null +++ b/Fake-Offline.cpp @@ -0,0 +1,554 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Share.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" +#include "Exceptions/Exceptions.h" + +#include "Math/Setup.h" +#include "Processor/Data_Files.h" +#include "Tools/mkpath.h" +#include "Tools/ezOptionParser.h" + +#include +#include +using namespace std; + + +string prep_data_prefix; + +/* N = Number players + * ntrip = Number triples needed + * str = "2" or "p" + */ +template +void make_mult_triples(const T& key,int N,int ntrip,const string& str,bool zero) +{ + PRNG G; + G.ReSeed(); + + ofstream* outf=new ofstream[N]; + T a,b,c; + vector > Sa(N),Sb(N),Sc(N); + /* Generate Triples */ + for (int i=0; i > Sa(N),Sb(N),Sc(N); + /* Generate Triples */ + for (int i=0; i +void make_square_tuples(const T& key,int N,int ntrip,const string& str,bool zero) +{ + PRNG G; + G.ReSeed(); + + ofstream* outf=new ofstream[N]; + T a,c; + vector > Sa(N),Sc(N); + /* Generate Squares */ + for (int i=0; i +void make_bits(const T& key,int N,int ntrip,const string& str,bool zero) +{ + PRNG G; + G.ReSeed(); + + ofstream* outf=new ofstream[N]; + T a; + vector > Sa(N); + /* Generate Bits */ + for (int i=0; i +void make_inputs(const T& key,int N,int ntrip,const string& str,bool zero) +{ + PRNG G; + G.ReSeed(); + + ofstream* outf=new ofstream[N]; + T a; + vector > Sa(N); + /* Generate Inputs */ + for (int player=0; player +void make_inverse(const T& key,int N,int ntrip,bool zero) +{ + PRNG G; + G.ReSeed(); + + ofstream* outf=new ofstream[N]; + T a,b; + vector > Sa(N),Sb(N); + /* Generate Triples */ + for (int i=0; i +void make_PreMulC(const T& key, int N, int ntrip, bool zero) +{ + stringstream ss; + ss << prep_data_prefix << "PreMulC-" << T::type_char(); + Files files(N, key, ss.str()); + PRNG G; + G.ReSeed(); + T a, b, c; + c = 1; + for (int i=0; i badOptions; + string usage; + unsigned int i; + if(!opt.gotRequired(badOptions)) + { + for (i=0; i < badOptions.size(); ++i) + cerr << "ERROR: Missing required option " << badOptions[i] << "."; + opt.getUsage(usage); + cout << usage; + return 1; + } + + if(!opt.gotExpected(badOptions)) + { + for(i=0; i < badOptions.size(); ++i) + cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << "."; + opt.getUsage(usage); + cout << usage; + return 1; + } + + int nplayers; + if (opt.firstArgs.size() == 2) + { + nplayers = atoi(opt.firstArgs[1]->c_str()); + } + else if (opt.lastArgs.size() == 1) + { + nplayers = atoi(opt.lastArgs[0]->c_str()); + } + else + { + cerr << "ERROR: invalid number of arguments\n"; + opt.getUsage(usage); + cout << usage; + return 1; + } + + int default_num = 0; + int ntrip2=0, ntripp=0, nbits2=0,nbitsp=0,nsqr2=0,nsqrp=0,ninp2=0,ninpp=0,ninv=0, nbittrip=0, nbitgf2ntrip=0; + vector list_options; + int lg2, lgp; + + opt.get("--lgp")->getInt(lgp); + opt.get("--lg2")->getInt(lg2); + + opt.get("--default")->getInt(default_num); + ntrip2 = ntripp = nbits2 = nbitsp = nsqr2 = nsqrp = ninp2 = ninpp = ninv = + nbittrip = nbitgf2ntrip = default_num; + + if (opt.isSet("--ntriples")) + { + opt.get("--ntriples")->getInts(list_options); + ntrip2 = list_options[0]; + ntripp = list_options[1]; + } + if (opt.isSet("--nbits")) + { + opt.get("--nbits")->getInts(list_options); + nbits2 = list_options[0]; + nbitsp = list_options[1]; + } + if (opt.isSet("--ninputs")) + { + opt.get("--ninputs")->getInts(list_options); + ninp2 = list_options[0]; + ninpp = list_options[1]; + } + if (opt.isSet("--nsquares")) + { + opt.get("--nsquares")->getInts(list_options); + nsqr2 = list_options[0]; + nsqrp = list_options[1]; + } + if (opt.isSet("--ninverses")) + opt.get("--ninverses")->getInt(ninv); + if (opt.isSet("--nbittriples")) + opt.get("--nbittriples")->getInt(nbittrip); + if (opt.isSet("--nbitgf2ntriples")) + opt.get("--nbitgf2ntriples")->getInt(nbitgf2ntrip); + + bool zero = opt.isSet("--zero"); + if (zero) + cout << "Set all values to zero" << endl; + + PRNG G; + G.ReSeed(); + prep_data_prefix = get_prep_dir(nplayers, lgp, lg2); + // Set up the fields + ofstream outf; + bigint p; + generate_online_setup(outf, prep_data_prefix, p, lgp, lg2); + + generate_keys(prep_data_prefix, nplayers); + + /* Find number players and MAC keys etc*/ + gfp keyp,pp; keyp.assign_zero(); + gf2n key2,p2; key2.assign_zero(); + int tmpN = 0; + ifstream inpf; + + // create Player-Data if not there + if (mkdir_p("Player-Data") == -1) + { + cerr << "mkdir_p(Player-Data) failed\n"; + throw file_error(); + } + + for (i = 0; i < (unsigned int)nplayers; i++) + { + stringstream filename; + filename << prep_data_prefix << "Player-MAC-Keys-P" << i; + inpf.open(filename.str().c_str()); + if (inpf.fail()) + { + inpf.close(); + cout << "No MAC key share for player " << i << ", generating a fresh one\n"; + pp.randomize(G); + p2.randomize(G); + ofstream outf(filename.str().c_str()); + if (outf.fail()) + throw file_error(filename.str().c_str()); + outf << nplayers << " " << pp << " " << p2; + outf.close(); + cout << "Written new MAC key share to " << filename.str() << endl; + } + else + { + inpf >> tmpN; // not needed here + pp.input(inpf,true); + p2.input(inpf,true); + inpf.close(); + } + cout << " Key " << i << "\t p: " << pp << "\n\t 2: " << p2 << endl; + keyp.add(pp); + key2.add(p2); + } + cout << "--------------\n"; + cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl; + + make_mult_triples(key2,nplayers,ntrip2,"2",zero); + make_mult_triples(keyp,nplayers,ntripp,"p",zero); + make_bits(key2,nplayers,nbits2,"2",zero); + make_bits(keyp,nplayers,nbitsp,"p",zero); + make_square_tuples(key2,nplayers,nsqr2,"2",zero); + make_square_tuples(keyp,nplayers,nsqrp,"p",zero); + make_inputs(key2,nplayers,ninp2,"2",zero); + make_inputs(keyp,nplayers,ninpp,"p",zero); + make_inverse(key2,nplayers,ninv,zero); + make_inverse(keyp,nplayers,ninv,zero); + make_bit_triples(key2,nplayers,nbittrip,DATA_BITTRIPLE,zero); + make_bit_triples(key2,nplayers,nbitgf2ntrip,DATA_BITGF2NTRIPLE,zero); + make_PreMulC(key2,nplayers,ninv,zero); + make_PreMulC(keyp,nplayers,ninv,zero); +} diff --git a/HOSTS.example b/HOSTS.example new file mode 100644 index 000000000..376c273bd --- /dev/null +++ b/HOSTS.example @@ -0,0 +1,5 @@ +192.168.0.1 +192.168.0.2 +192.168.0.3 +192.168.0.4 +192.168.0.5 diff --git a/License.txt b/License.txt new file mode 100644 index 000000000..e30817d3a --- /dev/null +++ b/License.txt @@ -0,0 +1,19 @@ +University of Bristol : Open Access Software Licence + +Copyright (c) 2016, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom. + +All rights reserved + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +Any use of the software for scientific publications or commercial purposes should be reported to the University of Bristol (OSI-notifications@bristol.ac.uk and quote reference 1914). This is for impact and usage monitoring purposes only. + +Enquiries about further applications and development opportunities are welcome. Please contact nigel@cs.bris.ac.uk + diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..eff45dac5 --- /dev/null +++ b/Makefile @@ -0,0 +1,75 @@ +# (C) 2016 University of Bristol. See License.txt + + +include CONFIG + +MATH = $(patsubst %.cpp,%.o,$(wildcard Math/*.cpp)) + +TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp)) + +NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp)) + +AUTH = $(patsubst %.cpp,%.o,$(wildcard Auth/*.cpp)) + +PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) + +# OT stuff needs GF2N_LONG, so only compile if this is enabled +ifeq ($(USE_GF2N_LONG),1) +OT = $(patsubst %.cpp,%.o,$(filter-out OT/OText_main.cpp,$(wildcard OT/*.cpp))) +OT_EXE = ot.x ot-offline.x +endif + +COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH) +COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) + +LIB = libSPDZ.a +LIBSIMPLEOT = SimpleOT/libsimpleot.a + + +all: gen_input online offline + +online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x + +offline: $(OT_EXE) Check-Offline.x + +gen_input: gen_input_f2n.x gen_input_fp.x + +Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS) + +Server.x: Server.cpp $(COMMON) + $(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS) + +Player-Online.x: Player-Online.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) Player-Online.cpp -o Player-Online.x $(COMMON) $(PROCESSOR) $(LDLIBS) + +ifeq ($(USE_GF2N_LONG),1) +ot.x: $(OT) $(COMMON) OT/OText_main.cpp + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(LIBSIMPLEOT) + +ot-check.x: $(OT) $(COMMON) + $(CXX) $(CFLAGS) -o ot-check.x OT/BitVector.o OT/OutputCheck.cpp $(COMMON) $(LDLIBS) + +ot-bitmatrix.x: $(OT) $(COMMON) OT/BitMatrixTest.cpp + $(CXX) $(CFLAGS) -o ot-bitmatrix.x OT/BitMatrixTest.cpp OT/BitMatrix.o OT/BitVector.o $(COMMON) $(LDLIBS) + +ot-offline.x: $(OT) $(COMMON) ot-offline.cpp + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(LIBSIMPLEOT) +endif + +check-passive.x: $(COMMON) check-passive.cpp + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON) + $(CXX) $(CFLAGS) Scripts/gen_input_f2n.cpp -o gen_input_f2n.x $(COMMON) $(LDLIBS) + +gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON) + $(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS) + + +clean: + -rm */*.o *.o *.x core.* *.a gmon.out + diff --git a/Math/Integer.cpp b/Math/Integer.cpp new file mode 100644 index 000000000..b6dc06e7c --- /dev/null +++ b/Math/Integer.cpp @@ -0,0 +1,24 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Integer.cpp + * + */ + +#include "Integer.h" + +void Integer::output(ostream& s,bool human) const +{ + if (human) + s << a; + else + s.write((char*)&a, sizeof(a)); +} + +void Integer::input(istream& s,bool human) +{ + if (human) + s >> a; + else + s.read((char*)&a, sizeof(a)); +} diff --git a/Math/Integer.h b/Math/Integer.h new file mode 100644 index 000000000..b1730b340 --- /dev/null +++ b/Math/Integer.h @@ -0,0 +1,34 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Integer.h + * + */ + +#ifndef INTEGER_H_ +#define INTEGER_H_ + +#include +using namespace std; + +// Wrapper class for integer, used for Memory + +class Integer +{ + long a; + + public: + + Integer() { a = 0; } + Integer(long a) : a(a) {} + + long get() const { return a; } + + void assign_zero() { a = 0; } + + void output(ostream& s,bool human) const; + void input(istream& s,bool human); + +}; + +#endif /* INTEGER_H_ */ diff --git a/Math/Setup.cpp b/Math/Setup.cpp new file mode 100644 index 000000000..ac780c40d --- /dev/null +++ b/Math/Setup.cpp @@ -0,0 +1,148 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Math/Setup.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" + +#include "Tools/mkpath.h" + +#include + + +/* + * Just setup the primes, doesn't need NTL. + * Sets idx and m to be used by SHE setup if necessary + */ +void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m) +{ + cout << "Setting up parameters" << endl; + + switch (lgp) + { case -1: + m=16; + idx=1; // Any old figures will do, but need to be for lgp at last + lgp=32; // Switch to bigger prime to get parameters + break; + case 32: + m=8192; + idx=0; + break; + case 64: + m=16384; + idx=1; + break; + case 128: + m=32768; + idx=2; + break; + case 256: + m=32768; + idx=3; + break; + case 512: + m=65536; + idx=4; + break; + default: + throw invalid_params(); + break; + } + cout << "m = " << m << endl; + + // Here we choose a prime which is the order of a BN curve + // - Reason is that there are some applications where this + // would be a good idea. So I have hard coded it in here + // - This is pointless/impossible for lgp=32, 64 so for + // these do something naive + // - Have not tested 256 and 512 + bigint u; + int ex; + if (lgp!=32 && lgp!=64) + { u=1; u=u<<(lgp-1); u=sqrt(sqrt(u/36))/m; + u=u*m; + bigint q; + // cout << ex << " " << u << " " << numBits(u) << endl; + p=(((36*u+36)*u+18)*u+6)*u+1; // The group order of a BN curve + q=(((36*u+36)*u+24)*u+6)*u+1; // The base field size of a BN curve + while (!probPrime(p) || !probPrime(q) || numBits(p)> p; + inpf >> lg2; + + inpf.close(); + + gfp::init_field(p); + gf2n::init_field(lg2); +} + +void read_setup(int nparties, int lg2p, int gf2ndegree) +{ + string dir = get_prep_dir(nparties, lg2p, gf2ndegree); + read_setup(dir); +} diff --git a/Math/Setup.h b/Math/Setup.h new file mode 100644 index 000000000..a3813f1d9 --- /dev/null +++ b/Math/Setup.h @@ -0,0 +1,35 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Setup.h + * + */ + +#ifndef MATH_SETUP_H_ +#define MATH_SETUP_H_ + +#include "Math/bigint.h" + +#include +using namespace std; + +/* + * Routines to create and read setup files for the finite fields + */ + +// Create setup file for gfp and gf2n +void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2); + +// Setup primes only +// Chooses a p of at least lgp bits +void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m); + +// get main directory for prep. data +string get_prep_dir(int nparties, int lg2p, int gf2ndegree); + +// Read online setup file for gfp and gf2n +void read_setup(const string& dir_prefix); +void read_setup(int nparties, int lg2p, int gf2ndegree); + + +#endif /* MATH_SETUP_H_ */ diff --git a/Math/Share.cpp b/Math/Share.cpp new file mode 100644 index 000000000..065c69a8d --- /dev/null +++ b/Math/Share.cpp @@ -0,0 +1,126 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Share.h" +//#include "Tools/random.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Math/operators.h" + + +template +Share::Share(const T& aa, int my_num, const T& alphai) +{ + if (my_num == 0) + a = aa; + else + a.assign_zero(); + mac = aa * alphai; +} + + +template +void Share::mul_by_bit(const Share& S,const T& aa) +{ + a.mul(S.a,aa); + mac.mul(S.mac,aa); +} + +template<> +void Share::mul_by_bit(const Share& S, const gf2n& aa) +{ + a.mul_by_bit(S.a,aa); + mac.mul_by_bit(S.mac,aa); +} + +template +void Share::add(const Share& S,const T& aa,bool playerone,const T& alphai) +{ + if (playerone) + { a.add(S.a,aa); } + else + { a=S.a; } + + T tmp; + tmp.mul(alphai,aa); + mac.add(S.mac,tmp); +} + + + +template +void Share::sub(const Share& S,const T& aa,bool playerone,const T& alphai) +{ + if (playerone) + { a.sub(S.a,aa); } + else + { a=S.a; } + + T tmp; + tmp.mul(alphai,aa); + mac.sub(S.mac,tmp); +} + + + +template +void Share::sub(const T& aa,const Share& S,bool playerone,const T& alphai) +{ + if (playerone) + { a.sub(aa,S.a); } + else + { a=S.a; + a.negate(); + } + + T tmp; + tmp.mul(alphai,aa); + mac.sub(tmp,S.mac); +} + + + +template +void Share::sub(const Share& S1,const Share& S2) +{ + a.sub(S1.a,S2.a); + mac.sub(S1.mac,S2.mac); +} + + + +template +T combine(const vector< Share >& S) +{ + T ans=S[0].a; + for (unsigned int i=1; i +bool check_macs(const vector< Share >& S,const T& key) +{ + T val=combine(S); + + // Now check the MAC is valid + val.mul(val,key); + for (unsigned i=0; i; +template class Share; +template gf2n combine(const vector< Share >& S); +template gfp combine(const vector< Share >& S); +template bool check_macs(const vector< Share >& S,const gf2n& key); +template bool check_macs(const vector< Share >& S,const gfp& key); + +#ifdef USE_GF2N_LONG +template class Share; +template gf2n_short combine(const vector< Share >& S); +template bool check_macs(const vector< Share >& S,const gf2n_short& key); +#endif diff --git a/Math/Share.h b/Math/Share.h new file mode 100644 index 000000000..95382c11d --- /dev/null +++ b/Math/Share.h @@ -0,0 +1,117 @@ +// (C) 2016 University of Bristol. See License.txt + + +#ifndef _Share +#define _Share + +/* Class for holding a share of either a T or gfp element */ + +#include +#include +using namespace std; + +#include "Math/gfp.h" +#include "Math/gf2n.h" + +// Forward declaration as apparently this is needed for friends in templates +template class Share; +template T combine(const vector< Share >& S); +template bool check_macs(const vector< Share >& S,const T& key); + + +template +class Share +{ + T a; // The share + T mac; // Shares of the mac + + public: + + typedef T value_type; + + static int size() + { return 2 * T::size(); } + + static string type_string() + { return T::type_string(); } + + void assign(const Share& S) + { a=S.a; mac=S.mac; } + void assign(const char* buffer) + { a.assign(buffer); mac.assign(buffer + T::size()); } + void assign_zero() + { a.assign_zero(); + mac.assign_zero(); + } + + Share() { assign_zero(); } + Share(const Share& S) { assign(S); } + Share(const T& aa, int my_num, const T& alphai); + ~Share() { ; } + Share& operator=(const Share& S) + { if (this!=&S) { assign(S); } + return *this; + } + + const T& get_share() const { return a; } + const T& get_mac() const { return mac; } + void set_share(const T& aa) { a=aa; } + void set_mac(const T& aa) { mac=aa; } + + /* Arithmetic Routines */ + void mul(const Share& S,const T& aa); + void mul_by_bit(const Share& S,const T& aa); + void add(const Share& S,const T& aa,bool playerone,const T& alphai); + void negate() { a.negate(); mac.negate(); } + void sub(const Share& S,const T& aa,bool playerone,const T& alphai); + void sub(const T& aa,const Share& S,bool playerone,const T& alphai); + void add(const Share& S1,const Share& S2); + void sub(const Share& S1,const Share& S2); + void add(const Share& S1) { add(*this,S1); } + + // Input and output from a stream + // - Can do in human or machine only format (later should be faster) + void output(ostream& s,bool human) const + { a.output(s,human); if (human) { s << " "; } + mac.output(s,human); + } + void input(istream& s,bool human) + { a.input(s,human); + mac.input(s,human); + } + + /* Takes a vector of shares, one from each player and + * determines the shared value + * - i.e. Partially open the shares + */ + friend T combine(const vector< Share >& S); + + /* Given a set of shares, one from each player and + * the global key, determines if the sharing is valid + * - Mainly for test purposes + */ + friend bool check_macs(const vector< Share >& S,const T& key); +}; + +// specialized mul by bit for gf2n +template <> +void Share::mul_by_bit(const Share& S,const gf2n& aa); + +template +Share operator*(const T& y, const Share& x) { Share res; res.mul(x, y); return res; } + +template +inline void Share::add(const Share& S1,const Share& S2) +{ + a.add(S1.a,S2.a); + mac.add(S1.mac,S2.mac); +} + +template +inline void Share::mul(const Share& S,const T& aa) +{ + a.mul(S.a,aa); + mac.mul(S.mac,aa); +} + +#endif diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp new file mode 100644 index 000000000..d5ef56cca --- /dev/null +++ b/Math/Zp_Data.cpp @@ -0,0 +1,138 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Zp_Data.h" + + +void Zp_Data::init(const bigint& p,bool mont) +{ pr=p; + mask=(1<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1; + + montgomery=mont; + t=mpz_size(pr.get_mpz_t()); + if (t>=MAX_MOD_SZ) + throw max_mod_sz_too_small(t+1); + if (montgomery) + { mpn_zero(R,MAX_MOD_SZ); + mpn_zero(R2,MAX_MOD_SZ); + mpn_zero(R3,MAX_MOD_SZ); + bigint r=2,pp=pr; + mpz_pow_ui(r.get_mpz_t(),r.get_mpz_t(),t*8*sizeof(mp_limb_t)); + mpz_invert(pp.get_mpz_t(),pr.get_mpz_t(),r.get_mpz_t()); + pp=r-pp; // pi=-1/p mod R + pi=(pp.get_mpz_t()->_mp_d)[0]; + + r=r%pr; + mpn_copyi(R,r.get_mpz_t()->_mp_d,mpz_size(r.get_mpz_t())); + + bigint r2=(r*r)%pr; + mpn_copyi(R2,r2.get_mpz_t()->_mp_d,mpz_size(r2.get_mpz_t())); + + bigint r3=(r2*r)%pr; + mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t())); + + if (sizeof(unsigned long)!=sizeof(mp_limb_t)) + { cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl; + throw not_implemented(); + } + } + mpn_zero(prA,MAX_MOD_SZ); + mpn_copyi(prA,pr.get_mpz_t()->_mp_d,t); +} + + +void Zp_Data::assign(const Zp_Data& Zp) +{ pr=Zp.pr; + mask=Zp.mask; + + montgomery=Zp.montgomery; + t=Zp.t; + mpn_copyi(R,Zp.R,t+1); + mpn_copyi(R2,Zp.R2,t+1); + mpn_copyi(R3,Zp.R3,t+1); + pi=Zp.pi; + + mpn_copyi(prA,Zp.prA,t+1); +} + + +void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const +{ + mp_limb_t borrow = mpn_sub_n(ans,x,y,t); + if (borrow!=0) + mpn_add_n(ans,ans,prA,t); +} + +__m128i Zp_Data::get_random128(PRNG& G) +{ + while (true) + { + __m128i res = G.get_doubleword(); + if (mpn_cmp((mp_limb_t*)&res, prA, t) < 0) + return res; + } +} + + +#include + +void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const +{ + if (x[t]!=0 || y[t]!=0) { cout << "Mont_Mult Bug" << endl; abort(); } + mp_limb_t ans[2*MAX_MOD_SZ],u; + // First loop + u=x[0]*y[0]*pi; + ans[t] = mpn_mul_1(ans,y,t,x[0]); + ans[t+1] = mpn_addmul_1(ans,prA,t+1,u); + for (int i=1; i=pr) { ans=z-pr; } + // else { z=ans; } + if (mpn_cmp(ans+t,prA,t+1)>=0) + { mpn_sub_n(z,ans+t,prA,t); } + else + { mpn_copyi(z,ans+t,t); } +} + + + +ostream& operator<<(ostream& s,const Zp_Data& ZpD) +{ + s << ZpD.pr << " " << ZpD.montgomery << endl; + if (ZpD.montgomery) + { s << ZpD.t << " " << ZpD.pi << endl; + for (int i=0; i>(istream& s,Zp_Data& ZpD) +{ + s >> ZpD.pr >> ZpD.montgomery; + if (ZpD.montgomery) + { s >> ZpD.t >> ZpD.pi; + if (ZpD.t>=MAX_MOD_SZ) + throw max_mod_sz_too_small(ZpD.t+1); + mpn_zero(ZpD.R,MAX_MOD_SZ); + mpn_zero(ZpD.R2,MAX_MOD_SZ); + mpn_zero(ZpD.R3,MAX_MOD_SZ); + mpn_zero(ZpD.prA,MAX_MOD_SZ); + for (int i=0; i> ZpD.R[i]; } + for (int i=0; i> ZpD.R2[i]; } + for (int i=0; i> ZpD.R3[i]; } + for (int i=0; i> ZpD.prA[i]; } + } + return s; +} diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h new file mode 100644 index 000000000..0ddb14e94 --- /dev/null +++ b/Math/Zp_Data.h @@ -0,0 +1,129 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Zp_Data +#define _Zp_Data + +/* Class to define helper information for a Zp element + * + * Basically the data needed for Montgomery operations + * + * Almost all data is public as this is basically a container class + * + */ + +#include "Math/bigint.h" +#include "Tools/random.h" + +#include +#include +using namespace std; + +#ifndef MAX_MOD_SZ + #ifdef LargeM + #define MAX_MOD_SZ 20 + #else + #define MAX_MOD_SZ 3 + #endif +#endif + +class modp; + +class Zp_Data +{ + bool montgomery; // True if we are using Montgomery arithmetic + mp_limb_t R[MAX_MOD_SZ],R2[MAX_MOD_SZ],R3[MAX_MOD_SZ],pi; + mp_limb_t prA[MAX_MOD_SZ]; + int t; // More Montgomery data + + void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; + + public: + + bigint pr; + mp_limb_t mask; + + void assign(const Zp_Data& Zp); + void init(const bigint& p,bool mont=true); + int get_t() const { return t; } + const mp_limb_t* get_prA() const { return prA; } + + // This one does nothing, needed so as to make vectors of Zp_Data + Zp_Data() : montgomery(0), pi(0), mask(0) { t=1; } + + // The main init funciton + Zp_Data(const bigint& p,bool mont=true) + { init(p,mont); } + + Zp_Data(const Zp_Data& Zp) { assign(Zp); } + Zp_Data& operator=(const Zp_Data& Zp) + { if (this!=&Zp) { assign(Zp); } + return *this; + } + ~Zp_Data() { ; } + + template + void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const; + void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const; + void Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const; + + __m128i get_random128(PRNG& G); + + friend void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce); + + friend void to_modp(modp& ans,int x,const Zp_Data& ZpD); + friend void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD); + + friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD); + friend void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD); + friend void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD); + friend void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD); + friend void Negate(modp& ans,const modp& x,const Zp_Data& ZpD); + friend void Inv(modp& ans,const modp& x,const Zp_Data& ZpD); + + friend void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD); + friend void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD); + + friend void assignOne(modp& x,const Zp_Data& ZpD); + friend void assignZero(modp& x,const Zp_Data& ZpD); + friend bool isZero(const modp& x,const Zp_Data& ZpD); + friend bool isOne(const modp& x,const Zp_Data& ZpD); + friend bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD); + + friend class modp; + + friend ostream& operator<<(ostream& s,const Zp_Data& ZpD); + friend istream& operator>>(istream& s,Zp_Data& ZpD); +}; + +template<> +inline void Zp_Data::Add<2>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const +{ + __uint128_t a, b, p; + memcpy(&a, x, sizeof(__uint128_t)); + memcpy(&b, y, sizeof(__uint128_t)); + memcpy(&p, prA, sizeof(__uint128_t)); + __uint128_t c = a + b; + asm goto ("jc %l[sub]" :::: sub); + if (c >= p) + sub: + c -= p; + memcpy(ans, &c, sizeof(__uint128_t)); +} + +template<> +inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const +{ + mp_limb_t carry = mpn_add_n(ans,x,y,t); + if (carry!=0 || mpn_cmp(ans,prA,t)>=0) + { mpn_sub_n(ans,ans,prA,t); } +} + +inline void Zp_Data::Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const +{ + if (t == 2) + return Add<2>(ans, x, y); + else + return Add<0>(ans, x, y); +} + +#endif diff --git a/Math/bigint.cpp b/Math/bigint.cpp new file mode 100644 index 000000000..ee2ce958a --- /dev/null +++ b/Math/bigint.cpp @@ -0,0 +1,95 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "bigint.h" +#include "Exceptions/Exceptions.h" + + +bigint sqrRootMod(const bigint& a,const bigint& p) +{ + bigint ans; + if (a==0) { ans=0; return ans; } + if (mpz_tstbit(p.get_mpz_t(),1)==1) + { // First do case with p=3 mod 4 + bigint exp=(p+1)/4; + mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t()); + } + else + { // Shanks algorithm + gmp_randclass Gen(gmp_randinit_default); + Gen.seed(0); + bigint x,y,n,q,t,b,temp; + // Find n such that (n/p)=-1 + int leg=1; + while (leg!=-1) + { n=Gen.get_z_range(p); + leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t()); + } + // Split p-1 = 2^e q + q=p-1; + int e=0; + while (mpz_even_p(q.get_mpz_t())) + { e++; q=q/2; } + // y=n^q mod p, x=a^((q-1)/2) mod p, r=e + int r=e; + mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t()); + temp=(q-1)/2; + mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t()); + // b=a*x^2 mod p, x=a*x mod p + b=(a*x*x)%p; + x=(a*x)%p; + // While b!=1 do + while (b!=1) + { // Find smallest m such that b^(2^m)=1 mod p + int m=1; + temp=(b*b)%p; + while (temp!=1) + { temp=(temp*temp)%p; m++; } + // t=y^(2^(r-m-1)) mod p, y=t^2, r=m + t=y; + for (int i=0; i=0) + { mpz_powm(ans.get_mpz_t(),x.get_mpz_t(),e.get_mpz_t(),p.get_mpz_t()); } + else + { bigint xi,ei=-e; + invMod(xi,x,p); + mpz_powm(ans.get_mpz_t(),xi.get_mpz_t(),ei.get_mpz_t(),p.get_mpz_t()); + } + + return ans; +} + + +int powerMod(int x,int e,int p) +{ + if (e==1) { return x; } + if (e==0) { return 1; } + if (e<0) + { throw not_implemented(); } + int t=x,ans=1; + while (e!=0) + { if ((e&1)==1) { ans=(ans*t)%p; } + e>>=1; + t=(t*t)%p; + } + return ans; +} + + diff --git a/Math/bigint.h b/Math/bigint.h new file mode 100644 index 000000000..99c6f4073 --- /dev/null +++ b/Math/bigint.h @@ -0,0 +1,116 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _bigint +#define _bigint + +#include +using namespace std; + +#include +#include + +typedef mpz_class bigint; + +#include "Exceptions/Exceptions.h" +#include "Tools/int.h" + +/********************************** + * Utility Functions * + **********************************/ + +inline int gcd(const int x,const int y) +{ + bigint xx=x; + return mpz_gcd_ui(NULL,xx.get_mpz_t(),y); +} + + +inline bigint gcd(const bigint& x,const bigint& y) +{ + bigint g; + mpz_gcd(g.get_mpz_t(),x.get_mpz_t(),y.get_mpz_t()); + return g; +} + + +inline void invMod(bigint& ans,const bigint& x,const bigint& p) +{ + mpz_invert(ans.get_mpz_t(),x.get_mpz_t(),p.get_mpz_t()); +} + +inline int numBits(const bigint& m) +{ + return mpz_sizeinbase(m.get_mpz_t(),2); +} + + + +inline int numBits(int m) +{ + bigint te=m; + return mpz_sizeinbase(te.get_mpz_t(),2); +} + + + +inline int numBytes(const bigint& m) +{ + return mpz_sizeinbase(m.get_mpz_t(),256); +} + + + + + +inline int probPrime(const bigint& x) +{ + gmp_randstate_t rand_state; + gmp_randinit_default(rand_state); + int ans=mpz_probable_prime_p(x.get_mpz_t(),rand_state,40,0); + gmp_randclear(rand_state); + return ans; +} + + +inline void bigintFromBytes(bigint& x,octet* bytes,int len) +{ + mpz_import(x.get_mpz_t(),len,1,sizeof(octet),0,0,bytes); +} + + +inline void bytesFromBigint(octet* bytes,const bigint& x,unsigned int len) +{ + size_t ll; + mpz_export(bytes,&ll,1,sizeof(octet),0,0,x.get_mpz_t()); + if (ll>len) + { throw invalid_length(); } + for (unsigned int i=ll; i=0 +int powerMod(int x,int e,int p); + +inline int Hwt(int N) +{ + int result=0; + while(N) + { result++; + N&=(N-1); + } + return result; +} + +#endif + diff --git a/Math/field_types.h b/Math/field_types.h new file mode 100644 index 000000000..f1f9fb6db --- /dev/null +++ b/Math/field_types.h @@ -0,0 +1,15 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * types.h + * + */ + +#ifndef MATH_FIELD_TYPES_H_ +#define MATH_FIELD_TYPES_H_ + + +enum DataFieldType { DATA_MODP, DATA_GF2N, N_DATA_FIELD_TYPE }; + + +#endif /* MATH_FIELD_TYPES_H_ */ diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp new file mode 100644 index 000000000..c0f7b5ea0 --- /dev/null +++ b/Math/gf2n.cpp @@ -0,0 +1,345 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Math/gf2n.h" + +#include "Exceptions/Exceptions.h" + +#include +#include +#include +#include + +int gf2n_short::n; +int gf2n_short::t1; +int gf2n_short::t2; +int gf2n_short::t3; +int gf2n_short::l0; +int gf2n_short::l1; +int gf2n_short::l2; +int gf2n_short::l3; +int gf2n_short::nterms; +word gf2n_short::mask; +bool gf2n_short::useC; +bool gf2n_short::rewind = false; + +word gf2n_short_table[256][256]; + +#define num_2_fields 4 + +/* Require + * 2*(n-1)-64+t1<64 + */ +int fields_2[num_2_fields][4] = { + {4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10} + }; + + +void gf2n_short::init_tables() +{ + if (sizeof(word)!=8) + { cout << "Word size is wrong" << endl; + throw not_implemented(); + } + int i,j; + for (i=0; i<256; i++) + { for (j=0; j<256; j++) + { word ii=i,jj=j; + gf2n_short_table[i][j]=0; + while (ii!=0) + { if ((ii&1)==1) { gf2n_short_table[i][j]^=jj; } + jj<<=1; + ii>>=1; + } + } + } +} + + +void gf2n_short::init_field(int nn) +{ + gf2n_short::init_tables(); + int i,j=-1; + for (i=0; i=64) { throw invalid_params(); } + + mask=(1ULL<>8, b2=y>>8; + + c0=gf2n_short_table[a1][b1]; + c1=gf2n_short_table[a2][b2]; + word te=gf2n_short_table[a1][b2]^gf2n_short_table[a2][b1]; + c0^=(te&0xFF)<<8; + c1^=te>>8; +} + +/* Takes 16 bit x and y and returns the 32 bit product */ +inline word mul16(word x,word y) +{ + word a1=x&(0xFF), b1=y&(0xFF); + word a2=x>>8, b2=y>>8; + + word ans=gf2n_short_table[a2][b2]<<8; + ans^=gf2n_short_table[a1][b2]^gf2n_short_table[a2][b1]; + ans<<=8; + ans^=gf2n_short_table[a1][b1]; + + return ans; +} + + + +/* Takes 16 bit x the 32 bit square */ +inline word sqr16(word x) +{ + word a1=x&(0xFF),a2=x>>8; + + word ans=gf2n_short_table[a2][a2]<<16; + ans^=gf2n_short_table[a1][a1]; + + return ans; +} + + + +void gf2n_short::reduce_trinomial(word xh,word xl) +{ + // Deal with xh first + a=xl; + a^=(xh<>n; + while (hi!=0) + { a&=mask; + + a^=hi; + a^=(hi<>n; + } +} + +void gf2n_short::reduce_pentanomial(word xh,word xl) +{ + // Deal with xh first + a=xl; + a^=(xh<>n; + while (hi!=0) + { a&=mask; + + a^=hi; + a^=(hi<>n; + } +} + + +void mul32(word x,word y,word& ans) +{ + word a1=x&(0xFFFF),b1=y&(0xFFFF); + word a2=x>>16, b2=y>>16; + + word c0,c1; + + ans=mul16(a1,b1); + word upp=mul16(a2,b2); + + mul16(a1,b2,c0,c1); + ans^=c0<<16; upp^=c1; + + mul16(a2,b1,c0,c1); + ans^=c0<<16; upp^=c1; + + ans^=(upp<<32); +} + + + +void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y) +{ + word hi,lo; + + if (gf2n_short::useC) + { /* Uses Karatsuba */ + word c,d,e,t; + word xl=x.a&0xFFFFFFFF,yl=y.a&0xFFFFFFFF; + word xh=x.a>>32,yh=y.a>>32; + mul32(xl,yl,c); + mul32(xh,yh,d); + mul32((xl^xh),(yl^yh),e); + t=c^e^d; + lo=c^(t<<32); + hi=d^(t>>32); + } + else + { /* Use Intel Instructions */ + __m128i xx,yy,zz; + uint64_t c[] __attribute__((aligned (16))) = { 0,0 }; + xx=_mm_set1_epi64x(x.a); + yy=_mm_set1_epi64x(y.a); + zz=_mm_clmulepi64_si128(xx,yy,0); + _mm_store_si128((__m128i*)c,zz); + lo=c[0]; + hi=c[1]; + } + + reduce(hi,lo); +} + + + + +inline void sqr32(word x,word& ans) +{ + word a1=x&(0xFFFF),a2=x>>16; + ans=sqr16(a1)^(sqr16(a2)<<32); +} + +void gf2n_short::square() +{ + word xh,xl; + sqr32(a&0xFFFFFFFF,xl); + sqr32(a>>32,xh); + reduce(xh,xl); +} + + +void gf2n_short::square(const gf2n_short& bb) +{ + word xh,xl; + sqr32(bb.a&0xFFFFFFFF,xl); + sqr32(bb.a>>32,xh); + reduce(xh,xl); +} + + + + + + +void gf2n_short::invert() +{ + if (is_one()) { return; } + if (is_zero()) { throw division_by_zero(); } + + word u,v=a,B=0,D=1,mod=1; + + mod^=(1ULL<>=1; + if ((B&1)!=0) { B^=mod; } + B>>=1; + } + while ((v&1)==0 && v!=0) + { v>>=1; + if ((D&1)!=0) { D^=mod; } + D>>=1; + } + + if (u>=v) { u=u^v; B=B^D; } + else { v=v^u; D=D^B; } + } + + a=D; +} + + +void gf2n_short::power(long i) +{ + long n=i; + if (n<0) { invert(); n=-n; } + + gf2n_short T=*this; + assign_one(); + while (n!=0) + { if ((n&1)!=0) { mul(*this,T); } + n>>=1; + T.square(); + } +} + + +void gf2n_short::randomize(PRNG& G) +{ + a=G.get_uint(); + a=(a<<32)^G.get_uint(); + a&=mask; +} + + +void gf2n_short::output(ostream& s,bool human) const +{ + if (human) + { s << hex << a << dec << " "; } + else + { s.write((char*) &a,sizeof(word)); } +} + +void gf2n_short::input(istream& s,bool human) +{ + if (s.peek() == EOF) + { if (s.tellg() == 0) + { cout << "IO problem. Empty file?" << endl; + throw file_error(); + } + //throw end_of_file(); + s.clear(); // unset EOF flag + s.seekg(0); + if (!rewind) + cout << "REWINDING - ONLY FOR BENCHMARKING" << endl; + rewind = true; + } + + if (human) + { s >> hex >> a >> dec; } + else + { s.read((char*) &a,sizeof(word)); } + + a &= mask; +} diff --git a/Math/gf2n.h b/Math/gf2n.h new file mode 100644 index 000000000..f76ed20cf --- /dev/null +++ b/Math/gf2n.h @@ -0,0 +1,191 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _gf2n +#define _gf2n + +#include +#include + +#include +using namespace std; + +#include "Tools/random.h" + +#include "Math/gf2nlong.h" +#include "Math/field_types.h" + +/* This interface compatible with the gfp interface + * which then allows us to template the Share + * data type. + */ + + +/* + Arithmetic in Gf_{2^n} with n<64 +*/ + +class gf2n_short +{ + word a; + + static int n,t1,t2,t3,nterms; + static int l0,l1,l2,l3; + static word mask; + static bool useC; + static bool rewind; + + /* Assign x[0..2*nwords] to a and reduce it... */ + void reduce_trinomial(word xh,word xl); + void reduce_pentanomial(word xh,word xl); + + void reduce(word xh,word xl) + { if (nterms==3) + { reduce_pentanomial(xh,xl); } + else + { reduce_trinomial(xh,xl); } + } + + static void init_tables(); + + public: + + typedef gf2n_short value_type; + typedef word internal_type; + + static void init_field(int nn); + static int degree() { return n; } + static int get_nterms() { return nterms; } + static int get_t(int i) + { if (i==0) { return t1; } + else if (i==1) { return t2; } + else if (i==2) { return t3; } + return -1; + } + + static DataFieldType field_type() { return DATA_GF2N; } + static char type_char() { return '2'; } + static string type_string() { return "gf2n"; } + + static int size() { return sizeof(a); } + static int t() { return 0; } + + word get() const { return a; } + word get_word() const { return a; } + + void assign(const gf2n_short& g) { a=g.a; } + + void assign_zero() { a=0; } + void assign_one() { a=1; } + void assign_x() { a=2; } + void assign(word aa) { a=aa&mask; } + void assign(int aa) { a=static_cast(aa)&mask; } + void assign(const char* buffer) { a = *(word*)buffer; } + + int get_bit(int i) const + { return (a>>i)&1; } + void set_bit(int i,unsigned int b) + { if (b==1) + { a |= (1UL< + void add(octet* x) + { a^=*(word*)(x); } + void add(octet* x) + { add<0>(x); } + void sub(const gf2n_short& x,const gf2n_short& y) + { a=x.a^y.a; } + void sub(const gf2n_short& x) + { a^=x.a; } + // = x * y + void mul(const gf2n_short& x,const gf2n_short& y); + void mul(const gf2n_short& x) { mul(*this,x); } + // x * y when one of x,y is a bit + void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; } + + gf2n_short operator+(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; } + gf2n_short operator*(const gf2n_short& x) { gf2n_short res; res.mul(*this, x); return res; } + gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; } + gf2n_short& operator*=(const gf2n_short& x) { mul(x); return *this; } + gf2n_short operator-(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; } + gf2n_short& operator-=(const gf2n_short& x) { sub(x); return *this; } + + void square(); + void square(const gf2n_short& aa); + void invert(); + void invert(const gf2n_short& aa) + { *this=aa; invert(); } + void negate() { return; } + void power(long i); + + /* Bitwise Ops */ + void AND(const gf2n_short& x,const gf2n_short& y) { a=x.a&y.a; } + void XOR(const gf2n_short& x,const gf2n_short& y) { a=x.a^y.a; } + void OR(const gf2n_short& x,const gf2n_short& y) { a=x.a|y.a; } + void NOT(const gf2n_short& x) { a=(~x.a)&mask; } + void SHL(const gf2n_short& x,int n) { a=(x.a<>n; } + + gf2n_short operator&(const gf2n_short& x) { gf2n_short res; res.AND(*this, x); return res; } + gf2n_short operator^(const gf2n_short& x) { gf2n_short res; res.XOR(*this, x); return res; } + gf2n_short operator|(const gf2n_short& x) { gf2n_short res; res.OR(*this, x); return res; } + gf2n_short operator!() { gf2n_short res; res.NOT(*this); return res; } + gf2n_short operator<<(int i) { gf2n_short res; res.SHL(*this, i); return res; } + gf2n_short operator>>(int i) { gf2n_short res; res.SHR(*this, i); return res; } + + /* Crap RNG */ + void randomize(PRNG& G); + // compatibility with gfp + void almost_randomize(PRNG& G) { randomize(G); } + + void output(ostream& s,bool human) const; + void input(istream& s,bool human); + + friend ostream& operator<<(ostream& s,const gf2n_short& x) + { s << hex << "0x" << x.a << dec; + return s; + } + friend istream& operator>>(istream& s,gf2n_short& x) + { s >> hex >> x.a >> dec; + return s; + } + + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const + { o.append((octet*) &a,sizeof(word)); } + void unpack(octetStream& o) + { o.consume((octet*) &a,sizeof(word)); } +}; + +#ifdef USE_GF2N_LONG +typedef gf2n_long gf2n; +#else +typedef gf2n_short gf2n; +#endif + +#endif diff --git a/Math/gf2nlong.cpp b/Math/gf2nlong.cpp new file mode 100644 index 000000000..52e4ee0e4 --- /dev/null +++ b/Math/gf2nlong.cpp @@ -0,0 +1,277 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * gf2n_longlong.cpp + * + */ + +#include "gf2nlong.h" + +#include "Exceptions/Exceptions.h" + +#include +#include +#include +#include + + +bool is_ge(__m128i a, __m128i b) +{ + word aa[2], bb[2]; + _mm_storeu_si128((__m128i*)aa, a); + _mm_storeu_si128((__m128i*)bb, b); +// cout << hex << "is_ge " << aa[1] << " " << bb[1] << " " << (aa[1] > bb[1]) << " "; +// cout << aa[0] << " " << bb[0] << " " << (aa[0] >= bb[0]) << endl; + return aa[1] == bb[1] ? aa[0] >= bb[0] : aa[1] > bb[1]; +} + + +ostream& operator<<(ostream& s, const int128& a) +{ + word* tmp = (word*)&a.a; + s << hex; + s.width(16); + s.fill('0'); + s << tmp[1]; + s.width(16); + s << tmp[0] << dec; + return s; +} + + +int gf2n_long::n; +int gf2n_long::t1; +int gf2n_long::t2; +int gf2n_long::t3; +int gf2n_long::l0; +int gf2n_long::l1; +int gf2n_long::l2; +int gf2n_long::l3; +int gf2n_long::nterms; +int128 gf2n_long::mask; +int128 gf2n_long::lowermask; +int128 gf2n_long::uppermask; +bool gf2n_long::rewind = false; + +#define num_2_fields 1 + +/* Require + * 2*(n-1)-64+t1<64 + */ +int long_fields_2[num_2_fields][4] = { + {128,7,2,1}, + }; + + +void gf2n_long::init_field(int nn) +{ + if (nn!=128) { + cout << "Compiled for GF(2^128) only. Change parameters or compile " + "without USE_GF2N_LONG" << endl; + throw not_implemented(); + } + + int i,j=-1; + for (i=0; i=128) { throw not_implemented(); } + // if (nterms==3 && n!=128) { throw not_implemented(); } + + mask=_mm_set_epi64x(-1,-1); + lowermask=_mm_set_epi64x((1LL<<(64-7))-1,-1); + uppermask=_mm_set_epi64x(((word)-1)<<(64-7),0); +} + + + +void gf2n_long::reduce_trinomial(int128 xh,int128 xl) +{ + // Deal with xh first + a=xl; + a^=(xh<>n; + while (hi==0) + { a&=mask; + + a^=hi; + a^=(hi<>n; + } +} + +void gf2n_long::reduce_pentanomial(int128 xh, int128 xl) +{ + // Deal with xh first + a=xl; + int128 upper, lower; + upper=xh&uppermask; + lower=xh&lowermask; + // Upper part + int128 tmp = 0; + tmp^=(upper>>(n-t1-l0)); + tmp^=(upper>>(n-t1-l1)); + tmp^=(upper>>(n-t1-l2)); + tmp^=(upper>>(n-t1-l3)); + lower^=(tmp>>(l1)); + a^=(tmp<<(n-l1)); + // Lower part + a^=(lower<>n; + while (hi!=0) + { a&=mask; + + a^=hi; + a^=(hi<>n; + } +*/ +} + + +gf2n_long& gf2n_long::mul(const gf2n_long& x,const gf2n_long& y) +{ + __m128i res[2]; + memset(res,0,sizeof(res)); + + mul128(x.a.a,y.a.a,res,res+1); + + reduce(res[1],res[0]); + + return *this; +} + + +class int129 +{ + int128 lower; + bool msb; + +public: + int129() : lower(_mm_setzero_si128()), msb(false) { } + int129(int128 lower, bool msb) : lower(lower), msb(msb) { } + int129(int128 a) : lower(a), msb(false) { } + int129(word a) + { *this = a; } + int128 get_lower() { return lower; } + int129& operator=(const __m128i& other) + { lower = other; msb = false; return *this; } + int129& operator=(const word& other) + { lower = _mm_set_epi64x(0, other); msb = false; return *this; } + bool operator==(const int129& other) + { return (lower == other.lower) && (msb == other.msb); } + bool operator!=(const int129& other) + { return !(*this == other); } + bool operator>=(const int129& other) + { //cout << ">= " << msb << other.msb << (msb > other.msb) << is_ge(lower.a, other.lower.a) << endl; + return msb == other.msb ? is_ge(lower.a, other.lower.a) : msb > other.msb; } + int129 operator<<(int other) + { return int129(lower << other, _mm_cvtsi128_si32(((lower >> (128-other)) & 1).a)); } + int129& operator>>=(int other) + { lower >>= other; lower |= (int128(msb) << (128-other)); msb = !other; return *this; } + int129 operator^(const int129& other) + { return int129(lower ^ other.lower, msb ^ other.msb); } + int129& operator^=(const int129& other) + { lower ^= other.lower; msb ^= other.msb; return *this; } + int129 operator&(const word& other) + { return int129(lower & other, false); } + friend ostream& operator<<(ostream& s, const int129& a) + { s << a.msb << a.lower; return s; } +}; + +void gf2n_long::invert() +{ + if (is_one()) { return; } + if (is_zero()) { throw division_by_zero(); } + + int129 u,v=a,B=0,D=1,mod=1; + + mod^=(int129(1)<>=1; + if ((B&1)!=0) { B^=mod; } + B>>=1; + } + while ((v&1)==0 && v!=0) + { v>>=1; + if ((D&1)!=0) { D^=mod; } + D>>=1; + } + + if (u>=v) { u=u^v; B=B^D; } + else { v=v^u; D=D^B; } + } + + a=D.get_lower(); +} + + +void gf2n_long::randomize(PRNG& G) +{ + a=G.get_doubleword(); + a&=mask; +} + + +void gf2n_long::output(ostream& s,bool human) const +{ + if (human) + { s << *this; } + else + { s.write((char*) &a,sizeof(__m128i)); } +} + +void gf2n_long::input(istream& s,bool human) +{ + if (s.peek() == EOF) + { if (s.tellg() == 0) + { cout << "IO problem. Empty file?" << endl; + throw file_error(); + } + //throw end_of_file(); + s.clear(); // unset EOF flag + s.seekg(0); + if (!rewind) + cout << "REWINDING - ONLY FOR BENCHMARKING" << endl; + rewind = true; + } + + if (human) + { s >> *this; } + else + { s.read((char*) &a,sizeof(__m128i)); } +} diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h new file mode 100644 index 000000000..79bd3f1d8 --- /dev/null +++ b/Math/gf2nlong.h @@ -0,0 +1,274 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * gf2nlong.h + * + */ + +#ifndef MATH_GF2NLONG_H_ +#define MATH_GF2NLONG_H_ + +#include +#include + +#include +using namespace std; + +#include + +#include "Tools/random.h" +#include "Math/field_types.h" + + +class int128 +{ +public: + __m128i a; + + int128() : a(_mm_setzero_si128()) { } + int128(const int128& a) : a(a.a) { } + int128(const __m128i& a) : a(a) { } + int128(const word& a) : a(_mm_cvtsi64_si128(a)) { } + int128(const word& upper, const word& lower) : a(_mm_set_epi64x(upper, lower)) { } + + word get_lower() { return (word)_mm_cvtsi128_si64(a); } + + bool operator==(const int128& other) const { return _mm_test_all_zeros(a ^ other.a, a ^ other.a); } + bool operator!=(const int128& other) const { return !(*this == other); } + + int128 operator<<(const int& other) const; + int128 operator>>(const int& other) const; + + int128 operator^(const int128& other) const { return a ^ other.a; } + int128 operator|(const int128& other) const { return a | other.a; } + int128 operator&(const int128& other) const { return a & other.a; } + + int128 operator~() const { return ~a; } + + int128& operator<<=(const int& other) { return *this = *this << other; } + int128& operator>>=(const int& other) { return *this = *this >> other; } + + int128& operator^=(const int128& other) { a ^= other.a; return *this; } + int128& operator|=(const int128& other) { a |= other.a; return *this; } + int128& operator&=(const int128& other) { a &= other.a; return *this; } + + friend ostream& operator<<(ostream& s, const int128& a); +}; + + +/* This interface compatible with the gfp interface + * which then allows us to template the Share + * data type. + */ + + +/* + Arithmetic in Gf_{2^n} with n<=128 +*/ + +class gf2n_long +{ + int128 a; + + static int n,t1,t2,t3,nterms; + static int l0,l1,l2,l3; + static int128 mask,lowermask,uppermask; + static bool rewind; + + /* Assign x[0..2*nwords] to a and reduce it... */ + void reduce_trinomial(int128 xh,int128 xl); + void reduce_pentanomial(int128 xh,int128 xl); + + public: + + typedef gf2n_long value_type; + typedef int128 internal_type; + + void reduce(int128 xh,int128 xl) + { + if (nterms==3) + { reduce_pentanomial(xh,xl); } + else + { reduce_trinomial(xh,xl); } + } + + static void init_field(int nn); + static int degree() { return n; } + static int get_nterms() { return nterms; } + static int get_t(int i) + { if (i==0) { return t1; } + else if (i==1) { return t2; } + else if (i==2) { return t3; } + return -1; + } + + static DataFieldType field_type() { return DATA_GF2N; } + static char type_char() { return '2'; } + static string type_string() { return "gf2n_long"; } + + static int size() { return sizeof(a); } + static int t() { return 0; } + + int128 get() const { return a; } + __m128i to_m128i() const { return a.a; } + word get_word() const { return _mm_cvtsi128_si64x(a.a); } + + void assign(const gf2n_long& g) { a=g.a; } + + void assign_zero() { a=_mm_setzero_si128(); } + void assign_one() { a=int128(0,1); } + void assign_x() { a=int128(0,2); } + void assign(int128 aa) { a=aa&mask; } + void assign(int aa) { a=int128(static_cast(aa))&mask; } + void assign(const char* buffer) { a = _mm_loadu_si128((__m128i*)buffer); } + + int get_bit(int i) const + { return ((a>>i)&1).get_lower(); } + void set_bit(int i,unsigned int b) + { if (b==1) + { a |= (1UL< + void add(octet* x) + { a^=int128(_mm_loadu_si128((__m128i*)x)); } + void add(octet* x) + { add<0>(x); } + void sub(const gf2n_long& x,const gf2n_long& y) + { a=x.a^y.a; } + void sub(const gf2n_long& x) + { a^=x.a; } + // = x * y + gf2n_long& mul(const gf2n_long& x,const gf2n_long& y); + void mul(const gf2n_long& x) { mul(*this,x); } + // x * y when one of x,y is a bit + void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; } + + gf2n_long operator+(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; } + gf2n_long operator*(const gf2n_long& x) { gf2n_long res; res.mul(*this, x); return res; } + gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; } + gf2n_long& operator*=(const gf2n_long& x) { mul(x); return *this; } + gf2n_long operator-(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; } + gf2n_long& operator-=(const gf2n_long& x) { sub(x); return *this; } + + void square(); + void square(const gf2n_long& aa); + void invert(); + void invert(const gf2n_long& aa) + { *this=aa; invert(); } + void negate() { return; } + void power(long i); + + /* Bitwise Ops */ + void AND(const gf2n_long& x,const gf2n_long& y) { a=x.a&y.a; } + void XOR(const gf2n_long& x,const gf2n_long& y) { a=x.a^y.a; } + void OR(const gf2n_long& x,const gf2n_long& y) { a=x.a|y.a; } + void NOT(const gf2n_long& x) { a=(~x.a)&mask; } + void SHL(const gf2n_long& x,int n) { a=(x.a<>n; } + + gf2n_long operator&(const gf2n_long& x) { gf2n_long res; res.AND(*this, x); return res; } + gf2n_long operator^(const gf2n_long& x) { gf2n_long res; res.XOR(*this, x); return res; } + gf2n_long operator|(const gf2n_long& x) { gf2n_long res; res.OR(*this, x); return res; } + gf2n_long operator!() { gf2n_long res; res.NOT(*this); return res; } + gf2n_long operator<<(int i) { gf2n_long res; res.SHL(*this, i); return res; } + gf2n_long operator>>(int i) { gf2n_long res; res.SHR(*this, i); return res; } + + /* Crap RNG */ + void randomize(PRNG& G); + // compatibility with gfp + void almost_randomize(PRNG& G) { randomize(G); } + + void output(ostream& s,bool human) const; + void input(istream& s,bool human); + + friend ostream& operator<<(ostream& s,const gf2n_long& x) + { s << hex << x.a << dec; + return s; + } + friend istream& operator>>(istream& s,gf2n_long& x) + { bigint tmp; + s >> hex >> tmp >> dec; + x.a = 0; + mpn_copyi((word*)&x.a.a, tmp.get_mpz_t()->_mp_d, tmp.get_mpz_t()->_mp_size); + return s; + } + + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const + { o.append((octet*) &a,sizeof(__m128i)); } + void unpack(octetStream& o) + { o.consume((octet*) &a,sizeof(__m128i)); } +}; + + +inline int128 int128::operator<<(const int& other) const +{ + int128 res(_mm_slli_epi64(a, other)); + __m128i mask; + if (other < 64) + mask = _mm_srli_epi64(a, 64 - other); + else + mask = _mm_slli_epi64(a, other - 64); + res.a ^= _mm_slli_si128(mask, 8); + return res; +} + +inline int128 int128::operator>>(const int& other) const +{ + int128 res(_mm_srli_epi64(a, other)); + __m128i mask; + if (other < 64) + mask = _mm_slli_epi64(a, 64 - other); + else + mask = _mm_srli_epi64(a, other - 64); + res.a ^= _mm_srli_si128(mask, 8); + return res; +} + +inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2) +{ + __m128i tmp3, tmp4, tmp5, tmp6; + + tmp3 = _mm_clmulepi64_si128(a, b, 0x00); + tmp4 = _mm_clmulepi64_si128(a, b, 0x10); + tmp5 = _mm_clmulepi64_si128(a, b, 0x01); + tmp6 = _mm_clmulepi64_si128(a, b, 0x11); + + tmp4 = _mm_xor_si128(tmp4, tmp5); + tmp5 = _mm_slli_si128(tmp4, 8); + tmp4 = _mm_srli_si128(tmp4, 8); + tmp3 = _mm_xor_si128(tmp3, tmp5); + tmp6 = _mm_xor_si128(tmp6, tmp4); + // initial mul now in tmp3, tmp6 + *res1 = tmp3; + *res2 = tmp6; +} + +#endif /* MATH_GF2NLONG_H_ */ diff --git a/Math/gfp.cpp b/Math/gfp.cpp new file mode 100644 index 000000000..da20448db --- /dev/null +++ b/Math/gfp.cpp @@ -0,0 +1,125 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Math/gfp.h" + +#include "Exceptions/Exceptions.h" + +Zp_Data gfp::ZpD; + +void gfp::almost_randomize(PRNG& G) +{ + G.get_octets((octet*)a.x,t()*sizeof(mp_limb_t)); + a.x[t()-1]&=ZpD.mask; +} + +void gfp::AND(const gfp& x,const gfp& y) +{ + bigint bi1,bi2; + to_bigint(bi1,x); + to_bigint(bi2,y); + mpz_and(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t()); + to_gfp(*this, bi1); +} + +void gfp::OR(const gfp& x,const gfp& y) +{ + bigint bi1,bi2; + to_bigint(bi1,x); + to_bigint(bi2,y); + mpz_ior(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t()); + to_gfp(*this, bi1); +} + +void gfp::XOR(const gfp& x,const gfp& y) +{ + bigint bi1,bi2; + to_bigint(bi1,x); + to_bigint(bi2,y); + mpz_xor(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t()); + to_gfp(*this, bi1); +} + +void gfp::AND(const gfp& x,const bigint& y) +{ + bigint bi; + to_bigint(bi,x); + mpz_and(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t()); + to_gfp(*this, bi); +} + +void gfp::OR(const gfp& x,const bigint& y) +{ + bigint bi; + to_bigint(bi,x); + mpz_ior(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t()); + to_gfp(*this, bi); +} + +void gfp::XOR(const gfp& x,const bigint& y) +{ + bigint bi; + to_bigint(bi,x); + mpz_xor(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t()); + to_gfp(*this, bi); +} + + + + +void gfp::SHL(const gfp& x,int n) +{ + if (!x.is_zero()) + { + bigint bi; + to_bigint(bi,x,false); + mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); + to_gfp(*this, bi); + } + else + { + assign_zero(); + } +} + + +void gfp::SHR(const gfp& x,int n) +{ + if (!x.is_zero()) + { + bigint bi; + to_bigint(bi,x); + mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); + to_gfp(*this, bi); + } + else + { + assign_zero(); + } +} + + +void gfp::SHL(const gfp& x,const bigint& n) +{ + SHL(x,mpz_get_si(n.get_mpz_t())); +} + + +void gfp::SHR(const gfp& x,const bigint& n) +{ + SHR(x,mpz_get_si(n.get_mpz_t())); +} + + +gfp gfp::sqrRoot() +{ + // Temp move to bigint so as to call sqrRootMod + bigint ti; + to_bigint(ti, *this); + ti = sqrRootMod(ti, ZpD.pr); + if (!isOdd(ti)) + ti = ZpD.pr - ti; + gfp temp; + to_gfp(temp, ti); + return temp; +} diff --git a/Math/gfp.h b/Math/gfp.h new file mode 100644 index 000000000..f35ebb8f6 --- /dev/null +++ b/Math/gfp.h @@ -0,0 +1,205 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _gfp +#define _gfp + +#include +using namespace std; + +#include "Math/gf2n.h" +#include "Math/modp.h" +#include "Math/Zp_Data.h" +#include "Math/field_types.h" +#include "Tools/random.h" + +/* This is a wrapper class for the modp data type + * It is used to be interface compatible with the gfp + * type, which then allows us to template the Share + * data type. + * + * So gfp is used ONLY for the stuff in the finite fields + * we are going to be doing MPC over, not the modp stuff + * for the FHE scheme + */ + + +class gfp +{ + modp a; + static Zp_Data ZpD; + + public: + + typedef gfp value_type; + + static void init_field(const bigint& p,bool mont=true) + { ZpD.init(p,mont); } + static bigint pr() + { return ZpD.pr; } + static int t() + { return ZpD.get_t(); } + static Zp_Data& get_ZpD() + { return ZpD; } + + static DataFieldType field_type() { return DATA_MODP; } + static char type_char() { return 'p'; } + static string type_string() { return "gfp"; } + + static int size() { return t() * sizeof(mp_limb_t); } + + void assign(const gfp& g) { a=g.a; } + void assign_zero() { assignZero(a,ZpD); } + void assign_one() { assignOne(a,ZpD); } + void assign(word aa) { bigint b=aa; to_gfp(*this,b); } + void assign(long aa) { bigint b=aa; to_gfp(*this,b); } + void assign(int aa) { bigint b=aa; to_gfp(*this,b); } + void assign(const char* buffer) { a.assign(buffer, ZpD.get_t()); } + + modp get() const { return a; } + + // Assumes prD behind x is equal to ZpD + void assign(modp& x) { a=x; } + + gfp() { assignZero(a,ZpD); } + gfp(const gfp& g) { a=g.a; } + gfp(const modp& g) { a=g; } + gfp(const __m128i& x) { *this=x; } + gfp(const int128& x) { *this=x.a; } + gfp(const bigint& x) { to_modp(a, x, ZpD); } + gfp(int x) { assign(x); } + ~gfp() { ; } + + gfp& operator=(const gfp& g) + { if (&g!=this) { a=g.a; } + return *this; + } + + gfp& operator=(const __m128i other) + { + memcpy(a.x, &other, sizeof(other)); + a.x[2] = 0; + return *this; + } + + void to_m128i(__m128i& ans) + { + memcpy(&ans, a.x, sizeof(ans)); + } + + __m128i to_m128i() + { + return _mm_loadu_si128((__m128i*)a.x); + } + + + bool is_zero() const { return isZero(a,ZpD); } + bool is_one() const { return isOne(a,ZpD); } + bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); } + bool operator==(const gfp& y) const { return equal(y); } + bool operator!=(const gfp& y) const { return !equal(y); } + + // x+y + template + void add(const gfp& x,const gfp& y) + { Add(a,x.a,y.a,ZpD); } + template + void add(const gfp& x) + { Add(a,a,x.a,ZpD); } + template + void add(void* x) + { ZpD.Add(a.x,a.x,(mp_limb_t*)x); } + void add(const gfp& x,const gfp& y) + { Add(a,x.a,y.a,ZpD); } + void add(const gfp& x) + { Add(a,a,x.a,ZpD); } + void add(void* x) + { ZpD.Add(a.x,a.x,(mp_limb_t*)x); } + void sub(const gfp& x,const gfp& y) + { Sub(a,x.a,y.a,ZpD); } + void sub(const gfp& x) + { Sub(a,a,x.a,ZpD); } + // = x * y + void mul(const gfp& x,const gfp& y) + { Mul(a,x.a,y.a,ZpD); } + void mul(const gfp& x) + { Mul(a,a,x.a,ZpD); } + + gfp operator+(const gfp& x) { gfp res; res.add(*this, x); return res; } + gfp operator-(const gfp& x) { gfp res; res.sub(*this, x); return res; } + gfp operator*(const gfp& x) { gfp res; res.mul(*this, x); return res; } + gfp& operator+=(const gfp& x) { add(x); return *this; } + gfp& operator-=(const gfp& x) { sub(x); return *this; } + gfp& operator*=(const gfp& x) { mul(x); return *this; } + + void square(const gfp& aa) + { Sqr(a,aa.a,ZpD); } + void square() + { Sqr(a,a,ZpD); } + void invert() + { Inv(a,a,ZpD); } + void invert(const gfp& aa) + { Inv(a,aa.a,ZpD); } + void negate() + { Negate(a,a,ZpD); } + void power(long i) + { Power(a,a,i,ZpD); } + + // deterministic square root + gfp sqrRoot(); + + void randomize(PRNG& G) + { a.randomize(G,ZpD); } + // faster randomization, see implementation for explanation + void almost_randomize(PRNG& G); + + void output(ostream& s,bool human) const + { a.output(s,ZpD,human); } + void input(istream& s,bool human) + { a.input(s,ZpD,human); } + + friend ostream& operator<<(ostream& s,const gfp& x) + { x.output(s,true); + return s; + } + friend istream& operator>>(istream& s,gfp& x) + { x.input(s,true); + return s; + } + + /* Bitwise Ops + * - Converts gfp args to bigints and then converts answer back to gfp + */ + void AND(const gfp& x,const gfp& y); + void XOR(const gfp& x,const gfp& y); + void OR(const gfp& x,const gfp& y); + void AND(const gfp& x,const bigint& y); + void XOR(const gfp& x,const bigint& y); + void OR(const gfp& x,const bigint& y); + void SHL(const gfp& x,int n); + void SHR(const gfp& x,int n); + void SHL(const gfp& x,const bigint& n); + void SHR(const gfp& x,const bigint& n); + + gfp operator&(const gfp& x) { gfp res; res.AND(*this, x); return res; } + gfp operator^(const gfp& x) { gfp res; res.XOR(*this, x); return res; } + gfp operator|(const gfp& x) { gfp res; res.OR(*this, x); return res; } + gfp operator<<(int i) { gfp res; res.SHL(*this, i); return res; } + gfp operator>>(int i) { gfp res; res.SHR(*this, i); return res; } + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const + { a.pack(o,ZpD); } + void unpack(octetStream& o) + { a.unpack(o,ZpD); } + + + // Convert representation to and from a bigint number + friend void to_bigint(bigint& ans,const gfp& x,bool reduce=true) + { to_bigint(ans,x.a,x.ZpD,reduce); } + friend void to_gfp(gfp& ans,const bigint& x) + { to_modp(ans.a,x,ans.ZpD); } +}; + + +#endif diff --git a/Math/modp.cpp b/Math/modp.cpp new file mode 100644 index 000000000..033edda32 --- /dev/null +++ b/Math/modp.cpp @@ -0,0 +1,263 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Zp_Data.h" +#include "modp.h" + +#include "Exceptions/Exceptions.h" + +bool modp::rewind = false; + +/*********************************************************************** + * The following functions remain the same in Real and Montgomery rep * + ***********************************************************************/ + +void modp::randomize(PRNG& G, const Zp_Data& ZpD) +{ + bigint x=G.randomBnd(ZpD.pr); + to_modp(*this,x,ZpD); +} + +void modp::pack(octetStream& o,const Zp_Data& ZpD) const +{ + o.append((octet*) x,ZpD.t*sizeof(mp_limb_t)); +} + + +void modp::unpack(octetStream& o,const Zp_Data& ZpD) +{ + o.consume((octet*) x,ZpD.t*sizeof(mp_limb_t)); +} + + +void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD) +{ + ZpD.Sub(ans.x, x.x, y.x); +} + + +void Negate(modp& ans,const modp& x,const Zp_Data& ZpD) +{ + if (isZero(x,ZpD)) { ans=x; return; } + mpn_sub_n(ans.x,ZpD.prA,x.x,ZpD.t); +} + + +bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD) +{ if (mpn_cmp(x.x,y.x,ZpD.t)!=0) + { return false; } + return true; +} + +bool isZero(const modp& ans,const Zp_Data& ZpD) +{ + for (int i=0; i_mp_d,x.x,one); + } + else + { mpn_copyi(a->_mp_d,x.x,ZpD.t+1); } + a->_mp_size=ZpD.t; + if (reduce) + while (a->_mp_size>=1 && (a->_mp_d)[a->_mp_size-1]==0) + { a->_mp_size--; } + ans=bigint(a); + + mpz_clear(a); +} + + +void to_modp(modp& ans,int x,const Zp_Data& ZpD) +{ + mpn_zero(ans.x,ZpD.t+1); + if (x>=0) + { ans.x[0]=x; + if (ZpD.t==1) { ans.x[0]=ans.x[0]%ZpD.prA[0]; } + } + else + { if (ZpD.t==1) + { ans.x[0]=(ZpD.prA[0]+x)%ZpD.prA[0]; } + else + { bigint xx=ZpD.pr+x; + to_modp(ans,xx,ZpD); + return; + } + } + if (ZpD.montgomery) + { ZpD.Mont_Mult(ans.x,ans.x,ZpD.R2); } +} + + +void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD) +{ + bigint xx=x%ZpD.pr; + if (xx<0) { xx+=ZpD.pr; } + //mpz_mod(xx.get_mpz_t(),x.get_mpz_t(),ZpD.pr.get_mpz_t()); + mpn_zero(ans.x,ZpD.t+1); + mpn_copyi(ans.x,xx.get_mpz_t()->_mp_d,xx.get_mpz_t()->_mp_size); + if (ZpD.montgomery) + { ZpD.Mont_Mult(ans.x,ans.x,ZpD.R2); } +} + + + +void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD) +{ + if (ZpD.montgomery) + { ZpD.Mont_Mult(ans.x,x.x,y.x); } + else + { //ans.x=(x.x*y.x)%ZpD.pr; + mp_limb_t aa[2*MAX_MOD_SZ],q[2*MAX_MOD_SZ]; + mpn_mul_n(aa,x.x,y.x,ZpD.t); + mpn_tdiv_qr(q,ans.x,0,aa,2*ZpD.t,ZpD.prA,ZpD.t); + } +} + + +void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD) +{ + if (ZpD.montgomery) + { ZpD.Mont_Mult(ans.x,x.x,x.x); } + else + { //ans.x=(x.x*x.x)%ZpD.pr; + mp_limb_t aa[2*MAX_MOD_SZ],q[2*MAX_MOD_SZ]; + mpn_sqr(aa,x.x,ZpD.t); + mpn_tdiv_qr(q,ans.x,0,aa,2*ZpD.t,ZpD.prA,ZpD.t); + } +} + + +void Inv(modp& ans,const modp& x,const Zp_Data& ZpD) +{ + mp_limb_t g[MAX_MOD_SZ],xx[MAX_MOD_SZ+1],yy[MAX_MOD_SZ+1]; + mp_size_t sz; + mpn_copyi(xx,x.x,ZpD.t); + mpn_copyi(yy,ZpD.prA,ZpD.t); + mpn_gcdext(g,ans.x,&sz,xx,ZpD.t,yy,ZpD.t); + if (sz<0) + { mpn_sub(ans.x,ZpD.prA,ZpD.t,ans.x,-sz); + sz=-sz; + } + else + { for (int i=sz; i>=1; + Sqr(t,t,ZpD); + } +} + + +// XXXX This is a crap version. Hopefully this is not time critical +void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD) +{ + if (exp==1) { ans=x; return; } + if (exp==0) { assignOne(ans,ZpD); return; } + if (exp<0) { throw not_implemented(); } + modp t=x; + assignOne(ans,ZpD); + bigint e=exp; + while (e!=0) + { if ((e&1)==1) { Mul(ans,ans,t,ZpD); } + e>>=1; + Sqr(t,t,ZpD); + } +} + + +void modp::output(ostream& s,const Zp_Data& ZpD,bool human) const +{ + if (human) + { bigint te; + to_bigint(te,*this,ZpD); + if (te < ZpD.pr / 2) + s << te; + else + s << (te - ZpD.pr); + } + else + { s.write((char*) x,ZpD.t*sizeof(mp_limb_t)); } +} + +void modp::input(istream& s,const Zp_Data& ZpD,bool human) +{ + if (s.peek() == EOF) + { if (s.tellg() == 0) + { cout << "IO problem. Empty file?" << endl; + throw file_error(); + } + //throw end_of_file(); + s.clear(); // unset EOF flag + s.seekg(0); + if (!rewind) + cout << "REWINDING - ONLY FOR BENCHMARKING" << endl; + rewind = true; + } + + if (human) + { bigint te; + s >> te; + to_modp(*this,te,ZpD); + } + else + { s.read((char*) x,ZpD.t*sizeof(mp_limb_t)); } +} + + diff --git a/Math/modp.h b/Math/modp.h new file mode 100644 index 000000000..b384c3a1d --- /dev/null +++ b/Math/modp.h @@ -0,0 +1,116 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Modp +#define _Modp + +/* + * Currently we only support an MPIR based implementation. + * + * What ever is type-def'd to bigint is assumed to have + * operator overloading for all standard operators, has + * comparison operations and istream/ostream operators >>/<<. + * + * All "integer" operations will be done using operator notation + * all "modp" operations should be done using the function calls + * below (interchange with Montgomery arithmetic). + * + */ + +#include "Tools/octetStream.h" +#include "Tools/random.h" + +#include "Math/bigint.h" +#include "Math/Zp_Data.h" + +class modp +{ + static bool rewind; + + mp_limb_t x[MAX_MOD_SZ]; + + public: + + // NEXT FUNCTION IS FOR DEBUG PURPOSES ONLY + mp_limb_t get_limb(int i) { return x[i]; } + + // use mem* functions instead of mpn_*, so the compiler can optimize + modp() + { memset(x, 0, sizeof(x)); } + modp(const modp& y) + { memcpy(x, y.x, sizeof(x)); } + modp& operator=(const modp& y) + { if (this!=&y) { memcpy(x, y.x, sizeof(x)); } + return *this; + } + + void assign(const char* buffer, int t) { memcpy(x, buffer, t * sizeof(mp_limb_t)); } + + void randomize(PRNG& G, const Zp_Data& ZpD); + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + // i.e. When we do montgomery we dont care about decoding + void pack(octetStream& o,const Zp_Data& ZpD) const; + void unpack(octetStream& o,const Zp_Data& ZpD); + + + /********************************** + * Modp Operations * + **********************************/ + + // Convert representation to and from a modp number + friend void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce=true); + + friend void to_modp(modp& ans,int x,const Zp_Data& ZpD); + friend void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD); + + template + friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD); + friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD) + { ZpD.Add(ans.x, x.x, y.x); } + friend void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD); + friend void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD); + friend void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD); + friend void Negate(modp& ans,const modp& x,const Zp_Data& ZpD); + friend void Inv(modp& ans,const modp& x,const Zp_Data& ZpD); + + friend void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD); + friend void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD); + + friend void assignOne(modp& x,const Zp_Data& ZpD); + friend void assignZero(modp& x,const Zp_Data& ZpD); + friend bool isZero(const modp& x,const Zp_Data& ZpD); + friend bool isOne(const modp& x,const Zp_Data& ZpD); + friend bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD); + + // Input and output from a stream + // - Can do in human or machine only format (later should be faster) + // - If human output appends a space to help with reading + // and also convert back/forth from Montgomery if needed + void output(ostream& s,const Zp_Data& ZpD,bool human) const; + void input(istream& s,const Zp_Data& ZpD,bool human); + + friend class gfp; + +}; + + +inline void assignZero(modp& x,const Zp_Data& ZpD) +{ + if (sizeof(x.x) <= 3 * 16) + // use memset to allow the compiler to optimize + // if x.x is at most 3*128 bits + memset(x.x, 0, sizeof(x.x)); + else + mpn_zero(x.x, ZpD.t + 1); +} + +template +inline void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD) +{ + ZpD.Add(ans.x, x.x, y.x); +} + + +#endif + diff --git a/Math/operators.h b/Math/operators.h new file mode 100644 index 000000000..d4d34b4a9 --- /dev/null +++ b/Math/operators.h @@ -0,0 +1,37 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * operations.h + * + */ + +#ifndef MATH_OPERATORS_H_ +#define MATH_OPERATORS_H_ + +template +T operator*(const bool& x, const T& y) { return x ? y : T(); } +template +T operator*(const T& y, const bool& x) { return x ? y : T(); } +template +T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; } + +template +T operator+(const T& x, const U& y) { T res; res.add(x, y); return res; } +template +T operator*(const T& x, const U& y) { T res; res.mul(x, y); return res; } +template +T operator-(const T& x, const U& y) { T res; res.sub(x, y); return res; } + +template +T& operator+=(T& x, const U& y) { x.add(y); return x; } +template +T& operator*=(T& x, const U& y) { x.mul(y); return x; } +template +T& operator-=(T& x, const U& y) { x.sub(y); return x; } + +template +T operator/(const T& x, const U& y) { U inv = y; inv.invert(); return x * inv; } +template +T& operator/=(const T& x, const U& y) { U inv = y; inv.invert(); return x *= inv; } + +#endif /* MATH_OPERATORS_H_ */ diff --git a/Networking/Player.cpp b/Networking/Player.cpp new file mode 100644 index 000000000..2325cec5d --- /dev/null +++ b/Networking/Player.cpp @@ -0,0 +1,411 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Player.h" +#include "Exceptions/Exceptions.h" + +#include + +// Use printf rather than cout so valgrind can detect thread issues + +void Names::init(int player,int pnb,const char* servername) +{ + player_no=player; + portnum_base=pnb; + setup_names(servername); + setup_server(); +} + + +void Names::init(int player,int pnb,vector Nms) +{ + player_no=player; + portnum_base=pnb; + nplayers=Nms.size(); + names.resize(nplayers); + for (int i=0; i Nms) +{ + player_no=player; + portnum_base=pnb; + nplayers=Nms.size(); + names=Nms; + setup_server(); +} + +// initialize hostnames from file +void Names::init(int player, int _nplayers, int pnb, const string& filename) +{ + ifstream hostsfile(filename.c_str()); + if (hostsfile.fail()) + { + stringstream ss; + ss << "Error opening " << filename << ". See HOSTS.example for an example."; + throw file_error(ss.str().c_str()); + } + player_no = player; + nplayers = _nplayers; + portnum_base = pnb; + string line; + while (getline(hostsfile, line)) + { + if (line.length() > 0 && line.at(0) != '#') + names.push_back(line); + } + if ((int)names.size() < nplayers) + throw invalid_params(); + names.resize(nplayers); + for (unsigned int i = 0; i < names.size(); i++) + cerr << "name: " << names[i] << endl; + setup_server(); +} + +void Names::setup_names(const char *servername) +{ + int socket_num; + int pn = portnum_base - 1; + set_up_client_socket(socket_num, servername, pn); + send(socket_num, (octet*)&player_no, sizeof(player_no)); + cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl; + + int inst=-1; // wait until instruction to start. + while (inst != GO) { receive(socket_num, inst); } + + // Send my name + octet my_name[512]; + memset(my_name,0,512*sizeof(octet)); + gethostname((char*)my_name,512); + fprintf(stderr, "My Name = %s\n",my_name); + send(socket_num,my_name,512); + cerr << "My number = " << player_no << endl; + + // Now get the set of names + int i; + receive(socket_num,nplayers); + cerr << nplayers << " players\n"; + names.resize(nplayers); + for (i=0; i& names,int portnum_base,int id_base,ServerSocket& server) +{ + sockets.resize(nplayers); + // Set up the client side + for (int i=player_no; i& o,bool donthash) const +{ for (int i=0; iplayer_no) + { o[player_no].Send(sockets[i]); } + else if (iplayer_no) + { o[i].reset_write_head(); + o[i].Receive(sockets[i]); + } + } + if (!donthash) + { for (int i=0; i h(nplayers); + blk_SHA1_Final(hashVal,&ctx); + h[player_no].append(hashVal,HASH_SIZE); + + Broadcast_Receive(h,true); + for (int i=0; i& players, vector& result) const +{ + fd_set rfds; + FD_ZERO(&rfds); + int highest = 0; + vector::iterator it; + for (it = players.begin(); it != players.end(); it++) + { + if (*it >= 0) + { + FD_SET(sockets[*it], &rfds); + highest = max(highest, sockets[*it]); + } + } + + int res = select(highest + 1, &rfds, 0, 0, 0); + + if (res < 0) + error("select()"); + + result.clear(); + result.reserve(res); + for (it = players.begin(); it != players.end(); it++) + { + if (res == 0) + break; + + if (*it >= 0 && FD_ISSET(sockets[*it], &rfds)) + { + res--; + result.push_back(*it); + } + } +} + + +ThreadPlayer::ThreadPlayer(const Names& Nms, int id_base) : Player(Nms, id_base) +{ + for (int i = 0; i < Nms.num_players(); i++) + { + receivers.push_back(new Receiver(sockets[i])); + receivers[i]->start(); + + senders.push_back(new Sender(socket_to_send(i))); + senders[i]->start(); + } +} + +ThreadPlayer::~ThreadPlayer() +{ + for (unsigned int i = 0; i < receivers.size(); i++) + { + receivers[i]->stop(); + if (receivers[i]->timer.elapsed() > 0) + cerr << "Waiting for receiving from " << i << ": " << receivers[i]->timer.elapsed() << endl; + delete receivers[i]; + } + + for (unsigned int i = 0; i < senders.size(); i++) + { + senders[i]->stop(); + if (senders[i]->timer.elapsed() > 0) + cerr << "Waiting for sending to " << i << ": " << senders[i]->timer.elapsed() << endl; + delete senders[i]; + } +} + +void ThreadPlayer::request_receive(int i, octetStream& o) const +{ + receivers[i]->request(o); +} + +void ThreadPlayer::wait_receive(int i, octetStream& o, bool donthash) const +{ + receivers[i]->wait(o); + if (!donthash) + { blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); } +} + +void ThreadPlayer::receive_player(int i, octetStream& o, bool donthash) const +{ + request_receive(i, o); + wait_receive(i, o, donthash); +} + +void ThreadPlayer::send_all(const octetStream& o,bool donthash) const +{ + for (int i=0; irequest(o); + } + + if (!donthash) + { blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); } + + for (int i = 0; i < nplayers; i++) + if (i != player_no) + senders[i]->wait(o); +} + + +TwoPartyPlayer::TwoPartyPlayer(const Names& Nms, int other_player, int id) : PlayerBase(Nms), other_player(other_player) +{ + is_server = Nms.my_num() > other_player; + setup_sockets(Nms.names[other_player].c_str(), *Nms.server, Nms.portnum_base + other_player, id); +} + +TwoPartyPlayer::~TwoPartyPlayer() +{ + close_client_socket(socket); +} + +void TwoPartyPlayer::setup_sockets(const char* hostname, ServerSocket& server, int pn, int id) +{ + if (is_server) + { + fprintf(stderr, "Setting up server with id %d\n",id); + socket = server.get_connection_socket(id); + } + else + { + fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, pn, id); + set_up_client_socket(socket, hostname, pn); + ::send(socket, (unsigned char*)&id, sizeof(id)); + } +} + +int TwoPartyPlayer::other_player_num() const +{ + return other_player; +} + +void TwoPartyPlayer::send(octetStream& o) const +{ + o.Send(socket); +} + +void TwoPartyPlayer::receive(octetStream& o) const +{ + o.reset_write_head(); + o.Receive(socket); +} + +void TwoPartyPlayer::send_receive_player(vector& o) const +{ + { + if (is_server) + { + o[0].Send(socket); + o[1].reset_write_head(); + o[1].Receive(socket); + } + else + { + o[1].reset_write_head(); + o[1].Receive(socket); + o[0].Send(socket); + } + } +} diff --git a/Networking/Player.h b/Networking/Player.h new file mode 100644 index 000000000..500061197 --- /dev/null +++ b/Networking/Player.h @@ -0,0 +1,185 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Player +#define _Player + +/* Class to create a player, for KeyGen, Offline and Online phases. + * + * Basically handles connection to the server to obtain the names + * of the other players. Plus sending and receiving of data + * + */ + +#include +#include +#include +#include +using namespace std; + +#include "Tools/octetStream.h" +#include "Networking/sockets.h" +#include "Networking/ServerSocket.h" +#include "Tools/sha1.h" +#include "Networking/Receiver.h" +#include "Networking/Sender.h" + +/* Class to get the names off the server */ +class Names +{ + vector names; + int nplayers; + int portnum_base; + int player_no; + + void setup_names(const char *servername); + + void setup_server(); + + public: + + mutable ServerSocket* server; + + // Usual setup names + void init(int player,int pnb,const char* servername); + Names(int player,int pnb,const char* servername) + { init(player,pnb,servername); } + // Set up names when we KNOW who we are going to be using before hand + void init(int player,int pnb,vector Nms); + Names(int player,int pnb,vector Nms) + { init(player,pnb,Nms); } + void init(int player,int pnb,vector Nms); + Names(int player,int pnb,vector Nms) + { init(player,pnb,Nms); } + // Set up names from file -- reads the first nplayers names in the file + void init(int player, int nplayers, int pnb, const string& hostsfile); + Names(int player, int nplayers, int pnb, const string& hostsfile) + { init(player, nplayers, pnb, hostsfile); } + + + Names() : nplayers(-1), portnum_base(-1), player_no(-1), server(0) { ; } + Names(const Names& other); + ~Names(); + + int num_players() const { return nplayers; } + int my_num() const { return player_no; } + const string get_name(int i) const { return names[i]; } + int get_portnum_base() const { return portnum_base; } + + friend class PlayerBase; + friend class Player; + friend class TwoPartyPlayer; +}; + + +class PlayerBase +{ +protected: + int player_no; + +public: + PlayerBase(const Names& Nms) : player_no(Nms.my_num()) {} + int my_num() const { return player_no; } +}; + + +class Player : public PlayerBase +{ +protected: + vector sockets; + int send_to_self_socket; + + void setup_sockets(const vector& names,int portnum_base,int id_base,ServerSocket& server); + + int nplayers; + + mutable blk_SHA_CTX ctx; + + map socket_players; + + int socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; } + +public: + // The offset is used for the multi-threaded call, to ensure different + // portnum bases in each thread + Player(const Names& Nms,int id_base=0); + + virtual ~Player(); + + int num_players() const { return nplayers; } + int socket(int i) const { return sockets[i]; } + + // Send/Receive data to/from player i + // 8-bit ints only (mainly for testing) + void send_int(int i,int a) const { send(sockets[i],a); } + void receive_int(int i,int& a) const { receive(sockets[i],a); } + + // Send an octetStream to all other players + // -- And corresponding receive + virtual void send_all(const octetStream& o,bool donthash=false) const; + void send_to(int player,const octetStream& o,bool donthash=false) const; + virtual void receive_player(int i,octetStream& o,bool donthash=false) const; + + // Receive one from player i + + /* Broadcast and Receive data to/from all players + * - Assumes o[player_no] contains the thing broadcast by me + */ + void Broadcast_Receive(vector& o,bool donthash=false) const; + + /* Run Protocol To Verify Broadcast Is Correct + * - Resets the blk_SHA_CTX at the same time + */ + void Check_Broadcast() const; + + // wait for available inputs + void wait_for_available(vector& players, vector& result) const; + + // dummy functions for compatibility + virtual void request_receive(int i, octetStream& o) const { sockets[i]; o.get_length(); } + virtual void wait_receive(int i, octetStream& o, bool donthash=false) const { receive_player(i, o, donthash); } +}; + + +class ThreadPlayer : public Player +{ +public: + mutable vector receivers; + mutable vector senders; + + ThreadPlayer(const Names& Nms,int id_base=0); + virtual ~ThreadPlayer(); + + void request_receive(int i, octetStream& o) const; + void wait_receive(int i, octetStream& o, bool donthash=false) const; + void receive_player(int i,octetStream& o,bool donthash=false) const; + + void send_all(const octetStream& o,bool donthash=false) const; +}; + + +class TwoPartyPlayer : public PlayerBase +{ +private: + // setup sockets for comm. with only one other player + void setup_sockets(const char* hostname, ServerSocket& server, int pn, int id); + + int socket; + bool is_server; + int other_player; + +public: + TwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0); + ~TwoPartyPlayer(); + + void send(octetStream& o) const; + void receive(octetStream& o) const; + + int other_player_num() const; + + /* Send and receive to/from the other player + * - o[0] contains my data, received data put in o[1] + */ + void send_receive_player(vector& o) const; +}; + +#endif diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp new file mode 100644 index 000000000..c7527f3c8 --- /dev/null +++ b/Networking/Receiver.cpp @@ -0,0 +1,58 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Receiver.cpp + * + */ + +#include "Receiver.h" + +#include +using namespace std; + +void* run_receiver_thread(void* receiver) +{ + ((Receiver*)receiver)->run(); + return 0; +} + +Receiver::Receiver(int socket) : socket(socket), thread(0) +{ +} + +void Receiver::start() +{ + pthread_create(&thread, 0, run_receiver_thread, this); +} + +void Receiver::stop() +{ + in.stop(); + pthread_join(thread, 0); +} + +void Receiver::run() +{ + octetStream* os = 0; + while (in.pop(os)) + { + os->reset_write_head(); + timer.start(); + os->Receive(socket); + timer.stop(); + out.push(os); + } +} + +void Receiver::request(octetStream& os) +{ + in.push(&os); +} + +void Receiver::wait(octetStream& os) +{ + octetStream* queued = 0; + out.pop(queued); + if (queued != &os) + throw not_implemented(); +} diff --git a/Networking/Receiver.h b/Networking/Receiver.h new file mode 100644 index 000000000..f7d62d075 --- /dev/null +++ b/Networking/Receiver.h @@ -0,0 +1,40 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Receiver.h + * + */ + +#ifndef NETWORKING_RECEIVER_H_ +#define NETWORKING_RECEIVER_H_ + +#include + +#include "Tools/octetStream.h" +#include "Tools/WaitQueue.h" +#include "Tools/time-func.h" + +class Receiver +{ + int socket; + WaitQueue in; + WaitQueue out; + pthread_t thread; + + // prevent copying + Receiver(const Receiver& other); + +public: + Timer timer; + + Receiver(int socket); + + void start(); + void stop(); + void run(); + + void request(octetStream& os); + void wait(octetStream& os); +}; + +#endif /* NETWORKING_RECEIVER_H_ */ diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp new file mode 100644 index 000000000..89dc0eedd --- /dev/null +++ b/Networking/Sender.cpp @@ -0,0 +1,54 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Sender.cpp + * + */ + +#include "Sender.h" + +void* run_sender_thread(void* sender) +{ + ((Sender*)sender)->run(); + return 0; +} + +Sender::Sender(int socket) : socket(socket), thread(0) +{ +} + +void Sender::start() +{ + pthread_create(&thread, 0, run_sender_thread, this); +} + +void Sender::stop() +{ + in.stop(); + pthread_join(thread, 0); +} + +void Sender::run() +{ + const octetStream* os = 0; + while (in.pop(os)) + { +// timer.start(); + os->Send(socket); +// timer.stop(); + out.push(os); + } +} + +void Sender::request(const octetStream& os) +{ + in.push(&os); +} + +void Sender::wait(const octetStream& os) +{ + const octetStream* queued = 0; + out.pop(queued); + if (queued != &os) + throw not_implemented(); +} diff --git a/Networking/Sender.h b/Networking/Sender.h new file mode 100644 index 000000000..ff95c8bbb --- /dev/null +++ b/Networking/Sender.h @@ -0,0 +1,40 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Sender.h + * + */ + +#ifndef NETWORKING_SENDER_H_ +#define NETWORKING_SENDER_H_ + +#include + +#include "Tools/octetStream.h" +#include "Tools/WaitQueue.h" +#include "Tools/time-func.h" + +class Sender +{ + int socket; + WaitQueue in; + WaitQueue out; + pthread_t thread; + + // prevent copying + Sender(const Sender& other); + +public: + Timer timer; + + Sender(int socket); + + void start(); + void stop(); + void run(); + + void request(const octetStream& os); + void wait(const octetStream& os); +}; + +#endif /* NETWORKING_SENDER_H_ */ diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp new file mode 100644 index 000000000..653068a1b --- /dev/null +++ b/Networking/ServerSocket.cpp @@ -0,0 +1,115 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * ServerSocket.cpp + * + */ + +#include +#include +#include "Exceptions/Exceptions.h" + +#include +#include + +#include +#include +using namespace std; + +void* accept_thread(void* server_socket) +{ + ((ServerSocket*)server_socket)->accept_clients(); + return 0; +} + +ServerSocket::ServerSocket(int Portnum) : portnum(Portnum) +{ + struct sockaddr_in serv; /* socket info about our server */ + + memset(&serv, 0, sizeof(serv)); /* zero the struct before filling the fields */ + serv.sin_family = AF_INET; /* set the type of connection to TCP/IP */ + serv.sin_addr.s_addr = INADDR_ANY; /* set our address to any interface */ + serv.sin_port = htons(Portnum); /* set the server port number */ + + main_socket = socket(AF_INET, SOCK_STREAM, 0); + if (main_socket<0) { error("set_up_socket:socket"); } + + int one=1; + int fl=setsockopt(main_socket,SOL_SOCKET,SO_REUSEADDR,(char*)&one,sizeof(int)); + if (fl<0) { error("set_up_socket:setsockopt"); } + + /* disable Nagle's algorithm */ + fl= setsockopt(main_socket, IPPROTO_TCP, TCP_NODELAY, (char*)&one,sizeof(int)); + if (fl<0) { error("set_up_socket:setsockopt"); } + + octet my_name[512]; + memset(my_name,0,512*sizeof(octet)); + gethostname((char*)my_name,512); + + /* bind serv information to mysocket + * - Just assume it will eventually wake up + */ + fl=1; + while (fl!=0) + { fl=bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr)); + if (fl != 0) + { cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl; + sleep(1); + } + else + { cerr << "Bound on port " << Portnum << endl; } + } + if (fl<0) { error("set_up_socket:bind"); } + + /* start listening, allowing a queue of up to 1000 pending connection */ + fl=listen(main_socket, 1000); + if (fl<0) { error("set_up_socket:listen"); } + + pthread_create(&thread, 0, accept_thread, this); +} + +ServerSocket::~ServerSocket() +{ + pthread_cancel(thread); + pthread_join(thread, 0); + if (close(main_socket)) { error("close(main_socket"); }; +} + +void ServerSocket::accept_clients() +{ + while (true) + { + struct sockaddr dest; + memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ + int socksize = sizeof(dest); + int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize); + if (consocket<0) { error("set_up_socket:accept"); } + + int client_id; + receive(consocket, (unsigned char*)&client_id, sizeof(client_id)); + + data_signal.lock(); + clients[client_id] = consocket; + data_signal.broadcast(); + data_signal.unlock(); + } +} + +int ServerSocket::get_connection_socket(int id) +{ + data_signal.lock(); + if (used.find(id) != used.end()) + { + stringstream ss; + ss << "Connection id " << hex << id << " already used"; + throw IO_Error(ss.str()); + } + + while (clients.find(id) == clients.end()) + data_signal.wait(); + + int client = clients[id]; + used.insert(id); + data_signal.unlock(); + return client; +} diff --git a/Networking/ServerSocket.h b/Networking/ServerSocket.h new file mode 100644 index 000000000..08a9cbc46 --- /dev/null +++ b/Networking/ServerSocket.h @@ -0,0 +1,44 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * ServerSocket.h + * + */ + +#ifndef NETWORKING_SERVERSOCKET_H_ +#define NETWORKING_SERVERSOCKET_H_ + +#include +#include +using namespace std; + +#include + +#include "Tools/WaitQueue.h" +#include "Tools/Signal.h" + +class ServerSocket +{ + int main_socket, portnum; + map clients; + set used; + Signal data_signal; + pthread_t thread; + + // disable copying + ServerSocket(const ServerSocket& other); + +public: + ServerSocket(int Portnum); + ~ServerSocket(); + + void accept_clients(); + + // This depends on clients sending their id as int. + // Has to be thread-safe. + int get_connection_socket(int number); + + void close_socket(); +}; + +#endif /* NETWORKING_SERVERSOCKET_H_ */ diff --git a/Networking/data.h b/Networking/data.h new file mode 100644 index 000000000..d131a6a9a --- /dev/null +++ b/Networking/data.h @@ -0,0 +1,45 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Data +#define _Data + +#include + +#include "Exceptions/Exceptions.h" + + +typedef unsigned char octet; + +// Assumes word is a 64 bit value +#ifdef WIN32 + typedef unsigned __int64 word; +#else + typedef unsigned long word; +#endif + +#define BROADCAST 0 +#define ROUTE 1 +#define TERMINATE 2 +#define GO 3 + +void encode_length(octet *buff,int len); +int decode_length(octet *buff); + + +inline void encode_length(octet *buff,int len) +{ + if (len<0) { throw invalid_length(); } + buff[0]=len&255; + buff[1]=(len>>8)&255; + buff[2]=(len>>16)&255; + buff[3]=(len>>24)&255; +} + +inline int decode_length(octet *buff) +{ + int len=buff[0]+256*buff[1]+65536*buff[2]+16777216*buff[3]; + if (len<0) { throw invalid_length(); } + return len; +} + +#endif diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp new file mode 100644 index 000000000..f225a13df --- /dev/null +++ b/Networking/sockets.cpp @@ -0,0 +1,225 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "sockets.h" +#include "Exceptions/Exceptions.h" + +#include +using namespace std; + +void error(const char *str) +{ + char err[1000]; + gethostname(err,1000); + strcat(err," : "); + strcat(err,str); + perror(err); + throw bad_value(); +} + +void error(const char *str1,const char *str2) +{ + char err[1000]; + gethostname(err,1000); + strcat(err," : "); + strcat(err,str1); + strcat(err,str2); + perror(err); + throw bad_value(); +} + + + +void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum) +{ + + struct sockaddr_in serv; /* socket info about our server */ + int socksize = sizeof(struct sockaddr_in); + + memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ + memset(&serv, 0, sizeof(serv)); /* zero the struct before filling the fields */ + serv.sin_family = AF_INET; /* set the type of connection to TCP/IP */ + serv.sin_addr.s_addr = INADDR_ANY; /* set our address to any interface */ + serv.sin_port = htons(Portnum); /* set the server port number */ + + main_socket = socket(AF_INET, SOCK_STREAM, 0); + if (main_socket<0) { error("set_up_socket:socket"); } + + int one=1; + int fl=setsockopt(main_socket,SOL_SOCKET,SO_REUSEADDR,(char*)&one,sizeof(int)); + if (fl<0) { error("set_up_socket:setsockopt"); } + + /* disable Nagle's algorithm */ + fl= setsockopt(main_socket, IPPROTO_TCP, TCP_NODELAY, (char*)&one,sizeof(int)); + if (fl<0) { error("set_up_socket:setsockopt"); } + + octet my_name[512]; + memset(my_name,0,512*sizeof(octet)); + gethostname((char*)my_name,512); + + /* bind serv information to mysocket + * - Just assume it will eventually wake up + */ + fl=1; + while (fl!=0) + { fl=bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr)); + if (fl != 0) + { cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl; + sleep(1); + } + else + { cerr << "Bound on port " << Portnum << endl; } + } + if (fl<0) { error("set_up_socket:bind"); } + + /* start listening, allowing a queue of up to 1 pending connection */ + fl=listen(main_socket, 1); + if (fl<0) { error("set_up_socket:listen"); } + + consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize); + + if (consocket<0) { error("set_up_socket:accept"); } + +} + + +void close_server_socket(int consocket,int main_socket) +{ + if (close(consocket)) { error("close(socket)"); } + if (close(main_socket)) { error("close(main_socket"); }; +} + + + +void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) +{ + mysocket = socket(AF_INET, SOCK_STREAM, 0); + if (mysocket<0) { error("set_up_socket:socket"); } + + /* disable Nagle's algorithm */ + int one=1; + int fl= setsockopt(mysocket, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)); + if (fl<0) { error("set_up_socket:setsockopt"); } + + fl=setsockopt(mysocket, SOL_SOCKET, SO_REUSEADDR, (char*)&one, sizeof(int)); + if (fl<0) { error("set_up_socket:setsockopt"); } + + struct sockaddr_in dest; + dest.sin_family = AF_INET; + dest.sin_port = htons(Portnum); // set destination port number + + /* + struct hostent *server; + server=gethostbyname(hostname); + if (server== NULL) + { error("set_up_socket:gethostbyname"); } + bcopy((char *)server->h_addr, + (char *)&dest.sin_addr.s_addr, + server->h_length); // set destination IP number + */ + struct addrinfo hints, *ai=NULL,*rp; + memset (&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_flags = AI_CANONNAME; + + octet my_name[512]; + memset(my_name,0,512*sizeof(octet)); + gethostname((char*)my_name,512); + + int erp; + for (int i = 0; i < 60; i++) + { erp=getaddrinfo (hostname, NULL, &hints, &ai); + if (erp == 0) + { break; } + else + { cerr << "getaddrinfo on " << my_name << " has returned '" << gai_strerror(erp) << + "' for " << hostname << ", trying again in a second ..." << endl; + if (ai) + freeaddrinfo(ai); + sleep(1); + } + } + if (erp!=0) + { error("set_up_socket:getaddrinfo"); } + + for (rp=ai; rp!=NULL; rp=rp->ai_next) + { const struct in_addr *addr4 = &((const struct sockaddr_in*)ai->ai_addr)->sin_addr; + + if (ai->ai_family == AF_INET) + { memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr)); + continue; + } + } + freeaddrinfo(ai); + + + do + { fl=1; + while (fl==1 || errno==EINPROGRESS) + { fl=connect(mysocket, (struct sockaddr *)&dest, sizeof(struct sockaddr)); } + } + while (fl==-1 && errno==ECONNREFUSED); + if (fl<0) { error("set_up_socket:connect:",hostname); } +} + + + +void close_client_socket(int socket) +{ + if (close(socket)) + { + char tmp[1000]; + sprintf(tmp, "close(%d)", socket); + error(tmp); + } +} + + + +unsigned long long sent_amount = 0, sent_counter = 0; + + +void send(int socket,int a) +{ + unsigned char msg[1]; + msg[0]=a&255; + if (send(socket,msg,1,0)!=1) + { error("Send error - 2 "); } +} + + +void receive(int socket,int& a) +{ + unsigned char msg[1]; + int i=0; + while (i==0) + { i=recv(socket,msg,1,0); + if (i<0) { error("Receiving error - 2"); } + } + a=msg[0]; +} + + + +void send_ack(int socket) +{ + char msg[]="OK"; + if (send(socket,msg,2,0)!=2) + { error("Send Ack"); } +} + + +int get_ack(int socket) +{ + char msg[]="OK"; + char msg_r[2]; + int i=0,j; + while (2-i>0) + { j=recv(socket,msg_r+i,2-i,0); + i=i+j; + } + + if (msg_r[0]!=msg[0] || msg_r[1]!=msg[1]) { return 1; } + return 0; +} + diff --git a/Networking/sockets.h b/Networking/sockets.h new file mode 100644 index 000000000..a7c38fb08 --- /dev/null +++ b/Networking/sockets.h @@ -0,0 +1,67 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _sockets +#define _sockets + +#include "Networking/data.h" + +#include /* Errors */ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include /* Wait for Process Termination */ + + +void error(const char *str1,const char *str2); +void error(const char *str); + +void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum); +void close_server_socket(int consocket,int main_socket); + +void set_up_client_socket(int& mysocket,const char* hostname,int Portnum); +void close_client_socket(int socket); + +void send(int socket,octet *msg,int len); +void receive(int socket,octet *msg,int len); + +/* Send and receive 8 bit integers */ +void send(int socket,int a); +void receive(int socket,int& a); + +void send_ack(int socket); +int get_ack(int socket); + + +extern unsigned long long sent_amount, sent_counter; + +inline void send(int socket,octet *msg,int len) +{ + if (send(socket,msg,len,0)!=len) + { error("Send error - 1 "); } + + sent_amount += len; + sent_counter++; +} + +inline void receive(int socket,octet *msg,int len) +{ + int i=0,j; + while (len-i>0) + { j=recv(socket,msg+i,len-i,0); + if (j<0) { error("Receiving error - 1"); } + i=i+j; + } +} + +#endif diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp new file mode 100644 index 000000000..2788ed726 --- /dev/null +++ b/OT/BaseOT.cpp @@ -0,0 +1,294 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "OT/BaseOT.h" +#include "Tools/random.h" + +#include +#include +#include +#include + +extern "C" { +#include "SimpleOT/ot_sender.h" +#include "SimpleOT/ot_receiver.h" +} + +using namespace std; + +const char* role_to_str(OT_ROLE role) +{ + if (role == RECEIVER) + return "RECEIVER"; + if (role == SENDER) + return "SENDER"; + return "BOTH"; +} + +OT_ROLE INV_ROLE(OT_ROLE role) +{ + if (role == RECEIVER) + return SENDER; + if (role == SENDER) + return RECEIVER; + else + return BOTH; +} + +void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE role) +{ + if (role == SENDER) + { + P->send(os[0]); + } + else if (role == RECEIVER) + { + P->receive(os[1]); + } + else + { + // both sender + receiver + P->send_receive_player(os); + } +} + +void send_if_ot_receiver(const TwoPartyPlayer* P, vector& os, OT_ROLE role) +{ + if (role == RECEIVER) + { + P->send(os[0]); + } + else if (role == SENDER) + { + P->receive(os[1]); + } + else + { + // both + P->send_receive_player(os); + } +} + + +void BaseOT::exec_base(bool new_receiver_inputs) +{ + int i, j, k, len; + PRNG G; + G.ReSeed(); + vector os(2); + SIMPLEOT_SENDER sender; + SIMPLEOT_RECEIVER receiver; + + unsigned char S_pack[ PACKBYTES ]; + unsigned char Rs_pack[ 2 ][ 4 * PACKBYTES ]; + unsigned char sender_keys[ 2 ][ 4 ][ HASHBYTES ]; + unsigned char receiver_keys[ 4 ][ HASHBYTES ]; + unsigned char cs[ 4 ]; + + if (ot_role & SENDER) + { + sender_genS(&sender, S_pack); + os[0].store_bytes(S_pack, sizeof(S_pack)); + } + send_if_ot_sender(P, os, ot_role); + + if (ot_role & RECEIVER) + { + os[1].get_bytes((octet*) receiver.S_pack, len); + if (len != HASHBYTES) + { + cerr << "Received invalid length in base OT\n"; + exit(1); + } + receiver_procS(&receiver); + receiver_maketable(&receiver); + } + + for (i = 0; i < nOT; i += 4) + { + if (ot_role & RECEIVER) + { + for (j = 0; j < 4; j++) + { + if (new_receiver_inputs) + receiver_inputs[i + j] = G.get_uchar()&1; + cs[j] = receiver_inputs[i + j]; + } + receiver_rsgen(&receiver, Rs_pack[0], cs); + os[0].reset_write_head(); + os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0])); + receiver_keygen(&receiver, receiver_keys); + } + send_if_ot_receiver(P, os, ot_role); + + if (ot_role & SENDER) + { + os[1].get_bytes((octet*) Rs_pack[1], len); + if (len != sizeof(Rs_pack[1])) + { + cerr << "Received invalid length in base OT\n"; + exit(1); + } + sender_keygen(&sender, Rs_pack[1], sender_keys); + + // Copy 128 bits of keys to sender_inputs + for (j = 0; j < 4; j++) + { + for (k = 0; k < AES_BLK_SIZE; k++) + { + sender_inputs[i + j][0].set_byte(k, sender_keys[0][j][k]); + sender_inputs[i + j][1].set_byte(k, sender_keys[1][j][k]); + } + } + } + + if (ot_role & RECEIVER) + { + // Copy keys to receiver_outputs + for (j = 0; j < 4; j++) + { + for (k = 0; k < AES_BLK_SIZE; k++) + { + receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); + } + } + } + #ifdef BASE_OT_DEBUG + for (j = 0; j < 4; j++) + { + if (ot_role & SENDER) + { + printf("%4d-th sender keys:", i+j); + for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[0][j][k]); + printf(" "); + for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]); + printf("\n"); + } + if (ot_role & RECEIVER) + { + printf("%4d-th receiver key:", i+j); + for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); + printf("\n"); + } + } + + printf("\n"); + #endif + } + set_seeds(); +} + +void BaseOT::set_seeds() +{ + for (int i = 0; i < nOT; i++) + { + // Set PRG seeds + if (ot_role & SENDER) + { + G_sender[i][0].SetSeed(sender_inputs[i][0].get_ptr()); + G_sender[i][1].SetSeed(sender_inputs[i][1].get_ptr()); + } + if (ot_role & RECEIVER) + { + G_receiver[i].SetSeed(receiver_outputs[i].get_ptr()); + } + } + extend_length(); +} + +void BaseOT::extend_length() +{ + for (int i = 0; i < nOT; i++) + { + if (ot_role & SENDER) + { + sender_inputs[i][0].randomize(G_sender[i][0]); + sender_inputs[i][1].randomize(G_sender[i][1]); + } + if (ot_role & RECEIVER) + { + receiver_outputs[i].randomize(G_receiver[i]); + } + } +} + + +void BaseOT::check() +{ + vector os(2); + BitVector tmp_vector(8 * AES_BLK_SIZE); + + + for (int i = 0; i < nOT; i++) + { + if (ot_role == SENDER) + { + // send both inputs over + sender_inputs[i][0].pack(os[0]); + sender_inputs[i][1].pack(os[0]); + P->send(os[0]); + } + else if (ot_role == RECEIVER) + { + P->receive(os[1]); + } + else + { + // both sender + receiver + sender_inputs[i][0].pack(os[0]); + sender_inputs[i][1].pack(os[0]); + P->send_receive_player(os); + } + if (ot_role & RECEIVER) + { + tmp_vector.unpack(os[1]); + + if (receiver_inputs[i] == 1) + { + tmp_vector.unpack(os[1]); + } + if (!tmp_vector.equals(receiver_outputs[i])) + { + cerr << "Incorrect OT\n"; + exit(1); + } + } + os[0].reset_write_head(); + os[1].reset_write_head(); + } +} + + +void FakeOT::exec_base(bool new_receiver_inputs) +{ + PRNG G; + G.ReSeed(); + vector os(2); + vector bv(2); + + if ((ot_role & RECEIVER) && new_receiver_inputs) + { + for (int i = 0; i < nOT; i++) + // Generate my receiver inputs + receiver_inputs[i] = G.get_uchar()&1; + } + + if (ot_role & SENDER) + for (int i = 0; i < nOT; i++) + for (int j = 0; j < 2; j++) + { + sender_inputs[i][j].randomize(G); + sender_inputs[i][j].pack(os[0]); + } + + send_if_ot_sender(P, os, ot_role); + + if (ot_role & RECEIVER) + for (int i = 0; i < nOT; i++) + { + for (int j = 0; j < 2; j++) + bv[j].unpack(os[1]); + receiver_outputs[i] = bv[receiver_inputs[i]]; + } + + set_seeds(); +} diff --git a/OT/BaseOT.h b/OT/BaseOT.h new file mode 100644 index 000000000..188801c35 --- /dev/null +++ b/OT/BaseOT.h @@ -0,0 +1,93 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _BASE_OT +#define _BASE_OT + +/* The OT thread uses the Miracl library, which is not thread safe. + * Thus all Miracl based code is contained in this one thread so as + * to avoid locking issues etc. + * + * Thus this thread serves all base OTs to all other threads + */ + +#include "Networking/Player.h" +#include "Tools/random.h" +#include "OT/BitVector.h" + +// currently always assumes BOTH, i.e. do 2 sets of OT symmetrically, +// use bitwise & to check for role +enum OT_ROLE +{ + RECEIVER = 0x01, + SENDER = 0x10, + BOTH = 0x11 +}; + +OT_ROLE INV_ROLE(OT_ROLE role); + +const char* role_to_str(OT_ROLE role); +void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE role); +void send_if_ot_receiver(const TwoPartyPlayer* P, vector& os, OT_ROLE role); + +class BaseOT +{ +public: + vector receiver_inputs; + vector< vector > sender_inputs; + vector receiver_outputs; + TwoPartyPlayer* P; + + BaseOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) + : P(player), nOT(nOT), ot_length(ot_length), ot_role(role) + { + receiver_inputs.resize(nOT); + sender_inputs.resize(nOT, vector(2)); + receiver_outputs.resize(nOT); + G_sender.resize(nOT, vector(2)); + G_receiver.resize(nOT); + + for (int i = 0; i < nOT; i++) + { + sender_inputs[i][0] = BitVector(8 * AES_BLK_SIZE); + sender_inputs[i][1] = BitVector(8 * AES_BLK_SIZE); + receiver_outputs[i] = BitVector(8 * AES_BLK_SIZE); + } + } + virtual ~BaseOT() {} + + int length() { return ot_length; } + + void set_receiver_inputs(const vector& new_inputs) + { + if ((int)new_inputs.size() != nOT) + throw invalid_length(); + receiver_inputs = new_inputs; + } + + // do the OTs -- generate fresh random choice bits by default + virtual void exec_base(bool new_receiver_inputs=true); + // use PRG to get the next ot_length bits + void extend_length(); + void check(); +protected: + int nOT, ot_length; + OT_ROLE ot_role; + + vector< vector > G_sender; + vector G_receiver; + + bool is_sender() { return (bool) (ot_role & SENDER); } + bool is_receiver() { return (bool) (ot_role & RECEIVER); } + + void set_seeds(); +}; + +class FakeOT : public BaseOT +{ +public: + FakeOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) : + BaseOT(nOT, ot_length, player, role) {} + void exec_base(bool new_receiver_inputs=true); +}; + +#endif diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp new file mode 100644 index 000000000..012bcafce --- /dev/null +++ b/OT/BitMatrix.cpp @@ -0,0 +1,646 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * BitMatrix.cpp + * + */ + +#include +#include +#include + +#include "BitMatrix.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" + +union matrix16x8 +{ + __m128i whole; + octet rows[16]; + + bool get_bit(int x, int y) + { return (rows[x] >> y) & 1; } + + void input(square128& input, int x, int y); + void transpose(square128& output, int x, int y); +}; + +class square16 +{ +public: + // 16x16 in two halves, 128 bits each + matrix16x8 halves[2]; + + bool get_bit(int x, int y) + { return halves[y/8].get_bit(x, y % 8); } + + void input(square128& output, int x, int y); + void transpose(square128& output, int x, int y); + + void check_transpose(square16& dual); + void print(); +}; + +__attribute__((optimize("unroll-loops"))) +inline void matrix16x8::input(square128& input, int x, int y) +{ + for (int l = 0; l < 16; l++) + rows[l] = input.bytes[16*x+l][y]; +} + +__attribute__((optimize("unroll-loops"))) +inline void square16::input(square128& input, int x, int y) +{ + for (int i = 0; i < 2; i++) + halves[i].input(input, x, 2 * y + i); +} + +__attribute__((optimize("unroll-loops"))) +inline void matrix16x8::transpose(square128& output, int x, int y) +{ + for (int j = 0; j < 8; j++) + { + int row = _mm_movemask_epi8(whole); + whole = _mm_slli_epi64(whole, 1); + + // _mm_movemask_epi8 uses most significant bit, hence +7-j + output.doublebytes[8*x+7-j][y] = row; + } +} + +__attribute__((optimize("unroll-loops"))) +inline void square16::transpose(square128& output, int x, int y) +{ + for (int i = 0; i < 2; i++) + halves[i].transpose(output, 2 * x + i, y); +} + +#ifdef __AVX2__ +union matrix32x8 +{ + __m256i whole; + octet rows[32]; + + void input(square128& input, int x, int y); + void transpose(square128& output, int x, int y); +}; + +class square32 +{ +public: + matrix32x8 quarters[4]; + + void input(square128& input, int x, int y); + void transpose(square128& output, int x, int y); +}; + +__attribute__((optimize("unroll-loops"))) +inline void matrix32x8::input(square128& input, int x, int y) +{ + for (int l = 0; l < 32; l++) + rows[l] = input.bytes[32*x+l][y]; +} + +__attribute__((optimize("unroll-loops"))) +inline void square32::input(square128& input, int x, int y) +{ + for (int i = 0; i < 4; i++) + quarters[i].input(input, x, 4 * y + i); +} + +__attribute__((optimize("unroll-loops"))) +inline void matrix32x8::transpose(square128& output, int x, int y) +{ + for (int j = 0; j < 8; j++) + { + int row = _mm256_movemask_epi8(whole); + whole = _mm256_slli_epi64(whole, 1); + + // _mm_movemask_epi8 uses most significant bit, hence +7-j + output.words[8*x+7-j][y] = row; + } +} + +__attribute__((optimize("unroll-loops"))) +inline void square32::transpose(square128& output, int x, int y) +{ + for (int i = 0; i < 4; i++) + quarters[i].transpose(output, 4 * x + i, y); +} +#endif + +#ifdef __AVX2__ +#warning Using AVX2 for transpose +typedef square32 subsquare; +#define N_SUBSQUARES 4 +#else +typedef square16 subsquare; +#define N_SUBSQUARES 8 +#endif + +__attribute__((optimize("unroll-loops"))) +void square128::transpose() +{ + for (int j = 0; j < N_SUBSQUARES; j++) + for (int k = 0; k < j; k++) + { + subsquare a, b; + a.input(*this, k, j); + b.input(*this, j, k); + a.transpose(*this, j, k); + b.transpose(*this, k, j); + } + + for (int j = 0; j < N_SUBSQUARES; j++) + { + subsquare a; + a.input(*this, j, j); + a.transpose(*this, j, j); + } +} + +void square128::randomize(PRNG& G) +{ + G.get_octets((octet*)&rows, sizeof(rows)); +} + +template <> +void square128::randomize(int row, PRNG& G) +{ + rows[row] = G.get_doubleword(); +} + +template <> +void square128::randomize(int row, PRNG& G) +{ + rows[row] = gfp::get_ZpD().get_random128(G); +} + + +void gfp_iadd(__m128i& a, __m128i& b) +{ + gfp::get_ZpD().Add((mp_limb_t*)&a, (mp_limb_t*)&a, (mp_limb_t*)&b); +} + +void gfp_isub(__m128i& a, __m128i& b) +{ + gfp::get_ZpD().Sub((mp_limb_t*)&a, (mp_limb_t*)&a, (mp_limb_t*)&b); +} + +void gfp_irsub(__m128i& a, __m128i& b) +{ + gfp::get_ZpD().Sub((mp_limb_t*)&a, (mp_limb_t*)&b, (mp_limb_t*)&a); +} + +template<> +void square128::conditional_add(BitVector& conditions, square128& other, int offset) +{ + for (int i = 0; i < 128; i++) + if (conditions.get_bit(128 * offset + i)) + rows[i] ^= other.rows[i]; +} + +template<> +void square128::conditional_add(BitVector& conditions, square128& other, int offset) +{ + for (int i = 0; i < 128; i++) + if (conditions.get_bit(128 * offset + i)) + gfp_iadd(rows[i], other.rows[i]); +} + +template +void square128::hash_row_wise(MMO& mmo, square128& input) +{ + mmo.hashBlockWise((octet*)rows, (octet*)input.rows); +} + +template <> +void square128::to(gf2n_long& result) +{ + int128 high, low; + for (int i = 0; i < 128; i++) + { + low ^= int128(rows[i]) << i; + high ^= int128(rows[i]) >> (128 - i); + } + result.reduce(high, low); +} + +template <> +void square128::to(gfp& result) +{ + mp_limb_t product[4], sum[4], tmp[2][4]; + memset(tmp, 0, sizeof(tmp)); + memset(sum, 0, sizeof(sum)); + for (int i = 0; i < 128; i++) + { + memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i])); + mpn_lshift(product, tmp[i/64], 4, i % 64); + mpn_add_n(sum, product, sum, 4); + } + mp_limb_t q[4], ans[4]; + mpn_tdiv_qr(q, ans, 0, sum, 4, gfp::get_ZpD().get_prA(), 2); + result = *(__m128i*)ans; +} + +void square128::check_transpose(square128& dual, int i, int k) +{ + for (int j = 0; j < 16; j++) + for (int l = 0; l < 16; l++) + if (get_bit(16 * i + j, 16 * k + l) != dual.get_bit(16 * k + l, 16 * i + j)) + { + cout << "Error in 16x16 square (" << i << "," << k << ")" << endl; + print(i, k); + dual.print(i, k); + exit(1); + } +} + +void square16::print() +{ + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 8; j++) + { + for (int k = 0; k < 2; k++) + { + for (int l = 0; l < 8; l++) + cout << halves[k].get_bit(8 * i + j, l); + cout << " "; + } + cout << endl; + } + cout << endl; + } +} + +void square128::print(int i, int k) +{ + square16 a; + a.input(*this, i, k); + a.print(); +} + +void square128::print() +{ + for (int i = 0; i < 128; i++) + { + for (int j = 0; j < 128; j++) + cout << get_bit(i, j); + cout << endl; + } +} + +void square128::set_zero() +{ + for (int i = 0; i < 128; i++) + rows[i] = _mm_setzero_si128(); +} + +square128& square128::operator^=(square128& other) +{ + for (int i = 0; i < 128; i++) + rows[i] ^= other.rows[i]; + return *this; +} + +template<> +square128& square128::add(square128& other) +{ + return *this ^= other; +} + +template<> +square128& square128::add(square128& other) +{ + for (int i = 0; i < 128; i++) + gfp_iadd(rows[i], other.rows[i]); + return *this; +} + +template<> +square128& square128::sub(square128& other) +{ + return *this ^= other; +} + +template<> +square128& square128::sub(square128& other) +{ + for (int i = 0; i < 128; i++) + gfp_isub(rows[i], other.rows[i]); + return *this; +} + +template<> +square128& square128::rsub(square128& other) +{ + return *this ^= other; +} + +template<> +square128& square128::rsub(square128& other) +{ + for (int i = 0; i < 128; i++) + gfp_irsub(rows[i], other.rows[i]); + return *this; +} + +square128& square128::operator^=(__m128i* other) +{ + __m128i value = _mm_loadu_si128(other); + for (int i = 0; i < 128; i++) + rows[i] ^= value; + return *this; +} + +template <> +square128& square128::sub(__m128i* other) +{ + return *this ^= other; +} + +template <> +square128& square128::sub(__m128i* other) +{ + __m128i value = _mm_loadu_si128(other); + for (int i = 0; i < 128; i++) + gfp_isub(rows[i], value); + return *this; +} + +square128& square128::operator^=(BitVector& other) +{ + return *this ^= (__m128i*)other.get_ptr(); +} + +bool square128::operator==(square128& other) +{ + for (int i = 0; i < 128; i++) + { + __m128i tmp = rows[i] ^ other.rows[i]; + if (not _mm_test_all_zeros(tmp, tmp)) + return false; + } + return true; +} + +void square128::pack(octetStream& o) const +{ + o.append((octet*)this->bytes, sizeof(bytes)); +} + +void square128::unpack(octetStream &o) +{ + o.consume((octet*)this->bytes, sizeof(bytes)); +} + + +BitMatrix::BitMatrix(int length) +{ + resize(length); +} + +void BitMatrix::resize(int length) +{ + if (length % 128 != 0) + throw invalid_length(); + squares.resize(length / 128); +} + +int BitMatrix::size() +{ + return squares.size() * 128; +} + +template +BitMatrix& BitMatrix::add(BitMatrix& other) +{ + if (squares.size() != other.squares.size()) + throw invalid_length(); + for (size_t i = 0; i < squares.size(); i++) + squares[i].add(other.squares[i]); + return *this; +} + +template +BitMatrix& BitMatrix::sub(BitMatrix& other) +{ + if (squares.size() != other.squares.size()) + throw invalid_length(); + for (size_t i = 0; i < squares.size(); i++) + squares[i].sub(other.squares[i]); + return *this; +} + +template +BitMatrix& BitMatrix::rsub(BitMatrixSlice& other) +{ + if (squares.size() < other.end) + throw invalid_length(); + for (size_t i = other.start; i < other.end; i++) + squares[i].rsub(other.bm.squares[i]); + return *this; +} + +template +BitMatrix& BitMatrix::sub(BitVector& other) +{ + if (squares.size() * 128 != other.size()) + throw invalid_length(); + for (size_t i = 0; i < squares.size(); i++) + squares[i].sub((__m128i*)other.get_ptr() + i); + return *this; +} + +bool BitMatrix::operator==(BitMatrix& other) +{ + if (squares.size() != other.squares.size()) + throw invalid_length(); + for (size_t i = 0; i < squares.size(); i++) + if (not(squares[i] == other.squares[i])) + return false; + return true; +} + +bool BitMatrix::operator!=(BitMatrix& other) +{ + return not (*this == other); +} + +void BitMatrix::randomize(PRNG& G) +{ + for (size_t i = 0; i < squares.size(); i++) + squares[i].randomize(G); +} + +void BitMatrix::randomize(int row, PRNG& G) +{ + for (size_t i = 0; i < squares.size(); i++) + squares[i].randomize(row, G); +} + +void BitMatrix::transpose() +{ + for (size_t i = 0; i < squares.size(); i++) + squares[i].transpose(); +} + +void BitMatrix::check_transpose(BitMatrix& dual) +{ + for (size_t i = 0; i < squares.size(); i++) + { + for (int j = 0; j < 128; j++) + for (int k = 0; k < 128; k++) + if (squares[i].get_bit(j, k) != dual.squares[i].get_bit(k, j)) + { + cout << "First error in square " << i << " row " << j + << " column " << k << endl; + squares[i].print(i / 8, j / 8); + dual.squares[i].print(i / 8, j / 8); + return; + } + } + cout << "No errors in transpose" << endl; +} + +void BitMatrix::print_side_by_side(BitMatrix& other) +{ + for (int i = 0; i < 32; i++) + { + for (int j = 0; j < 64; j++) + cout << squares[0].get_bit(i,j); + cout << " "; + for (int j = 0; j < 64; j++) + cout << other.squares[0].get_bit(i,j); + cout << endl; + } +} + +void BitMatrix::print_conditional(BitVector& conditions) +{ + for (int i = 0; i < 32; i++) + { + if (conditions.get_bit(i)) + for (int j = 0; j < 65; j++) + cout << " "; + for (int j = 0; j < 64; j++) + cout << squares[0].get_bit(i,j); + if (!conditions.get_bit(i)) + for (int j = 0; j < 65; j++) + cout << " "; + cout << endl; + } +} + +void BitMatrix::pack(octetStream& os) const +{ + for (size_t i = 0; i < squares.size(); i++) + squares[i].pack(os); +} + +void BitMatrix::unpack(octetStream& os) +{ + for (size_t i = 0; i < squares.size(); i++) + squares[i].unpack(os); +} + +void BitMatrix::to(vector& output) +{ + output.resize(128); + for (int i = 0; i < 128; i++) + { + output[i].resize(128 * squares.size()); + for (size_t j = 0; j < squares.size(); j++) + output[i].set_int128(j, squares[j].rows[i]); + } +} + +BitMatrixSlice::BitMatrixSlice(BitMatrix& bm, size_t start, size_t size) : + bm(bm), start(start), size(size) +{ + end = start + size; + if (end > bm.squares.size()) + { + stringstream ss; + ss << "Matrix slice (" << start << "," << end << ") larger than matrix (" << bm.squares.size() << ")"; + throw invalid_argument(ss.str()); + } +} + +template +BitMatrixSlice& BitMatrixSlice::rsub(BitMatrixSlice& other) +{ + bm.rsub(other); + return *this; +} + +template +BitMatrixSlice& BitMatrixSlice::add(BitVector& other, int repeat) +{ + if (end * 128 > other.size() * repeat) + throw invalid_length(); + for (size_t i = start; i < end; i++) + bm.squares[i].sub((__m128i*)other.get_ptr() + i / repeat); + return *this; +} + +template +void BitMatrixSlice::randomize(int row, PRNG& G) +{ + for (size_t i = start; i < end; i++) + bm.squares[i].randomize(row, G); +} + +template +void BitMatrixSlice::conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset) +{ + for (size_t i = start; i < end; i++) + bm.squares[i].conditional_add(conditions, other.squares[i], useOffset * i); +} + +void BitMatrixSlice::transpose() +{ + for (size_t i = start; i < end; i++) + bm.squares[i].transpose(); +} + +template +void BitMatrixSlice::print() +{ + cout << "hex / value" << endl; + for (int i = 0; i < 16; i++) + { + cout << int128(bm.squares[0].rows[i]) << " " << T(bm.squares[0].rows[i]) << endl; + } + cout << endl; +} + +void BitMatrixSlice::pack(octetStream& os) const +{ + for (size_t i = start; i < end; i++) + bm.squares[i].pack(os); +} + +void BitMatrixSlice::unpack(octetStream& os) +{ + for (size_t i = start; i < end; i++) + bm.squares[i].unpack(os); +} + +template void BitMatrixSlice::conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset); +template void BitMatrixSlice::conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset); +template BitMatrixSlice& BitMatrixSlice::rsub(BitMatrixSlice& other); +template BitMatrixSlice& BitMatrixSlice::rsub(BitMatrixSlice& other); +template BitMatrixSlice& BitMatrixSlice::add(BitVector& other, int repeat); +template BitMatrixSlice& BitMatrixSlice::add(BitVector& other, int repeat); +template BitMatrix& BitMatrix::add(BitMatrix& other); +template BitMatrix& BitMatrix::add(BitMatrix& other); +template BitMatrix& BitMatrix::sub(BitMatrix& other); +template BitMatrix& BitMatrix::sub(BitMatrix& other); +template void BitMatrixSlice::print(); +template void BitMatrixSlice::print(); +template void BitMatrixSlice::randomize(int row, PRNG& G); +template void BitMatrixSlice::randomize(int row, PRNG& G); +template void square128::hash_row_wise(MMO& mmo, square128& input); +template void square128::hash_row_wise(MMO& mmo, square128& input); diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h new file mode 100644 index 000000000..3dae60702 --- /dev/null +++ b/OT/BitMatrix.h @@ -0,0 +1,136 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * BitMatrix.h + * + */ + +#ifndef OT_BITMATRIX_H_ +#define OT_BITMATRIX_H_ + +#include +#include + +#include "BitVector.h" +#include "Tools/random.h" +#include "Tools/MMO.h" +#include "Math/gf2nlong.h" + +using namespace std; + +union square128 { + __m128i rows[128]; + octet bytes[128][16]; + int16_t doublebytes[128][8]; + int32_t words[128][4]; + + bool get_bit(int x, int y) + { return (bytes[x][y/8] >> (y % 8)) & 1; } + + void set_zero(); + + square128& operator^=(square128& other); + square128& operator^=(__m128i* other); + square128& operator^=(BitVector& other); + bool operator==(square128& other); + + template + square128& add(square128& other); + template + square128& sub(square128& other); + template + square128& rsub(square128& other); + template + square128& sub(__m128i* other); + + void randomize(PRNG& G); + template + void randomize(int row, PRNG& G); + template + void conditional_add(BitVector& conditions, square128& other, int offset); + void transpose(); + template + void hash_row_wise(MMO& mmo, square128& input); + template + void to(T& result); + + void check_transpose(square128& dual, int i, int k); + void print(int i, int k); + void print(); + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const; + void unpack(octetStream& o); +}; + +class BitMatrixSlice; + +class BitMatrix +{ +public: + vector squares; + + BitMatrix() {} + BitMatrix(int length); + void resize(int length); + int size(); + + template + BitMatrix& add(BitMatrix& other); + template + BitMatrix& sub(BitMatrix& other); + template + BitMatrix& rsub(BitMatrixSlice& other); + template + BitMatrix& sub(BitVector& other); + bool operator==(BitMatrix& other); + bool operator!=(BitMatrix& other); + + void randomize(PRNG& G); + void randomize(int row, PRNG& G); + void transpose(); + + void check_transpose(BitMatrix& dual); + void print_side_by_side(BitMatrix& other); + void print_conditional(BitVector& conditions); + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const; + void unpack(octetStream& o); + + void to(vector& output); +}; + +class BitMatrixSlice +{ + friend class BitMatrix; + + BitMatrix& bm; + size_t start, size, end; + +public: + BitMatrixSlice(BitMatrix& bm, size_t start, size_t size); + + template + BitMatrixSlice& rsub(BitMatrixSlice& other); + template + BitMatrixSlice& add(BitVector& other, int repeat = 1); + + template + void randomize(int row, PRNG& G); + template + void conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset = false); + void transpose(); + + template + void print(); + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const; + void unpack(octetStream& o); +}; + +#endif /* OT_BITMATRIX_H_ */ diff --git a/OT/BitVector.cpp b/OT/BitVector.cpp new file mode 100644 index 000000000..285dfd6c1 --- /dev/null +++ b/OT/BitVector.cpp @@ -0,0 +1,107 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "OT/BitVector.h" +#include "Tools/random.h" +#include "Tools/octetStream.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" + +#include + +void BitVector::randomize(PRNG& G) +{ + G.get_octets(bytes, nbytes); +} + +template<> +void BitVector::randomize_blocks(PRNG& G) +{ + randomize(G); +} + +template<> +void BitVector::randomize_blocks(PRNG& G) +{ + gfp tmp; + for (size_t i = 0; i < (nbits / 128); i++) + { + tmp.randomize(G); + for (int j = 0; j < 2; j++) + ((mp_limb_t*)bytes)[2*i+j] = tmp.get().get_limb(j); + } +} + +void BitVector::randomize_at(int a, int nb, PRNG& G) +{ + if (nb < 1) + throw invalid_length(); + G.get_octets(bytes + a, nb); +} + +/* + */ + +void BitVector::output(ostream& s,bool human) const +{ + if (human) + { + s << nbits << " " << hex; + for (unsigned int i = 0; i < nbytes; i++) + { + s << int(bytes[i]) << " "; + } + s << dec << endl; + } + else + { + int len = nbits; + s.write((char*) &len, sizeof(int)); + s.write((char*) bytes, nbytes); + } +} + + +void BitVector::input(istream& s,bool human) +{ + if (s.peek() == EOF) + { + if (s.tellg() == 0) + { + cout << "IO problem. Empty file?" << endl; + throw file_error(); + } + throw end_of_file(); + } + int len; + if (human) + { + s >> len >> hex; + resize(len); + for (size_t i = 0; i < nbytes; i++) + { + s >> bytes[i]; + } + s >> dec; + } + else + { + s.read((char*) &len, sizeof(int)); + resize(len); + s.read((char*) bytes, nbytes); + } +} + + +void BitVector::pack(octetStream& o) const +{ + o.append((octet*)bytes, nbytes); +} + + +void BitVector::unpack(octetStream& o) +{ + o.consume((octet*)bytes, nbytes); +} + + diff --git a/OT/BitVector.h b/OT/BitVector.h new file mode 100644 index 000000000..54eac5a4f --- /dev/null +++ b/OT/BitVector.h @@ -0,0 +1,212 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _BITVECTOR +#define _BITVECTOR + +/* Vector of bits */ + +#include +#include +using namespace std; +#include +#include + +#include "Exceptions/Exceptions.h" +#include "Networking/data.h" +// just for util functions +#include "Math/bigint.h" +#include "Math/gf2nlong.h" + +class PRNG; +class octetStream; + + +class BitVector +{ + octet* bytes; + + size_t nbytes; + size_t nbits; + size_t length; + + public: + + void assign(const BitVector& K) + { + if (nbits != K.nbits) + { + resize(K.nbits); + } + memcpy(bytes, K.bytes, nbytes); + } + void assign_bytes(char* new_bytes, int len) + { + resize(len*8); + memcpy(bytes, new_bytes, len); + } + void assign_zero() + { + memset(bytes, 0, nbytes); + } + // only grows, never destroys + void resize(size_t new_nbits) + { + if (nbits != new_nbits) + { + int new_nbytes = DIV_CEIL(new_nbits,8); + + if (nbits < new_nbits) + { + octet* tmp = new octet[new_nbytes]; + memcpy(tmp, bytes, nbytes); + delete[] bytes; + bytes = tmp; + } + + nbits = new_nbits; + nbytes = new_nbytes; + /* + // use realloc to preserve original contents + if (new_nbits < nbits) + { + memcpy(tmp, bytes, new_nbytes); + } + else + { + memset(tmp, 0, new_nbytes); + memcpy(tmp, bytes, nbytes); + }*/ + + // realloc may fail on size 0 + /*if (new_nbits == 0) + { + free(bytes); + bytes = (octet*) malloc(0);//new octet[0]; + //free(bytes); + return; + } + bytes = (octet*)realloc(bytes, nbytes); + if (bytes == NULL) + { + cerr << "realloc failed\n"; + exit(1); + }*/ + /*delete[] bytes; + nbits = new_nbits; + nbytes = DIV_CEIL(nbits, 8); + bytes = new octet[nbytes];*/ + } + } + unsigned int size() const { return nbits; } + unsigned int size_bytes() const { return nbytes; } + octet* get_ptr() { return bytes; } + + BitVector(size_t n=128) + { + nbits = n; + nbytes = DIV_CEIL(nbits, 8); + bytes = new octet[nbytes]; + length = n; + assign_zero(); + } + BitVector(const BitVector& K) + { + bytes = new octet[K.nbytes]; + nbytes = K.nbytes; + nbits = K.nbits; + assign(K); + } + ~BitVector() { + //cout << "Destroy, size = " << nbytes << endl; + delete[] bytes; + } + BitVector& operator=(const BitVector& K) + { + if (this!=&K) { assign(K); } + return *this; + } + + octet get_byte(int i) const { return bytes[i]; } + + void set_byte(int i, octet b) { bytes[i] = b; } + + // get the i-th 64-bit word + word get_word(int i) const { return *(word*)(bytes + i*8); } + + void set_word(int i, word w) + { + int offset = i * sizeof(word); + memcpy(bytes + offset, (octet*)&w, sizeof(word)); + } + + int128 get_int128(int i) const { return _mm_lddqu_si128((__m128i*)bytes + i); } + void set_int128(int i, int128 a) { *((__m128i*)bytes + i) = a.a; } + + int get_bit(int i) const + { + return (bytes[i/8] >> (i % 8)) & 1; + } + void set_bit(int i,unsigned int a) + { + int j = i/8, k = i&7; + if (a==1) + { bytes[j] |= (octet)(1UL< + void randomize_blocks(PRNG& G); + // randomize bytes a, ..., a+nb-1 + void randomize_at(int a, int nb, PRNG& G); + + void output(ostream& s,bool human) const; + void input(istream& s,bool human); + + // Pack and unpack in native format + // i.e. Dont care about conversion to human readable form + void pack(octetStream& o) const; + void unpack(octetStream& o); + + string str() + { + stringstream ss; + ss << hex; + for(size_t i(0);i < nbytes;++i) + ss << (int)bytes[i] << " "; + return ss.str(); + } +}; + +#endif diff --git a/OT/NPartyTripleGenerator.cpp b/OT/NPartyTripleGenerator.cpp new file mode 100644 index 000000000..14a14f7ba --- /dev/null +++ b/OT/NPartyTripleGenerator.cpp @@ -0,0 +1,555 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "NPartyTripleGenerator.h" + +#include "OT/OTExtensionWithMatrix.h" +#include "OT/OTMultiplier.h" +#include "Math/gfp.h" +#include "Math/Share.h" +#include "Math/operators.h" +#include "Auth/Subroutines.h" +#include "Auth/MAC_Check.h" + +#include +#include +#include + +template +class Triple +{ +public: + T a[N]; + T b; + T c[N]; + + int repeat(int l) + { + switch (l) + { + case 0: + case 2: + return N; + case 1: + return 1; + default: + throw bad_value(); + } + } + + T& byIndex(int l, int j) + { + switch (l) + { + case 0: + return a[j]; + case 1: + return b; + case 2: + return c[j]; + default: + throw bad_value(); + } + } + + template + void amplify(const Triple& uncheckedTriple, PRNG& G) + { + b = uncheckedTriple.b; + for (int i = 0; i < N; i++) + for (int j = 0; j < M; j++) + { + typename T::value_type r; + r.randomize(G); + a[i] += r * uncheckedTriple.a[j]; + c[i] += r * uncheckedTriple.c[j]; + } + } + + void output(ostream& outputStream, int n = N, bool human = false) + { + for (int i = 0; i < n; i++) + { + a[i].output(outputStream, human); + b.output(outputStream, human); + c[i].output(outputStream, human); + } + } +}; + +template +class PlainTriple : public Triple +{ +public: + // this assumes that valueBits[1] is still set to the bits of b + void to(vector& valueBits, int i) + { + for (int j = 0; j < N; j++) + { + valueBits[0].set_int128(i * N + j, this->a[j].to_m128i()); + valueBits[2].set_int128(i * N + j, this->c[j].to_m128i()); + } + } +}; + +template +class ShareTriple : public Triple, N> +{ +public: + void from(PlainTriple& triple, vector*>& ot_multipliers, + int iTriple, const NPartyTripleGenerator& generator) + { + for (int l = 0; l < 3; l++) + { + int repeat = this->repeat(l); + for (int j = 0; j < repeat; j++) + { + T value = triple.byIndex(l,j); + T mac = value * generator.machine.get_mac_key(); + for (int i = 0; i < generator.nparties-1; i++) + mac += ot_multipliers[i]->macs[l][iTriple * repeat + j]; + Share& share = this->byIndex(l,j); + share.set_share(value); + share.set_mac(mac); + } + } + } + + T computeCheckMAC(const T& maskedA) + { + return this->c[0].get_mac() - maskedA * this->b.get_mac(); + } +}; + +/* + * Copies the relevant base OTs from setup + * N.B. setup must not be stored as it will be used by other threads + */ +NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup, + const Names& names, int thread_num, int _nTriples, int nloops, + TripleMachine& machine) : + globalPlayer(names, - thread_num * machine.nplayers * machine.nplayers), + thread_num(thread_num), + my_num(setup.get_my_num()), + nloops(nloops), + nparties(setup.get_nparties()), + machine(machine) +{ + nTriplesPerLoop = DIV_CEIL(_nTriples, nloops); + nTriples = nTriplesPerLoop * nloops; + field_size = 128; + nAmplify = machine.amplify ? N_AMPLIFY : 1; + nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify; + + int n = nparties; + //baseReceiverInput = machines[0]->baseReceiverInput; + //baseSenderInputs.resize(n-1); + //baseReceiverOutputs.resize(n-1); + nbase = setup.get_nbase(); + baseReceiverInput.resize(nbase); + baseReceiverOutputs.resize(n - 1); + baseSenderInputs.resize(n - 1); + players.resize(n-1); + + gf2n_long::init_field(128); + + for (int i = 0; i < n-1; i++) + { + // i for indexing, other_player is actual number + int other_player, id; + if (i >= my_num) + other_player = i + 1; + else + other_player = i; + + // copy base OT inputs + outputs + for (int j = 0; j < 128; j++) + { + baseReceiverInput.set_bit(j, (unsigned int)setup.get_base_receiver_input(j)); + } + baseReceiverOutputs[i] = setup.baseOTs[i]->receiver_outputs; + baseSenderInputs[i] = setup.baseOTs[i]->sender_inputs; + + // new TwoPartyPlayer with unique id for each thread + pair of players + if (my_num < other_player) + id = (thread_num+1)*n*n + my_num*n + other_player; + else + id = (thread_num+1)*n*n + other_player*n + my_num; + players[i] = new TwoPartyPlayer(names, other_player, id); + cout << "Set up with player " << other_player << " in thread " << thread_num << " with id " << id << endl; + } + + pthread_mutex_init(&mutex, 0); + pthread_cond_init(&ready, 0); +} + +NPartyTripleGenerator::~NPartyTripleGenerator() +{ + for (size_t i = 0; i < players.size(); i++) + delete players[i]; + //delete nplayer; + pthread_mutex_destroy(&mutex); + pthread_cond_destroy(&ready); +} + +template +void* run_ot_thread(void* ptr) +{ + ((OTMultiplier*)ptr)->multiply(); + return NULL; +} + +template +void NPartyTripleGenerator::generate() +{ + vector< OTMultiplier* > ot_multipliers(nparties-1); + + timers["Generator thread"].start(); + + for (int i = 0; i < nparties-1; i++) + { + ot_multipliers[i] = new OTMultiplier(*this, i); + pthread_mutex_lock(&ot_multipliers[i]->mutex); + pthread_create(&(ot_multipliers[i]->thread), 0, run_ot_thread, ot_multipliers[i]); + } + + // add up the shares from each thread and write to file + stringstream ss; + ss << machine.prep_data_dir; + if (machine.generateBits) + ss << "Bits-"; + else + ss << "Triples-"; + ss << T::type_char() << "-P" << my_num; + if (thread_num != 0) + ss << "-" << thread_num; + ofstream outputFile(ss.str().c_str()); + + if (machine.generateBits) + generateBits(ot_multipliers, outputFile); + else + generateTriples(ot_multipliers, outputFile); + + timers["Generator thread"].stop(); + if (machine.output) + cout << "Written " << nTriples << " outputs to " << ss.str() << endl; + else + cout << "Generated " << nTriples << " outputs" << endl; + + // wait for threads to finish + for (int i = 0; i < nparties-1; i++) + { + pthread_mutex_unlock(&ot_multipliers[i]->mutex); + pthread_join(ot_multipliers[i]->thread, NULL); + cout << "OT thread " << i << " finished\n" << flush; + } + cout << "OT threads finished\n"; + + for (size_t i = 0; i < ot_multipliers.size(); i++) + delete ot_multipliers[i]; +} + +template<> +void NPartyTripleGenerator::generateBits(vector< OTMultiplier* >& ot_multipliers, + ofstream& outputFile) +{ + PRNG share_prg; + share_prg.ReSeed(); + + int nBitsToCheck = nTriplesPerLoop + field_size; + valueBits.resize(1); + valueBits[0].resize(ceil(1.0 * nBitsToCheck / field_size) * field_size); + MAC_Check MC(machine.get_mac_key()); + vector< Share > bits(nBitsToCheck); + vector< Share > to_open(1); + vector opened(1); + + start_progress(ot_multipliers); + + for (int k = 0; k < nloops; k++) + { + print_progress(k); + + valueBits[0].randomize_blocks(share_prg); + + for (int i = 0; i < nparties-1; i++) + pthread_cond_signal(&ot_multipliers[i]->ready); + timers["Authentication OTs"].start(); + for (int i = 0; i < nparties-1; i++) + pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + timers["Authentication OTs"].stop(); + + octet seed[SEED_SIZE]; + Create_Random_Seed(seed, globalPlayer, SEED_SIZE); + PRNG G; + G.SetSeed(seed); + + Share check_sum; + gf2n r; + for (int j = 0; j < nBitsToCheck; j++) + { + gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key(); + for (int i = 0; i < nparties-1; i++) + mac_sum += ot_multipliers[i]->macs[0][j]; + bits[j].set_share(valueBits[0].get_bit(j)); + bits[j].set_mac(mac_sum); + r.randomize(G); + check_sum += r * bits[j]; + } + + to_open[0] = check_sum; + MC.POpen_Begin(opened, to_open, globalPlayer); + MC.POpen_End(opened, to_open, globalPlayer); + MC.Check(globalPlayer); + + if (machine.output) + for (int j = 0; j < nTriplesPerLoop; j++) + bits[j].output(outputFile, false); + + for (int i = 0; i < nparties-1; i++) + pthread_cond_signal(&ot_multipliers[i]->ready); + } +} + +template<> +void NPartyTripleGenerator::generateBits(vector< OTMultiplier* >& ot_multipliers, + ofstream& outputFile) +{ + generateTriples(ot_multipliers, outputFile); +} + +template +void NPartyTripleGenerator::generateTriples(vector< OTMultiplier* >& ot_multipliers, + ofstream& outputFile) +{ + PRNG share_prg; + share_prg.ReSeed(); + + valueBits.resize(3); + for (int i = 0; i < 2; i++) + valueBits[2*i].resize(field_size * nPreampTriplesPerLoop); + valueBits[1].resize(field_size * nTriplesPerLoop); + vector< PlainTriple > preampTriples; + vector< PlainTriple > amplifiedTriples; + vector< ShareTriple > uncheckedTriples; + MAC_Check MC(machine.get_mac_key()); + + if (machine.amplify) + preampTriples.resize(nTriplesPerLoop); + if (machine.generateMACs) + { + amplifiedTriples.resize(nTriplesPerLoop); + uncheckedTriples.resize(nTriplesPerLoop); + } + + start_progress(ot_multipliers); + + for (int k = 0; k < nloops; k++) + { + print_progress(k); + + for (int j = 0; j < 2; j++) + valueBits[j].randomize_blocks(share_prg); + + timers["OTs"].start(); + for (int i = 0; i < nparties-1; i++) + pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + timers["OTs"].stop(); + + for (int j = 0; j < nPreampTriplesPerLoop; j++) + { + T a(valueBits[0].get_int128(j)); + T b(valueBits[1].get_int128(j / nAmplify)); + T c = a * b; + timers["Triple computation"].start(); + for (int i = 0; i < nparties-1; i++) + { + c += ot_multipliers[i]->c_output[j]; + } + timers["Triple computation"].stop(); + if (machine.amplify) + { + preampTriples[j/nAmplify].a[j%nAmplify] = a; + preampTriples[j/nAmplify].b = b; + preampTriples[j/nAmplify].c[j%nAmplify] = c; + } + else + { + timers["Writing"].start(); + a.output(outputFile, false); + b.output(outputFile, false); + c.output(outputFile, false); + timers["Writing"].stop(); + } + } + + if (machine.amplify) + { + octet seed[SEED_SIZE]; + Create_Random_Seed(seed, globalPlayer, SEED_SIZE); + PRNG G; + G.SetSeed(seed); + for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) + { + PlainTriple triple; + triple.amplify(preampTriples[iTriple], G); + + if (machine.generateMACs) + amplifiedTriples[iTriple] = triple; + else + { + timers["Writing"].start(); + triple.output(outputFile); + timers["Writing"].stop(); + } + } + + if (machine.generateMACs) + { + for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) + amplifiedTriples[iTriple].to(valueBits, iTriple); + + for (int i = 0; i < nparties-1; i++) + pthread_cond_signal(&ot_multipliers[i]->ready); + timers["Authentication OTs"].start(); + for (int i = 0; i < nparties-1; i++) + pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + timers["Authentication OTs"].stop(); + + for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) + { + uncheckedTriples[iTriple].from(amplifiedTriples[iTriple], ot_multipliers, iTriple, *this); + + if (!machine.check) + { + timers["Writing"].start(); + amplifiedTriples[iTriple].output(outputFile); + timers["Writing"].stop(); + } + } + + if (machine.check) + { + vector< Share > maskedAs(nTriplesPerLoop); + vector< ShareTriple > maskedTriples(nTriplesPerLoop); + for (int j = 0; j < nTriplesPerLoop; j++) + { + maskedTriples[j].amplify(uncheckedTriples[j], G); + maskedAs[j] = maskedTriples[j].a[0]; + } + + vector openedAs(nTriplesPerLoop); + MC.POpen_Begin(openedAs, maskedAs, globalPlayer); + MC.POpen_End(openedAs, maskedAs, globalPlayer); + + for (int j = 0; j < nTriplesPerLoop; j++) + MC.AddToCheck(maskedTriples[j].computeCheckMAC(openedAs[j]), int128(0), globalPlayer); + + MC.Check(globalPlayer); + + if (machine.generateBits) + generateBitsFromTriples(uncheckedTriples, MC, outputFile); + else + if (machine.output) + for (int j = 0; j < nTriplesPerLoop; j++) + uncheckedTriples[j].output(outputFile, 1); + } + } + } + + for (int i = 0; i < nparties-1; i++) + pthread_cond_signal(&ot_multipliers[i]->ready); + } +} + +template<> +void NPartyTripleGenerator::generateBitsFromTriples( + vector< ShareTriple >& triples, MAC_Check& MC, ofstream& outputFile) +{ + vector< Share > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop); + for (int i = 0; i < nTriplesPerLoop; i++) + a_plus_b[i] = triples[i].a[0] + triples[i].b; + vector opened(nTriplesPerLoop); + MC.POpen_Begin(opened, a_plus_b, globalPlayer); + MC.POpen_End(opened, a_plus_b, globalPlayer); + for (int i = 0; i < nTriplesPerLoop; i++) + a_squared[i] = triples[i].a[0] * opened[i] - triples[i].c[0]; + MC.POpen_Begin(opened, a_squared, globalPlayer); + MC.POpen_End(opened, a_squared, globalPlayer); + Share one(gfp(1), globalPlayer.my_num(), MC.get_alphai()); + for (int i = 0; i < nTriplesPerLoop; i++) + { + gfp root = opened[i].sqrRoot(); + if (root.is_zero()) + continue; + Share bit = (triples[i].a[0] / root + one) / gfp(2); + if (machine.output) + bit.output(outputFile, false); + } +} + +template<> +void NPartyTripleGenerator::generateBitsFromTriples( + vector< ShareTriple >& triples, MAC_Check& MC, ofstream& outputFile) +{ + throw how_would_that_work(); + // warning gymnastics + triples[0]; + MC.number(); + outputFile << ""; +} + +template +void NPartyTripleGenerator::start_progress(vector< OTMultiplier* >& ot_multipliers) +{ + for (int i = 0; i < nparties-1; i++) + pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + lock(); + signal(); + wait(); + gettimeofday(&last_lap, 0); + for (int i = 0; i < nparties-1; i++) + pthread_cond_signal(&ot_multipliers[i]->ready); +} + +void NPartyTripleGenerator::print_progress(int k) +{ + if (thread_num == 0 && my_num == 0) + { + struct timeval stop; + gettimeofday(&stop, 0); + if (timeval_diff_in_seconds(&last_lap, &stop) > 1) + { + double diff = timeval_diff_in_seconds(&machine.start, &stop); + double throughput = k * nTriplesPerLoop * machine.nthreads / diff; + double remaining = diff * (nloops - k) / k; + cout << k << '/' << nloops << ", throughput: " << throughput + << ", time left: " << remaining << ", elapsed: " << diff + << ", estimated total: " << (diff + remaining) << endl; + last_lap = stop; + } + } +} + +void NPartyTripleGenerator::lock() +{ + pthread_mutex_lock(&mutex); +} + +void NPartyTripleGenerator::unlock() +{ + pthread_mutex_unlock(&mutex); +} + +void NPartyTripleGenerator::signal() +{ + pthread_cond_signal(&ready); +} + +void NPartyTripleGenerator::wait() +{ + pthread_cond_wait(&ready, &mutex); +} + + +template void NPartyTripleGenerator::generate(); +template void NPartyTripleGenerator::generate(); diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h new file mode 100644 index 000000000..91a905729 --- /dev/null +++ b/OT/NPartyTripleGenerator.h @@ -0,0 +1,84 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef OT_NPARTYTRIPLEGENERATOR_H_ +#define OT_NPARTYTRIPLEGENERATOR_H_ + +#include "Networking/Player.h" +#include "OT/BaseOT.h" +#include "Tools/random.h" +#include "Tools/time-func.h" +#include "Math/gfp.h" +#include "Auth/MAC_Check.h" + +#include "OT/OTTripleSetup.h" +#include "OT/TripleMachine.h" +#include "OT/OTMultiplier.h" + +#include +#include + +#define N_AMPLIFY 3 + +template +class ShareTriple; + +class NPartyTripleGenerator +{ + //OTTripleSetup* setup; + Player globalPlayer; + + int thread_num; + int my_num; + int nbase; + + struct timeval last_lap; + + pthread_mutex_t mutex; + pthread_cond_t ready; + + template + void generateTriples(vector< OTMultiplier* >& ot_multipliers, ofstream& outputFile); + template + void generateBits(vector< OTMultiplier* >& ot_multipliers, ofstream& outputFile); + template + void generateBitsFromTriples(vector >& triples, + MAC_Check& MC, ofstream& outputFile); + + template + void start_progress(vector< OTMultiplier* >& ot_multipliers); + void print_progress(int k); + +public: + // TwoPartyPlayer's for OTs, n-party Player for sacrificing + vector players; + //vector machines; + BitVector baseReceiverInput; // same for every set of OTs + vector< vector< vector > > baseSenderInputs; + vector< vector > baseReceiverOutputs; + vector valueBits; + + int nTriples; + int nTriplesPerLoop; + int nloops; + int field_size; + int nAmplify; + int nPreampTriplesPerLoop; + int repeat[3]; + int nparties; + + TripleMachine& machine; + + map timers; + + NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, TripleMachine& machine); + ~NPartyTripleGenerator(); + template + void generate(); + + void lock(); + void unlock(); + void signal(); + void wait(); +}; + +#endif diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp new file mode 100644 index 000000000..8efd07af7 --- /dev/null +++ b/OT/OTExtension.cpp @@ -0,0 +1,791 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "OTExtension.h" + +#include "OT/Tools.h" +#include "Math/gf2n.h" +#include "Tools/aes.h" +#include "Tools/MMO.h" +#include +#include + + +word TRANSPOSE_MASKS128[7][2] = { + { 0x0000000000000000, 0xFFFFFFFFFFFFFFFF }, + { 0x00000000FFFFFFFF, 0x00000000FFFFFFFF }, + { 0x0000FFFF0000FFFF, 0x0000FFFF0000FFFF }, + { 0x00FF00FF00FF00FF, 0x00FF00FF00FF00FF }, + { 0x0F0F0F0F0F0F0F0F, 0x0F0F0F0F0F0F0F0F }, + { 0x3333333333333333, 0x3333333333333333 }, + { 0x5555555555555555, 0x5555555555555555 } +}; + +string word_to_str(word a) +{ + stringstream ss; + ss << hex; + for(int i = 0;i < 8; i++) + ss << ((a >> (i*8)) & 255) << " "; + return ss.str(); +} + +// Transpose 16x16 matrix starting at bv[x][y] in-place using SSE2 +void sse_transpose16(vector& bv, int x, int y) +{ + __m128i input[2]; + // 16x16 in two halves, 128 bits each + for (int i = 0; i < 2; i++) + for (int j = 0; j < 16; j++) + ((octet*)&input[i])[j] = bv[x+j].get_byte(y / 8 + i); + + for (int i = 0; i < 2; i++) + for (int j = 0; j < 8; j++) + { + int output = _mm_movemask_epi8(input[i]); + input[i] = _mm_slli_epi64(input[i], 1); + + for (int k = 0; k < 2; k++) + // _mm_movemask_epi8 uses most significant bit, hence +7-j + bv[x+8*i+7-j].set_byte(y / 8 + k, ((octet*)&output)[k]); + } +} + +/* + * Transpose 128x128 bit-matrix using Eklundh's algorithm + * + * Input is in input[i] [ bits to ], i = 0, ..., 127 + * Output is in output[i + offset] (entire 128-bit vector), i = 0, ..., 127. + * + * Transposes 128-bit vectors in little-endian format. + */ +//void eklundh_transpose64(vector& output, const vector& input, int offset) +void eklundh_transpose128(vector& output, const vector& input, + int offset) +{ + int width = 64; + int logn = 7, nswaps = 1; + +#ifdef TRANSPOSE_DEBUG + stringstream input_ss[128]; + stringstream output_ss[128]; +#endif + // first copy input to output + for (int i = 0; i < 128; i++) + { + //output[i + offset*64] = input[i].get(offset); + output[i + offset].set_word(0, input[i].get_word(offset/64)); + output[i + offset].set_word(1, input[i].get_word(offset/64 + 1)); + +#ifdef TRANSPOSE_DEBUG + for (int j = 0; j < 128; j++) + { + input_ss[j] << input[i].get_bit(offset + j); + } +#endif + } + + // now transpose output in-place + for (int i = 0; i < logn; i++) + { + word mask1 = TRANSPOSE_MASKS128[i][1], mask2 = TRANSPOSE_MASKS128[i][0]; + word inv_mask1 = ~mask1, inv_mask2 = ~mask2; + + if (width == 8) + { + for (int j = 0; j < 8; j++) + for (int k = 0; k < 8; k++) + sse_transpose16(output, offset + 16 * j, 16 * k); + break; + } + else + // for width >= 64, shift is undefined so treat as a special case + // (and avoid branching in inner loop) + if (width < 64) + { + for (int j = 0; j < nswaps; j++) + { + for (int k = 0; k < width; k++) + { + int i1 = k + 2*width*j; + int i2 = k + width + 2*width*j; + + // t1 is lower 64 bits, t2 is upper 64 bits + // (remember we're transposing in little-endian format) + word t1 = output[i1 + offset].get_word(0); + word t2 = output[i1 + offset].get_word(1); + + word tt1 = output[i2 + offset].get_word(0); + word tt2 = output[i2 + offset].get_word(1); + + // swap operations due to little endian-ness + output[i1 + offset].set_word(0, (t1 & mask1) ^ + ((tt1 & mask1) << width)); + output[i1 + offset].set_word(1, (t2 & mask2) ^ + ((tt2 & mask2) << width) ^ + ((tt1 & mask1) >> (64 - width))); + + output[i2 + offset].set_word(0, (tt1 & inv_mask1) ^ + ((t1 & inv_mask1) >> width) ^ + ((t2 & inv_mask2)) << (64 - width)); + output[i2 + offset].set_word(1, (tt2 & inv_mask2) ^ + ((t2 & inv_mask2) >> width)); + } + } + } + else + { + for (int j = 0; j < nswaps; j++) + { + for (int k = 0; k < width; k++) + { + int i1 = k + 2*width*j; + int i2 = k + width + 2*width*j; + + // t1 is lower 64 bits, t2 is upper 64 bits + // (remember we're transposing in little-endian format) + word t1 = output[i1 + offset].get_word(0); + word t2 = output[i1 + offset].get_word(1); + + word tt1 = output[i2 + offset].get_word(0); + word tt2 = output[i2 + offset].get_word(1); + + output[i1 + offset].set_word(0, (t1 & mask1)); + output[i1 + offset].set_word(1, (t2 & mask2) ^ + ((tt1 & mask1) >> (64 - width))); + + output[i2 + offset].set_word(0, (tt1 & inv_mask1) ^ + ((t2 & inv_mask2)) << (64 - width)); + output[i2 + offset].set_word(1, (tt2 & inv_mask2)); + } + } + } + nswaps *= 2; + width /= 2; + } +#ifdef TRANSPOSE_DEBUG + for (int i = 0; i < 128; i++) + { + for (int j = 0; j < 128; j++) + { + output_ss[j] << output[offset + j].get_bit(i); + } + } + for (int i = 0; i < 128; i++) + { + if (output_ss[i].str().compare(input_ss[i].str()) != 0) + { + cerr << "String " << i << " failed. offset = " << offset << endl; + cerr << input_ss[i].str() << endl; + cerr << output_ss[i].str() << endl; + exit(1); + } + } + cout << "\ttranspose with offset " << offset << " ok\n"; +#endif +} + + +// get bit, starting from MSB as bit 0 +int get_bit(word x, int b) +{ + return (x >> (63 - b)) & 1; +} +int get_bit128(word x1, word x2, int b) +{ + if (b < 64) + { + return (x1 >> (b - 64)) & 1; + } + else + { + return (x2 >> b) & 1; + } +} + +void naive_transpose128(vector& output, const vector& input, + int offset) +{ + for (int i = 0; i < 128; i++) + { + // NB: words are read from input in big-endian format + word w1 = input[i].get_word(offset/64); + word w2 = input[i].get_word(offset/64 + 1); + + for (int j = 0; j < 128; j++) + { + //output[j + offset].set_bit(i, input[i].get_bit(j + offset)); + + if (j < 64) + output[j + offset].set_bit(i, (w1 >> j) & 1); + else + output[j + offset].set_bit(i, w2 >> (j-64) & 1); + } + } +} + +void transpose64( + vector::iterator& output_it, + vector::iterator& input_it) +{ + for (int i = 0; i < 64; i++) + { + for (int j = 0; j < 64; j++) + { + (output_it + j)->set_bit(i, (input_it + i)->get_bit(j)); + } + } +} + +// Naive 64x64 bit matrix transpose +void naive_transpose64(vector& output, const vector& input, + int xoffset, int yoffset) +{ + int word_size = 64; + + for (int i = 0; i < word_size; i++) + { + word w = input[i + yoffset].get_word(xoffset); + for (int j = 0; j < word_size; j++) + { + //cout << j + xoffset*word_size << ", " << yoffset/word_size << endl; + //int wbit = (((w >> j) & 1) << i); cout << "wbit " << wbit << endl; + + // set i-th bit of output to j-th bit of w + // scale yoffset by 64 since we're selecting words from the BitVector + word tmp = output[j + xoffset*word_size].get_word(yoffset/word_size); + output[j + xoffset*word_size].set_word(yoffset/word_size, tmp ^ ((w >> j) & 1) << i); + // set i-th bit of output to j-th bit of w + //output[j + offset*word_size] ^= ((w >> j) & 1) << i; + } + } +} + +void OTExtension::transfer(int nOTs, + const BitVector& receiverInput) +{ +#ifdef OTEXT_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + cout << "\tDoing " << nOTs << " extended OTs as " << role_to_str(ot_role) << endl; + // add k + s to account for discarding k OTs + nOTs += 2 * 128; + if (nOTs % nbaseOTs != 0) + throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n"); + if (nOTs == 0) + return; + + vector t0(nbaseOTs, BitVector(nOTs)), tmp(nbaseOTs, BitVector(nOTs)), t1(nbaseOTs, BitVector(nOTs)); + BitVector u(nOTs); + senderOutput.resize(2, vector(nOTs, BitVector(nbaseOTs))); + // resize to account for extra k OTs that are discarded + PRNG G; + G.ReSeed(); + BitVector newReceiverInput(nOTs); + for (unsigned int i = 0; i < receiverInput.size_bytes(); i++) + { + newReceiverInput.set_byte(i, receiverInput.get_byte(i)); + } + + //BitVector newReceiverInput(receiverInput); + newReceiverInput.resize(nOTs); + + receiverOutput.resize(nOTs, BitVector(nbaseOTs)); + + for (int loop = 0; loop < nloops; loop++) + { + vector os(2), tmp_os(2); + + // randomize last 128 + 128 bits that will be discarded + for (int i = 0; i < 4; i++) + newReceiverInput.set_word(nOTs/64 - i, G.get_word()); + + // expand with PRG and create correlation + if (ot_role & RECEIVER) + { + for (int i = 0; i < nbaseOTs; i++) + { + t0[i].randomize(G_sender[i][0]); + t1[i].randomize(G_sender[i][1]); + tmp[i].assign(t1[i]); + + tmp[i].add(t0[i]); + tmp[i].add(newReceiverInput); + tmp[i].pack(os[0]); + /*cout << "t0: " << t0[i].str() << endl; + cout << "t1: " << t1[i].str() << endl; + cout << "Sending tmp: " << tmp[i].str() << endl;*/ + } + } +#ifdef OTEXT_TIMER + timeval commst1, commst2; + gettimeofday(&commst1, NULL); +#endif + // send t0 + t1 + x + send_if_ot_receiver(player, os, ot_role); + + // sender adjusts using base receiver bits + if (ot_role & SENDER) + { + for (int i = 0; i < nbaseOTs; i++) + { + // randomize base receiver output + tmp[i].randomize(G_receiver[i]); + + // u = t0 + t1 + x + u.unpack(os[1]); + + if (baseReceiverInput.get_bit(i) == 1) + { + // now tmp is q[i] = t0[i] + Delta[i] * x + tmp[i].add(u); + } + } + } +#ifdef OTEXT_TIMER + gettimeofday(&commst2, NULL); + double commstime = timeval_diff(&commst1, &commst2); + cout << "\t\tCommunication took time " << commstime/1000000 << endl << flush; + times["Communication"] += timeval_diff(&commst1, &commst2); +#endif + + // transpose t0[i] onto receiverOutput and tmp (q[i]) onto senderOutput[i][0] + + // stupid transpose + /*for (int j = 0; j < nOTs; j++) + { + for (int i = 0; i < nbaseOTs; i++) + { + senderOutput[0][j].set_bit(i, t0[i].get_bit(j)); + receiverOutput[j].set_bit(i, tmp[i].get_bit(j)); + } + }*/ + cout << "Starting matrix transpose\n" << flush << endl; +#ifdef OTEXT_TIMER + timeval transt1, transt2; + gettimeofday(&transt1, NULL); +#endif + // transpose in 128-bit chunks with Eklundh's algorithm + for (int i = 0; i < nOTs / 128; i++) + { + if (ot_role & RECEIVER) + { + eklundh_transpose128(receiverOutput, t0, i*128); + //naive_transpose128(receiverOutput, t0, i*128); + } + if (ot_role & SENDER) + { + eklundh_transpose128(senderOutput[0], tmp, i*128); + //naive_transpose128(senderOutput[0], tmp, i*128); + } + } + +#ifdef OTEXT_TIMER + gettimeofday(&transt2, NULL); + double transtime = timeval_diff(&transt1, &transt2); + cout << "\t\tMatrix transpose took time " << transtime/1000000 << endl << flush; + times["Matrix transpose"] += timeval_diff(&transt1, &transt2); +#endif + +#ifdef OTEXT_DEBUG + // verify correctness of the OTs + // i.e. senderOutput[0][i] + x_i * Delta = receiverOutput[i] + // (where Delta = baseReceiverOutput) + BitVector tmp_vector1(nbaseOTs), tmp_vector2(nOTs);//nbaseOTs); + cout << "\tVerifying OT extensions (debugging)\n"; + for (int i = 0; i < nOTs; i++) + { + os[0].reset_write_head(); + os[1].reset_write_head(); + + if (ot_role & RECEIVER) + { + // send t0 and x over + receiverOutput[i].pack(os[0]); + //t0[i].pack(os[0]); + newReceiverInput.pack(os[0]); + } + send_if_ot_receiver(player, os, ot_role); + + if (ot_role & SENDER) + { + tmp_vector1.unpack(os[1]); + tmp_vector2.unpack(os[1]); + + // if x_i = 1, add Delta + if (tmp_vector2.get_bit(i) == 1) + { + tmp_vector1.add(baseReceiverInput); + } + if (!tmp_vector1.equals(senderOutput[0][i])) + { + cerr << "Incorrect OT at " << i << "\n"; + exit(1); + } + } + } + cout << "Correlated OTs all OK\n"; +#endif + +#ifdef OTEXT_TIMER + double elapsed; +#endif + // correlation check + if (!passive_only) + { +#ifdef OTEXT_TIMER + timeval startv, endv; + gettimeofday(&startv, NULL); +#endif + check_correlation(nOTs, newReceiverInput); +#ifdef OTEXT_TIMER + gettimeofday(&endv, NULL); + elapsed = timeval_diff(&startv, &endv); + cout << "\t\tTotal correlation check time: " << elapsed/1000000 << endl << flush; + times["Total correlation check"] += timeval_diff(&startv, &endv); +#endif + } + + hash_outputs(nOTs, receiverOutput); +#ifdef OTEXT_TIMER + gettimeofday(&totalendv, NULL); + elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tTotal thread time: " << elapsed/1000000 << endl << flush; +#endif + +#ifdef OTEXT_DEBUG + // verify correctness of the random OTs + // i.e. senderOutput[0][i] + x_i * Delta = receiverOutput[i] + // (where Delta = baseReceiverOutput) + cout << "Verifying random OTs (debugging)\n"; + for (int i = 0; i < nOTs; i++) + { + os[0].reset_write_head(); + os[1].reset_write_head(); + + if (ot_role & RECEIVER) + { + // send receiver's input/output over + receiverOutput[i].pack(os[0]); + newReceiverInput.pack(os[0]); + } + send_if_ot_receiver(player, os, ot_role); + //player->send_receive_player(os); + if (ot_role & SENDER) + { + tmp_vector1.unpack(os[1]); + tmp_vector2.unpack(os[1]); + + // if x_i = 1, comp with sender output[1] + if ((tmp_vector2.get_bit(i) == 1)) + { + if (!tmp_vector1.equals(senderOutput[1][i])) + { + cerr << "Incorrect OT\n"; + exit(1); + } + } + // else should be sender output[0] + else if (!tmp_vector1.equals(senderOutput[0][i])) + { + cerr << "Incorrect OT\n"; + exit(1); + } + } + } + cout << "Random OTs all OK\n"; +#endif + } + +#ifdef OTEXT_TIMER + gettimeofday(&totalendv, NULL); + times["Total thread"] += timeval_diff(&totalstartv, &totalendv); +#endif + + receiverOutput.resize(nOTs - 2 * 128); + senderOutput[0].resize(nOTs - 2 * 128); + senderOutput[1].resize(nOTs - 2 * 128); +} + +/* + * Hash outputs to make into random OT + */ +void OTExtension::hash_outputs(int nOTs, vector& receiverOutput) +{ + cout << "Hashing... " << flush; + octetStream os, h_os(HASH_SIZE); + BitVector tmp(nbaseOTs); + MMO mmo; +#ifdef OTEXT_TIMER + timeval startv, endv; + gettimeofday(&startv, NULL); +#endif + + for (int i = 0; i < nOTs; i++) + { + if (ot_role & SENDER) + { + tmp.add(senderOutput[0][i], baseReceiverInput); + + if (senderOutput[0][i].size() == 128) + { + mmo.hashOneBlock(senderOutput[0][i].get_ptr(), senderOutput[0][i].get_ptr()); + mmo.hashOneBlock(senderOutput[1][i].get_ptr(), tmp.get_ptr()); + } + else + { + os.reset_write_head(); + h_os.reset_write_head(); + + senderOutput[0][i].pack(os); + os.hash(h_os); + senderOutput[0][i].unpack(h_os); + + os.reset_write_head(); + h_os.reset_write_head(); + + tmp.pack(os); + os.hash(h_os); + senderOutput[1][i].unpack(h_os); + } + } + if (ot_role & RECEIVER) + { + if (receiverOutput[i].size() == 128) + mmo.hashOneBlock(receiverOutput[i].get_ptr(), receiverOutput[i].get_ptr()); + else + { + os.reset_write_head(); + h_os.reset_write_head(); + + receiverOutput[i].pack(os); + os.hash(h_os); + receiverOutput[i].unpack(h_os); + } + } + } + cout << "done.\n"; +#ifdef OTEXT_TIMER + gettimeofday(&endv, NULL); + double elapsed = timeval_diff(&startv, &endv); + cout << "\t\tOT ext hashing took time " << elapsed/1000000 << endl << flush; + times["Hashing"] += timeval_diff(&startv, &endv); +#endif +} + + +// test if a == b +int eq_m128i(__m128i a, __m128i b) +{ + __m128i vcmp = _mm_cmpeq_epi8(a, b); + uint16_t vmask = _mm_movemask_epi8(vcmp); + return (vmask == 0xffff); +} + +void random_m128i(PRNG& G, __m128i *r) +{ + BitVector rv(128); + rv.randomize(G); + *r = _mm_load_si128((__m128i*)&(rv.get_ptr()[0])); +} + +void test_mul() +{ + cout << "Testing GF(2^128) multiplication\n"; + __m128i t1, t2, t3, t4, t5, t6, t7, t8; + PRNG G; + G.ReSeed(); + BitVector r(128); + for (int i = 0; i < 1000; i++) + { + random_m128i(G, &t1); + random_m128i(G, &t2); + // test commutativity + gfmul128(t1, t2, &t3); + gfmul128(t2, t1, &t4); + if (!eq_m128i(t3, t4)) + { + cerr << "Incorrect multiplication:\n"; + cerr << "t1 * t2 = " << __m128i_toString(t3) << endl; + cerr << "t2 * t1 = " << __m128i_toString(t4) << endl; + } + // test distributivity: t1*t3 + t2*t3 = (t1 + t2) * t3 + random_m128i(G, &t1); + random_m128i(G, &t2); + random_m128i(G, &t3); + gfmul128(t1, t3, &t4); + gfmul128(t2, t3, &t5); + t6 = _mm_xor_si128(t4, t5); + + t7 = _mm_xor_si128(t1, t2); + gfmul128(t7, t3, &t8); + if (!eq_m128i(t6, t8)) + { + cerr << "Incorrect multiplication:\n"; + cerr << "t1 * t3 + t2 * t3 = " << __m128i_toString(t6) << endl; + cerr << "(t1 + t2) * t3 = " << __m128i_toString(t8) << endl; + } + } + t1 = _mm_set_epi32(0, 0, 0, 03); + t2 = _mm_set_epi32(0, 0, 0, 11); + //gfmul128(t1, t2, &t3); + mul128(t1, t2, &t3, &t4); + cout << "t1 = " << __m128i_toString(t1) << endl; + cout << "t2 = " << __m128i_toString(t2) << endl; + cout << "t3 = " << __m128i_toString(t3) << endl; + cout << "t4 = " << __m128i_toString(t4) << endl; + + uint64_t cc[] __attribute__((aligned (16))) = { 0,0 }; + _mm_store_si128((__m128i*)cc, t1); + word t1w = cc[0]; + _mm_store_si128((__m128i*)cc, t2); + word t2w = cc[0]; + + cout << "t1w = " << t1w << endl; + cout << "t1 = " << word_to_bytes(t1w) << endl; + cout << "t2 = " << word_to_bytes(t2w) << endl; + cout << "t1 * t2 = " << word_to_bytes(t1w*t2w) << endl; +} + + + +void OTExtension::check_correlation(int nOTs, + const BitVector& receiverInput) +{ + //cout << "\tStarting correlation check\n" << flush; +#ifdef OTEXT_TIMER + timeval startv, endv; + gettimeofday(&startv, NULL); +#endif + if (nbaseOTs != 128) + { + cerr << "Correlation check not implemented for length != 128\n"; + throw not_implemented(); + } + PRNG G; + octet* seed = new octet[SEED_SIZE]; + random_seed_commit(seed, *player, SEED_SIZE); +#ifdef OTEXT_TIMER + gettimeofday(&endv, NULL); + double elapsed = timeval_diff(&startv, &endv); + cout << "\t\tCommitment for seed took time " << elapsed/1000000 << endl << flush; + times["Commitment for seed"] += timeval_diff(&startv, &endv); + gettimeofday(&startv, NULL); +#endif + G.SetSeed(seed); + delete[] seed; + vector os(2); + + if (!Check_CPU_support_AES()) + { + cerr << "Not implemented GF(2^128) multiplication in C\n"; + throw not_implemented(); + } + + __m128i Delta, x128i; + Delta = _mm_load_si128((__m128i*)&(baseReceiverInput.get_ptr()[0])); + + BitVector chi(nbaseOTs); + BitVector x(nbaseOTs); + __m128i t = _mm_setzero_si128(); + __m128i q = _mm_setzero_si128(); + __m128i t2 = _mm_setzero_si128(); + __m128i q2 = _mm_setzero_si128(); + __m128i chii, ti, qi, ti2, qi2; + x128i = _mm_setzero_si128(); + + for (int i = 0; i < nOTs; i++) + { +// chi.randomize(G); +// chii = _mm_load_si128((__m128i*)&(chi.get_ptr()[0])); + chii = G.get_doubleword(); + + if (ot_role & RECEIVER) + { + if (receiverInput.get_bit(i) == 1) + { + x128i = _mm_xor_si128(x128i, chii); + } + ti = _mm_loadu_si128((__m128i*)get_receiver_output(i)); + // multiply over polynomial ring to avoid reduction + mul128(ti, chii, &ti, &ti2); + t = _mm_xor_si128(t, ti); + t2 = _mm_xor_si128(t2, ti2); + } + if (ot_role & SENDER) + { + qi = _mm_loadu_si128((__m128i*)(get_sender_output(0, i))); + mul128(qi, chii, &qi, &qi2); + q = _mm_xor_si128(q, qi); + q2 = _mm_xor_si128(q2, qi2); + } + } +#ifdef OTEXT_DEBUG + if (ot_role & RECEIVER) + { + cout << "\tSending x,t\n"; + cout << "\tsend x = " << __m128i_toString(x128i) << endl; + cout << "\tsend t = " << __m128i_toString(t) << endl; + cout << "\tsend t2 = " << __m128i_toString(t2) << endl; + } +#endif + check_iteration(Delta, q, q2, t, t2, x128i); +#ifdef OTEXT_TIMER + gettimeofday(&endv, NULL); + elapsed = timeval_diff(&startv, &endv); + cout << "\t\tChecking correlation took time " << elapsed/1000000 << endl << flush; + times["Checking correlation"] += timeval_diff(&startv, &endv); +#endif +} + +void OTExtension::check_iteration(__m128i delta, __m128i q, __m128i q2, + __m128i t, __m128i t2, __m128i x) +{ + vector os(2); + // send x, t; + __m128i received_t, received_t2, received_x, tmp1, tmp2; + if (ot_role & RECEIVER) + { + os[0].append((octet*)&x, sizeof(x)); + os[0].append((octet*)&t, sizeof(t)); + os[0].append((octet*)&t2, sizeof(t2)); + } + send_if_ot_receiver(player, os, ot_role); + + if (ot_role & SENDER) + { + os[1].consume((octet*)&received_x, sizeof(received_x)); + os[1].consume((octet*)&received_t, sizeof(received_t)); + os[1].consume((octet*)&received_t2, sizeof(received_t2)); + + // check t = x * Delta + q + //gfmul128(received_x, delta, &tmp1); + mul128(received_x, delta, &tmp1, &tmp2); + tmp1 = _mm_xor_si128(tmp1, q); + tmp2 = _mm_xor_si128(tmp2, q2); + + if (eq_m128i(tmp1, received_t) && eq_m128i(tmp2, received_t2)) + { + //cout << "\tCheck passed\n"; + } + else + { + cerr << "Correlation check failed\n"; + cout << "rec t = " << __m128i_toString(received_t) << endl; + cout << "tmp1 = " << __m128i_toString(tmp1) << endl; + cout << "q = " << __m128i_toString(q) << endl; + exit(1); + } + } +} + + +octet* OTExtension::get_receiver_output(int i) +{ + return receiverOutput[i].get_ptr(); +} + +octet* OTExtension::get_sender_output(int choice, int i) +{ + return senderOutput[choice][i].get_ptr(); +} diff --git a/OT/OTExtension.h b/OT/OTExtension.h new file mode 100644 index 000000000..91724a6a0 --- /dev/null +++ b/OT/OTExtension.h @@ -0,0 +1,122 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _OTEXTENSION +#define _OTEXTENSION + +#include "OT/BaseOT.h" + +#include "Exceptions/Exceptions.h" + +#include "Networking/Player.h" + +#include "Tools/time-func.h" + +#include +#include +#include +#include +#include +#include + +using namespace std; + +//#define OTEXT_TIMER +//#define OTEXT_DEBUG + + +class OTExtension +{ +public: + BitVector baseReceiverInput; + vector< vector > senderOutput; + vector receiverOutput; + map times; + + OTExtension(int nbaseOTs, int baseLength, + int nloops, int nsubloops, + TwoPartyPlayer* player, + BitVector& baseReceiverInput, + vector< vector >& baseSenderInput, + vector& baseReceiverOutput, + OT_ROLE role=BOTH, + bool passive=false) + : baseReceiverInput(baseReceiverInput), passive_only(passive), nbaseOTs(nbaseOTs), + baseLength(baseLength), nloops(nloops), nsubloops(nsubloops), ot_role(role), player(player) + { + G_sender.resize(nbaseOTs, vector(2)); + G_receiver.resize(nbaseOTs); + + // set up PRGs for expanding the seed OTs + for (int i = 0; i < nbaseOTs; i++) + { + assert(baseSenderInput[i][0].size_bytes() >= AES_BLK_SIZE); + assert(baseSenderInput[i][1].size_bytes() >= AES_BLK_SIZE); + assert(baseReceiverOutput[i].size_bytes() >= AES_BLK_SIZE); + + if (ot_role & RECEIVER) + { + G_sender[i][0].SetSeed(baseSenderInput[i][0].get_ptr()); + G_sender[i][1].SetSeed(baseSenderInput[i][1].get_ptr()); + } + if (ot_role & SENDER) + { + G_receiver[i].SetSeed(baseReceiverOutput[i].get_ptr()); + } + +#ifdef OTEXT_DEBUG + // sanity check for base OTs + vector os(2); + BitVector t0(128); + + if (ot_role & RECEIVER) + { + // send both inputs to test + baseSenderInput[i][0].pack(os[0]); + baseSenderInput[i][1].pack(os[0]); + } + send_if_ot_receiver(player, os, ot_role); + + if (ot_role & SENDER) + { + // sender checks results + t0.unpack(os[1]); + if (baseReceiverInput.get_bit(i) == 1) + t0.unpack(os[1]); + if (!t0.equals(baseReceiverOutput[i])) + { + cerr << "Incorrect base OT\n"; + exit(1); + } + } + + + os[0].reset_write_head(); + os[1].reset_write_head(); +#endif + } + } + + virtual ~OTExtension() {} + + virtual void transfer(int nOTs, const BitVector& receiverInput); + virtual octet* get_receiver_output(int i); + virtual octet* get_sender_output(int choice, int i); + +protected: + bool passive_only; + int nbaseOTs, baseLength, nloops, nsubloops; + OT_ROLE ot_role; + TwoPartyPlayer* player; + vector< vector > G_sender; + vector G_receiver; + + void check_correlation(int nOTs, + const BitVector& receiverInput); + + void check_iteration(__m128i delta, __m128i q, __m128i q2, + __m128i t, __m128i t2, __m128i x); + + void hash_outputs(int nOTs, vector& receiverOutput); +}; + +#endif diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp new file mode 100644 index 000000000..7894ae949 --- /dev/null +++ b/OT/OTExtensionWithMatrix.cpp @@ -0,0 +1,466 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OTExtensionWithMatrix.cpp + * + */ + +#include "OTExtensionWithMatrix.h" +#include "Math/gfp.h" + +void OTExtensionWithMatrix::seed(vector& baseSenderInput, + BitMatrix& baseReceiverOutput) +{ + nbaseOTs = baseReceiverInput.size(); + //cout << "nbaseOTs " << nbaseOTs << endl; + G_sender.resize(nbaseOTs, vector(2)); + G_receiver.resize(nbaseOTs); + + // set up PRGs for expanding the seed OTs + for (int i = 0; i < nbaseOTs; i++) + { + if (ot_role & RECEIVER) + { + G_sender[i][0].SetSeed((octet*)&baseSenderInput[0].squares[i/128].rows[i%128]); + G_sender[i][1].SetSeed((octet*)&baseSenderInput[1].squares[i/128].rows[i%128]); + } + if (ot_role & SENDER) + { + G_receiver[i].SetSeed((octet*)&baseReceiverOutput.squares[i/128].rows[i%128]); + } + } +} + +void OTExtensionWithMatrix::transfer(int nOTs, + const BitVector& receiverInput) +{ +#ifdef OTEXT_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + cout << "\tDoing " << nOTs << " extended OTs as " << role_to_str(ot_role) << endl; + + // resize to account for extra k OTs that are discarded + BitVector newReceiverInput(nOTs); + for (unsigned int i = 0; i < receiverInput.size_bytes(); i++) + { + newReceiverInput.set_byte(i, receiverInput.get_byte(i)); + } + + for (int loop = 0; loop < nloops; loop++) + { + extend(nOTs, newReceiverInput); +#ifdef OTEXT_TIMER + gettimeofday(&totalendv, NULL); + double elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tTotal thread time: " << elapsed/1000000 << endl << flush; +#endif + } + +#ifdef OTEXT_TIMER + gettimeofday(&totalendv, NULL); + times["Total thread"] += timeval_diff(&totalstartv, &totalendv); +#endif +} + +void OTExtensionWithMatrix::resize(int nOTs) +{ + t1.resize(nOTs); + u.resize(nOTs); + senderOutputMatrices.resize(2); + for (int i = 0; i < 2; i++) + senderOutputMatrices[i].resize(nOTs); + receiverOutputMatrix.resize(nOTs); +} + +// the template is used to denote the field of the hash output +template +void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput) +{ +// if (nOTs % nbaseOTs != 0) +// throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n"); + if (nOTs_requested == 0) + return; + // add k + s to account for discarding k OTs + int nOTs = nOTs_requested + 2 * 128; + + int slice = nOTs / nsubloops / 128; + nOTs = slice * nsubloops * 128; + resize(nOTs); + newReceiverInput.resize(nOTs); + + // randomize last 128 + 128 bits that will be discarded + for (int i = 0; i < 4; i++) + newReceiverInput.set_word(nOTs/64 - i - 1, G.get_word()); + + // subloop for first part to interleave communication with computation + for (int start = 0; start < nOTs / 128; start += slice) + { + expand(start, slice); + correlate(start, slice, newReceiverInput, true); + transpose(start, slice); + } + +#ifdef OTEXT_TIMER + double elapsed; +#endif + // correlation check + if (!passive_only) + { +#ifdef OTEXT_TIMER + timeval startv, endv; + gettimeofday(&startv, NULL); +#endif + check_correlation(nOTs, newReceiverInput); +#ifdef OTEXT_TIMER + gettimeofday(&endv, NULL); + elapsed = timeval_diff(&startv, &endv); + cout << "\t\tTotal correlation check time: " << elapsed/1000000 << endl << flush; + times["Total correlation check"] += timeval_diff(&startv, &endv); +#endif + } + + hash_outputs(nOTs); + + receiverOutputMatrix.resize(nOTs_requested); + senderOutputMatrices[0].resize(nOTs_requested); + senderOutputMatrices[1].resize(nOTs_requested); + newReceiverInput.resize(nOTs_requested); +} + +template +void OTExtensionWithMatrix::expand(int start, int slice) +{ + BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice); + BitMatrixSlice senderOutputSlices[2] = { + BitMatrixSlice(senderOutputMatrices[0], start, slice), + BitMatrixSlice(senderOutputMatrices[1], start, slice) + }; + BitMatrixSlice t1Slice(t1, start, slice); + + // expand with PRG + if (ot_role & RECEIVER) + { + for (int i = 0; i < nbaseOTs; i++) + { + receiverOutputSlice.randomize(i, G_sender[i][0]); + t1Slice.randomize(i, G_sender[i][1]); + } + } + + if (ot_role & SENDER) + { + for (int i = 0; i < nbaseOTs; i++) + // randomize base receiver output + senderOutputSlices[0].randomize(i, G_receiver[i]); + } +} + +template +void OTExtensionWithMatrix::expand_transposed() +{ + for (int i = 0; i < nbaseOTs; i++) + { + if (ot_role & RECEIVER) + { + receiverOutputMatrix.squares[i/128].randomize(i % 128, G_sender[i][0]); + t1.squares[i/128].randomize(i % 128, G_sender[i][1]); + } + if (ot_role & SENDER) + { + senderOutputMatrices[0].squares[i/128].randomize(i % 128, G_receiver[i]); + } + } +} + +void OTExtensionWithMatrix::setup_for_correlation(vector& baseSenderOutputs, BitMatrix& baseReceiverOutput) +{ + receiverOutputMatrix = baseSenderOutputs[0]; + t1 = baseSenderOutputs[1]; + u.resize(t1.size()); + senderOutputMatrices.resize(2); + senderOutputMatrices[0] = baseReceiverOutput; +} + +template +void OTExtensionWithMatrix::correlate(int start, int slice, + BitVector& newReceiverInput, bool useConstantBase, int repeat) +{ + vector os(2); + + BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice); + BitMatrixSlice senderOutputSlices[2] = { + BitMatrixSlice(senderOutputMatrices[0], start, slice), + BitMatrixSlice(senderOutputMatrices[1], start, slice) + }; + BitMatrixSlice t1Slice(t1, start, slice); + BitMatrixSlice uSlice(u, start, slice); + + // create correlation + if (ot_role & RECEIVER) + { + t1Slice.rsub(receiverOutputSlice); + t1Slice.add(newReceiverInput, repeat); + t1Slice.pack(os[0]); + +// t1 = receiverOutputMatrix; +// t1 ^= newReceiverInput; +// receiverOutputMatrix.print_side_by_side(t1); + } +#ifdef OTEXT_TIMER + timeval commst1, commst2; + gettimeofday(&commst1, NULL); +#endif + // send t0 + t1 + x + send_if_ot_receiver(player, os, ot_role); + + // sender adjusts using base receiver bits + if (ot_role & SENDER) + { + // u = t0 + t1 + x + uSlice.unpack(os[1]); + senderOutputSlices[0].conditional_add(baseReceiverInput, u, !useConstantBase); + } +#ifdef OTEXT_TIMER + gettimeofday(&commst2, NULL); + double commstime = timeval_diff(&commst1, &commst2); + cout << "\t\tCommunication took time " << commstime/1000000 << endl << flush; + times["Communication"] += timeval_diff(&commst1, &commst2); +#endif +} + +void OTExtensionWithMatrix::transpose(int start, int slice) +{ + BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice); + BitMatrixSlice senderOutputSlices[2] = { + BitMatrixSlice(senderOutputMatrices[0], start, slice), + BitMatrixSlice(senderOutputMatrices[1], start, slice) + }; + + // transpose t0[i] onto receiverOutput and tmp (q[i]) onto senderOutput[i][0] + + //cout << "Starting matrix transpose\n" << flush << endl; +#ifdef OTEXT_TIMER + timeval transt1, transt2; + gettimeofday(&transt1, NULL); +#endif + // transpose in 128-bit chunks + if (ot_role & RECEIVER) + receiverOutputSlice.transpose(); + if (ot_role & SENDER) + senderOutputSlices[0].transpose(); + +#ifdef OTEXT_TIMER + gettimeofday(&transt2, NULL); + double transtime = timeval_diff(&transt1, &transt2); + cout << "\t\tMatrix transpose took time " << transtime/1000000 << endl << flush; + times["Matrix transpose"] += timeval_diff(&transt1, &transt2); +#endif +} + +/* + * Hash outputs to make into random OT + */ +template +void OTExtensionWithMatrix::hash_outputs(int nOTs) +{ + //cout << "Hashing... " << flush; + octetStream os, h_os(HASH_SIZE); + square128 tmp; + MMO mmo; +#ifdef OTEXT_TIMER + timeval startv, endv; + gettimeofday(&startv, NULL); +#endif + + for (int i = 0; i < nOTs / 128; i++) + { + if (ot_role & SENDER) + { + tmp = senderOutputMatrices[0].squares[i]; + tmp ^= baseReceiverInput; + senderOutputMatrices[0].squares[i].hash_row_wise(mmo, senderOutputMatrices[0].squares[i]); + senderOutputMatrices[1].squares[i].hash_row_wise(mmo, tmp); + } + if (ot_role & RECEIVER) + { + receiverOutputMatrix.squares[i].hash_row_wise(mmo, receiverOutputMatrix.squares[i]); + } + } + //cout << "done.\n"; +#ifdef OTEXT_TIMER + gettimeofday(&endv, NULL); + double elapsed = timeval_diff(&startv, &endv); + cout << "\t\tOT ext hashing took time " << elapsed/1000000 << endl << flush; + times["Hashing"] += timeval_diff(&startv, &endv); +#endif +} + +template +void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples, vector& output) +{ + if (receiverOutputMatrix.squares.size() < nTriples) + throw invalid_length(); + output.resize(nTriples); + for (unsigned int j = 0; j < nTriples; j++) + { + T c1, c2; + receiverOutputMatrix.squares[j].to(c1); + senderOutputMatrices[0].squares[j].to(c2); + output[j] = c1 - c2; + } +} + +octet* OTExtensionWithMatrix::get_receiver_output(int i) +{ + return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128]; +} + +octet* OTExtensionWithMatrix::get_sender_output(int choice, int i) +{ + return (octet*)&senderOutputMatrices[choice].squares[i/128].rows[i%128]; +} + +void OTExtensionWithMatrix::print(BitVector& newReceiverInput, int i) +{ + if (player->my_num() == 0) + { + print_receiver(newReceiverInput, receiverOutputMatrix, i); + print_sender(senderOutputMatrices[0].squares[i], senderOutputMatrices[1].squares[i]); + } + else + { + print_sender(senderOutputMatrices[0].squares[i], senderOutputMatrices[1].squares[i]); + print_receiver(newReceiverInput, receiverOutputMatrix, i); + } +} + +template +void OTExtensionWithMatrix::print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int k, int offset) +{ + if (ot_role & RECEIVER) + { + for (int i = 0; i < 16; i++) + { + if (newReceiverInput.get_bit((offset + k) * 128 + i)) + { + for (int j = 0; j < 33; j++) + cout << " "; + cout << T(matrix.squares[k].rows[i]); + } + else + cout << int128(matrix.squares[k].rows[i]); + cout << endl; + } + cout << endl; + } +} + +void OTExtensionWithMatrix::print_sender(square128& square0, square128& square1) +{ + if (ot_role & SENDER) + { + for (int i = 0; i < 16; i++) + { + cout << int128(square0.rows[i]) << " "; + cout << int128(square1.rows[i]) << " "; + cout << endl; + } + cout << endl; + } +} + +template +void OTExtensionWithMatrix::print_post_correlate(BitVector& newReceiverInput, int j, int offset, int sender) +{ + cout << "post correlate, sender" << sender << endl; + if (player->my_num() == sender) + { + T delta = newReceiverInput.get_int128(offset + j); + for (int i = 0; i < 16; i++) + { + cout << (int128(receiverOutputMatrix.squares[j].rows[i])); + cout << " "; + cout << (T(receiverOutputMatrix.squares[j].rows[i]) - delta); + cout << endl; + } + cout << endl; + } + else + { + print_receiver(baseReceiverInput, senderOutputMatrices[0], j); + } +} + +void OTExtensionWithMatrix::print_pre_correlate(int i) +{ + cout << "pre correlate" << endl; + if (player->my_num() == 0) + print_sender(receiverOutputMatrix.squares[i], t1.squares[i]); + else + print_receiver(baseReceiverInput, senderOutputMatrices[0], i); +} + +void OTExtensionWithMatrix::print_post_transpose(BitVector& newReceiverInput, int i, int sender) +{ + cout << "post transpose, sender " << sender << endl; + if (player->my_num() == sender) + { + print_receiver(newReceiverInput, receiverOutputMatrix); + } + else + { + square128 tmp = senderOutputMatrices[0].squares[i]; + tmp ^= baseReceiverInput; + print_sender(senderOutputMatrices[0].squares[i], tmp); + } +} + +void OTExtensionWithMatrix::print_pre_expand() +{ + cout << "pre expand" << endl; + if (player->my_num() == 0) + { + for (int i = 0; i < 16; i++) + { + for (int j = 0; j < 2; j++) + cout << int128(_mm_loadu_si128((__m128i*)G_sender[i][j].get_seed())) << " "; + cout << endl; + } + cout << endl; + } + else + { + for (int i = 0; i < 16; i++) + { + if (baseReceiverInput.get_bit(i)) + { + for (int j = 0; j < 33; j++) + cout << " "; + } + cout << int128(_mm_loadu_si128((__m128i*)G_receiver[i].get_seed())) << endl; + } + cout << endl; + } +} + +template void OTExtensionWithMatrix::correlate(int start, int slice, + BitVector& newReceiverInput, bool useConstantBase, int repeat); +template void OTExtensionWithMatrix::correlate(int start, int slice, + BitVector& newReceiverInput, bool useConstantBase, int repeat); +template void OTExtensionWithMatrix::print_post_correlate( + BitVector& newReceiverInput, int j, int offset, int sender); +template void OTExtensionWithMatrix::print_post_correlate( + BitVector& newReceiverInput, int j, int offset, int sender); +template void OTExtensionWithMatrix::extend(int nOTs_requested, + BitVector& newReceiverInput); +template void OTExtensionWithMatrix::extend(int nOTs_requested, + BitVector& newReceiverInput); +template void OTExtensionWithMatrix::expand(int start, int slice); +template void OTExtensionWithMatrix::expand(int start, int slice); +template void OTExtensionWithMatrix::expand_transposed(); +template void OTExtensionWithMatrix::expand_transposed(); +template void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples, + vector& output); +template void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples, + vector& output); diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h new file mode 100644 index 000000000..d9cf59bde --- /dev/null +++ b/OT/OTExtensionWithMatrix.h @@ -0,0 +1,71 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OTExtensionWithMatrix.h + * + */ + +#ifndef OT_OTEXTENSIONWITHMATRIX_H_ +#define OT_OTEXTENSIONWITHMATRIX_H_ + +#include "OTExtension.h" +#include "BitMatrix.h" +#include "Math/gf2n.h" + +class OTExtensionWithMatrix : public OTExtension +{ +public: + vector senderOutputMatrices; + BitMatrix receiverOutputMatrix; + BitMatrix t1, u; + PRNG G; + + OTExtensionWithMatrix(int nbaseOTs, int baseLength, + int nloops, int nsubloops, + TwoPartyPlayer* player, + BitVector& baseReceiverInput, + vector< vector >& baseSenderInput, + vector& baseReceiverOutput, + OT_ROLE role=BOTH, + bool passive=false) + : OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, + baseSenderInput, baseReceiverOutput, role, passive) { + G.ReSeed(); + } + + void seed(vector& baseSenderInput, + BitMatrix& baseReceiverOutput); + void transfer(int nOTs, const BitVector& receiverInput); + void resize(int nOTs); + template + void extend(int nOTs, BitVector& newReceiverInput); + template + void expand(int start, int slice); + template + void expand_transposed(); + template + void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1); + void transpose(int start, int slice); + void setup_for_correlation(vector& baseSenderOutputs, BitMatrix& baseReceiverOutput); + template + void reduce_squares(unsigned int nTriples, vector& output); + + void print(BitVector& newReceiverInput, int i = 0); + template + void print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int i = 0, int offset = 0); + void print_sender(square128& square0, square128& square); + template + void print_post_correlate(BitVector& newReceiverInput, int i = 0, int offset = 0, int sender = 0); + void print_pre_correlate(int i = 0); + void print_post_transpose(BitVector& newReceiverInput, int i = 0, int sender = 0); + void print_pre_expand(); + + octet* get_receiver_output(int i); + octet* get_sender_output(int choice, int i); + +protected: + template + void hash_outputs(int nOTs); +}; + +#endif /* OT_OTEXTENSIONWITHMATRIX_H_ */ diff --git a/OT/OTMachine.cpp b/OT/OTMachine.cpp new file mode 100644 index 000000000..db5e2b6b0 --- /dev/null +++ b/OT/OTMachine.cpp @@ -0,0 +1,401 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Networking/Player.h" +#include "OT/OTExtension.h" +#include "OT/OTExtensionWithMatrix.h" +#include "Exceptions/Exceptions.h" +#include "Tools/time-func.h" + +#include +#include +#include +#include + +#include + +#include "OutputCheck.h" +#include "OTMachine.h" + +//#define BASE_OT_DEBUG + +class OT_thread_info +{ + public: + + int thread_num; + bool stop; + int other_player_num; + OTExtension* ot_ext; + int nOTs, nbase; + BitVector receiverInput; +}; + +void* run_otext_thread(void* ptr) +{ + OT_thread_info *tinfo = (OT_thread_info*) ptr; + + //int num = tinfo->thread_num; + //int other_player_num = tinfo->other_player_num; + printf("\tI am in thread %d\n", tinfo->thread_num); + tinfo->ot_ext->transfer(tinfo->nOTs, tinfo->receiverInput); + return NULL; +} + +OTMachine::OTMachine(int argc, const char** argv) +{ + opt.add( + "", // Default. + 1, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "This player's number, 0/1 (required).", // Help description. + "-p", // Flag token. + "--player" // Flag token. + ); + + opt.add( + "5000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Base port number (default: 5000).", // Help description. + "-pn", // Flag token. + "--portnum" // Flag token. + ); + + opt.add( + "localhost", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Host name(s) that player 0 is running on (default: localhost). Split with commas.", // Help description. + "-h", // Flag token. + "--hostname" // Flag token. + ); + + opt.add( + "1024", + 0, + 1, + 0, + "Number of extended OTs to run (default: 1024).", + "-n", + "--nOTs" + ); + + opt.add( + "128", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of base OTs to run (default: 128).", // Help description. + "-b", // Flag token. + "--nbase" // Flag token. + ); + + opt.add( + "s", + 0, + 1, + 0, + "Mode for OT. a (asymmetric) or s (symmetric, i.e. play both sender/receiver) (default: s).", + "-m", + "--mode" + ); + opt.add( + "1", + 0, + 1, + 0, + "Number of threads (default: 1).", + "-x", + "--nthreads" + ); + + opt.add( + "1", + 0, + 1, + 0, + "Number of loops (default: 1).", + "-l", + "--nloops" + ); + + opt.add( + "1", + 0, + 1, + 0, + "Number of subloops (default: 1).", + "-s", + "--nsubloops" + ); + + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Run in passive security mode.", // Help description. + "-pas", // Flag token. + "--passive" // Flag token. + ); + + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Write results to files.", // Help description. + "-o", // Flag token. + "--output" // Flag token. + ); + + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Real base OT.", // Help description. + "-r", // Flag token. + "--real" // Flag token. + ); + + opt.parse(argc, argv); + + string hostname, ot_mode, usage; + passive = false; + opt.get("-p")->getInt(my_num); + opt.get("-pn")->getInt(portnum_base); + opt.get("-h")->getString(hostname); + opt.get("-n")->getLong(nOTs); + opt.get("-m")->getString(ot_mode); + opt.get("--nthreads")->getInt(nthreads); + opt.get("--nloops")->getInt(nloops); + opt.get("--nsubloops")->getInt(nsubloops); + opt.get("--nbase")->getInt(nbase); + if (opt.isSet("-pas")) + passive = true; + + if (!opt.isSet("-p")) + { + opt.getUsage(usage); + cout << usage; + exit(0); + } + + cout << "Player 0 host name = " << hostname << endl; + cout << "Creating " << nOTs << " extended OTs in " << nthreads << " threads\n"; + cout << "Running in mode " << ot_mode << endl; + + if (passive) + cout << "Running with PASSIVE security only\n"; + + if (nbase < 128) + cout << "WARNING: only using " << nbase << " seed OTs, using these for OT extensions is insecure.\n"; + + if (ot_mode.compare("s") == 0) + ot_role = BOTH; + else if (ot_mode.compare("a") == 0) + { + if (my_num == 0) + ot_role = SENDER; + else + ot_role = RECEIVER; + } + else + { + cerr << "Invalid OT mode argument: " << ot_mode << endl; + exit(1); + } + + // Several names for multiplexing + unsigned int pos = 0; + while (pos < hostname.length()) + { + string::size_type new_pos = hostname.find(',', pos); + if (new_pos == string::npos) + new_pos = hostname.length(); + int len = new_pos - pos; + string name = hostname.substr(pos, len); + pos = new_pos + 1; + + vector names(2); + names[my_num] = "localhost"; + names[1-my_num] = name; + N.resize(N.size() + 1); + N[N.size()-1].init(my_num, portnum_base, names); + } + + P = new TwoPartyPlayer(N[0], 1 - my_num, 500); + + timeval baseOTstart, baseOTend; + gettimeofday(&baseOTstart, NULL); + // swap role for base OTs + if (opt.isSet("-r")) + bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role)); + else + bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role)); + cout << "real mode " << opt.isSet("-r") << endl; + BaseOT& bot = *bot_; + bot.exec_base(); + gettimeofday(&baseOTend, NULL); + double basetime = timeval_diff(&baseOTstart, &baseOTend); + cout << "\t\tBaseTime (" << role_to_str(ot_role) << "): " << basetime/1000000 << endl << flush; + + // Receiver send something to force synchronization + // (since Sender finishes baseOTs before Receiver) + int a = 3; + vector os(2); + os[0].store(a); + P->send_receive_player(os); + os[1].get(a); + cout << a << endl; + +#ifdef BASE_OT_DEBUG + // check base OTs + bot.check(); + // check after extending with PRG a few times + for (int i = 0; i < 8; i++) + { + bot.extend_length(); + bot.check(); + } + cout << "Verifying base OTs (debugging)\n"; +#endif + + // convert baseOT selection bits to BitVector + // (not already BitVector due to legacy PVW code) + baseReceiverInput.resize(nbase); + for (int i = 0; i < nbase; i++) + { + baseReceiverInput.set_bit(i, bot.receiver_inputs[i]); + } +} + +OTMachine::~OTMachine() +{ + delete bot_; + delete P; +} + + +void OTMachine::run() +{ + // divide nOTs between threads and loops + nOTs = DIV_CEIL(nOTs, nthreads * nloops); + // round up to multiple of base OTs and subloops + // discount for discarded OTs + nOTs = DIV_CEIL(nOTs + 2 * 128, nbase * nsubloops) * nbase * nsubloops - 2 * 128; + cout << "Running " << nOTs << " OT extensions per thread and loop\n" << flush; + + // PRG for generating inputs etc + PRNG G; + G.ReSeed(); + BitVector receiverInput(nOTs); + receiverInput.randomize(G); + BaseOT& bot = *bot_; + + cout << "Initialize OT Extension\n"; + vector tinfos(nthreads); + vector threads(nthreads); + timeval OTextstart, OTextend; + gettimeofday(&OTextstart, NULL); + + // copy base inputs/outputs for each thread + vector base_receiver_input_copy(nthreads); + vector > > base_sender_inputs_copy(nthreads, vector >(nbase, vector(2))); + vector< vector > base_receiver_outputs_copy(nthreads, vector(nbase)); + vector players(nthreads); + + for (int i = 0; i < nthreads; i++) + { + tinfos[i].receiverInput.assign(receiverInput); + + base_receiver_input_copy[i].assign(baseReceiverInput); + for (int j = 0; j < nbase; j++) + { + base_sender_inputs_copy[i][j][0].assign(bot.sender_inputs[j][0]); + base_sender_inputs_copy[i][j][1].assign(bot.sender_inputs[j][1]); + base_receiver_outputs_copy[i][j].assign(bot.receiver_outputs[j]); + } + // now setup resources for each thread + // round robin with the names + players[i] = new TwoPartyPlayer(N[i%N.size()], 1 - my_num, (i+1) * 1000); + tinfos[i].thread_num = i+1; + tinfos[i].other_player_num = 1 - my_num; + tinfos[i].nOTs = nOTs; + tinfos[i].ot_ext = new OTExtensionWithMatrix(nbase, bot.length(), + nloops, nsubloops, + players[i], + base_receiver_input_copy[i], + base_sender_inputs_copy[i], + base_receiver_outputs_copy[i], + ot_role, + passive); + + // create the thread + pthread_create(&threads[i], NULL, run_otext_thread, &tinfos[i]); + + // extend base OTs with PRG for the next thread + bot.extend_length(); + } + // wait for threads to finish + for (int i = 0; i < nthreads; i++) + { + pthread_join(threads[i],NULL); + cout << "thread " << i+1 << " finished\n" << flush; + } + + map& times = tinfos[0].ot_ext->times; + for (map::iterator it = times.begin(); it != times.end(); it++) + { + long long sum = 0; + for (int i = 0; i < nthreads; i++) + sum += tinfos[i].ot_ext->times[it->first]; + + cout << it->first << " on average took time " + << double(sum) / nthreads / 1e6 << endl; + } + + gettimeofday(&OTextend, NULL); + double totaltime = timeval_diff(&OTextstart, &OTextend); + cout << "Time for OTExt threads (" << role_to_str(ot_role) << "): " << totaltime/1000000 << endl << flush; + + if (opt.isSet("-o")) + { + BitVector receiver_output, sender_output; + char filename[1024]; + sprintf(filename, RECEIVER_INPUT, my_num); + ofstream outf(filename); + receiverInput.output(outf, false); + outf.close(); + sprintf(filename, RECEIVER_OUTPUT, my_num); + outf.open(filename); + for (unsigned int i = 0; i < nOTs; i++) + { + receiver_output.assign_bytes((char*) tinfos[0].ot_ext->get_receiver_output(i), sizeof(__m128i)); + receiver_output.output(outf, false); + } + outf.close(); + + for (int i = 0; i < 2; i++) + { + sprintf(filename, SENDER_OUTPUT, my_num, i); + outf.open(filename); + for (int j = 0; j < nOTs; j++) + { + sender_output.assign_bytes((char*) tinfos[0].ot_ext->get_sender_output(i, j), sizeof(__m128i)); + sender_output.output(outf, false); + } + outf.close(); + } + } + + for (int i = 0; i < nthreads; i++) + { + delete players[i]; + delete tinfos[i].ot_ext; + } +} diff --git a/OT/OTMachine.h b/OT/OTMachine.h new file mode 100644 index 000000000..9706e68f5 --- /dev/null +++ b/OT/OTMachine.h @@ -0,0 +1,33 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OTMachine.h + * + */ + +#ifndef OT_OTMACHINE_H_ +#define OT_OTMACHINE_H_ + +#include "OT/OTExtension.h" +#include "Tools/ezOptionParser.h" + +class OTMachine +{ + ez::ezOptionParser opt; + OT_ROLE ot_role; + +public: + int my_num, portnum_base, nthreads, nloops, nsubloops, nbase; + long nOTs; + bool passive; + TwoPartyPlayer* P; + BitVector baseReceiverInput; + BaseOT* bot_; + vector N; + + OTMachine(int argc, const char** argv); + ~OTMachine(); + void run(); +}; + +#endif /* OT_OTMACHINE_H_ */ diff --git a/OT/OTMultiplier.cpp b/OT/OTMultiplier.cpp new file mode 100644 index 000000000..658ae6504 --- /dev/null +++ b/OT/OTMultiplier.cpp @@ -0,0 +1,164 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OTMultiplier.cpp + * + */ + +#include "OT/OTMultiplier.h" +#include "OT/NPartyTripleGenerator.h" + +#include + +template +OTMultiplier::OTMultiplier(NPartyTripleGenerator& generator, + int thread_num) : + generator(generator), thread_num(thread_num), + rot_ext(128, 128, 0, 1, + generator.players[thread_num], generator.baseReceiverInput, + generator.baseSenderInputs[thread_num], + generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check) +{ + c_output.resize(generator.nTriplesPerLoop); + pthread_mutex_init(&mutex, 0); + pthread_cond_init(&ready, 0); + thread = 0; +} + +template +OTMultiplier::~OTMultiplier() +{ + pthread_mutex_destroy(&mutex); + pthread_cond_destroy(&ready); +} + +template +void OTMultiplier::multiply() +{ + BitVector keyBits(generator.field_size); + keyBits.set_int128(0, generator.machine.get_mac_key().to_m128i()); + rot_ext.extend(generator.field_size, keyBits); + vector< vector > senderOutput(128); + vector receiverOutput; + for (int j = 0; j < 128; j++) + { + senderOutput[j].resize(2); + for (int i = 0; i < 2; i++) + { + senderOutput[j][i].resize(128); + senderOutput[j][i].set_int128(0, rot_ext.senderOutputMatrices[i].squares[0].rows[j]); + } + } + rot_ext.receiverOutputMatrix.to(receiverOutput); + OTExtensionWithMatrix auth_ot_ext(128, 128, 0, 1, + generator.players[thread_num], keyBits, senderOutput, + receiverOutput, BOTH, true); + + if (generator.machine.generateBits) + multiplyForBits(auth_ot_ext); + else + multiplyForTriples(auth_ot_ext); +} + +template +void OTMultiplier::multiplyForTriples(OTExtensionWithMatrix& auth_ot_ext) +{ + auth_ot_ext.resize(generator.nPreampTriplesPerLoop * generator.field_size); + + // dummy input for OT correlator + vector _; + vector< vector > __; + BitVector ___; + + OTExtensionWithMatrix otCorrelator(0, 0, 0, 0, generator.players[thread_num], + ___, __, _, BOTH, true); + otCorrelator.resize(128 * generator.nPreampTriplesPerLoop); + + rot_ext.resize(generator.field_size * generator.nPreampTriplesPerLoop + 2 * 128); + + pthread_mutex_lock(&mutex); + pthread_cond_signal(&ready); + pthread_cond_wait(&ready, &mutex); + + for (int i = 0; i < generator.nloops; i++) + { + BitVector aBits = generator.valueBits[0]; + //timers["Extension"].start(); + rot_ext.extend(generator.field_size * generator.nPreampTriplesPerLoop, aBits); + //timers["Extension"].stop(); + + //timers["Correlation"].start(); + otCorrelator.baseReceiverInput = aBits; + otCorrelator.setup_for_correlation(rot_ext.senderOutputMatrices, rot_ext.receiverOutputMatrix); + otCorrelator.correlate(0, generator.nPreampTriplesPerLoop, generator.valueBits[1], false, generator.nAmplify); + //timers["Correlation"].stop(); + + //timers["Triple computation"].start(); + + otCorrelator.reduce_squares(generator.nPreampTriplesPerLoop, c_output); + + pthread_cond_signal(&ready); + pthread_cond_wait(&ready, &mutex); + + if (generator.machine.generateMACs) + { + macs.resize(3); + for (int j = 0; j < 3; j++) + { + int nValues = generator.nTriplesPerLoop; + if (generator.machine.check && (j % 2 == 0)) + nValues *= 2; + auth_ot_ext.expand(0, nValues); + auth_ot_ext.correlate(0, nValues, generator.valueBits[j], true); + auth_ot_ext.reduce_squares(nValues, macs[j]); + } + + pthread_cond_signal(&ready); + pthread_cond_wait(&ready, &mutex); + } + } + + pthread_mutex_unlock(&mutex); +} + +template<> +void OTMultiplier::multiplyForBits(OTExtensionWithMatrix& auth_ot_ext) +{ + multiplyForTriples(auth_ot_ext); +} + +template<> +void OTMultiplier::multiplyForBits(OTExtensionWithMatrix& auth_ot_ext) +{ + int nBits = generator.nTriplesPerLoop + generator.field_size; + int nBlocks = ceil(1.0 * nBits / generator.field_size); + auth_ot_ext.resize(nBlocks * generator.field_size); + macs.resize(1); + macs[0].resize(nBits); + + pthread_mutex_lock(&mutex); + pthread_cond_signal(&ready); + pthread_cond_wait(&ready, &mutex); + + for (int i = 0; i < generator.nloops; i++) + { + auth_ot_ext.expand(0, nBlocks); + auth_ot_ext.correlate(0, nBlocks, generator.valueBits[0], true); + auth_ot_ext.transpose(0, nBlocks); + + for (int j = 0; j < nBits; j++) + { + int128 r = auth_ot_ext.receiverOutputMatrix.squares[j/128].rows[j%128]; + int128 s = auth_ot_ext.senderOutputMatrices[0].squares[j/128].rows[j%128]; + macs[0][j] = r ^ s; + } + + pthread_cond_signal(&ready); + pthread_cond_wait(&ready, &mutex); + } + + pthread_mutex_unlock(&mutex); +} + +template class OTMultiplier; +template class OTMultiplier; diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h new file mode 100644 index 000000000..15d9530e4 --- /dev/null +++ b/OT/OTMultiplier.h @@ -0,0 +1,41 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OTMultiplier.h + * + */ + +#ifndef OT_OTMULTIPLIER_H_ +#define OT_OTMULTIPLIER_H_ + +#include +using namespace std; + +#include "OT/OTExtensionWithMatrix.h" +#include "Tools/random.h" + +class NPartyTripleGenerator; + +template +class OTMultiplier +{ + void multiplyForTriples(OTExtensionWithMatrix& auth_ot_ext); + void multiplyForBits(OTExtensionWithMatrix& auth_ot_ext); +public: + NPartyTripleGenerator& generator; + int thread_num; + OTExtensionWithMatrix rot_ext; + //OTExtensionWithMatrix* auth_ot_ext; + vector c_output; + vector< vector > macs; + + pthread_t thread; + pthread_mutex_t mutex; + pthread_cond_t ready; + + OTMultiplier(NPartyTripleGenerator& generator, int thread_num); + ~OTMultiplier(); + void multiply(); +}; + +#endif /* OT_OTMULTIPLIER_H_ */ diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp new file mode 100644 index 000000000..049eda821 --- /dev/null +++ b/OT/OTTripleSetup.cpp @@ -0,0 +1,44 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "OTTripleSetup.h" + +void OTTripleSetup::setup() +{ + timeval baseOTstart, baseOTend; + gettimeofday(&baseOTstart, NULL); + + G.ReSeed(); + for (int i = 0; i < nbase; i++) + { + base_receiver_inputs[i] = G.get_uchar() & 1; + } + //baseReceiverInput.randomize(G); + + for (int i = 0; i < nparties - 1; i++) + { + baseOTs[i]->set_receiver_inputs(base_receiver_inputs); + baseOTs[i]->exec_base(false); + } + gettimeofday(&baseOTend, NULL); + double basetime = timeval_diff(&baseOTstart, &baseOTend); + cout << "\t\tBaseTime: " << basetime/1000000 << endl << flush; + + // Receiver send something to force synchronization + // (since Sender finishes baseOTs before Receiver) +} + +void OTTripleSetup::close_connections() +{ + for (size_t i = 0; i < players.size(); i++) + { + delete players[i]; + } +} + +OTTripleSetup::~OTTripleSetup() +{ + for (size_t i = 0; i < baseOTs.size(); i++) + { + delete baseOTs[i]; + } +} diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h new file mode 100644 index 000000000..354cd21e9 --- /dev/null +++ b/OT/OTTripleSetup.h @@ -0,0 +1,91 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef OT_TRIPLESETUP_H_ +#define OT_TRIPLESETUP_H_ + +#include "Networking/Player.h" +#include "OT/BaseOT.h" +#include "OT/OTMachine.h" +#include "Tools/random.h" +#include "Tools/time-func.h" +#include "Math/gfp.h" + +/* + * Class for creating and storing base OTs between every pair of parties. + */ +class OTTripleSetup +{ + vector base_receiver_inputs; + vector< vector< vector > > baseSenderInputs; + vector< vector > baseReceiverOutputs; + + PRNG G; + int nparties; + int my_num; + int nbase; + bool real_OTs; + +public: + map timers; + vector baseOTs; + vector players; + + int get_nparties() { return nparties; } + int get_nbase() { return nbase; } + int get_my_num() { return my_num; } + int get_base_receiver_input(int i) { return base_receiver_inputs[i]; } + + OTTripleSetup(Names& N, bool real_OTs) + : nparties(N.num_players()), my_num(N.my_num()), nbase(128), real_OTs(real_OTs) + { + base_receiver_inputs.resize(nbase); + players.resize(nparties - 1); + baseOTs.resize(nparties - 1); + baseSenderInputs.resize(nparties - 1); + baseReceiverOutputs.resize(nparties - 1); + + if (real_OTs) + cout << "Doing real base OTs\n"; + else + cout << "Doing fake base OTs\n"; + + for (int i = 0; i < nparties - 1; i++) + { + int other_player, id; + // i for indexing, other_player is actual number + if (i >= my_num) + other_player = i + 1; + else + other_player = i; + // unique id per pair of parties (to assign port no.) + if (my_num < other_player) + id = my_num*nparties + other_player; + else + id = other_player*nparties + my_num; + + players[i] = new TwoPartyPlayer(N, other_player, id); + + // sets up a pair of base OTs, playing both roles + if (real_OTs) + { + baseOTs[i] = new BaseOT(nbase, 128, players[i]); + } + else + { + baseOTs[i] = new FakeOT(nbase, 128, players[i]); + } + } + } + ~OTTripleSetup(); + + // run the Base OTs + void setup(); + // close down the sockets + void close_connections(); + + //template + //T get_mac_key(); +}; + + +#endif diff --git a/OT/OText_main.cpp b/OT/OText_main.cpp new file mode 100644 index 000000000..fc2edaaf7 --- /dev/null +++ b/OT/OText_main.cpp @@ -0,0 +1,13 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OText_main.cpp + * + */ + +#include "OTMachine.h" + +int main(int argc, const char** argv) +{ + OTMachine(argc, argv).run(); +} diff --git a/OT/OutputCheck.h b/OT/OutputCheck.h new file mode 100644 index 000000000..a598b8930 --- /dev/null +++ b/OT/OutputCheck.h @@ -0,0 +1,15 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * check.h + * + */ + +#ifndef OT_OUTPUTCHECK_H_ +#define OT_OUTPUTCHECK_H_ + +#define RECEIVER_INPUT "Player-Data/OT-receiver%d-input" +#define RECEIVER_OUTPUT "Player-Data/OT-receiver%d-output" +#define SENDER_OUTPUT "Player-Data/OT-sender%d-output%d" + +#endif /* OT_OUTPUTCHECK_H_ */ diff --git a/OT/Tools.cpp b/OT/Tools.cpp new file mode 100644 index 000000000..32e9208fe --- /dev/null +++ b/OT/Tools.cpp @@ -0,0 +1,107 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Tools.h" +#include "Math/gf2nlong.h" + +void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len) +{ + PRNG G; + G.ReSeed(); + vector seed_strm(2); + vector Comm_seed(2); + vector Open_seed(2); + G.get_octetStream(seed_strm[0], len); + + Commit(Comm_seed[0], Open_seed[0], seed_strm[0], player.my_num()); + player.send_receive_player(Comm_seed); + player.send_receive_player(Open_seed); + + memset(seed, 0, len*sizeof(octet)); + + if (!Open(seed_strm[1], Comm_seed[1], Open_seed[1], player.other_player_num())) + { + throw invalid_commitment(); + } + for (int i = 0; i < len; i++) + { + seed[i] = seed_strm[0].get_data()[i] ^ seed_strm[1].get_data()[i]; + } +} + +void shiftl128(word x1, word x2, word& res1, word& res2, size_t k) +{ + if (k > 128) + throw invalid_length(); + if (k >= 64) // shifting a 64-bit integer by more than 63 bits is "undefined" + { + x1 = x2; + x2 = 0; + shiftl128(x1, x2, res1, res2, k - 64); + } + else + { + res1 = (x1 << k) | (x2 >> (64-k)); + res2 = (x2 << k); + } +} + +// reduce modulo x^128 + x^7 + x^2 + x + 1 +// NB this is incorrect as it bit-reflects the result as required for +// GCM mode +void gfred128(__m128i tmp3, __m128i tmp6, __m128i *res) +{ + __m128i tmp2, tmp4, tmp5, tmp7, tmp8, tmp9; + tmp7 = _mm_srli_epi32(tmp3, 31); + tmp8 = _mm_srli_epi32(tmp6, 31); + + tmp3 = _mm_slli_epi32(tmp3, 1); + tmp6 = _mm_slli_epi32(tmp6, 1); + + tmp9 = _mm_srli_si128(tmp7, 12); + tmp8 = _mm_slli_si128(tmp8, 4); + tmp7 = _mm_slli_si128(tmp7, 4); + tmp3 = _mm_or_si128(tmp3, tmp7); + tmp6 = _mm_or_si128(tmp6, tmp8); + tmp6 = _mm_or_si128(tmp6, tmp9); + + tmp7 = _mm_slli_epi32(tmp3, 31); + tmp8 = _mm_slli_epi32(tmp3, 30); + tmp9 = _mm_slli_epi32(tmp3, 25); + + tmp7 = _mm_xor_si128(tmp7, tmp8); + tmp7 = _mm_xor_si128(tmp7, tmp9); + tmp8 = _mm_srli_si128(tmp7, 4); + tmp7 = _mm_slli_si128(tmp7, 12); + tmp3 = _mm_xor_si128(tmp3, tmp7); + + tmp2 = _mm_srli_epi32(tmp3, 1); + tmp4 = _mm_srli_epi32(tmp3, 2); + tmp5 = _mm_srli_epi32(tmp3, 7); + tmp2 = _mm_xor_si128(tmp2, tmp4); + tmp2 = _mm_xor_si128(tmp2, tmp5); + tmp2 = _mm_xor_si128(tmp2, tmp8); + tmp3 = _mm_xor_si128(tmp3, tmp2); + + tmp6 = _mm_xor_si128(tmp6, tmp3); + *res = tmp6; +} + +// Based on Intel's code for GF(2^128) mul, with reduction +void gfmul128 (__m128i a, __m128i b, __m128i *res) +{ + __m128i tmp3, tmp6; + mul128(a, b, &tmp3, &tmp6); + // Now do the reduction + gfred128(tmp3, tmp6, res); +} + +string word_to_bytes(const word w) +{ + stringstream ss; + octet* bytes = (octet*) &w; + ss << hex; + for (unsigned int i = 0; i < sizeof(word); i++) + ss << (int)bytes[i] << " "; + return ss.str(); +} + diff --git a/OT/Tools.h b/OT/Tools.h new file mode 100644 index 000000000..1038d6430 --- /dev/null +++ b/OT/Tools.h @@ -0,0 +1,51 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _OTTOOLS +#define _OTTOOLS + +#include "Networking/Player.h" +#include "Tools/Commit.h" +#include "Tools/random.h" + +#define SEED_SIZE_BYTES SEED_SIZE + +/* + * Generate a secure, random seed between 2 parties via commitment + */ +void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len); + +/* + * GF(2^128) multiplication using Intel instructions + * (should this go in gf2n class???) + */ +void gfmul128(__m128i a, __m128i b, __m128i *res); +void gfred128(__m128i a1, __m128i a2, __m128i *res); + +//#if defined(__SSE2__) +/* + * Convert __m128i to string of type T + */ +template +string __m128i_toString(const __m128i var) { + stringstream sstr; + sstr << hex; + const T* values = (const T*) &var; + if (sizeof(T) == 1) { + for (unsigned int i = 0; i < sizeof(__m128i); i++) { + sstr << (int) values[i] << " "; + } + } else { + for (unsigned int i = 0; i < sizeof(__m128i) / sizeof(T); i++) { + sstr << values[i] << " "; + } + } + return sstr.str(); +} +//#endif + +string word_to_bytes(const word w); + +void shiftl128(word x1, word x2, word& res1, word& res2, size_t k); + + +#endif diff --git a/OT/TripleMachine.cpp b/OT/TripleMachine.cpp new file mode 100644 index 000000000..b1f72c9d1 --- /dev/null +++ b/OT/TripleMachine.cpp @@ -0,0 +1,270 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * TripleMachine.cpp + * + */ + +#include +#include "OT/NPartyTripleGenerator.h" +#include "OT/OTMachine.h" +#include "OT/OTTripleSetup.h" +#include "Math/gf2n.h" +#include "Math/Setup.h" +#include "Tools/ezOptionParser.h" +#include "Math/Setup.h" + +#include +#include +using namespace std; + +template +void* run_ngenerator_thread(void* ptr) +{ + ((NPartyTripleGenerator*)ptr)->generate(); + return 0; +} + +TripleMachine::TripleMachine(int argc, const char** argv) +{ + ez::ezOptionParser opt; + opt.add( + "2", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of parties (default: 2).", // Help description. + "-N", // Flag token. + "--nparties" // Flag token. + ); + opt.add( + "", // Default. + 1, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "This player's number, 0/1 (required).", // Help description. + "-p", // Flag token. + "--player" // Flag token. + ); + opt.add( + "1", + 0, + 1, + 0, + "Number of threads (default: 1).", + "-x", + "--nthreads" + ); + opt.add( + "1000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of triples (default: 1000).", // Help description. + "-n", // Flag token. + "--ntriples" // Flag token. + ); + opt.add( + "1", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of loops (default: 1).", // Help description. + "-l", // Flag token. + "--nloops" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Generate MACs (implies -a).", // Help description. + "-m", // Flag token. + "--macs" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Amplify triples.", // Help description. + "-a", // Flag token. + "--amplify" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Check triples (implies -m).", // Help description. + "-c", // Flag token. + "--check" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "GF(p) triples", // Help description. + "-P", // Flag token. + "--prime-field" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Channel bonding", // Help description. + "-b", // Flag token. + "--bonding" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Generate bits", // Help description. + "-B", // Flag token. + "--bits" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Write results to files.", // Help description. + "-o", // Flag token. + "--output" // Flag token. + ); + + opt.parse(argc, argv); + opt.get("-p")->getInt(my_num); + opt.get("-N")->getInt(nplayers); + opt.get("-x")->getInt(nthreads); + opt.get("-n")->getInt(ntriples); + opt.get("-l")->getInt(nloops); + generateBits = opt.get("-B")->isSet; + check = opt.get("-c")->isSet || generateBits; + generateMACs = opt.get("-m")->isSet || check; + amplify = opt.get("-a")->isSet || generateMACs; + primeField = opt.get("-P")->isSet; + bonding = opt.get("-b")->isSet; + output = opt.get("-o")->isSet; + + if (!opt.isSet("-p")) + { + string usage; + opt.getUsage(usage); + cout << usage; + exit(0); + } + + nTriplesPerThread = DIV_CEIL(ntriples, nthreads); + + prep_data_dir = get_prep_dir(nplayers, 128, 128); + ofstream outf; + bigint p; + generate_online_setup(outf, prep_data_dir, p, 128, 128); + // doesn't work with Montgomery multiplication + gfp::init_field(p, false); + gf2n::init_field(128); + + PRNG G; + G.ReSeed(); + mac_key2.randomize(G); + mac_keyp.randomize(G); +} + +void TripleMachine::run() +{ + cout << "my_num: " << my_num << endl; + Names N[2]; + N[0].init(my_num, nplayers, 10000, "HOSTS"); + int nConnections = 1; + if (bonding) + { + N[1].init(my_num, nplayers, 11000, "HOSTS2"); + nConnections = 2; + } + // do the base OTs + OTTripleSetup setup(N[0], false); + setup.setup(); + setup.close_connections(); + + vector generators(nthreads); + vector threads(nthreads); + + for (int i = 0; i < nthreads; i++) + { + generators[i] = new NPartyTripleGenerator(setup, N[i%nConnections], i, nTriplesPerThread, nloops, *this); + } + ntriples = generators[0]->nTriples * nthreads; + cout <<"Setup generators\n"; + for (int i = 0; i < nthreads; i++) + { + // lock before starting thread to avoid race condition + generators[i]->lock(); + if (primeField) + pthread_create(&threads[i], 0, run_ngenerator_thread, generators[i]); + else + pthread_create(&threads[i], 0, run_ngenerator_thread, generators[i]); + } + + // wait for initialization, then start clock and computation + for (int i = 0; i < nthreads; i++) + generators[i]->wait(); + cout << "Starting computation" << endl; + gettimeofday(&start, 0); + for (int i = 0; i < nthreads; i++) + { + generators[i]->signal(); + generators[i]->unlock(); + } + + // wait for threads to finish + for (int i = 0; i < nthreads; i++) + { + pthread_join(threads[i],NULL); + cout << "thread " << i+1 << " finished\n" << flush; + } + + map& timers = generators[0]->timers; + for (map::iterator it = timers.begin(); it != timers.end(); it++) + { + double sum = 0; + for (size_t i = 0; i < generators.size(); i++) + sum += generators[i]->timers[it->first].elapsed(); + cout << it->first << " on average took time " + << sum / generators.size() << endl; + } + + gettimeofday(&stop, 0); + double time = timeval_diff_in_seconds(&start, &stop); + cout << "Time: " << time << endl; + cout << "Throughput: " << ntriples / time << endl; + + for (size_t i = 0; i < generators.size(); i++) + delete generators[i]; + + output_mac_keys(); +} + +void TripleMachine::output_mac_keys() +{ + stringstream ss; + ss << prep_data_dir << "Player-MAC-Keys-P" << my_num; + cout << "Writing MAC key to " << ss.str() << endl; + ofstream outputFile(ss.str().c_str()); + outputFile << nplayers << endl; + outputFile << mac_keyp << " " << mac_key2 << endl; +} + +template<> gf2n TripleMachine::get_mac_key() +{ + return mac_key2; +} + +template<> gfp TripleMachine::get_mac_key() +{ + return mac_keyp; +} diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h new file mode 100644 index 000000000..d377bca22 --- /dev/null +++ b/OT/TripleMachine.h @@ -0,0 +1,40 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * TripleMachine.h + * + */ + +#ifndef OT_TRIPLEMACHINE_H_ +#define OT_TRIPLEMACHINE_H_ + +#include "Math/gf2n.h" +#include "Math/gfp.h" + +class TripleMachine +{ + gf2n mac_key2; + gfp mac_keyp; + +public: + int my_num, nplayers, nthreads, ntriples, nloops; + int nTriplesPerThread; + string prep_data_dir; + bool generateMACs; + bool amplify; + bool check; + bool primeField; + bool bonding; + bool generateBits; + bool output; + struct timeval start, stop; + + TripleMachine(int argc, const char** argv); + void run(); + + template + T get_mac_key(); + void output_mac_keys(); +}; + +#endif /* OT_TRIPLEMACHINE_H_ */ diff --git a/Player-Online.cpp b/Player-Online.cpp new file mode 100644 index 000000000..246e845c9 --- /dev/null +++ b/Player-Online.cpp @@ -0,0 +1,179 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Processor/Machine.h" +#include "Tools/ezOptionParser.h" + +#include +#include +#include +using namespace std; + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + + opt.syntax = "./Player-Online.x [OPTIONS] \n"; + opt.example = "./Player-Online.x -lgp 64 -lg2 128 -m new 0 sample-prog\n./Player-Online.x -pn 13000 -h localhost 1 sample-prog\n"; + + opt.add( + "128", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(p) field (default: 128)", // Help description. + "-lgp", // Flag token. + "--lgp" // Flag token. + ); + opt.add( + "40", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(2^n) field (default: 40)", // Help description. + "-lg2", // Flag token. + "--lg2" // Flag token. + ); + opt.add( + "5000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Port number base to attempt to start connections from (default: 5000)", // Help description. + "-pn", // Flag token. + "--portnumbase" // Flag token. + ); + opt.add( + "localhost", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Host where Server.x is running (default: localhost)", // Help description. + "-h", // Flag token. + "--hostname" // Flag token. + ); + opt.add( + "empty", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Where to obtain memory, new|old|empty (default: empty)\n\t" + "new: copy from Player-Memory-P file\n\t" + "old: reuse previous memory in Memory-P\n\t" + "empty: create new empty memory", // Help description. + "-m", // Flag token. + "--memory" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Direct communication instead of star-shaped", // Help description. + "-d", // Flag token. + "--direct" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Star-shaped communication handled by background threads", // Help description. + "-P", // Flag token. + "--parallel" // Flag token. + ); + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Sum at most n shares at once when using indirect communication", // Help description. + "-s", // Flag token. + "--opening-sum" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use player-specific threads for communication", // Help description. + "-t", // Flag token. + "--threads" // Flag token. + ); + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Maximum number of parties to send to at once", // Help description. + "-b", // Flag token. + "--max-broadcast" // Flag token. + ); + + opt.parse(argc, argv); + + vector allArgs(opt.firstArgs); + allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end()); + string progname; + int playerno; + string usage; + vector badOptions; + unsigned int i; + + if (allArgs.size() != 3) + { + cerr << "ERROR: incorrect number of arguments to Player-Online.x\n"; + cerr << "Arguments given were:\n"; + for (unsigned int j = 1; j < allArgs.size(); j++) + cout << "'" << *allArgs[j] << "'" << endl; + opt.getUsage(usage); + cout << usage; + return 1; + } + else + { + playerno = atoi(allArgs[1]->c_str()); + progname = *allArgs[2]; + + } + + if(!opt.gotRequired(badOptions)) + { + for (i=0; i < badOptions.size(); ++i) + cerr << "ERROR: Missing required option " << badOptions[i] << "."; + opt.getUsage(usage); + cout << usage; + return 1; + } + + if(!opt.gotExpected(badOptions)) + { + for(i=0; i < badOptions.size(); ++i) + cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << "."; + opt.getUsage(usage); + cout << usage; + return 1; + } + + string memtype, hostname; + int lg2, lgp, pnbase, opening_sum, max_broadcast; + + opt.get("--portnumbase")->getInt(pnbase); + opt.get("--lgp")->getInt(lgp); + opt.get("--lg2")->getInt(lg2); + opt.get("--memory")->getString(memtype); + opt.get("--hostname")->getString(hostname); + opt.get("--opening-sum")->getInt(opening_sum); + opt.get("--max-broadcast")->getInt(max_broadcast); + + + Machine(playerno, pnbase, hostname, progname, memtype, lgp, lg2, + opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, + opt.get("--threads")->isSet, max_broadcast).run(); + + cerr << "Command line:"; + for (int i = 0; i < argc; i++) + cerr << " " << argv[i]; + cerr << endl; +} + + diff --git a/Processor/Buffer.cpp b/Processor/Buffer.cpp new file mode 100644 index 000000000..a5ca44526 --- /dev/null +++ b/Processor/Buffer.cpp @@ -0,0 +1,145 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Buffer.cpp + * + */ + +#include "Buffer.h" +#include "Processor/InputTuple.h" + +bool BufferBase::rewind = false; + + +void BufferBase::setup(ifstream* f, int length, const char* type) +{ + file = f; + tuple_length = length; + data_type = type; +} + +void BufferBase::seekg(int pos) +{ + file->seekg(pos * tuple_length); + if (file->eof() || file->fail()) + { + file->clear(); + file->seekg(0); + if (!rewind) + cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl; + rewind = true; + } + next = BUFFER_SIZE; +} + +template +Buffer::~Buffer() +{ + if (timer.elapsed() && data_type) + cerr << T::type_string() << " " << data_type << " reading: " + << timer.elapsed() << endl; +} + +template +void Buffer::fill_buffer() +{ + if (T::size() == sizeof(T)) + { + // read directly + read((char*)buffer); + } + else + { + char read_buffer[sizeof(buffer)]; + read(read_buffer); + //memset(buffer, 0, sizeof(buffer)); + for (int i = 0; i < BUFFER_SIZE; i++) + buffer[i].assign(&read_buffer[i*T::size()]); + } +} + +template +void Buffer::read(char* read_buffer) +{ + int size_in_bytes = T::size() * BUFFER_SIZE; + int n_read = 0; + timer.start(); + do + { + file->read(read_buffer + n_read, size_in_bytes - n_read); + n_read += file->gcount(); + if (file->eof()) + { + file->clear(); // unset EOF flag + file->seekg(0); + if (!rewind) + cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl; + rewind = true; + eof = true; + } + if (file->fail()) + { + stringstream ss; + ss << "IO problem when buffering " << T::type_string(); + if (data_type) + ss << " " << data_type; + throw file_error(ss.str()); + } + } + while (n_read < size_in_bytes); + timer.stop(); +} + +template +void Buffer::input(U& a) +{ + if (next == BUFFER_SIZE) + { + fill_buffer(); + next = 0; + } + + a = buffer[next]; + next++; +} + +template < template class U, template class V > +BufferBase& BufferHelper::get_buffer(DataFieldType field_type) +{ + if (field_type == DATA_MODP) + return bufferp; + else if (field_type == DATA_GF2N) + return buffer2; + else + throw not_implemented(); +} + +template < template class U, template class V > +void BufferHelper::setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type) +{ + files[field_type] = new ifstream(filename.c_str(), ios::in | ios::binary); + if (files[field_type]->fail()) + throw file_error(filename); + get_buffer(field_type).setup(files[field_type], tuple_length, data_type); +} + +template class U, template class V> +void BufferHelper::close() +{ + for (int i = 0; i < N_DATA_FIELD_TYPE; i++) + if (files[i]) + { + files[i]->close(); + delete files[i]; + } +} + +template class Buffer< Share, Share >; +template class Buffer< Share, Share >; +template class Buffer< InputTuple, RefInputTuple >; +template class Buffer< InputTuple, RefInputTuple >; +template class Buffer< gfp, gfp >; +template class Buffer< gf2n, gf2n >; + +template class BufferHelper; +template class BufferHelper; diff --git a/Processor/Buffer.h b/Processor/Buffer.h new file mode 100644 index 000000000..a9936d1e1 --- /dev/null +++ b/Processor/Buffer.h @@ -0,0 +1,74 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Buffer.h + * + */ + +#ifndef PROCESSOR_BUFFER_H_ +#define PROCESSOR_BUFFER_H_ + +#include +using namespace std; + +#include "Math/Share.h" +#include "Math/field_types.h" +#include "Tools/time-func.h" + +#ifndef BUFFER_SIZE +#define BUFFER_SIZE 101 +#endif + + +class BufferBase +{ +protected: + static bool rewind; + + ifstream* file; + int next; + const char* data_type; + Timer timer; + int tuple_length; + +public: + bool eof; + + BufferBase() : file(0), next(BUFFER_SIZE), data_type(0), tuple_length(-1), eof(false) {}; + void setup(ifstream* f, int length, const char* type = 0); + void seekg(int pos); + bool is_up() { return file != 0; } +}; + + +template +class Buffer : public BufferBase +{ + T buffer[BUFFER_SIZE]; + + void read(char* read_buffer); + +public: + ~Buffer(); + void input(U& a); + void fill_buffer(); +}; + + +template < template class U, template class V > +class BufferHelper +{ +public: + Buffer< U, V > bufferp; + Buffer< U, V > buffer2; + ifstream* files[N_DATA_FIELD_TYPE]; + + BufferHelper() { memset(files, 0, sizeof(files)); } + void input(V& a) { bufferp.input(a); } + void input(V& a) { buffer2.input(a); } + BufferBase& get_buffer(DataFieldType field_type); + void setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type = 0); + void close(); +}; + +#endif /* PROCESSOR_BUFFER_H_ */ diff --git a/Processor/Data_Files.cpp b/Processor/Data_Files.cpp new file mode 100644 index 000000000..96d64d0ab --- /dev/null +++ b/Processor/Data_Files.cpp @@ -0,0 +1,218 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Processor/Data_Files.h" +#include "Processor/Processor.h" + +#include + +const char* Data_Files::dtype_names[N_DTYPE] = { "Triples", "Squares", "Bits", "Inverses", "BitTriples", "BitGF2NTriples" }; +const char* Data_Files::field_names[] = { "p", "2" }; +const char* Data_Files::long_field_names[] = { "gfp", "gf2n" }; +const bool Data_Files::implemented[N_DATA_FIELD_TYPE][N_DTYPE] = { + { true, true, true, true, false, false }, + { true, true, true, true, true, true }, +}; +const int Data_Files::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 }; + +Lock Data_Files::tuple_lengths_lock; +map Data_Files::tuple_lengths; + + +void DataPositions::set_num_players(int num_players) +{ + files.resize(N_DATA_FIELD_TYPE, vector(N_DTYPE)); + inputs.resize(num_players, vector(N_DATA_FIELD_TYPE)); +} + +void DataPositions::increase(const DataPositions& delta) +{ + if (inputs.size() != delta.inputs.size()) + throw invalid_length(); + for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) + { + for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++) + files[field_type][dtype] += delta.files[field_type][dtype]; + for (unsigned int j = 0; j < inputs.size(); j++) + inputs[j][field_type] += delta.inputs[j][field_type]; + + map::const_iterator it; + const map& delta_ext = delta.extended[field_type]; + for (it = delta_ext.begin(); it != delta_ext.end(); it++) + extended[field_type][it->first] += it->second; + } +} + +void DataPositions::print_cost() const +{ + ifstream file("cost"); + double total_cost = 0; + for (int i = 0; i < N_DATA_FIELD_TYPE; i++) + { + cerr << " Type " << Data_Files::field_names[i] << endl; + for (int j = 0; j < N_DTYPE; j++) + { + double cost_per_item = 0; + file >> cost_per_item; + if (cost_per_item < 0) + break; + int items_used = files[i][j]; + double cost = items_used * cost_per_item; + total_cost += cost; + cerr.fill(' '); + cerr << " " << setw(10) << cost << " = " << setw(10) << items_used + << " " << setw(14) << Data_Files::dtype_names[j] << " à " << setw(11) + << cost_per_item << endl; + } + for (map::const_iterator it = extended[i].begin(); + it != extended[i].end(); it++) + { + cerr.fill(' '); + cerr << setw(27) << it->second << " " << setw(14) << it->first.get_string() << endl; + } + } + + cerr << "Total cost: " << total_cost << endl; +} + + +int Data_Files::share_length(int field_type) +{ + switch (field_type) + { + case DATA_MODP: + return 2 * gfp::t() * sizeof(mp_limb_t); + case DATA_GF2N: + return 2 * sizeof(word); + default: + throw invalid_params(); + } +} + +int Data_Files::tuple_length(int field_type, int dtype) +{ + return tuple_size[dtype] * share_length(field_type); +} + +Data_Files::Data_Files(int myn, int n, const string& prep_data_dir) : + usage(n), prep_data_dir(prep_data_dir) +{ + cerr << "Setting up Data_Files in: " << prep_data_dir << endl; + num_players=n; + my_num=myn; + char filename[1024]; + input_buffers = new BufferHelper[num_players]; + + for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) + { + for (int dtype = 0; dtype < N_DTYPE; dtype++) + { + if (implemented[field_type][dtype]) + { + sprintf(filename,(prep_data_dir + "%s-%s-P%d").c_str(),dtype_names[dtype], + field_names[field_type],my_num); + buffers[dtype].setup(DataFieldType(field_type), filename, + tuple_length(field_type, dtype), dtype_names[dtype]); + } + } + + for (int i=0; i >::iterator it = + extended.begin(); it != extended.end(); it++) + it->second.close(); +} + +void Data_Files::seekg(DataPositions& pos) +{ + for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) + { + for (int dtype = 0; dtype < N_DTYPE; dtype++) + if (implemented[field_type][dtype]) + buffers[dtype].get_buffer(DataFieldType(field_type)).seekg(pos.files[field_type][dtype]); + for (int j = 0; j < num_players; j++) + if (j == my_num) + my_input_buffers.get_buffer(DataFieldType(field_type)).seekg(pos.inputs[j][field_type]); + else + input_buffers[j].get_buffer(DataFieldType(field_type)).seekg(pos.inputs[j][field_type]); + for (map::const_iterator it = pos.extended[field_type].begin(); + it != pos.extended[field_type].end(); it++) + { + setup_extended(DataFieldType(field_type), it->first); + extended[it->first].get_buffer(DataFieldType(field_type)).seekg(it->second); + } + } + + usage = pos; +} + +void Data_Files::skip(const DataPositions& pos) +{ + DataPositions new_pos = usage; + new_pos.increase(pos); + seekg(new_pos); +} + +void Data_Files::setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size) +{ + BufferBase& buffer = extended[tag].get_buffer(field_type); + tuple_lengths_lock.lock(); + int tuple_length = tuple_lengths[tag]; + int my_tuple_length = tuple_size * share_length(field_type); + if (tuple_length > 0) + { + if (tuple_size > 0 && my_tuple_length != tuple_length) + { + stringstream ss; + ss << "Inconsistent size of " << field_names[field_type] << " " + << tag.get_string() << ": " << my_tuple_length << " vs " + << tuple_length; + throw Processor_Error(ss.str()); + } + } + else + tuple_lengths[tag] = my_tuple_length; + tuple_lengths_lock.unlock(); + + if (!buffer.is_up()) + { + stringstream ss; + ss << prep_data_dir << tag.get_string() << "-" << field_names[field_type] << "-P" << my_num; + extended[tag].setup(field_type, ss.str(), tuple_length); + } +} + +template +void Data_Files::get(Processor& proc, DataTag tag, const vector& regs, int vector_size) +{ + usage.extended[T::field_type()][tag] += vector_size; + setup_extended(T::field_type(), tag, regs.size()); + for (int j = 0; j < vector_size; j++) + for (unsigned int i = 0; i < regs.size(); i++) + extended[tag].input(proc.get_S_ref(regs[i] + j)); +} + +template void Data_Files::get(Processor& proc, DataTag tag, const vector& regs, int vector_size); +template void Data_Files::get(Processor& proc, DataTag tag, const vector& regs, int vector_size); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h new file mode 100644 index 000000000..2b5581121 --- /dev/null +++ b/Processor/Data_Files.h @@ -0,0 +1,158 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Data_Files +#define _Data_Files + +/* This class holds the Online data files all in one place + * so the streams are easy to pass around and access + */ + +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Math/Share.h" +#include "Math/field_types.h" +#include "Processor/Buffer.h" +#include "Processor/InputTuple.h" +#include "Tools/Lock.h" + +#include +#include +using namespace std; + +enum Dtype { DATA_TRIPLE, DATA_SQUARE, DATA_BIT, DATA_INVERSE, DATA_BITTRIPLE, DATA_BITGF2NTRIPLE, N_DTYPE }; + +class DataTag +{ + int t[4]; + +public: + // assume that tag is three integers + DataTag(const int* tag) + { + strncpy((char*)t, (char*)tag, 3 * sizeof(int)); + t[3] = 0; + } + string get_string() const + { + return string((char*)t); + } + bool operator<(const DataTag& other) const + { + for (int i = 0; i < 3; i++) + if (t[i] != other.t[i]) + return t[i] < other.t[i]; + return false; + } +}; + +struct DataPositions +{ + vector< vector > files; + vector< vector > inputs; + map extended[N_DATA_FIELD_TYPE]; + + DataPositions(int num_players = 0) { set_num_players(num_players); } + void set_num_players(int num_players); + void increase(const DataPositions& delta); + void print_cost() const; +}; + +class Processor; + +class Data_Files +{ + static map tuple_lengths; + static Lock tuple_lengths_lock; + + BufferHelper buffers[N_DTYPE]; + BufferHelper* input_buffers; + BufferHelper my_input_buffers; + map > extended; + + int my_num,num_players; + + DataPositions usage; + + const string prep_data_dir; + + public: + + static const char* dtype_names[N_DTYPE]; + static const char* field_names[N_DATA_FIELD_TYPE]; + static const char* long_field_names[N_DATA_FIELD_TYPE]; + static const bool implemented[N_DATA_FIELD_TYPE][N_DTYPE]; + static const int tuple_size[N_DTYPE]; + + static int share_length(int field_type); + static int tuple_length(int field_type, int dtype); + + Data_Files(int my_num,int n,const string& prep_data_dir); + ~Data_Files(); + + DataPositions tellg(); + void seekg(DataPositions& pos); + void skip(const DataPositions& pos); + template + bool eof(Dtype dtype); + template + bool input_eof(int player); + + void setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size = 0); + template + void get(Processor& proc, DataTag tag, const vector& regs, int vector_size); + + DataPositions get_usage() + { + return usage; + } + + template + void get_three(DataFieldType field_type, Dtype dtype, Share& a, Share& b, Share& c) + { + usage.files[field_type][dtype]++; + buffers[dtype].input(a); + buffers[dtype].input(b); + buffers[dtype].input(c); + } + + template + void get_two(DataFieldType field_type, Dtype dtype, Share& a, Share& b) + { + usage.files[field_type][dtype]++; + buffers[dtype].input(a); + buffers[dtype].input(b); + } + + template + void get_one(DataFieldType field_type, Dtype dtype, Share& a) + { + usage.files[field_type][dtype]++; + buffers[dtype].input(a); + } + + template + void get_input(Share& a,T& x,int i) + { + usage.inputs[i][T::field_type()]++; + RefInputTuple tuple(a, x); + if (i==my_num) + my_input_buffers.input(tuple); + else + input_buffers[i].input(a); + } +}; + +template inline +bool Data_Files::eof(Dtype dtype) + { return buffers[dtype].get_buffer(T::field_type()).eof; } + +template inline +bool Data_Files::input_eof(int player) +{ + if (player == my_num) + return my_input_buffers.get_buffer(T::field_type()).eof; + else + return input_buffers[player].get_buffer(T::field_type()).eof; +} + +#endif diff --git a/Processor/Input.cpp b/Processor/Input.cpp new file mode 100644 index 000000000..c76f2c9be --- /dev/null +++ b/Processor/Input.cpp @@ -0,0 +1,89 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Input.cpp + * + */ + +#include "Input.h" +#include "Processor.h" + +template +Input::Input(Processor& proc, MAC_Check& mc) : + proc(proc), MC(mc), values_input(0) +{ + buffer.setup(&proc.private_input, -1, "private input"); +} + +template +Input::~Input() +{ + if (timer.elapsed() > 0) + cerr << T::type_string() << " inputs: " << timer.elapsed() << endl; +} + +template +void Input::adjust_mac(Share& share, T& value) +{ + T tmp; + tmp.mul(MC.get_alphai(), value); + tmp.add(share.get_mac(),tmp); + share.set_mac(tmp); +} + +template +void Input::start(int player, int n_inputs) +{ + if (player == proc.P.my_num()) + { + octetStream o; + shares.resize(n_inputs); + + for (int i = 0; i < n_inputs; i++) + { + T rr, t; + Share& share = shares[i]; + proc.DataF.get_input(share, rr, player); + T xi; + buffer.input(t); + t.sub(t, rr); + t.pack(o); + xi.add(t, share.get_share()); + share.set_share(xi); + adjust_mac(share, t); + } + + proc.P.send_all(o, true); + values_input += n_inputs; + } +} + +template +void Input::stop(int player, vector targets) +{ + T tmp; + + if (player == proc.P.my_num()) + { + for (unsigned int i = 0; i < targets.size(); i++) + proc.get_S_ref(targets[i]) = shares[i]; + } + else + { + T t; + octetStream o; + timer.start(); + proc.P.receive_player(player, o, true); + timer.stop(); + for (unsigned int i = 0; i < targets.size(); i++) + { + Share& share = proc.get_S_ref(targets[i]); + proc.DataF.get_input(share, t, player); + t.unpack(o); + adjust_mac(share, t); + } + } +} + +template class Input; +template class Input; diff --git a/Processor/Input.h b/Processor/Input.h new file mode 100644 index 000000000..bb0a767b4 --- /dev/null +++ b/Processor/Input.h @@ -0,0 +1,43 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Input.h + * + */ + +#ifndef PROCESSOR_INPUT_H_ +#define PROCESSOR_INPUT_H_ + +#include +using namespace std; + +#include "Math/Share.h" +#include "Auth/MAC_Check.h" +#include "Processor/Buffer.h" +#include "Tools/time-func.h" + +class Processor; + +template +class Input +{ + Processor& proc; + MAC_Check& MC; + vector< Share > shares; + Buffer buffer; + Timer timer; + + void adjust_mac(Share& share, T& value); + +public: + int values_input; + + Input(Processor& proc, MAC_Check& mc); + ~Input(); + + void start(int player, int n_inputs); + void stop(int player, vector targets); + +}; + +#endif /* PROCESSOR_INPUT_H_ */ diff --git a/Processor/InputTuple.h b/Processor/InputTuple.h new file mode 100644 index 000000000..67c38a126 --- /dev/null +++ b/Processor/InputTuple.h @@ -0,0 +1,42 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * InputTuple.h + * + */ + +#ifndef PROCESSOR_INPUTTUPLE_H_ +#define PROCESSOR_INPUTTUPLE_H_ + + +template +struct InputTuple +{ + Share share; + T value; + + static int size() + { return Share::size() + T::size(); } + + static string type_string() + { return T::type_string(); } + + void assign(const char* buffer) + { + share.assign(buffer); + value.assign(buffer + Share::size()); + } +}; + + +template +struct RefInputTuple +{ + Share& share; + T& value; + RefInputTuple(Share& share, T& value) : share(share), value(value) {} + void operator=(InputTuple& other) { share = other.share; value = other.value; } +}; + + +#endif /* PROCESSOR_INPUTTUPLE_H_ */ diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp new file mode 100644 index 000000000..01081b288 --- /dev/null +++ b/Processor/Instruction.cpp @@ -0,0 +1,1558 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Processor/Instruction.h" +#include "Processor/Machine.h" +#include "Processor/Processor.h" +#include "Exceptions/Exceptions.h" +#include "Tools/time-func.h" + +#include +#include +#include +#include + + +// Read a byte +int get_val(istream& s) +{ + char cc; + s.get(cc); + int a=cc; + if (a<0) { a+=256; } + return a; +} + +// Read a 4-byte integer +int get_int(istream& s) +{ + int n = 0; + for (int i=0; i<4; i++) + { n<<=8; + int t=get_val(s); + n+=t; + } + return n; +} + + +// Convert modp to signed bigint of a given bit length +void to_signed_bigint(bigint& bi, const gfp& x, int len) +{ + to_bigint(bi, x); + int neg; + // get sign and abs(x) + bigint p_half=(gfp::pr()-1)/2; + if (mpz_cmp(bi.get_mpz_t(), p_half.get_mpz_t()) < 0) + neg = 0; + else + { + bi = gfp::pr() - bi; + neg = 1; + } + // reduce to range -2^(len-1), ..., 2^(len-1) + bigint one = 1; + bi &= (one << len) - 1; + if (neg) + bi = -bi; +} + + +void get_vector(int m, vector& start, istream& s) +{ + start.resize(m); + for (int i = 0; i < m; i++) + start[i] = get_int(s); +} + + +void Instruction::parse(istream& s) +{ + n=0; start.resize(0); + r[0]=0; r[1]=0; r[2]=0; + + int pos=s.tellg(); + opcode=get_int(s); + size=opcode>>9; + opcode&=0x1FF; + + if (size==0) + size=1; + + switch (opcode) + { + // instructions with 3 register operands + case ADDC: + case ADDS: + case ADDM: + case SUBC: + case SUBS: + case SUBML: + case SUBMR: + case MULC: + case MULM: + case DIVC: + case MODC: + case TRIPLE: + case ANDC: + case XORC: + case ORC: + case SHLC: + case SHRC: + case GADDC: + case GADDS: + case GADDM: + case GSUBC: + case GSUBS: + case GSUBML: + case GSUBMR: + case GMULC: + case GMULM: + case GDIVC: + case GTRIPLE: + case GBITTRIPLE: + case GBITGF2NTRIPLE: + case GANDC: + case GXORC: + case GORC: + case GMULBITC: + case GMULBITM: + case LTC: + case GTC: + case EQC: + case ADDINT: + case SUBINT: + case MULINT: + case DIVINT: + r[0]=get_int(s); + r[1]=get_int(s); + r[2]=get_int(s); + break; + // instructions with 2 register operands + case LDMCI: + case LDMSI: + case STMCI: + case STMSI: + case MOVC: + case MOVS: + case MOVINT: + case LDMINTI: + case STMINTI: + case LEGENDREC: + case SQUARE: + case INV: + case GINV: + case CONVINT: + case GLDMCI: + case GLDMSI: + case GSTMCI: + case GSTMSI: + case GMOVC: + case GMOVS: + case GSQUARE: + case GNOTC: + case GCONVINT: + case GCONVGF2N: + case LTZC: + case EQZC: + case RAND: + case PROTECTMEMS: + case PROTECTMEMC: + case GPROTECTMEMS: + case GPROTECTMEMC: + case PROTECTMEMINT: + r[0]=get_int(s); + r[1]=get_int(s); + break; + // instructions with 1 register operand + case BIT: + case PRINTMEM: + case PRINTREGPLAIN: + case LDTN: + case LDARG: + case STARG: + case JMPI: + case GBIT: + case GPRINTMEM: + case GPRINTREGPLAIN: + case JOIN_TAPE: + case PUSHINT: + case POPINT: + case PUBINPUT: + case RAWOUTPUT: + case GRAWOUTPUT: + case PRINTCHRINT: + case PRINTSTRINT: + r[0]=get_int(s); + break; + // instructions with 3 registers + 1 integer operand + r[0]=get_int(s); + r[1]=get_int(s); + r[2]=get_int(s); + n = get_int(s); + break; + // instructions with 2 registers + 1 integer operand + case ADDCI: + case ADDSI: + case SUBCI: + case SUBSI: + case SUBCFI: + case SUBSFI: + case MULCI: + case MULSI: + case DIVCI: + case MODCI: + case ANDCI: + case XORCI: + case ORCI: + case SHLCI: + case SHRCI: + case NOTC: + case CONVMODP: + case GADDCI: + case GADDSI: + case GSUBCI: + case GSUBSI: + case GSUBCFI: + case GSUBSFI: + case GMULCI: + case GMULSI: + case GDIVCI: + case GANDCI: + case GXORCI: + case GORCI: + case GSHLCI: + case GSHRCI: + case USE: + case USE_INP: + case RUN_TAPE: + case STARTPRIVATEOUTPUT: + case GSTARTPRIVATEOUTPUT: + r[0]=get_int(s); + r[1]=get_int(s); + n = get_int(s); + break; + // instructions with 1 register + 1 integer operand + case LDI: + case LDSI: + case LDMC: + case LDMS: + case STMC: + case STMS: + case LDMINT: + case STMINT: + case INPUT: + case JMPNZ: + case JMPEQZ: + case GLDI: + case GLDSI: + case GLDMC: + case GLDMS: + case GSTMC: + case GSTMS: + case GINPUT: + case PRINTREG: + case GPRINTREG: + case LDINT: + case STARTINPUT: + case GSTARTINPUT: + case STOPPRIVATEOUTPUT: + case GSTOPPRIVATEOUTPUT: + case INPUTMASK: + case GINPUTMASK: + case READSOCKETC: + case READSOCKETS: + case WRITESOCKETC: + case WRITESOCKETS: + r[0]=get_int(s); + n = get_int(s); + break; + // instructions with 1 integer operand + case PRINTSTR: + case PRINTCHR: + case JMP: + case START: + case STOP: + case OPENSOCKET: + n = get_int(s); + break; + // instructions with no operand + case TIME: + case CRASH: + case CLOSESOCKET: + break; + // open instructions + case STARTOPEN: + case STOPOPEN: + case GSTARTOPEN: + case GSTOPOPEN: + int m; + m = get_int(s); + get_vector(m, start, s); + break; + // raw input + case STOPINPUT: + case GSTOPINPUT: + // subtract player number argument + m = get_int(s) - 1; + n = get_int(s); + get_vector(m, start, s); + break; + case GBITDEC: + case GBITCOM: + m = get_int(s) - 2; + r[0] = get_int(s); + n = get_int(s); + get_vector(m, start, s); + break; + case PREP: + case GPREP: + // subtract extra argument + m = get_int(s) - 1; + s.read((char*)r, sizeof(r)); + start.resize(m); + for (int i = 0; i < m; i++) + { start[i] = get_int(s); } + break; + case USE_PREP: + case GUSE_PREP: + s.read((char*)r, sizeof(r)); + n = get_int(s); + break; + case REQBL: + n = get_int(s); + if (n > 0 && gfp::pr() < bigint(1) << (n-1)) + { + cout << "Tape requires prime of bit length " << n << endl; + throw invalid_params(); + } + break; + case GREQBL: + n = get_int(s); + if (n > 0 && gf2n::degree() < int(n)) + { + stringstream ss; + ss << "Tape requires prime of bit length " << n << endl; + throw Processor_Error(ss.str()); + } + break; + default: + ostringstream os; + os << "Invalid instruction " << hex << showbase << opcode << " at " << pos; + throw Processor_Error(os.str()); + } +} + + +bool Instruction::get_offline_data_usage(DataPositions& usage) +{ + switch (opcode) + { + case USE: + if (r[0] >= N_DATA_FIELD_TYPE) + throw invalid_program(); + if (r[1] >= N_DTYPE) + throw invalid_program(); + usage.files[r[0]][r[1]] = n; + return int(n) >= 0; + case USE_INP: + if (r[0] >= N_DATA_FIELD_TYPE) + throw invalid_program(); + if ((unsigned)r[1] >= usage.inputs.size()) + throw Processor_Error("Player number too high"); + usage.inputs[r[1]][r[0]] = n; + return int(n) >= 0; + case USE_PREP: + usage.extended[gfp::field_type()][r] = n; + return int(n) >= 0; + case GUSE_PREP: + usage.extended[gf2n::field_type()][r] = n; + return int(n) >= 0; + default: + return true; + } +} + +RegType Instruction::get_reg_type() const +{ + switch (opcode) { + case LDMINT: + case STMINT: + case LDMINTI: + case STMINTI: + case PUSHINT: + case POPINT: + case MOVINT: + return INT; + case PREP: + case USE_PREP: + case GUSE_PREP: + // those use r[] for a string + return NONE; + default: + if (is_gf2n_instruction()) + return GF2N; + else if (opcode >> 4 == 0x9) + return INT; + else + return MODP; + } +} + +int Instruction::get_max_reg(RegType reg_type) const +{ + if (get_reg_type() != reg_type) { return 0; } + + if (start.size()) + return *max_element(start.begin(), start.end()) + size; + else + return *max_element(r, r + 3) + size; +} + +int Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const +{ + if (get_reg_type() == reg_type and is_direct_memory_access(sec_type)) + return n + size; + else + return 0; +} + +bool Instruction::is_direct_memory_access(SecrecyType sec_type) const +{ + if (sec_type == SECRET) + { + switch (opcode) + { + case LDMS: + case STMS: + case GLDMS: + case GSTMS: + return true; + default: + return false; + } + } + else + { + switch (opcode) + { + case LDMC: + case STMC: + case GLDMC: + case GSTMC: + case LDMINT: + case STMINT: + return true; + default: + return false; + } + } +} + + + +ostream& operator<<(ostream& s,const Instruction& instr) +{ + s << instr.opcode << " : "; + for (int i=0; i<3; i++) + { s << instr.r[i] << " "; } + s << " : " << instr.n; + if (instr.start.size()!=0) + { s << " : " << instr.start.size() << " : "; + for (unsigned int i=0; ir[0], this->r[1], this->r[2]}; + int n = this->n; + for (int i = 0; i < size; i++) + { switch (opcode) + { case LDI: + Proc.temp.ansp.assign(n); + Proc.write_Cp(r[0],Proc.temp.ansp); + break; + case GLDI: + Proc.temp.ans2.assign(n); + Proc.write_C2(r[0],Proc.temp.ans2); + break; + case LDSI: + { Proc.temp.ansp.assign(n); + if (Proc.P.my_num()==0) + Proc.get_Sp_ref(r[0]).set_share(Proc.temp.ansp); + else + Proc.get_Sp_ref(r[0]).assign_zero(); + gfp& tmp=Proc.temp.tmpp; + tmp.mul(Proc.MCp.get_alphai(),Proc.temp.ansp); + Proc.get_Sp_ref(r[0]).set_mac(tmp); + } + break; + case GLDSI: + { Proc.temp.ans2.assign(n); + if (Proc.P.my_num()==0) + Proc.get_S2_ref(r[0]).set_share(Proc.temp.ans2); + else + Proc.get_S2_ref(r[0]).assign_zero(); + gf2n& tmp=Proc.temp.tmp2; + tmp.mul(Proc.MC2.get_alphai(),Proc.temp.ans2); + Proc.get_S2_ref(r[0]).set_mac(tmp); + } + break; + case LDMC: + Proc.write_Cp(r[0],Proc.machine.Mp.read_C(n)); + n++; + break; + case GLDMC: + Proc.write_C2(r[0],Proc.machine.M2.read_C(n)); + n++; + break; + case LDMS: + Proc.write_Sp(r[0],Proc.machine.Mp.read_S(n)); + n++; + break; + case GLDMS: + Proc.write_S2(r[0],Proc.machine.M2.read_S(n)); + n++; + break; + case LDMINT: + Proc.write_Ci(r[0],Proc.machine.Mi.read_C(n).get()); + n++; + break; + case LDMCI: + Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1]))); + break; + case GLDMCI: + Proc.write_C2(r[0], Proc.machine.M2.read_C(Proc.read_Ci(r[1]))); + break; + case LDMSI: + Proc.write_Sp(r[0], Proc.machine.Mp.read_S(Proc.read_Ci(r[1]))); + break; + case GLDMSI: + Proc.write_S2(r[0], Proc.machine.M2.read_S(Proc.read_Ci(r[1]))); + break; + case LDMINTI: + Proc.write_Ci(r[0], Proc.machine.Mi.read_C(Proc.read_Ci(r[1])).get()); + break; + case STMC: + Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0]),Proc.PC); + n++; + break; + case GSTMC: + Proc.machine.M2.write_C(n,Proc.read_C2(r[0]),Proc.PC); + n++; + break; + case STMS: + Proc.machine.Mp.write_S(n,Proc.read_Sp(r[0]),Proc.PC); + n++; + break; + case GSTMS: + Proc.machine.M2.write_S(n,Proc.read_S2(r[0]),Proc.PC); + n++; + break; + case STMINT: + Proc.machine.Mi.write_C(n,Integer(Proc.read_Ci(r[0])),Proc.PC); + n++; + break; + case STMCI: + Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0]),Proc.PC); + break; + case GSTMCI: + Proc.machine.M2.write_C(Proc.read_Ci(r[1]), Proc.read_C2(r[0]),Proc.PC); + break; + case STMSI: + Proc.machine.Mp.write_S(Proc.read_Ci(r[1]), Proc.read_Sp(r[0]),Proc.PC); + break; + case GSTMSI: + Proc.machine.M2.write_S(Proc.read_Ci(r[1]), Proc.read_S2(r[0]),Proc.PC); + break; + case STMINTI: + Proc.machine.Mi.write_C(Proc.read_Ci(r[1]), Integer(Proc.read_Ci(r[0])),Proc.PC); + break; + case MOVC: + Proc.write_Cp(r[0],Proc.read_Cp(r[1])); + break; + case GMOVC: + Proc.write_C2(r[0],Proc.read_C2(r[1])); + break; + case MOVS: + Proc.write_Sp(r[0],Proc.read_Sp(r[1])); + break; + case GMOVS: + Proc.write_S2(r[0],Proc.read_S2(r[1])); + break; + case MOVINT: + Proc.write_Ci(r[0],Proc.read_Ci(r[1])); + break; + case PROTECTMEMS: + Proc.machine.Mp.protect_s(Proc.read_Ci(r[0]), Proc.read_Ci(r[1])); + break; + case PROTECTMEMC: + Proc.machine.Mp.protect_c(Proc.read_Ci(r[0]), Proc.read_Ci(r[1])); + break; + case GPROTECTMEMS: + Proc.machine.M2.protect_s(Proc.read_Ci(r[0]), Proc.read_Ci(r[1])); + break; + case GPROTECTMEMC: + Proc.machine.M2.protect_c(Proc.read_Ci(r[0]), Proc.read_Ci(r[1])); + break; + case PROTECTMEMINT: + Proc.machine.Mi.protect_c(Proc.read_Ci(r[0]), Proc.read_Ci(r[1])); + break; + case PUSHINT: + Proc.pushi(Proc.read_Ci(r[0])); + break; + case POPINT: + Proc.popi(Proc.get_Ci_ref(r[0])); + break; + case LDTN: + Proc.write_Ci(r[0],Proc.get_thread_num()); + break; + case LDARG: + Proc.write_Ci(r[0],Proc.get_arg()); + break; + case STARG: + Proc.set_arg(Proc.read_Ci(r[0])); + break; + case ADDC: + #ifdef DEBUG + Proc.temp.ansp.add(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).add(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GADDC: + #ifdef DEBUG + ans2.add(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).add(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case ADDS: + #ifdef DEBUG + Sansp.add(Proc.read_Sp(r[1]),Proc.read_Sp(r[2])); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).add(Proc.read_Sp(r[1]),Proc.read_Sp(r[2])); + #endif + break; + case GADDS: + #ifdef DEBUG + Sans2.add(Proc.read_S2(r[1]),Proc.read_S2(r[2])); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).add(Proc.read_S2(r[1]),Proc.read_S2(r[2])); + #endif + break; + case ADDM: + #ifdef DEBUG + Sansp.add(Proc.read_Sp(r[1]),Proc.read_Cp(r[2]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).add(Proc.read_Sp(r[1]),Proc.read_Cp(r[2]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + #endif + break; + case GADDM: + #ifdef DEBUG + Sans2.add(Proc.read_S2(r[1]),Proc.read_C2(r[2]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).add(Proc.read_S2(r[1]),Proc.read_C2(r[2]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + #endif + break; + case SUBC: + #ifdef DEBUG + ansp.sub(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).sub(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GSUBC: + #ifdef DEBUG + Proc.temp.ans2.sub(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).sub(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case SUBS: + #ifdef DEBUG + Sansp.sub(Proc.read_Sp(r[1]),Proc.read_Sp(r[2])); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).sub(Proc.read_Sp(r[1]),Proc.read_Sp(r[2])); + #endif + break; + case GSUBS: + #ifdef DEBUG + Sans2.sub(Proc.read_S2(r[1]),Proc.read_S2(r[2])); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).sub(Proc.read_S2(r[1]),Proc.read_S2(r[2])); + #endif + break; + case SUBML: + #ifdef DEBUG + Sansp.sub(Proc.read_Sp(r[1]),Proc.read_Cp(r[2]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).sub(Proc.read_Sp(r[1]),Proc.read_Cp(r[2]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + #endif + break; + case GSUBML: + #ifdef DEBUG + Sans2.sub(Proc.read_S2(r[1]),Proc.read_C2(r[2]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).sub(Proc.read_S2(r[1]),Proc.read_C2(r[2]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + #endif + break; + case SUBMR: + #ifdef DEBUG + Sansp.sub(Proc.read_Cp(r[1]),Proc.read_Sp(r[2]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).sub(Proc.read_Cp(r[1]),Proc.read_Sp(r[2]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + #endif + break; + case GSUBMR: + #ifdef DEBUG + Sans2.sub(Proc.read_C2(r[1]),Proc.read_S2(r[2]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).sub(Proc.read_C2(r[1]),Proc.read_S2(r[2]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + #endif + break; + case MULC: + #ifdef DEBUG + ansp.mul(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).mul(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GMULC: + #ifdef DEBUG + Proc.temp.ans2.mul(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).mul(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case MULM: + #ifdef DEBUG + Sansp.mul(Proc.read_Sp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).mul(Proc.read_Sp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GMULM: + #ifdef DEBUG + Sans2.mul(Proc.read_S2(r[1]),Proc.read_C2(r[2])); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).mul(Proc.read_S2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case DIVC: + if (Proc.read_Cp(r[2]).is_zero()) + throw Processor_Error("Division by zero from register"); + Proc.temp.ansp.invert(Proc.read_Cp(r[2])); + Proc.temp.ansp.mul(Proc.read_Cp(r[1])); + Proc.write_Cp(r[0],Proc.temp.ansp); + break; + case GDIVC: + if (Proc.read_C2(r[2]).is_zero()) + throw Processor_Error("Division by zero from register"); + Proc.temp.ans2.invert(Proc.read_C2(r[2])); + Proc.temp.ans2.mul(Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + break; + case MODC: + to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); + to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2])); + mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t()); + to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + break; + case LEGENDREC: + to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); + Proc.temp.aa = mpz_legendre(Proc.temp.aa.get_mpz_t(), gfp::pr().get_mpz_t()); + //Proc.temp.aa = legendre; + to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.write_Cp(r[0], Proc.temp.ansp); + break; + case DIVCI: + if (n == 0) + throw Processor_Error("Division by immediate zero"); + to_gfp(Proc.temp.ansp,n%gfp::pr()); + Proc.temp.ansp.invert(); + Proc.temp.ansp.mul(Proc.read_Cp(r[1])); + Proc.write_Cp(r[0],Proc.temp.ansp); + break; + case GDIVCI: + if (n == 0) + throw Processor_Error("Division by immediate zero"); + Proc.temp.ans2.assign(n); + Proc.temp.ans2.invert(); + Proc.temp.ans2.mul(Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + break; + case MODCI: + to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); + to_gfp(Proc.temp.ansp, mpz_fdiv_ui(Proc.temp.aa.get_mpz_t(), n)); + Proc.write_Cp(r[0],Proc.temp.ansp); + break; + case GMULBITC: + #ifdef DEBUG + Proc.temp.ans2.mul_by_bit(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).mul_by_bit(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case GMULBITM: + #ifdef DEBUG + Sans2.mul_by_bit(Proc.read_S2(r[1]),Proc.read_C2(r[2])); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).mul_by_bit(Proc.read_S2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case ADDCI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Proc.temp.ansp.add(Proc.temp.ansp,Proc.read_Cp(r[1])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).add(Proc.temp.ansp,Proc.read_Cp(r[1])); + #endif + break; + case GADDCI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.add(Proc.temp.ans2,Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).add(Proc.temp.ans2,Proc.read_C2(r[1])); + #endif + break; + case ADDSI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Sansp.add(Proc.read_Sp(r[1]),Proc.temp.ansp,Proc.P.my_num()==0,Proc.MCp.get_alphai()); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).add(Proc.read_Sp(r[1]),Proc.temp.ansp,Proc.P.my_num()==0,Proc.MCp.get_alphai()); + #endif + break; + case GADDSI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Sans2.add(Proc.read_S2(r[1]),Proc.temp.ans2,Proc.P.my_num()==0,Proc.MC2.get_alphai()); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).add(Proc.read_S2(r[1]),Proc.temp.ans2,Proc.P.my_num()==0,Proc.MC2.get_alphai()); + #endif + break; + case SUBCI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Proc.temp.ansp.sub(Proc.read_Cp(r[1]),Proc.temp.ansp); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).sub(Proc.read_Cp(r[1]),Proc.temp.ansp); + #endif + break; + case GSUBCI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.sub(Proc.read_C2(r[1]),Proc.temp.ans2); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).sub(Proc.read_C2(r[1]),Proc.temp.ans2); + #endif + break; + case SUBSI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Sansp.sub(Proc.read_Sp(r[1]),Proc.temp.ansp,Proc.P.my_num()==0,Proc.MCp.get_alphai()); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).sub(Proc.read_Sp(r[1]),Proc.temp.ansp,Proc.P.my_num()==0,Proc.MCp.get_alphai()); + #endif + break; + case GSUBSI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Sans2.sub(Proc.read_S2(r[1]),Proc.temp.ans2,Proc.P.my_num()==0,Proc.MC2.get_alphai()); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).sub(Proc.read_S2(r[1]),Proc.temp.ans2,Proc.P.my_num()==0,Proc.MC2.get_alphai()); + #endif + break; + case SUBCFI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Proc.temp.ansp.sub(Proc.temp.ansp,Proc.read_Cp(r[1])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).sub(Proc.temp.ansp,Proc.read_Cp(r[1])); + #endif + break; + case GSUBCFI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.sub(Proc.temp.ans2,Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).sub(Proc.temp.ans2,Proc.read_C2(r[1])); + #endif + break; + case SUBSFI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Sansp.sub(Proc.temp.ansp,Proc.read_Sp(r[1]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).sub(Proc.temp.ansp,Proc.read_Sp(r[1]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); + #endif + break; + case GSUBSFI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Sans2.sub(Proc.temp.ans2,Proc.read_S2(r[1]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).sub(Proc.temp.ans2,Proc.read_S2(r[1]),Proc.P.my_num()==0,Proc.MC2.get_alphai()); + #endif + break; + case MULCI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Proc.temp.ansp.mul(Proc.temp.ansp,Proc.read_Cp(r[1])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).mul(Proc.temp.ansp,Proc.read_Cp(r[1])); + #endif + break; + case GMULCI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.mul(Proc.temp.ans2,Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).mul(Proc.temp.ans2,Proc.read_C2(r[1])); + #endif + break; + case MULSI: + Proc.temp.ansp.assign(n); + #ifdef DEBUG + Sansp.mul(Proc.read_Sp(r[1]),Proc.temp.ansp); + Proc.write_Sp(r[0],Sansp); + #else + Proc.get_Sp_ref(r[0]).mul(Proc.read_Sp(r[1]),Proc.temp.ansp); + #endif + break; + case GMULSI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Sans2.mul(Proc.read_S2(r[1]),Proc.temp.ans2); + Proc.write_S2(r[0],Sans2); + #else + Proc.get_S2_ref(r[0]).mul(Proc.read_S2(r[1]),Proc.temp.ans2); + #endif + break; + case TRIPLE: + Proc.DataF.get_three(DATA_MODP, DATA_TRIPLE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1]),Proc.get_Sp_ref(r[2])); + break; + case GTRIPLE: + Proc.DataF.get_three(DATA_GF2N, DATA_TRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); + break; + case GBITTRIPLE: + Proc.DataF.get_three(DATA_GF2N, DATA_BITTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); + break; + case GBITGF2NTRIPLE: + Proc.DataF.get_three(DATA_GF2N, DATA_BITGF2NTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); + break; + case SQUARE: + Proc.DataF.get_two(DATA_MODP, DATA_SQUARE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); + break; + case GSQUARE: + Proc.DataF.get_two(DATA_GF2N, DATA_SQUARE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); + break; + case BIT: + Proc.DataF.get_one(DATA_MODP, DATA_BIT, Proc.get_Sp_ref(r[0])); + break; + case GBIT: + Proc.DataF.get_one(DATA_GF2N, DATA_BIT, Proc.get_S2_ref(r[0])); + break; + case INV: + Proc.DataF.get_two(DATA_MODP, DATA_INVERSE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); + break; + case GINV: + Proc.DataF.get_two(DATA_GF2N, DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); + break; + case INPUTMASK: + Proc.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.ansp, n); + if (n == Proc.P.my_num()) + Proc.temp.ansp.output(Proc.private_output, false); + break; + case GINPUTMASK: + Proc.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); + if (n == Proc.P.my_num()) + Proc.temp.ans2.output(Proc.private_output, false); + break; + case INPUT: + { gfp& rr=Proc.temp.rrp; gfp& t=Proc.temp.tp; gfp& tmp=Proc.temp.tmpp; + Proc.DataF.get_input(Proc.get_Sp_ref(r[0]),rr,n); + octetStream o; + if (n==Proc.P.my_num()) + { gfp& xi=Proc.temp.xip; + #ifdef DEBUG + printf("Enter your input : \n"); + #endif + word x; + cin >> x; + t.assign(x); + t.sub(t,rr); + t.pack(o); + Proc.P.send_all(o); + xi.add(t,Proc.get_Sp_ref(r[0]).get_share()); + Proc.get_Sp_ref(r[0]).set_share(xi); + } + else + { Proc.P.receive_player(n,o); + t.unpack(o); + } + tmp.mul(Proc.MCp.get_alphai(),t); + tmp.add(Proc.get_Sp_ref(r[0]).get_mac(),tmp); + Proc.get_Sp_ref(r[0]).set_mac(tmp); + } + break; + case GINPUT: + { gf2n& rr=Proc.temp.rr2; gf2n& t=Proc.temp.t2; gf2n& tmp=Proc.temp.tmp2; + Proc.DataF.get_input(Proc.get_S2_ref(r[0]),rr,n); + octetStream o; + if (n==Proc.P.my_num()) + { gf2n& xi=Proc.temp.xi2; + #ifdef DEBUG + printf("Enter your input : \n"); + #endif + word x; + cin >> x; + t.assign(x); + t.sub(t,rr); + t.pack(o); + Proc.P.send_all(o); + xi.add(t,Proc.get_S2_ref(r[0]).get_share()); + Proc.get_S2_ref(r[0]).set_share(xi); + } + else + { Proc.P.receive_player(n,o); + t.unpack(o); + } + tmp.mul(Proc.MC2.get_alphai(),t); + tmp.add(Proc.get_S2_ref(r[0]).get_mac(),tmp); + Proc.get_S2_ref(r[0]).set_mac(tmp); + } + break; + case STARTINPUT: + Proc.inputp.start(r[0],n); + break; + case GSTARTINPUT: + Proc.input2.start(r[0],n); + break; + case STOPINPUT: + Proc.inputp.stop(n,start); + break; + case GSTOPINPUT: + Proc.input2.stop(n,start); + break; + case ANDC: + #ifdef DEBUG + ansp.AND(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).AND(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GANDC: + #ifdef DEBUG + Proc.temp.ans2.AND(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).AND(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case XORC: + #ifdef DEBUG + Proc.temp.ansp.XOR(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).XOR(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GXORC: + #ifdef DEBUG + Proc.temp.ans2.XOR(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).XOR(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case ORC: + #ifdef DEBUG + Proc.temp.ansp.OR(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).OR(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); + #endif + break; + case GORC: + #ifdef DEBUG + Proc.temp.ans2.OR(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).OR(Proc.read_C2(r[1]),Proc.read_C2(r[2])); + #endif + break; + case ANDCI: + Proc.temp.aa=n; + #ifdef DEBUG + Proc.temp.ansp.AND(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],ansp); + #else + Proc.get_Cp_ref(r[0]).AND(Proc.read_Cp(r[1]),Proc.temp.aa); + #endif + break; + case GANDCI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.AND(Proc.temp.ans2,Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).AND(Proc.temp.ans2,Proc.read_C2(r[1])); + #endif + break; + case XORCI: + Proc.temp.aa=n; + #ifdef DEBUG + ansp.XOR(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).XOR(Proc.read_Cp(r[1]),Proc.temp.aa); + #endif + break; + case GXORCI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.XOR(Proc.temp.ans2,Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).XOR(Proc.temp.ans2,Proc.read_C2(r[1])); + #endif + break; + case ORCI: + Proc.temp.aa=n; + #ifdef DEBUG + Proc.temp.ansp.OR(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).OR(Proc.read_Cp(r[1]),Proc.temp.aa); + #endif + break; + case GORCI: + Proc.temp.ans2.assign(n); + #ifdef DEBUG + Proc.temp.ans2.OR(Proc.temp.ans2,Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).OR(Proc.temp.ans2,Proc.read_C2(r[1])); + #endif + break; + // Note: Fp version has different semantics for NOTC than GNOTC + case NOTC: + to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); + mpz_com(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t()); + Proc.temp.aa2 = 1; + Proc.temp.aa2 <<= n; + Proc.temp.aa += Proc.temp.aa2; + to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + break; + case GNOTC: + #ifdef DEBUG + Proc.temp.ans2.NOT(Proc.read_C2(r[1])); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).NOT(Proc.read_C2(r[1])); + #endif + break; + case SHLC: + to_bigint(Proc.temp.aa,Proc.read_Cp(r[2])); + if (Proc.temp.aa > 63) + throw not_implemented(); + #ifdef DEBUG + Proc.temp.ansp.SHL(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).SHL(Proc.read_Cp(r[1]),Proc.temp.aa); + #endif + break; + case SHRC: + to_bigint(Proc.temp.aa,Proc.read_Cp(r[2])); + if (Proc.temp.aa > 63) + throw not_implemented(); + #ifdef DEBUG + Proc.temp.ansp.SHR(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).SHR(Proc.read_Cp(r[1]),Proc.temp.aa); + #endif + break; + case SHLCI: + #ifdef DEBUG + Proc.temp.ansp.SHL(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).SHL(Proc.read_Cp(r[1]),n); + #endif + break; + case GSHLCI: + #ifdef DEBUG + Proc.temp.ans2.SHL(Proc.read_C2(r[1]),n); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).SHL(Proc.read_C2(r[1]),n); + #endif + break; + case SHRCI: + #ifdef DEBUG + Proc.temp.ansp.SHR(Proc.read_Cp(r[1]),Proc.temp.aa); + Proc.write_Cp(r[0],Proc.temp.ansp); + #else + Proc.get_Cp_ref(r[0]).SHR(Proc.read_Cp(r[1]),n); + #endif + break; + case GSHRCI: + #ifdef DEBUG + Proc.temp.ans2.SHR(Proc.read_C2(r[1]),n); + Proc.write_C2(r[0],Proc.temp.ans2); + #else + Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n); + #endif + break; + case GBITDEC: + for (int j = 0; j < size; j++) + { + gf2n::internal_type a = Proc.read_C2(r[0] + j).get(); + for (unsigned int i = 0; i < start.size(); i++) + { + Proc.get_C2_ref(start[i] + j) = a & 1; + a >>= n; + } + } + return; + case GBITCOM: + for (int j = 0; j < size; j++) + { + gf2n::internal_type a = 0; + for (unsigned int i = 0; i < start.size(); i++) + { + a ^= Proc.read_C2(start[i] + j).get() << (i * n); + } + Proc.get_C2_ref(r[0] + j) = a; + } + return; + case STARTOPEN: + Proc.POpen_Start(start,Proc.P,Proc.MCp,size); + return; + case GSTARTOPEN: + Proc.POpen_Start(start,Proc.P,Proc.MC2,size); + return; + case STOPOPEN: + Proc.POpen_Stop(start,Proc.P,Proc.MCp,size); + return; + case GSTOPOPEN: + Proc.POpen_Stop(start,Proc.P,Proc.MC2,size); + return; + case JMP: + Proc.PC += (signed int) n; + break; + case JMPI: + Proc.PC += (signed int) Proc.read_Ci(r[0]); + break; + case JMPNZ: + if (Proc.read_Ci(r[0]) != 0) + { Proc.PC += (signed int) n; } + break; + case JMPEQZ: + if (Proc.read_Ci(r[0]) == 0) + { Proc.PC += (signed int) n; } + break; + case EQZC: + if (Proc.read_Ci(r[1]) == 0) + Proc.write_Ci(r[0], 1); + else + Proc.write_Ci(r[0], 0); + break; + case LTZC: + if (Proc.read_Ci(r[1]) < 0) + Proc.write_Ci(r[0], 1); + else + Proc.write_Ci(r[0], 0); + break; + case LTC: + if (Proc.read_Ci(r[1]) < Proc.read_Ci(r[2])) + Proc.write_Ci(r[0], 1); + else + Proc.write_Ci(r[0], 0); + break; + case GTC: + if (Proc.read_Ci(r[1]) > Proc.read_Ci(r[2])) + Proc.write_Ci(r[0], 1); + else + Proc.write_Ci(r[0], 0); + break; + case EQC: + if (Proc.read_Ci(r[1]) == Proc.read_Ci(r[2])) + Proc.write_Ci(r[0], 1); + else + Proc.write_Ci(r[0], 0); + break; + case LDINT: + Proc.write_Ci(r[0], n); + break; + case ADDINT: + Proc.get_Ci_ref(r[0]) = Proc.read_Ci(r[1]) + Proc.read_Ci(r[2]); + break; + case SUBINT: + Proc.get_Ci_ref(r[0]) = Proc.read_Ci(r[1]) - Proc.read_Ci(r[2]); + break; + case MULINT: + Proc.get_Ci_ref(r[0]) = Proc.read_Ci(r[1]) * Proc.read_Ci(r[2]); + break; + case DIVINT: + Proc.get_Ci_ref(r[0]) = Proc.read_Ci(r[1]) / Proc.read_Ci(r[2]); + break; + case CONVINT: + Proc.get_Cp_ref(r[0]).assign(Proc.read_Ci(r[1])); + break; + case GCONVINT: + Proc.get_C2_ref(r[0]).assign((word)Proc.read_Ci(r[1])); + break; + case CONVMODP: + to_signed_bigint(Proc.temp.aa,Proc.read_Cp(r[1]),n); + Proc.write_Ci(r[0], Proc.temp.aa.get_si()); + break; + case GCONVGF2N: + Proc.write_Ci(r[0], Proc.read_C2(r[1]).get_word()); + break; + case PRINTMEM: + if (Proc.P.my_num() == 0) + { cout << "Mem[" << r[0] << "] = " << Proc.machine.Mp.read_C(r[0]) << endl; } + break; + case GPRINTMEM: + if (Proc.P.my_num() == 0) + { cout << "Mem[" << r[0] << "] = " << Proc.machine.M2.read_C(r[0]) << endl; } + break; + case PRINTREG: + if (Proc.P.my_num() == 0) + { + cout << "Reg[" << r[0] << "] = " << Proc.read_Cp(r[0]) + << " # " << string((char*)&n,sizeof(n)) << endl; + } + break; + case GPRINTREG: + if (Proc.P.my_num() == 0) + { + cout << "Reg[" << r[0] << "] = " << Proc.read_C2(r[0]) + << " # " << string((char*)&n,sizeof(n)) << endl; + } + break; + case PRINTREGPLAIN: + if (Proc.P.my_num() == 0) + { + cout << Proc.read_Cp(r[0]) << flush; + } + break; + case GPRINTREGPLAIN: + if (Proc.P.my_num() == 0) + { + cout << Proc.read_C2(r[0]) << flush; + } + break; + case PRINTSTR: + if (Proc.P.my_num() == 0) + { + cout << string((char*)&n,sizeof(n)) << flush; + } + break; + case PRINTCHR: + if (Proc.P.my_num() == 0) + { + cout << string((char*)&n,1) << flush; + } + break; + case PRINTCHRINT: + if (Proc.P.my_num() == 0) + { + cout << string((char*)&(Proc.read_Ci(r[0])),1) << flush; + } + break; + case PRINTSTRINT: + if (Proc.P.my_num() == 0) + { + cout << string((char*)&(Proc.read_Ci(r[0])),sizeof(int)) << flush; + } + break; + case RAND: + Proc.write_Ci(r[0], Proc.prng.get_uint() % (1 << Proc.read_Ci(r[1]))); + break; + case REQBL: + case GREQBL: + case USE: + case USE_INP: + case USE_PREP: + case GUSE_PREP: + break; + case TIME: + cout << "Elapsed time: " << Proc.machine.timer[0].elapsed() << endl; + break; + case START: + cout << "Starting timer " << n << " at " << Proc.machine.timer[n].elapsed() + << " after " << Proc.machine.timer[n].idle() << endl; + Proc.machine.timer[n].start(); + break; + case STOP: + Proc.machine.timer[n].stop(); + cout << "Stopped timer " << n << " at " << Proc.machine.timer[n].elapsed() << endl; + break; + case RUN_TAPE: + Proc.DataF.skip(Proc.machine.run_tape(r[0], n, r[1], -1)); + break; + case JOIN_TAPE: + Proc.machine.join_tape(r[0]); + break; + case CRASH: + throw crash_requested(); + break; + // *** + // TODO: read/write shared GF(2^n) data instructions + // *** + case OPENSOCKET: + Proc.open_socket(n); + break; + case CLOSESOCKET: + Proc.close_socket(); + break; + case READSOCKETC: // n is *unused atm*, r[0] is register to write to + int dest; + Proc.read_socket(dest); + Proc.write_Ci(r[0], (long)dest); + break; + case READSOCKETS: + // read share then MAC share + Proc.read_socket(Proc.temp.ansp); + Proc.get_Sp_ref(r[0]).set_share(Proc.temp.ansp); + Proc.read_socket(Proc.temp.ansp); + Proc.get_Sp_ref(r[0]).set_mac(Proc.temp.ansp); + break; + case GREADSOCKETS: + //Proc.get_S2_ref(r[0]).get_share().pack(socket_octetstream); + //Proc.get_S2_ref(r[0]).get_mac().pack(socket_octetstream); + break; + case WRITESOCKETC: // n is *unused atm*, r[0] is register to write to; + Proc.write_socket((int&)Proc.get_Ci_ref(r[0])); + break; + case WRITESOCKETS: + Proc.write_socket(Proc.get_Sp_ref(r[0]).get_share()); + Proc.write_socket(Proc.get_Sp_ref(r[0]).get_mac()); + break; + /*case GWRITESOCKETS: + Proc.get_S2_ref(r[0]).get_share().pack(socket_octetstream); + Proc.get_S2_ref(r[0]).get_mac().pack(socket_octetstream); + break;*/ + case PUBINPUT: + Proc.public_input >> Proc.get_Ci_ref(r[0]); + break; + case RAWOUTPUT: + Proc.read_Cp(r[0]).output(Proc.public_output, false); + break; + case GRAWOUTPUT: + Proc.read_C2(r[0]).output(Proc.public_output, false); + break; + case STARTPRIVATEOUTPUT: + Proc.privateOutputp.start(n,r[0],r[1]); + break; + case GSTARTPRIVATEOUTPUT: + Proc.privateOutput2.start(n,r[0],r[1]); + break; + case STOPPRIVATEOUTPUT: + Proc.privateOutputp.stop(n,r[0]); + break; + case GSTOPPRIVATEOUTPUT: + Proc.privateOutput2.stop(n,r[0]); + break; + case PREP: + Proc.DataF.get(Proc, r, start, size); + return; + case GPREP: + Proc.DataF.get(Proc, r, start, size); + return; + default: + printf("Case of opcode=%d not implemented yet\n",opcode); + throw not_implemented(); + break; + } + if (size > 1) + { + r[0]++; r[1]++; r[2]++; + } + } +} diff --git a/Processor/Instruction.h b/Processor/Instruction.h new file mode 100644 index 000000000..8855111fa --- /dev/null +++ b/Processor/Instruction.h @@ -0,0 +1,311 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Instruction +#define _Instruction + +/* Class to read and decode an instruction + */ + +#include +#include +#include +using namespace std; + +#include "Processor/Memory.h" +#include "Processor/Data_Files.h" +#include "Networking/Player.h" +#include "Math/Integer.h" +#include "Auth/MAC_Check.h" + +class Machine; +class Processor; + +/* + * Opcode constants + * + * Whenever these are changed the corresponding dict in Compiler/instructions.py + * MUST also be changed. (+ the documentation) + */ +enum +{ + // Load/store + LDI = 0x1, + LDSI = 0x2, + LDMC = 0x3, + LDMS = 0x4, + STMC = 0x5, + STMS = 0x6, + LDMCI = 0x7, + LDMSI = 0x8, + STMCI = 0x9, + STMSI = 0xA, + MOVC = 0xB, + MOVS = 0xC, + PROTECTMEMS = 0xD, + PROTECTMEMC = 0xE, + PROTECTMEMINT = 0xF, + LDMINT = 0xCA, + STMINT = 0xCB, + LDMINTI = 0xCC, + STMINTI = 0xCD, + PUSHINT = 0xCE, + POPINT = 0xCF, + MOVINT = 0xD0, + // Machine + LDTN = 0x10, + LDARG = 0x11, + REQBL = 0x12, + STARG = 0x13, + TIME = 0x14, + START = 0x15, + STOP = 0x16, + USE = 0x17, + USE_INP = 0x18, + RUN_TAPE = 0x19, + JOIN_TAPE = 0x1A, + CRASH = 0x1B, + USE_PREP = 0x1C, + // Addition + ADDC = 0x20, + ADDS = 0x21, + ADDM = 0x22, + ADDCI = 0x23, + ADDSI = 0x24, + SUBC = 0x25, + SUBS = 0x26, + SUBML = 0x27, + SUBMR = 0x28, + SUBCI = 0x29, + SUBSI = 0x2A, + SUBCFI = 0x2B, + SUBSFI = 0x2C, + // Multiplication/division/other arithmetic + MULC = 0x30, + MULM = 0x31, + MULCI = 0x32, + MULSI = 0x33, + DIVC = 0x34, + DIVCI = 0x35, + MODC = 0x36, + MODCI = 0x37, + LEGENDREC = 0x38, + // Open + STARTOPEN = 0xA0, + STOPOPEN = 0xA1, + // Data access + TRIPLE = 0x50, + BIT = 0x51, + SQUARE = 0x52, + INV = 0x53, + INPUTMASK = 0x56, + PREP = 0x57, + // Input + INPUT = 0x60, + STARTINPUT = 0x61, + STOPINPUT = 0x62, + READSOCKETC = 0x63, + READSOCKETS = 0x64, + WRITESOCKETC = 0x65, + WRITESOCKETS = 0x66, + OPENSOCKET = 0x67, + CLOSESOCKET = 0x68, + // Bitwise logic + ANDC = 0x70, + XORC = 0x71, + ORC = 0x72, + ANDCI = 0x73, + XORCI = 0x74, + ORCI = 0x75, + NOTC = 0x76, + // Bitwise shifts + SHLC = 0x80, + SHRC = 0x81, + SHLCI = 0x82, + SHRCI = 0x83, + // Branching and comparison + JMP = 0x90, + JMPNZ = 0x91, + JMPEQZ = 0x92, + EQZC = 0x93, + LTZC = 0x94, + LTC = 0x95, + GTC = 0x96, + EQC = 0x97, + JMPI = 0x98, + // Integers + LDINT = 0x9A, + ADDINT = 0x9B, + SUBINT = 0x9C, + MULINT = 0x9D, + DIVINT = 0x9E, + // Conversion + CONVINT = 0xC0, + CONVMODP = 0xC1, + + // IO + PRINTMEM = 0xB0, + PRINTREG = 0XB1, + RAND = 0xB2, + PRINTREGPLAIN = 0xB3, + PRINTCHR = 0xB4, + PRINTSTR = 0xB5, + PUBINPUT = 0xB6, + RAWOUTPUT = 0xB7, + STARTPRIVATEOUTPUT = 0xB8, + STOPPRIVATEOUTPUT = 0xB9, + PRINTCHRINT = 0xBA, + PRINTSTRINT = 0xBB, + + // GF(2^n) versions + + // Load/store + GLDI = 0x101, + GLDSI = 0x102, + GLDMC = 0x103, + GLDMS = 0x104, + GSTMC = 0x105, + GSTMS = 0x106, + GLDMCI = 0x107, + GLDMSI = 0x108, + GSTMCI = 0x109, + GSTMSI = 0x10A, + GMOVC = 0x10B, + GMOVS = 0x10C, + GPROTECTMEMS = 0x10D, + GPROTECTMEMC = 0x10E, + // Machine + GREQBL = 0x112, + GUSE_PREP = 0x11C, + // Addition + GADDC = 0x120, + GADDS = 0x121, + GADDM = 0x122, + GADDCI = 0x123, + GADDSI = 0x124, + GSUBC = 0x125, + GSUBS = 0x126, + GSUBML = 0x127, + GSUBMR = 0x128, + GSUBCI = 0x129, + GSUBSI = 0x12A, + GSUBCFI = 0x12B, + GSUBSFI = 0x12C, + // Multiplication/division + GMULC = 0x130, + GMULM = 0x131, + GMULCI = 0x132, + GMULSI = 0x133, + GDIVC = 0x134, + GDIVCI = 0x135, + GMULBITC = 0x136, + GMULBITM = 0x137, + // Open + GSTARTOPEN = 0x1A0, + GSTOPOPEN = 0x1A1, + // Data access + GTRIPLE = 0x150, + GBIT = 0x151, + GSQUARE = 0x152, + GINV = 0x153, + GBITTRIPLE = 0x154, + GBITGF2NTRIPLE = 0x155, + GINPUTMASK = 0x156, + GPREP = 0x157, + // Input + GINPUT = 0x160, + GSTARTINPUT = 0x161, + GSTOPINPUT = 0x162, + GREADSOCKETS = 0x164, + GWRITESOCKETS = 0x166, + // Bitwise logic + GANDC = 0x170, + GXORC = 0x171, + GORC = 0x172, + GANDCI = 0x173, + GXORCI = 0x174, + GORCI = 0x175, + GNOTC = 0x176, + // Bitwise shifts + GSHLCI = 0x182, + GSHRCI = 0x183, + GBITDEC = 0x184, + GBITCOM = 0x185, + // Conversion + GCONVINT = 0x1C0, + GCONVGF2N = 0x1C1, + // IO + GPRINTMEM = 0x1B0, + GPRINTREG = 0X1B1, + GPRINTREGPLAIN = 0x1B3, + GRAWOUTPUT = 0x1B7, + GSTARTPRIVATEOUTPUT = 0x1B8, + GSTOPPRIVATEOUTPUT = 0x1B9, +}; + + +// Register types +enum RegType { + MODP, + GF2N, + INT, + MAX_REG_TYPE, + NONE +}; + +enum SecrecyType { + SECRET, + CLEAR, + MAX_SECRECY_TYPE +}; + + +struct TempVars { + gf2n ans2; Share Sans2; + gfp ansp; Share Sansp; + bigint aa,aa2; + // INPUT and LDSI + gfp rrp,tp,tmpp; + gfp xip; + // GINPUT and GLDSI + gf2n rr2,t2,tmp2; + gf2n xi2; +}; + + +class Instruction +{ + int opcode; // The code + int size; // Vector size + int r[3]; // Three possible registers + unsigned int n; // Possible immediate value + vector start; // Values for a start/stop open + + public: + + // Reads a single instruction from the istream + void parse(istream& s); + + // Return whether usage is known + bool get_offline_data_usage(DataPositions& usage); + + bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); } + RegType get_reg_type() const; + + bool is_direct_memory_access(SecrecyType sec_type) const; + + // Returns the maximal register used + int get_max_reg(RegType reg_type) const; + + // Returns the memory size used if applicable and known + int get_mem(RegType reg_type, SecrecyType sec_type) const; + + friend ostream& operator<<(ostream& s,const Instruction& instr); + + // Execute this instruction, updateing the processor and memory + // and streams pointing to the triples etc + void execute(Processor& Proc) const; +}; + + +#endif + diff --git a/Processor/Machine.cpp b/Processor/Machine.cpp new file mode 100644 index 000000000..061a4f416 --- /dev/null +++ b/Processor/Machine.cpp @@ -0,0 +1,362 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Machine.h" + +#include "Exceptions/Exceptions.h" + +#include + +#include "Math/Setup.h" + +#include +#include +#include +#include +#include +using namespace std; + +Machine::Machine(int my_number, int PortnumBase, string hostname, + string progname_str, string memtype, int lgp, int lg2, bool direct, + int opening_sum, bool parallel, bool receive_threads, int max_broadcast) + : my_number(my_number), nthreads(0), tn(0), numt(0), usage_unknown(false), + progname(progname_str), direct(direct), opening_sum(opening_sum), parallel(parallel), + receive_threads(receive_threads), max_broadcast(max_broadcast) +{ + N.init(my_number,PortnumBase,hostname.c_str()); + + if (opening_sum < 2) + this->opening_sum = N.num_players(); + if (max_broadcast < 2) + this->max_broadcast = N.num_players(); + + // Set up the fields + prep_dir_prefix = get_prep_dir(N.num_players(), lgp, lg2); + read_setup(prep_dir_prefix); + + char filename[1024]; + int nn; + + sprintf(filename, (prep_dir_prefix + "Player-MAC-Keys-P%d").c_str(), my_number); + inpf.open(filename); + if (inpf.fail()) + { + cerr << "Could not open MAC key file. Perhaps it needs to be generated?\n"; + throw file_error(filename); + } + inpf >> nn; + if (nn!=N.num_players()) + { cerr << "KeyGen was last run with " << nn << " players." << endl; + cerr << " - You are running Online with " << N.num_players() << " players." << endl; + exit(1); + } + + alphapi.input(inpf,true); + alpha2i.input(inpf,true); + cerr << "MAC Key p = " << alphapi << endl; + cerr << "MAC Key 2 = " << alpha2i << endl; + inpf.close(); + + + // Initialize the global memory + if (memtype.compare("new")==0) + {sprintf(filename, "Player-Data/Player-Memory-P%d", my_number); + ifstream memfile(filename); + if (memfile.fail()) { throw file_error(filename); } + Load_Memory(M2,memfile); + Load_Memory(Mp,memfile); + Load_Memory(Mi,memfile); + memfile.close(); + } + else if (memtype.compare("old")==0) + { + sprintf(filename, "Player-Data/Memory-P%d", my_number); + inpf.open(filename,ios::in | ios::binary); + if (inpf.fail()) { throw file_error(); } + inpf >> M2 >> Mp >> Mi; + inpf.close(); + } + else if (!(memtype.compare("empty")==0)) + { cerr << "Invalid memory argument" << endl; + exit(1); + } + + sprintf(filename, "Programs/Schedules/%s.sch",progname.c_str()); + cerr << "Opening file " << filename << endl; + inpf.open(filename); + if (inpf.fail()) { throw file_error("Missing '" + string(filename) + "'. Did you compile '" + progname + "'?"); } + + int nprogs; + inpf >> nthreads; + inpf >> nprogs; + + // Keep record of used offline data + pos.set_num_players(N.num_players()); + + cerr << "Number of threads I will run in parallel = " << nthreads << endl; + cerr << "Number of program sequences I need to load = " << nprogs << endl; + + // Load in the programs + progs.resize(nprogs,N.num_players()); + char threadname[1024]; + for (int i=0; i> threadname; + sprintf(filename,"Programs/Bytecode/%s.bc",threadname); + cerr << "Loading program " << i << " from " << filename << endl; + ifstream pinp(filename); + if (pinp.fail()) { throw file_error(filename); } + progs[i].parse(pinp); + pinp.close(); + if (progs[i].direct_mem2_s() > M2.size_s()) + { + cerr << threadname << " needs more secret mod2 memory, resizing to " + << progs[i].direct_mem2_s() << endl; + M2.resize_s(progs[i].direct_mem2_s()); + } + if (progs[i].direct_memp_s() > Mp.size_s()) + { + cerr << threadname << " needs more secret modp memory, resizing to " + << progs[i].direct_memp_s() << endl; + Mp.resize_s(progs[i].direct_memp_s()); + } + if (progs[i].direct_mem2_c() > M2.size_c()) + { + cerr << threadname << " needs more clear mod2 memory, resizing to " + << progs[i].direct_mem2_c() << endl; + M2.resize_c(progs[i].direct_mem2_c()); + } + if (progs[i].direct_memp_c() > Mp.size_c()) + { + cerr << threadname << " needs more clear modp memory, resizing to " + << progs[i].direct_memp_c() << endl; + Mp.resize_c(progs[i].direct_memp_c()); + } + if (progs[i].direct_memi_c() > Mi.size_c()) + { + cerr << threadname << " needs more clear integer memory, resizing to " + << progs[i].direct_memi_c() << endl; + Mi.resize_c(progs[i].direct_memi_c()); + } + } + + progs[0].print_offline_cost(); + + /* Set up the threads */ + tinfo.resize(nthreads); + threads.resize(nthreads); + t_mutex.resize(nthreads); + client_ready.resize(nthreads); + server_ready.resize(nthreads); + join_timer.resize(nthreads); + + for (int i=0; i1) + { cerr << "Line " << line_number << " has " << + numt << " threads but tape " << tape_number << + " has unknown offline data usage" << endl; + throw invalid_program(); + } + else if (line_number == -1) + { + cerr << "Internally called tape " << tape_number << + " has unknown offline data usage" << endl; + throw invalid_program(); + } + usage_unknown = true; + return DataPositions(N.num_players()); + } + else + { + // Bits, Triples, Squares, and Inverses skipping + return progs[tape_number].get_offline_data_used(); + } +} + +void Machine::join_tape(int i) +{ + join_timer[i].start(); + pthread_mutex_lock(&t_mutex[i]); + //printf("Waiting for client to terminate\n"); + if ((tinfo[i].finished)==false) + { pthread_cond_wait(&client_ready[i],&t_mutex[i]); } + pthread_mutex_unlock(&t_mutex[i]); + join_timer[i].stop(); +} + +void Machine::run() +{ + Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); + proc_timer.start(); + timer[0].start(); + + bool flag=true; + usage_unknown=false; + int exec=0; + while (flag) + { inpf >> numt; + if (numt==0) + { flag=false; } + else + { for (int i=0; i> tn; + + // Cope with passing an integer parameter to a tape + int arg; + if (inpf.get() == ':') + inpf >> arg; + else + arg = 0; + + //cerr << "Run scheduled tape " << tn << " in thread " << i << endl; + pos.increase(run_tape(i, tn, arg, exec)); + } + // Make sure all terminate before we continue + for (int i=0; i::iterator it = timer.begin(); it != timer.end(); it++) + cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; + + if (opening_sum < N.num_players() && !direct) + cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl; + else + cerr << "Summed all shares at once" << endl; + + if (max_broadcast < N.num_players() && !direct) + cerr << "Send to at most " << max_broadcast << " parties at once" << endl; + else + cerr << "Full broadcast" << endl; + + // Reduce memory size to speed up + int max_size = 1 << 20; + if (M2.size_s() > max_size) + M2.resize_s(max_size); + if (Mp.size_s() > max_size) + Mp.resize_s(max_size); + + // Write out the memory to use next time + char filename[1024]; + sprintf(filename,"Player-Data/Memory-P%d",my_number); + ofstream outf(filename,ios::out | ios::binary); + outf << M2 << Mp << Mi; + outf.close(); + + extern unsigned long long sent_amount, sent_counter; + cerr << "Data sent = " << sent_amount << " bytes in " + << sent_counter << " calls,"; + cerr << sent_amount / sent_counter / N.num_players() + << " bytes per call" << endl; + + for (int dtype = 0; dtype < N_DTYPE; dtype++) + { + cerr << "Num " << Data_Files::dtype_names[dtype] << "\t="; + for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) + cerr << " " << pos.files[field_type][dtype]; + cerr << endl; + } + for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) + { + cerr << "Num " << Data_Files::long_field_names[field_type] << " Inputs\t="; + for (int i = 0; i < N.num_players(); i++) + cerr << " " << pos.inputs[i][field_type]; + cerr << endl; + } + + cerr << "Total cost of program:" << endl; + pos.print_cost(); + + cerr << "End of prog" << endl; +} + + diff --git a/Processor/Machine.h b/Processor/Machine.h new file mode 100644 index 000000000..703d000b1 --- /dev/null +++ b/Processor/Machine.h @@ -0,0 +1,83 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Machine.h + * + */ + +#ifndef MACHINE_H_ +#define MACHINE_H_ + +#include "Processor/Memory.h" +#include "Processor/Program.h" + +#include "Processor/Online-Thread.h" +#include "Processor/Data_Files.h" +#include "Math/gfp.h" + +#include "Tools/time-func.h" + +#include +#include +using namespace std; + +class Machine +{ + /* The mutex's lock the C-threads and then only release + * then we an MPC thread is ready to run on the C-thread. + * Control is passed back to the main loop when the + * MPC thread releases the mutex + */ + + vector tinfo; + vector threads; + + int my_number; + Names N; + gfp alphapi; + gf2n alpha2i; + + int nthreads; + + ifstream inpf; + + // Keep record of used offline data + DataPositions pos; + + int tn,numt; + bool usage_unknown; + + public: + + vector t_mutex; + vector client_ready; + vector server_ready; + vector progs; + + Memory M2; + Memory Mp; + Memory Mi; + + std::map timer; + vector join_timer; + Timer finish_timer; + + string prep_dir_prefix; + string progname; + + bool direct; + int opening_sum; + bool parallel; + bool receive_threads; + int max_broadcast; + + Machine(int my_number, int PortnumBase, string hostname, string progname, + string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel, + bool receive_threads, int max_broadcast); + + DataPositions run_tape(int thread_number, int tape_number, int arg, int line_number); + void join_tape(int thread_number); + void run(); +}; + +#endif /* MACHINE_H_ */ diff --git a/Processor/Memory.cpp b/Processor/Memory.cpp new file mode 100644 index 000000000..5ce3c564c --- /dev/null +++ b/Processor/Memory.cpp @@ -0,0 +1,147 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Processor/Memory.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Integer.h" + +#include + +#ifdef MEMPROTECT +template +void Memory::protect_s(unsigned int start, unsigned int end) +{ + protected_s.insert(pair(start, end)); +} + +template +void Memory::protect_c(unsigned int start, unsigned int end) +{ + protected_c.insert(pair(start, end)); +} + +template +bool Memory::is_protected_s(unsigned int index) +{ + for (set< pair >::iterator it = protected_s.begin(); + it != protected_s.end(); it++) + if (it->first <= index and it->second > index) + return true; + return false; +} + +template +bool Memory::is_protected_c(unsigned int index) +{ + for (set< pair >::iterator it = protected_c.begin(); + it != protected_c.end(); it++) + if (it->first <= index and it->second > index) + return true; + return false; +} +#endif + + +template +ostream& operator<<(ostream& s,const Memory& M) +{ + s << M.MS.size() << endl; + s << M.MC.size() << endl; + +#ifdef DEBUG + for (unsigned int i=0; i +istream& operator>>(istream& s,Memory& M) +{ + int len; + + s >> len; + M.resize_s(len); + s >> len; + M.resize_c(len); + s.seekg(1, istream::cur); + + for (unsigned int i=0; i +void Load_Memory(Memory& M,ifstream& inpf) +{ + int a; + T val; + Share S; + + inpf >> a; + M.resize_s(a); + inpf >> a; + M.resize_c(a); + + cerr << "Reading Clear Memory" << endl; + + // Read clear memory + inpf >> a; + val.input(inpf,true); + while (a!=-1) + { M.write_C(a,val); + inpf >> a; + val.input(inpf,true); + } + cerr << "Reading Shared Memory" << endl; + + // Read shared memory + inpf >> a; + S.input(inpf,true); + while (a!=-1) + { M.write_S(a,S); + inpf >> a; + S.input(inpf,true); + } +} + +template class Memory; +template class Memory; +template class Memory; + +template istream& operator>>(istream& s,Memory& M); +template istream& operator>>(istream& s,Memory& M); +template istream& operator>>(istream& s,Memory& M); + +template ostream& operator<<(ostream& s,const Memory& M); +template ostream& operator<<(ostream& s,const Memory& M); +template ostream& operator<<(ostream& s,const Memory& M); + +template void Load_Memory(Memory& M,ifstream& inpf); +template void Load_Memory(Memory& M,ifstream& inpf); +template void Load_Memory(Memory& M,ifstream& inpf); + +#ifdef USE_GF2N_LONG +template class Memory; +template istream& operator>>(istream& s,Memory& M); +template ostream& operator<<(ostream& s,const Memory& M); +template void Load_Memory(Memory& M,ifstream& inpf); +#endif diff --git a/Processor/Memory.h b/Processor/Memory.h new file mode 100644 index 000000000..21c1b1f87 --- /dev/null +++ b/Processor/Memory.h @@ -0,0 +1,97 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Memory +#define _Memory + +/* Class to hold global memory of our system */ + +#include +#include +using namespace std; + +// Forward declaration as apparently this is needed for friends in templates +template class Memory; +template ostream& operator<<(ostream& s,const Memory& M); +template istream& operator>>(istream& s,Memory& M); + +#include "Math/Share.h" +template +class Memory +{ + vector > MS; + vector MC; +#ifdef MEMPROTECT + set< pair > protected_s; + set< pair > protected_c; +#endif + + public: + + void resize_s(int sz) + { MS.resize(sz); } + void resize_c(int sz) + { MC.resize(sz); } + + int size_s() + { return MS.size(); } + int size_c() + { return MC.size(); } + + const T& read_C(int i) const + { return MC[i]; } + const Share & read_S(int i) const + { return MS[i]; } + + void write_C(unsigned int i,const T& x,int PC=-1) + { MC[i]=x; + (void)PC; +#ifdef MEMPROTECT + if (is_protected_c(i)) + cerr << "Protected clear memory access of " << i << " by " << PC - 1 << endl; +#endif + } + void write_S(unsigned int i,const Share & x,int PC=-1) + { MS[i]=x; + (void)PC; +#ifdef MEMPROTECT + if (is_protected_s(i)) + cerr << "Protected secret memory access of " << i << " by " << PC - 1 << endl; +#endif + } + + +#ifdef MEMPROTECT + void protect_s(unsigned int start, unsigned int end); + void protect_c(unsigned int start, unsigned int end); + bool is_protected_s(unsigned int index); + bool is_protected_c(unsigned int index); +#else + void protect_s(unsigned int start, unsigned int end) + { (void)start, (void)end; cerr << "Memory protection not activated" << endl; } + void protect_c(unsigned int start, unsigned int end) + { (void)start, (void)end; cerr << "Memory protection not activated" << endl; } +#endif + + friend ostream& operator<< <>(ostream& s,const Memory& M); + friend istream& operator>> <>(istream& s,Memory& M); + +}; + + +/* This function loads a un-shared global memory from disk and + * produces the memory + * + * The global unshared memory is of the form + * sz <- Size + * n val <- Clear values + * n val <- Clear values + * -1 -1 <- End of clear values + * n val <- Shared values + * n val <- Shared values + * -1 -1 + */ +template +void Load_Memory(Memory& M,ifstream& inpf); + +#endif + diff --git a/Processor/Online-Thread.cpp b/Processor/Online-Thread.cpp new file mode 100644 index 000000000..f3cfa51c0 --- /dev/null +++ b/Processor/Online-Thread.cpp @@ -0,0 +1,173 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Processor/Program.h" +#include "Processor/Online-Thread.h" +#include "Tools/time-func.h" +#include "Processor/Data_Files.h" +#include "Processor/Machine.h" +#include "Processor/Processor.h" + +#include +#include +#include +using namespace std; + + +void* Main_Func(void* ptr) +{ + thread_info *tinfo=(thread_info *) ptr; + Machine& machine=*(tinfo->machine); + vector& t_mutex = machine.t_mutex; + vector& client_ready = machine.client_ready; + vector& server_ready = machine.server_ready; + vector& progs = machine.progs; + + int num=tinfo->thread_num; + fprintf(stderr, "\tI am in thread %d\n",num); + Player* player; + if (!machine.receive_threads or machine.direct or machine.parallel) + { + cerr << "Using single-threaded receiving" << endl; + player = new Player(*(tinfo->Nms), num << 16); + } + else + { + cerr << "Using player-specific threads for receiving" << endl; + player = new ThreadPlayer(*(tinfo->Nms), num << 16); + } + Player& P = *player; + fprintf(stderr, "\tSet up player in thread %d\n",num); + + Data_Files DataF(P.my_num(),P.num_players(),machine.prep_dir_prefix); + + MAC_Check* MC2; + MAC_Check* MCp; + + // Use MAC_Check instead for more than 10000 openings at once + if (machine.direct) + { + cerr << "Using direct communication. If computation stalls, use -m when compiling." << endl; + MC2 = new Direct_MAC_Check(*(tinfo->alpha2i), *(tinfo->Nms), num); + MCp = new Direct_MAC_Check(*(tinfo->alphapi), *(tinfo->Nms), num); + } + else if (machine.parallel) + { + cerr << "Using indirect communication with background threads." << endl; + MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i), *(tinfo->Nms), num, machine.opening_sum); + MCp = new Parallel_MAC_Check(*(tinfo->alphapi), *(tinfo->Nms), num, machine.opening_sum); + } + else + { + cerr << "Using indirect communication." << endl; + MC2 = new MAC_Check(*(tinfo->alpha2i), machine.opening_sum); + MCp = new MAC_Check(*(tinfo->alphapi), machine.opening_sum); + } + + Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine); + Share a,b,c; + + bool flag=true; + int program=-3; + // int exec=0; + + // Allocate memory for first program before starting the clock + Proc.reset(progs[0].num_regs2(),progs[0].num_regsp(),progs[0].num_regi(),tinfo->arg); + + // synchronize + cerr << "Locking for sync of thread " << num << endl; + pthread_mutex_lock(&t_mutex[num]); + tinfo->ready=true; + pthread_cond_signal(&client_ready[num]); + pthread_mutex_unlock(&t_mutex[num]); + + Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer; + thread_timer.start(); + + while (flag) + { // Wait until I have a program to run + wait_timer.start(); + pthread_mutex_lock(&t_mutex[num]); + if ((tinfo->prognum)==-2) + { pthread_cond_wait(&server_ready[num],&t_mutex[num]); } + program=(tinfo->prognum); + (tinfo->prognum)=-2; + pthread_mutex_unlock(&t_mutex[num]); + wait_timer.stop(); + //printf("\tRunning program %d\n",program); + + if (program==-1) + { flag=false; + fprintf(stderr, "\tThread %d terminating\n",num); + } + else + { // RUN PROGRAM + //printf("\tClient %d about to run %d in execution %d\n",num,program,exec); + Proc.reset(progs[program].num_regs2(),progs[program].num_regsp(),progs[program].num_regi(),tinfo->arg); + + // Bits, Triples, Squares, and Inverses skipping + DataF.seekg(tinfo->pos); + + //printf("\tExecuting program"); + // Execute the program + progs[program].execute(Proc); + + if (progs[program].usage_unknown()) + { // communicate file positions to main thread + tinfo->pos = DataF.get_usage(); + } + + //double elapsed = timeval_diff(&startv, &endv); + //printf("Thread time = %f seconds\n",elapsed/1000000); + //printf("\texec = %d\n",exec); exec++; + //printf("\tMC2.number = %d\n",MC2.number()); + //printf("\tMCp.number = %d\n",MCp.number()); + + // MACCheck + MC2->Check(P); + MCp->Check(P); + //printf("\tMAC checked\n"); + P.Check_Broadcast(); + //printf("\tBroadcast checked\n"); + + // printf("\tSignalling I have finished\n"); + wait_timer.start(); + pthread_mutex_lock(&t_mutex[num]); + (tinfo->finished)=true; + pthread_cond_signal(&client_ready[num]); + pthread_mutex_unlock(&t_mutex[num]); + wait_timer.stop(); + } + } + + // MACCheck + MC2->Check(P); + MCp->Check(P); + + //cout << num << " : Checking broadcast" << endl; + P.Check_Broadcast(); + //cout << num << " : Broadcast checked "<< endl; + + wait_timer.start(); + pthread_mutex_lock(&t_mutex[num]); + if (!tinfo->ready) + pthread_cond_wait(&server_ready[num], &t_mutex[num]); + pthread_mutex_unlock(&t_mutex[num]); + wait_timer.stop(); + + cerr << num << " : MAC Checking" << endl; + cerr << "\tMC2.number=" << MC2->number() << endl; + cerr << "\tMCp.number=" << MCp->number() << endl; + + cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl; + cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl; + + delete MC2; + delete MCp; + delete player; + + return NULL; +} + + + diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h new file mode 100644 index 000000000..485a073e1 --- /dev/null +++ b/Processor/Online-Thread.h @@ -0,0 +1,41 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _Online_Thread +#define _Online_Thread + +#include "Networking/Player.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Integer.h" +#include "Processor/Data_Files.h" + +#include +using namespace std; + +class Machine; + +class thread_info +{ + public: + + int thread_num; + int covert; + Names* Nms; + gf2n *alpha2i; + gfp *alphapi; + int prognum; + bool finished; + bool ready; + + // rownums for triples, bits, squares, and inverses etc + DataPositions pos; + // Integer arg (optional) + int arg; + + Machine* machine; +}; + +void* Main_Func(void *ptr); + +#endif + diff --git a/Processor/PrivateOutput.cpp b/Processor/PrivateOutput.cpp new file mode 100644 index 000000000..909871d4e --- /dev/null +++ b/Processor/PrivateOutput.cpp @@ -0,0 +1,35 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * PrivateOutput.cpp + * + */ + +#include "PrivateOutput.h" +#include "Processor.h" + +template +void PrivateOutput::start(int player, int target, int source) +{ + T mask; + proc.DataF.get_input(proc.get_S_ref(target), mask, player); + proc.get_S_ref(target).add(proc.get_S_ref(source)); + + if (player == proc.P.my_num()) + masks.push_back(mask); +} + +template +void PrivateOutput::stop(int player, int source) +{ + if (player == proc.P.my_num()) + { + T value; + value.sub(proc.get_C_ref(source), masks.front()); + value.output(proc.private_output, false); + masks.pop_front(); + } +} + +template class PrivateOutput; +template class PrivateOutput; diff --git a/Processor/PrivateOutput.h b/Processor/PrivateOutput.h new file mode 100644 index 000000000..52c7522ff --- /dev/null +++ b/Processor/PrivateOutput.h @@ -0,0 +1,31 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * PrivateOutput.h + * + */ + +#ifndef PROCESSOR_PRIVATEOUTPUT_H_ +#define PROCESSOR_PRIVATEOUTPUT_H_ + +#include +using namespace std; + +#include "Math/Share.h" + +class Processor; + +template +class PrivateOutput +{ + Processor& proc; + deque masks; + +public: + PrivateOutput(Processor& proc) : proc(proc) { }; + + void start(int player, int target, int source); + void stop(int player, int source); +}; + +#endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ diff --git a/Processor/Processor.cpp b/Processor/Processor.cpp new file mode 100644 index 000000000..a62c78e26 --- /dev/null +++ b/Processor/Processor.cpp @@ -0,0 +1,222 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Processor/Processor.h" +#include "Auth/MAC_Check.h" + +#include "Auth/fake-stuff.h" + + +Processor::Processor(int thread_num,Data_Files& DataF,Player& P, + MAC_Check& MC2,MAC_Check& MCp,Machine& machine, + int num_regs2,int num_regsp,int num_regi) +: thread_num(thread_num),socket_is_open(false),DataF(DataF),P(P),MC2(MC2),MCp(MCp),machine(machine), + input2(*this,MC2),inputp(*this,MCp),privateOutput2(*this),privateOutputp(*this),sent(0),rounds(0) +{ + reset(num_regs2,num_regsp,num_regi,0); + + public_input.open(get_filename("Programs/Public-Input/",false).c_str()); + private_input.open(get_filename("Player-Data/Private-Input-",true).c_str()); + public_output.open(get_filename("Player-Data/Public-Output-",true).c_str(), ios_base::out); + private_output.open(get_filename("Player-Data/Private-Output-",true).c_str(), ios_base::out); +} + + +Processor::~Processor() +{ + cerr << "Sent " << sent << " elements in " << rounds << " rounds" << endl; +} + + +string Processor::get_filename(const char* prefix, bool use_number) +{ + stringstream filename; + filename << prefix; + if (!use_number) + filename << machine.progname; + if (use_number) + filename << P.my_num(); + if (thread_num > 0) + filename << "-" << thread_num; + cerr << "Opening file " << filename.str() << endl; + return filename.str(); +} + + +void Processor::reset(int num_regs2,int num_regsp,int num_regi,int arg) +{ + reg_max2 = num_regs2; + reg_maxp = num_regsp; + reg_maxi = num_regi; + C2.resize(reg_max2); Cp.resize(reg_maxp); + S2.resize(reg_max2); Sp.resize(reg_maxp); + Ci.resize(reg_maxi); + this->arg = arg; + close_socket(); + + #ifdef DEBUG + rw2.resize(2*reg_max2); + for (int i=0; i<2*reg_max2; i++) { rw2[i]=0; } + rwp.resize(2*reg_maxp); + for (int i=0; i<2*reg_maxp; i++) { rwp[i]=0; } + rwi.resize(2*reg_maxi); + for (int i=0; i<2*reg_maxi; i++) { rwi[i]=0; } + #endif +} + +#include "Networking/sockets.h" + +// Set up a server socket for some client +void Processor::open_socket(int portnum_base) +{ + if (!socket_is_open) + { + socket_is_open = true; + sockaddr_in dest; + set_up_server_socket(dest, final_socket_fd, socket_fd, portnum_base + P.my_num()); + } +} + +void Processor::close_socket() +{ + if (socket_is_open) + { + socket_is_open = false; + close_server_socket(final_socket_fd, socket_fd); + } +} + +// Receive 32-bit int +void Processor::read_socket(int& x) +{ + octet bytes[4]; + receive(final_socket_fd, bytes, 4); + x = BYTES_TO_INT(bytes); +} + +// Send 32-bit int +void Processor::write_socket(int x) +{ + octet bytes[4]; + INT_TO_BYTES(bytes, x); + send(final_socket_fd, bytes, 4); +} + +// Receive field element +template +void Processor::read_socket(T& x) +{ + socket_stream.reset_write_head(); + socket_stream.Receive(final_socket_fd); + x.unpack(socket_stream); +} + +// Send field element +template +void Processor::write_socket(const T& x) +{ + socket_stream.reset_write_head(); + x.pack(socket_stream); + socket_stream.Send(final_socket_fd); +} + +template +void Processor::POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size) +{ + int sz=reg.size(); + vector< Share >& Sh_PO = get_Sh_PO(); + vector& PO = get_PO(); + Sh_PO.clear(); + Sh_PO.reserve(sz*size); + if (size>1) + { + for (typename vector::const_iterator reg_it=reg.begin(); + reg_it!=reg.end(); reg_it++) + { + typename vector >::iterator begin=get_S().begin()+*reg_it; + Sh_PO.insert(Sh_PO.end(),begin,begin+size); + } + } + else + { + for (int i=0; i(reg[i])); } + } + PO.resize(sz*size); + MC.POpen_Begin(PO,Sh_PO,P); +} + + +template +void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size) +{ + vector< Share >& Sh_PO = get_Sh_PO(); + vector& PO = get_PO(); + vector& C = get_C(); + int sz=reg.size(); + PO.resize(sz*size); + MC.POpen_End(PO,Sh_PO,P); + if (size>1) + { + typename vector::iterator PO_it=PO.begin(); + for (typename vector::const_iterator reg_it=reg.begin(); + reg_it!=reg.end(); reg_it++) + { + for (typename vector::iterator C_it=C.begin()+*reg_it; + C_it!=C.begin()+*reg_it+size; C_it++) + { + *C_it=*PO_it; + PO_it++; + } + } + } + else + { + for (unsigned int i=0; i(reg[i]) = PO[i]; } + } + + sent += reg.size() * size; + rounds++; +} + + + + + + + +ostream& operator<<(ostream& s,const Processor& P) +{ + s << "Processor State" << endl; + s << "Char 2 Registers" << endl; + s << "Val\tClearReg\tSharedReg" << endl; + for (int i=0; i& reg,const Player& P,MAC_Check& MC,int size); +template void Processor::POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); +template void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); +template void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); +template void Processor::read_socket(gfp& x); +template void Processor::read_socket(gf2n& x); +template void Processor::write_socket(const gfp& x); +template void Processor::write_socket(const gf2n& x); diff --git a/Processor/Processor.h b/Processor/Processor.h new file mode 100644 index 000000000..a4a0d37e1 --- /dev/null +++ b/Processor/Processor.h @@ -0,0 +1,267 @@ +// (C) 2016 University of Bristol. See License.txt + + +#ifndef _Processor +#define _Processor + +/* This is a representation of a processing element + * Consisting of 256 clear and 256 shared registers + */ + +#include "Math/Share.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Integer.h" +#include "Exceptions/Exceptions.h" +#include "Networking/Player.h" +#include "Auth/MAC_Check.h" +#include "Data_Files.h" +#include "Input.h" +#include "PrivateOutput.h" +#include "Machine.h" + +#include + +class Processor +{ + vector C2; + vector Cp; + vector > S2; + vector > Sp; + vector Ci; + + // Stack + stack stacki; + + // This is the vector of partially opened values and shares we need to store + // as the Open commands are split in two + vector PO2; + vector POp; + vector > Sh_PO2; + vector > Sh_POp; + + int reg_max2,reg_maxp,reg_maxi; + int thread_num; + + // Optional argument to tape + int arg; + + // For reading/reading data from a socket (i.e. external party to SPDZ) + octetStream socket_stream; + int socket_fd, final_socket_fd; + bool socket_is_open; + + #ifdef DEBUG + vector rw2; + vector rwp; + vector rwi; + #endif + + template + vector< Share >& get_S(); + template + vector& get_C(); + + template + vector< Share >& get_Sh_PO(); + template + vector& get_PO(); + + public: + Data_Files& DataF; + Player& P; + MAC_Check& MC2; + MAC_Check& MCp; + Machine& machine; + + Input input2; + Input inputp; + + PrivateOutput privateOutput2; + PrivateOutput privateOutputp; + + ifstream public_input; + ifstream private_input; + ofstream public_output; + ofstream private_output; + + unsigned int PC; + TempVars temp; + PRNG prng; + + int sent, rounds; + + static const int reg_bytes = 4; + + void reset(int num_regs2,int num_regsp,int num_regi,int arg); // Reset the state of the processor + string get_filename(const char* basename, bool use_number); + + Processor(int thread_num,Data_Files& DataF,Player& P, + MAC_Check& MC2,MAC_Check& MCp,Machine& machine, + int num_regs2 = 256,int num_regsp = 256,int num_regi = 256); + ~Processor(); + + int get_thread_num() + { + return thread_num; + } + + int get_arg() const + { + return arg; + } + + void set_arg(int new_arg) + { + arg=new_arg; + } + + void pushi(long x) { stacki.push(x); } + void popi(long& x) { x = stacki.top(); stacki.pop(); } + + #ifdef DEBUG + const gf2n& read_C2(int i) const + { if (rw2[i]==0) + { throw Processor_Error("Invalid read on clear register"); } + return C2.at(i); + } + const Share & read_S2(int i) const + { if (rw2[i+reg_max2]==0) + { throw Processor_Error("Invalid read on shared register"); } + return S2.at(i); + } + gf2n& get_C2_ref(int i) + { rw2[i]=1; + return C2.at(i); + } + Share & get_S2_ref(int i) + { rw2[i+reg_max2]=1; + return S2.at(i); + } + void write_C2(int i,const gf2n& x) + { rw2[i]=1; + C2.at(i)=x; + } + void write_S2(int i,const Share & x) + { rw2[i+reg_max2]=1; + S2.at(i)=x; + } + + const gfp& read_Cp(int i) const + { if (rwp[i]==0) + { throw Processor_Error("Invalid read on clear register"); } + return Cp.at(i); + } + const Share & read_Sp(int i) const + { if (rwp[i+reg_maxp]==0) + { throw Processor_Error("Invalid read on shared register"); } + return Sp.at(i); + } + gfp& get_Cp_ref(int i) + { rwp[i]=1; + return Cp.at(i); + } + Share & get_Sp_ref(int i) + { rwp[i+reg_maxp]=1; + return Sp.at(i); + } + void write_Cp(int i,const gfp& x) + { rwp[i]=1; + Cp.at(i)=x; + } + void write_Sp(int i,const Share & x) + { rwp[i+reg_maxp]=1; + Sp.at(i)=x; + } + + const long& read_Ci(int i) const + { if (rwi[i]==0) + { throw Processor_Error("Invalid read on integer register"); } + return Ci.at(i); + } + long& get_Ci_ref(int i) + { rwi[i]=1; + return Ci.at(i); + } + void write_Ci(int i,const long& x) + { rwi[i]=1; + Ci.at(i)=x; + } + #else + const gf2n& read_C2(int i) const + { return C2[i]; } + const Share & read_S2(int i) const + { return S2[i]; } + gf2n& get_C2_ref(int i) + { return C2[i]; } + Share & get_S2_ref(int i) + { return S2[i]; } + void write_C2(int i,const gf2n& x) + { C2[i]=x; } + void write_S2(int i,const Share & x) + { S2[i]=x; } + + const gfp& read_Cp(int i) const + { return Cp[i]; } + const Share & read_Sp(int i) const + { return Sp[i]; } + gfp& get_Cp_ref(int i) + { return Cp[i]; } + Share & get_Sp_ref(int i) + { return Sp[i]; } + void write_Cp(int i,const gfp& x) + { Cp[i]=x; } + void write_Sp(int i,const Share & x) + { Sp[i]=x; } + + const long& read_Ci(int i) const + { return Ci[i]; } + long& get_Ci_ref(int i) + { return Ci[i]; } + void write_Ci(int i,const long& x) + { Ci[i]=x; } + #endif + + // Template-based access + template Share& get_S_ref(int i); + template T& get_C_ref(int i); + + // Access to sockets for reading clear/shared data + void open_socket(int portnum_base); + void close_socket(); + void read_socket(int& x); + void write_socket(int x); + template + void read_socket(T& x); + template + void write_socket(const T& x); + + // Access to PO (via calls to POpen start/stop) + template + void POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); + + template + void POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); + + // Print the processor state + friend ostream& operator<<(ostream& s,const Processor& P); +}; + +template<> inline Share& Processor::get_S_ref(int i) { return get_S2_ref(i); } +template<> inline gf2n& Processor::get_C_ref(int i) { return get_C2_ref(i); } +template<> inline Share& Processor::get_S_ref(int i) { return get_Sp_ref(i); } +template<> inline gfp& Processor::get_C_ref(int i) { return get_Cp_ref(i); } + +template<> inline vector< Share >& Processor::get_S() { return S2; } +template<> inline vector< Share >& Processor::get_S() { return Sp; } + +template<> inline vector& Processor::get_C() { return C2; } +template<> inline vector& Processor::get_C() { return Cp; } + +template<> inline vector< Share >& Processor::get_Sh_PO() { return Sh_PO2; } +template<> inline vector& Processor::get_PO() { return PO2; } +template<> inline vector< Share >& Processor::get_Sh_PO() { return Sh_POp; } +template<> inline vector& Processor::get_PO() { return POp; } + +#endif + diff --git a/Processor/Program.cpp b/Processor/Program.cpp new file mode 100644 index 000000000..9b132305b --- /dev/null +++ b/Processor/Program.cpp @@ -0,0 +1,80 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Processor/Program.h" +#include "Processor/Data_Files.h" +#include "Processor/Processor.h" + +void Program::compute_constants() +{ + max_reg2 = 0; + max_regp = 0; + max_regi = 0; + for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) + for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++) + max_mem[reg_type][sec_type] = 0; + for (unsigned int i=0; i p; + // Here we note the number of bits, squares and triples and input + // data needed + // - This is computed for a whole program sequence to enable + // the run time to be able to determine which ones to pass to it + DataPositions offline_data_used; + + // Maximal register used + int max_reg2,max_regp,max_regi; + + // Memory size used directly + int max_mem[MAX_REG_TYPE][MAX_SECRECY_TYPE]; + + // True if program contains variable-sized loop + bool unknown_usage; + + void compute_constants(); + + public: + + Program(int nplayers) : offline_data_used(nplayers), + max_reg2(0), max_regp(0), max_regi(0), unknown_usage(false) + { p.resize(0); } + + // Read in a program + void parse(istream& s); + + DataPositions get_offline_data_used() const { return offline_data_used; } + void print_offline_cost() const; + + bool usage_unknown() const { return unknown_usage; } + + int num_regs2() const { return max_reg2; } + int num_regsp() const { return max_regp; } + int num_regi() const { return max_regi; } + + int direct_mem(RegType reg_type, SecrecyType sec_type) + { return max_mem[reg_type][sec_type]; } + + int direct_mem2_s() const { return max_mem[GF2N][SECRET]; } + int direct_memp_s() const { return max_mem[MODP][SECRET]; } + int direct_mem2_c() const { return max_mem[GF2N][CLEAR]; } + int direct_memp_c() const { return max_mem[MODP][CLEAR]; } + int direct_memi_c() const { return max_mem[INT][CLEAR]; } + + friend ostream& operator<<(ostream& s,const Program& P); + + // Execute this program, updateing the processor and memory + // and streams pointing to the triples etc + void execute(Processor& Proc) const; + +}; + +#endif + diff --git a/Programs/Source/aes.mpc b/Programs/Source/aes.mpc new file mode 100644 index 000000000..58a87da7b --- /dev/null +++ b/Programs/Source/aes.mpc @@ -0,0 +1,395 @@ +# (C) 2016 University of Bristol. See License.txt + +from copy import copy + +rcon_raw = [ + 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, + 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, + 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, + 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, + 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, + 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, + 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, + 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, + 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, + 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, + 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, + 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, + 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, + 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, + 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, + 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb +] + +nparallel = 1 +noutput = 1 +nthreads = 1 + +rcon = VectorArray(len(rcon_raw), cgf2n, nparallel) +for idx in range(len(rcon_raw)): + rcon[idx] = cgf2n(rcon_raw[idx],size=nparallel) + +powers2 = VectorArray(8, cgf2n, nparallel) +for idx in range(8): + powers2[idx] = cgf2n(2,size=nparallel) ** (5 * idx) + +@vectorize +def ApplyEmbedding(x): + in_bytes = x.bit_decompose(8) + + out_bytes = [cgf2n(0) for _ in range(8)] + + out_bytes[0] = sum(in_bytes[0:8]) + out_bytes[1] = sum(in_bytes[idx] for idx in range(1, 8, 2)) + out_bytes[2] = in_bytes[2] + in_bytes[3] + in_bytes[6] + in_bytes[7] + out_bytes[3] = in_bytes[3] + in_bytes[7] + out_bytes[4] = in_bytes[4] + in_bytes[5] + in_bytes[6] + in_bytes[7] + out_bytes[5] = in_bytes[5] + in_bytes[7] + out_bytes[6] = in_bytes[6] + in_bytes[7] + out_bytes[7] = in_bytes[7] + + return sum(powers2[idx] * out_bytes[idx] for idx in range(8)) + + +def embed_helper(in_bytes): + out_bytes = [None] * 8 + out_bytes[0] = sum(in_bytes[0:8]) + out_bytes[1] = sum(in_bytes[idx] for idx in range(1, 8, 2)) + out_bytes[2] = in_bytes[2] + in_bytes[3] + in_bytes[6] + in_bytes[7] + out_bytes[3] = in_bytes[3] + in_bytes[7] + out_bytes[4] = in_bytes[4] + in_bytes[5] + in_bytes[6] + in_bytes[7] + out_bytes[5] = in_bytes[5] + in_bytes[7] + out_bytes[6] = in_bytes[6] + in_bytes[7] + out_bytes[7] = in_bytes[7] + return out_bytes + +@vectorize +def ApplyBDEmbedding(x): + entire_sequence_bits = copy(x) + + while len(entire_sequence_bits) < 8: + entire_sequence_bits.append(0) + + in_bytes = entire_sequence_bits + out_bytes = embed_helper(in_bytes) + + return sum(powers2[idx] * out_bytes[idx] for idx in range(8)) + + +def PreprocInverseEmbedding(x): + in_bytes = x.bit_decompose_embedding() + + out_bytes = [cgf2n(0) for _ in range(8)] + + out_bytes[7] = in_bytes[7] + out_bytes[6] = in_bytes[6] + out_bytes[7] + out_bytes[5] = in_bytes[5] + out_bytes[7] + out_bytes[4] = in_bytes[4] + out_bytes[5] + out_bytes[6] + out_bytes[7] + out_bytes[3] = in_bytes[3] + out_bytes[7] + out_bytes[2] = in_bytes[2] + out_bytes[3] + out_bytes[6] + out_bytes[7] + out_bytes[1] = in_bytes[1] + out_bytes[3] + out_bytes[5] + out_bytes[7] + out_bytes[0] = in_bytes[0] + sum(out_bytes[1:8]) + + return out_bytes + +@vectorize +def InverseEmbedding(x): + out_bytes = PreprocInverseEmbedding(x) + ret = cgf2n(0) + for idx in range(7, -1, -1): + ret = ret + (cgf2n(2) ** idx) * out_bytes[idx] + return ret + +def InverseBDEmbedding(x): + return PreprocInverseEmbedding(x) + +def expandAESKey(cipherKey, Nr = 10, Nb = 4, Nk = 4): + #cipherkey should be in hex + cipherKeySize = len(cipherKey) + + round_key = [sgf2n(0,size=nparallel)] * 176 + temp = [cgf2n(0,size=nparallel)] * 4 + + for i in range(Nk): + for j in range(4): + round_key[4 * i + j] = cipherKey[4 * i + j] + + for i in range(Nk, Nb * (Nr + 1)): + for j in range(4): + temp[j] = round_key[(i-1) * 4 + j] + if i % Nk == 0: + #rotate the 4 bytes word to the left + k = temp[0] + temp[0] = temp[1] + temp[1] = temp[2] + temp[2] = temp[3] + temp[3] = k + + #now substitute word + temp[0] = box.apply_sbox(temp[0]) + temp[1] = box.apply_sbox(temp[1]) + temp[2] = box.apply_sbox(temp[2]) + temp[3] = box.apply_sbox(temp[3]) + + temp[0] = temp[0] + ApplyEmbedding(rcon[int(i/Nk)]) + + for j in range(4): + round_key[4 * i + j] = round_key[4 * (i - Nk) + j] + temp[j] + return round_key + + #Nr = 10 -> The number of rounds in AES Cipher. + #Nb = 4 -> The number of columns of the AES state + #Nk = 4 -> The number of words of a AES key + +def SecretArrayEmbedd(byte_array): + return [ApplyEmbedding(_) for _ in byte_array] + +@vectorize +def subBytes(state): + for i in range(len(state)): + state[i] = box.apply_sbox(state[i]) + +def addRoundKey(roundKey): + @vectorize + def inner(state): + for i in range(len(state)): + state[i] = state[i] + roundKey[i] + return inner + +# mixColumn takes a column and does stuff + +Kv = VectorArray(4, cgf2n, nparallel) +Kv[1] = ApplyEmbedding(cgf2n(1,size=nparallel)) +Kv[2] = ApplyEmbedding(cgf2n(2,size=nparallel)) +Kv[3] = ApplyEmbedding(cgf2n(3,size=nparallel)) +Kv[4] = ApplyEmbedding(cgf2n(4,size=nparallel)) + + +@vectorize +def mixColumn(column): + temp = copy(column) + v1 = Kv[1] + v2 = Kv[2] + v3 = Kv[3] + v4 = Kv[4] + # no multiplication + doubles = [Kv[2] * t for t in temp] + column[0] = doubles[0] + (temp[1] + doubles[1]) + temp[2] + temp[3] + column[1] = temp[0] + doubles[1] + (temp[2] + doubles[2]) + temp[3] + column[2] = temp[0] + temp[1] + doubles[2] + (temp[3] + doubles[3]) + column[3] = (temp[0] + doubles[0]) + temp[1] + temp[2] + doubles[3] + +@vectorize +def mixColumns(state): + for i in range(4): + column = [] + for j in range(4): + column.append(state[i*4+j]) + mixColumn(column) + for j in range(4): + state[i*4+j] = column[j] + +def rotate(word, n): + return word[n:]+word[0:n] + +def shiftRows(state): + for i in range(4): + state[i::4] = rotate(state[i::4],i) + +@vectorize +def state_collapse(state): + return [InverseEmbedding(_) for _ in state] + + +# such constants. very wow. +_embedded_powers = [ + [0x1,0x2,0x4,0x8,0x10,0x20,0x40,0x80,0x100,0x200,0x400,0x800,0x1000,0x2000,0x4000,0x8000,0x10000,0x20000,0x40000,0x80000,0x100000,0x200000,0x400000,0x800000,0x1000000,0x2000000,0x4000000,0x8000000,0x10000000,0x20000000,0x40000000,0x80000000,0x100000000,0x200000000,0x400000000,0x800000000,0x1000000000,0x2000000000,0x4000000000,0x8000000000], + [0x1,0x4,0x10,0x40,0x100,0x400,0x1000,0x4000,0x10000,0x40000,0x100000,0x400000,0x1000000,0x4000000,0x10000000,0x40000000,0x100000000,0x400000000,0x1000000000,0x4000000000,0x108401,0x421004,0x1084010,0x4210040,0x10840100,0x42100400,0x108401000,0x421004000,0x1084010000,0x4210040000,0x840008401,0x2100021004,0x8400084010,0x1000000842,0x4000002108,0x100021,0x400084,0x1000210,0x4000840,0x10002100], + [0x1,0x10,0x100,0x1000,0x10000,0x100000,0x1000000,0x10000000,0x100000000,0x1000000000,0x108401,0x1084010,0x10840100,0x108401000,0x1084010000,0x840008401,0x8400084010,0x4000002108,0x400084,0x4000840,0x40008400,0x400084000,0x4000840000,0x8021004,0x80210040,0x802100400,0x8021004000,0x210802008,0x2108020080,0x1080010002,0x800008421,0x8000084210,0x108,0x1080,0x10800,0x108000,0x1080000,0x10800000,0x108000000,0x1080000000], + [0x1,0x100,0x10000,0x1000000,0x100000000,0x108401,0x10840100,0x1084010000,0x8400084010,0x400084,0x40008400,0x4000840000,0x80210040,0x8021004000,0x2108020080,0x800008421,0x108,0x10800,0x1080000,0x108000000,0x800108401,0x10002108,0x1000210800,0x20004010,0x2000401000,0x42008020,0x4200802000,0x84200842,0x8420084200,0x2000421084,0x40000420,0x4000042000,0x10040,0x1004000,0x100400000,0x40108401,0x4010840100,0x1080200040,0x8021080010,0x2100421080], + [0x1,0x10000,0x100000000,0x10840100,0x8400084010,0x40008400,0x80210040,0x2108020080,0x108,0x1080000,0x800108401,0x1000210800,0x2000401000,0x4200802000,0x8420084200,0x40000420,0x10040,0x100400000,0x4010840100,0x8021080010,0x40108421,0x1080000040,0x100421080,0x4200040100,0x1084200,0x842108401,0x1004210042,0x2008400004,0x4210000008,0x401080210,0x840108001,0x1000000840,0x100001000,0x840100,0x8401000000,0x800000001,0x84210800,0x2100001084,0x210802100,0x8001004210], + [0x1,0x100000000,0x8400084010,0x80210040,0x108,0x800108401,0x2000401000,0x8420084200,0x10040,0x4010840100,0x40108421,0x100421080,0x1084200,0x1004210042,0x4210000008,0x840108001,0x100001000,0x8401000000,0x84210800,0x210802100,0x800000401,0x2100420080,0x8000004000,0x4010002,0x4000800100,0x842000420,0x8421084,0x421080210,0x80010042,0x10802108,0x800000020,0x1084,0x8401084010,0x1004200040,0x4000840108,0x100020,0x2108401000,0x8400080210,0x84210802,0x10802100], + [0x1,0x8400084010,0x108,0x2000401000,0x10040,0x40108421,0x1084200,0x4210000008,0x100001000,0x84210800,0x800000401,0x8000004000,0x4000800100,0x8421084,0x80010042,0x800000020,0x8401084010,0x4000840108,0x2108401000,0x84210802,0x20,0x8000004210,0x2100,0x8401004,0x200800,0x802108420,0x21084000,0x4200842108,0x2000020000,0x1084210000,0x100421,0x1004010,0x10840008,0x108421080,0x1000200840,0x108001,0x8020004210,0x10040108,0x2108401004,0x1084210040], + [0x1,0x108,0x10040,0x1084200,0x100001000,0x800000401,0x4000800100,0x80010042,0x8401084010,0x2108401000,0x20,0x2100,0x200800,0x21084000,0x2000020000,0x100421,0x10840008,0x1000200840,0x8020004210,0x2108401004,0x400,0x42000,0x4010000,0x421080000,0x21004,0x2008420,0x210800100,0x4200002,0x401000210,0x2108401084,0x8000,0x840000,0x80200000,0x8421000000,0x420080,0x40108400,0x4210002000,0x84000040,0x8020004200,0x2108400084] +] + +enum_squarings = VectorArray(8 * 40, cgf2n, nparallel) +for i,_list in enumerate(_embedded_powers): + for j,x in enumerate(_list): + enum_squarings[40 * i + j] = cgf2n(x, size=nparallel) + +@vectorize +def fancy_squaring(bd_val, exponent): + #This is even more fancy; it performs directly on bit dec values + #returns x ** (2 ** exp) from a bit decomposed value + return sum(enum_squarings[exponent * 40 + idx] * bd_val[idx] + for idx in range(len(bd_val))) + +def inverseMod(val): + #embedded now! + #returns x ** 254 using offline squaring + #returns an embedded result + + raw_bit_dec = val.bit_decompose_embedding() + bd_val = [cgf2n(0,size=nparallel)] * 40 + + for idx in range(40): + if idx % 5 == 0: + bd_val[idx] = raw_bit_dec[idx / 5] + + bd_squared = bd_val + squared_index = 2 + + mapper = [0] * 129 + for idx in range(1, 8): + bd_squared = fancy_squaring(bd_val, idx) + mapper[squared_index] = bd_squared + squared_index *= 2 + + enum_powers = [ + 2, 4, 8, 16, 32, 64, 128 + ] + + inverted_product = \ + ((mapper[2] * mapper[4]) * (mapper[8] * mapper[16])) * ((mapper[32] * mapper[64]) * mapper[128]) + return inverted_product + +K01 = VectorArray(8, cgf2n, nparallel) +for idx in range(8): + K01[idx] = ApplyBDEmbedding([0,1]) ** idx + +class SpdzBox(object): + def init_matrices(self): + self.matrix_inv = [ + [0,0,1,0,0,1,0,1], + [1,0,0,1,0,0,1,0], + [0,1,0,0,1,0,0,1], + [1,0,1,0,0,1,0,0], + [0,1,0,1,0,0,1,0], + [0,0,1,0,1,0,0,1], + [1,0,0,1,0,1,0,0], + [0,1,0,0,1,0,1,0] + ] + to_add = [1,0,1,0,0,0,0,0] + self.addition_inv = [cgf2n(_,size=nparallel) for _ in to_add] + self.forward_matrix = [ + [1,0,0,0,1,1,1,1], + [1,1,0,0,0,1,1,1], + [1,1,1,0,0,0,1,1], + [1,1,1,1,0,0,0,1], + [1,1,1,1,1,0,0,0], + [0,1,1,1,1,1,0,0], + [0,0,1,1,1,1,1,0], + [0,0,0,1,1,1,1,1] + ] + forward_add = [1,1,0,0,0,1,1,0] + self.forward_add = VectorArray(len(forward_add), cgf2n, nparallel) + for i,x in enumerate(forward_add): + self.forward_add[i] = cgf2n(x, size=nparallel) + + def __init__(self): + constants = [ + 0x63, 0x8F, 0xB5, 0x01, 0xF4, 0x25, 0xF9, 0x09, 0x05 + ] + self.powers = [ + 0, 127, 191, 223, 239, 247, 251, 253, 254 + ] + self.constants = [ApplyEmbedding(cgf2n(_,size=nparallel)) for _ in constants] + self.init_matrices() + + def forward_bit_sbox(self, emb_byte): + emb_byte_inverse = inverseMod(emb_byte) + unembedded_x = InverseBDEmbedding(emb_byte_inverse) + + linear_transform = list() + for row in self.forward_matrix: + result = cgf2n(0, size=nparallel) + for idx in range(len(row)): + result = result + unembedded_x[idx] * row[idx] + linear_transform.append(result) + + #do the sum(linear_transfor + additive_layer) + summation_bd = [0 for _ in range(8)] + for idx in range(8): + summation_bd[idx] = linear_transform[idx] + self.forward_add[idx] + + #Now raise this to power of 254 + result = cgf2n(0,size=nparallel) + for idx in range(8): + result += ApplyBDEmbedding([summation_bd[idx]]) * K01[idx]; + return result + + def apply_sbox(self, what): + #applying with the multiplicative chain + return self.forward_bit_sbox(what) + +box = SpdzBox() + +def aesRound(roundKey): + @vectorize + def inner(state): + subBytes(state) + shiftRows(state) + mixColumns(state) + addRoundKey(roundKey)(state) + return inner + +# returns a 16-byte round key based on an expanded key and round number +def createRoundKey(expandedKey, n): + return expandedKey[(n*16):(n*16+16)] + +# wrapper function for 10 rounds of AES since we're using a 128-bit key +def aesMain(expandedKey, numRounds=10): + @vectorize + def inner(state): + roundKey = createRoundKey(expandedKey, 0) + addRoundKey(roundKey)(state) + for i in range(1, numRounds): + + roundKey = createRoundKey(expandedKey, i) + aesRound(roundKey)(state) + + roundKey = createRoundKey(expandedKey, numRounds) + + subBytes(state) + shiftRows(state) + addRoundKey(roundKey)(state) + return inner + +def encrypt_without_key_schedule(expandedKey): + @vectorize + def encrypt(plaintext): + plaintext = SecretArrayEmbedd(plaintext) + aesMain(expandedKey)(plaintext) + return state_collapse(plaintext) + return encrypt; + +""" +Test Vectors: + +plaintext: +6bc1bee22e409f96e93d7e117393172a + +key: +2b7e151628aed2a6abf7158809cf4f3c + +resulting cipher +3ad77bb40d7a3660a89ecaf32466ef97 + +""" + +def single_encryption(): + key = [sgf2n.get_raw_input_from(0) for _ in range(16)] + message = [sgf2n.get_raw_input_from(1) for _ in range(16)] + + key = [ApplyEmbedding(_) for _ in key] + expanded_key = expandAESKey(key) + + AES = encrypt_without_key_schedule(expanded_key) + + ciphertext = AES(message) + + for block in ciphertext: + print_ln('%s', block.reveal()) + +single_encryption() diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc new file mode 100644 index 000000000..05f394d97 --- /dev/null +++ b/Programs/Source/tutorial.mpc @@ -0,0 +1,58 @@ +# (C) 2016 University of Bristol. See License.txt + +def test(actual, expected): + if isinstance(actual, (sint, sgf2n)): + actual = actual.reveal() + print_ln('expected %s, got %s', expected, actual) + +# cint: clear integers modulo p +# sint: secret integers modulo p + +a = sint(1) +b = cint(2) + +test(a + b, 3) +test(a + a, 2) +test(a * b, 2) +test(a * a, 1) +test(a - b, -1) +test(a < b, 1) +test(a <= b, 1) +test(a >= b, 0) +test(a > b, 0) +test(a == b, 0) +test(a != b, 1) + +clear_a = a.reveal() + +# sgfn2/cgf2n: secret/clear elements of GF(2^n) + +a = sgf2n(1) +b = cgf2n(2) + +test(a + b, 3) +test(a + a, 0) +test(a * b, 2) +test(a * a, 1) +test(a == b, 0) +test(a != b, 1) + +# arrays and loops + +a = Array(100, sint) + +@for_range(100) +def f(i): + a[i] = sint(i)**2 + +test(a[99], 99**2) + +# conditional + +if_then(cint(0)) +a[0] = 123 +else_then() +a[0] = 789 +end_if() + +test(a[0], 789) diff --git a/Programs/Source/vickrey.mpc b/Programs/Source/vickrey.mpc new file mode 100644 index 000000000..6494ee4c1 --- /dev/null +++ b/Programs/Source/vickrey.mpc @@ -0,0 +1,77 @@ +# (C) 2016 University of Bristol. See License.txt + +import util +from Compiler import types + +import math +import re +r = re.search('(\D*)(\d*)', program.name) + +if r.group(2): + n_inputs = int(r.group(2)) +else: + n_inputs = 100 + +n_parties = 2 +n_threads = int(math.ceil(2 ** (int(math.log(n_inputs, 2) - 7)))) +n_loops = 1 +n_bits = 64 +#value_type = types.get_sgf2nuint(n_bits) +value_type = sint + +program.set_bit_length(n_bits) +program.set_security(40) + +print_ln('n_inputs = %s, n_parties = %s, n_threads = %s, n_loops = %s, ' + 'value_type = %s', + n_inputs, n_parties, n_threads, n_loops, value_type.__name__) + +@for_range(n_loops) +def f(_): + Bid = types.getNamedTupleType('party', 'price') + bids = Bid.get_array(n_inputs, value_type) + + for i in range(n_inputs): + # i * 10 because inputs are all zero by default + bids[i] = Bid(i, value_type.get_raw_input_from(i % n_parties) + i * 10) + #bids = [Bid(i, value_type(i * 10)) for i in range(n_parties)] + + def bid_sort(a, b): + comp = a.price < b.price + res = util.cond_swap(comp, a, b) + for i in res: + i.price = value_type.hard_conv(i.price) + return res + + def first_and_second(left, right): + top = left[0].price < right[0].price + cross = [left[i].price < right[1-i].price for i in range(2)] + first = top.if_else(right[0], left[0]) + tmp = [cross[i].if_else(right[1-i], left[i]) for i in (0,1)] + second = top.if_else(*tmp) + for i in (first, second): + i.price = value_type.hard_conv(i.price) + return first, second + + results = Bid.get_array(2 * n_threads, value_type) + + def thread(): + i = get_arg() + n_per_thread = n_inputs / n_threads + if n_per_thread % 2 != 0: + raise Exception('Number of inputs must be divisible by 2') + start = i * n_per_thread + tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1]) \ + for j in range(n_per_thread / 2)] + first, second = util.tree_reduce(first_and_second, tuples) + results[2*i], results[2*i+1] = first, second + + tape = program.new_tape(thread) + threads = [program.run_tape(tape, i) for i in range(n_threads)] + for i in threads: + program.join_tape(i) + + tuples = [(results[2*i], results[2*i+1]) for i in range(n_threads)] + first, second = util.tree_reduce(first_and_second, tuples) + + print_ln('Winner: %s, price: %s', first.party.reveal(), second.price.reveal()) diff --git a/README.md b/README.md new file mode 100644 index 000000000..ebfd019d2 --- /dev/null +++ b/README.md @@ -0,0 +1,114 @@ +(C) 2016 University of Bristol. See License.txt + +Software for the SPDZ and MASCOT secure multi-party computation protocols. +See `Programs/Source/` for some example MPC programs, and `tutorial.md` for +a basic tutorial. More examples and documentation will be available in the +coming weeks. + +See also https://www.cs.bris.ac.uk/Research/CryptographySecurity/SPDZ + +#### Requirements: + - GCC + - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) + - CPU supporting AES-NI and PCLMUL + - Python 2.x, ideally with `gmpy` package (for testing) + +#### To compile SPDZ: + +1) Optionally, edit CONFIG and CONFIG.mine so that the following variables point to the right locations: + - PREP_DIR: this should be a local, unversioned directory to store preprocessing data (defaults to Player-Data in the working directory) + +2) Run make (use the flag -j for faster compilation with multiple threads) + + +#### To setup for the online phase + +Run: + +`Scripts/setup-online.sh` + +This sets up parameters for the online phase for 2 parties with a 128-bit prime field and 40-bit binary field, and creates fake offline data (multiplication triples etc.) for these parameters. + +Parameters can be customised by running + +`Scripts/setup-online.sh ` + + +#### To compile a program + +To compile the program in `./Programs/Source/tutorial.mpc`, run: + +`./compile.py tutorial` + +This creates the bytecode and schedule files in Programs/Bytecode/ and Programs/Schedules/ + +#### To run a program + +To run the above program (on one machine), first run: + +`./Server.x 2 5000 &` + +(or replace `5000` with your desired port number) + +Then run both parties' online phase: + +`./Player-Online.x -pn 5000 0 tutorial` + +`./Player-Online.x -pn 5000 1 tutorial` (in a separate terminal) + +Or, you can use a script to do the above automatically: + +`Scripts/run-online.sh tutorial` + +To run a program on two different machines, firstly the preprocessing data must be +copied across to the second machine (or shared using sshfs), and secondly, Player-Online.x +needs to be passed the machine where Server.x is running. +e.g. if this machine is name `diffie` on the local network: + +`./Player-Online.x -pn 5000 -h diffie 0 tutorial` + +`./Player-Online.x -pn 5000 -h diffie 1 tutorial` + +#### Compiling and running programs from external directories + +Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: + +``` +spdz$ cd ../ +$ mkdir myprogs +$ cd myprogs +$ mkdir -p Programs/Source +$ vi Programs/Source/test.mpc +$ ../spdz/compile.py test.mpc +$ ls Programs/ +Bytecode Public-Input Schedules Source +$ ../spdz/Scripts/setup-online.sh +$ ls +Player-Data Programs +$ ../spdz/Scripts/run-online.sh test +``` + +#### Offline phase (MASCOT) + +In order to compile the MASCOT code, the following must be set in CONFIG or CONFIG.mine: + +`USE_GF2N_LONG = 1` + +It also requires SimpleOT: +``` +git submodule update --init SimpleOT +cd SimpleOT +make +``` + +If SPDZ has been built before, any compiled code needs to be removed: + +`make clean` + +HOSTS must contain the hostnames or IPs of the players, see HOSTS.example for an example. + +Then, MASCOT can be run as follows: + +`host1:$ ./ot-offline.x -p 0 -c` + +`host2:$ ./ot-offline.x -p 1 -c` diff --git a/Scripts/gen_input_f2n.cpp b/Scripts/gen_input_f2n.cpp new file mode 100644 index 000000000..4544d514c --- /dev/null +++ b/Scripts/gen_input_f2n.cpp @@ -0,0 +1,25 @@ +// (C) 2016 University of Bristol. See License.txt + +#include +#include +#include "Math/gf2n.h" + +using namespace std; + +int main() { + ifstream cin("gf2n_vals.in"); + ofstream cout("gf2n_vals.out"); + + gf2n::init_field(40); + + int n; cin >> n; + for (int i = 0; i < n; ++i) { + gf2n_short x; cin >> x; + cerr << "value is: " << x << "\n"; + x.output(cout,false); + } + + cin.close(); + cout.close(); + return 0; +} \ No newline at end of file diff --git a/Scripts/gen_input_fp.cpp b/Scripts/gen_input_fp.cpp new file mode 100644 index 000000000..cc47f8a4e --- /dev/null +++ b/Scripts/gen_input_fp.cpp @@ -0,0 +1,27 @@ +// (C) 2016 University of Bristol. See License.txt + +#include +#include +#include "Math/gfp.h" + +using namespace std; + +int main() { + ifstream cin("gfp_vals.in"); + ofstream cout("gfp_vals.out"); + + gfp::init_field(bigint("172035116406933162231178957667602464769")); + + int n; cin >> n; + for (int i = 0; i < n; ++i) { + bigint a; + cin >> a; + gfp b; + to_gfp(b, a); + b.output(cout, false); + } + + cin.close(); + cout.close(); + return 0; +} \ No newline at end of file diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh new file mode 100644 index 000000000..13a0d34f5 --- /dev/null +++ b/Scripts/run-common.sh @@ -0,0 +1,37 @@ +# (C) 2016 University of Bristol. See License.txt + + +run_player() { + port=$((RANDOM%10000+10000)) + >&2 echo Port $port + bin=$1 + shift + if test $bin = Player-Online.x; then + params="$* -pn $port -h localhost" + else + params="$port localhost $*" + fi + if test $bin = Player-KeyGen.x -a ! -e Player-Data/Params-Data; then + ./Setup.x $players $size 40 + fi + >&2 echo Parameters $params + $SPDZROOT/Server.x $players $port & + rem=$(($players - 2)) + for i in $(seq 0 $rem); do + echo "trying with player $i" + $prefix $SPDZROOT/$bin $i $params 2>&1 | tee $SPDZROOT/logs/$i & + done + last_player=$(($players - 1)) + $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$last_player 2>&1 || return 1 +} + +killall Player-Online.x Server.x +sleep 0.5 + +#mkdir /dev/shm/Player-Data + +players=${PLAYERS:-2} + +#. Scripts/setup.sh + +mkdir logs diff --git a/Scripts/run-online.sh b/Scripts/run-online.sh new file mode 100755 index 000000000..739dca3a1 --- /dev/null +++ b/Scripts/run-online.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# (C) 2016 University of Bristol. See License.txt + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +bits=${2:-128} +g=${3:-40} +mem=${4:-empty} + +. $HERE/run-common.sh + +Scripts/setup-online.sh 2 ${bits} ${g} + +run_player Player-Online.x ${1:-test_all} -lgp ${bits} -lg2 ${g} -m ${mem} || exit 1 diff --git a/Scripts/setup-online.sh b/Scripts/setup-online.sh new file mode 100755 index 000000000..320058343 --- /dev/null +++ b/Scripts/setup-online.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# (C) 2016 University of Bristol. See License.txt + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +# number of players +players=${1:-2} +# prime field bit length +bits=${2:-128} +# binary field bit length +g=${3:-40} +# default number of triples etc. to create +default=${4:-10000} + +die () { + echo >&2 "$@" + echo >&2 "Usage: +setup-online.sh [nplayers] [prime_bitlength] [gf2n_bitlength] [num_prep] +Defaults: +nplayers=2, prime_bitlength=128, gf2n_bitlength=40, num_prep=10000" + exit 1 +} + +[ "$#" -le 4 ] || die "More than 4 arguments provided" + +for arg in "$@" +do + echo "$arg" | grep -E -q '^[0-9]+$' || die "Integer argument required, $arg provided" +done + +$SPDZROOT/Fake-Offline.x ${players} -lgp ${bits} -lg2 ${g} --default ${default} + +for i in 0 1; do + dd if=/dev/zero of=Player-Data/Private-Input-$i bs=10000 count=1 +done diff --git a/Server.cpp b/Server.cpp new file mode 100644 index 000000000..103887cf9 --- /dev/null +++ b/Server.cpp @@ -0,0 +1,96 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Networking/sockets.h" +#include "Networking/ServerSocket.h" + +#include +#include +#include +using namespace std; + +vector socket_num; + + +vector names; + +int nmachines; + + + +void get_name(int num) +{ + // Now all machines are set up, send GO to start them. + send(socket_num[num], GO); + cerr << "Player " << num << " started." << endl; + + // Receive Name + names[num]=new octet[512]; + receive(socket_num[num],names[num],512); + cerr << "Player " << num << " is on machine " << names[num] << endl; +} + + +void send_names(int num) +{ + /* Now send the machine names back to each client + * and the number of machines + */ + send(socket_num[num],nmachines); + for (int i=0; i + +Lock::Lock() +{ + pthread_mutex_init(&mutex, 0); +} + +Lock::~Lock() +{ + pthread_mutex_destroy(&mutex); +} + +void Lock::lock() +{ + pthread_mutex_lock(&mutex); +} + +void Lock::unlock() +{ + pthread_mutex_unlock(&mutex); +} diff --git a/Tools/Lock.h b/Tools/Lock.h new file mode 100644 index 000000000..06c3b4b88 --- /dev/null +++ b/Tools/Lock.h @@ -0,0 +1,24 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Lock.h + * + */ + +#ifndef TOOLS_LOCK_H_ +#define TOOLS_LOCK_H_ + +#include + +class Lock +{ + pthread_mutex_t mutex; +public: + Lock(); + virtual ~Lock(); + + void lock(); + void unlock(); +}; + +#endif /* TOOLS_LOCK_H_ */ diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp new file mode 100644 index 000000000..a940305a1 --- /dev/null +++ b/Tools/MMO.cpp @@ -0,0 +1,109 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * MMO.cpp + * + * + */ + +#include "MMO.h" +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/bigint.h" +#include + + +void MMO::zeroIV() +{ + octet key[AES_BLK_SIZE]; + memset(key,0,AES_BLK_SIZE*sizeof(octet)); + setIV(key); +} + + +void MMO::setIV(octet key[AES_BLK_SIZE]) +{ + aes_schedule(IV,key); +} + + +template <> +void MMO::hashOneBlock(octet* output, octet* input) +{ + __m128i in = _mm_loadu_si128((__m128i*)input); + __m128i ct = aes_encrypt(in, IV); +// __m128i out = ct ^ in; + _mm_storeu_si128((__m128i*)output, ct); +} + + +template <> +void MMO::hashOneBlock(octet* output, octet* input) +{ + __m128i in = _mm_loadu_si128((__m128i*)input); + __m128i ct = aes_encrypt(in, IV); + while (mpn_cmp((mp_limb_t*)&ct, gfp::get_ZpD().get_prA(), gfp::t()) >= 0) + ct = aes_encrypt(ct, IV); + _mm_storeu_si128((__m128i*)output, ct); +} + +template <> +void MMO::hashBlockWise(octet* output, octet* input) +{ + for (int i = 0; i < 16; i++) + ecb_aes_128_encrypt<8>(&((__m128i*)output)[i*8], &((__m128i*)input)[i*8], IV); +} + +template <> +void MMO::hashBlockWise(octet* output, octet* input) +{ + for (int i = 0; i < 16; i++) + { + __m128i* in = &((__m128i*)input)[i*8]; + __m128i* out = &((__m128i*)output)[i*8]; + ecb_aes_128_encrypt<8>(out, in, IV); + int left = 8; + int indices[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + while (left) + { + int now_left = 0; + for (int j = 0; j < left; j++) + if (mpn_cmp((mp_limb_t*)&out[indices[j]], gfp::get_ZpD().get_prA(), gfp::t()) >= 0) + { + indices[now_left] = indices[j]; + now_left++; + } + left = now_left; + + // and now my favorite hack + switch (left) { + case 8: + ecb_aes_128_encrypt<8>(out, out, IV, indices); + break; + case 7: + ecb_aes_128_encrypt<7>(out, out, IV, indices); + break; + case 6: + ecb_aes_128_encrypt<6>(out, out, IV, indices); + break; + case 5: + ecb_aes_128_encrypt<5>(out, out, IV, indices); + break; + case 4: + ecb_aes_128_encrypt<4>(out, out, IV, indices); + break; + case 3: + ecb_aes_128_encrypt<3>(out, out, IV, indices); + break; + case 2: + ecb_aes_128_encrypt<2>(out, out, IV, indices); + break; + case 1: + ecb_aes_128_encrypt<1>(out, out, IV, indices); + break; + default: + break; + } + } + } +} diff --git a/Tools/MMO.h b/Tools/MMO.h new file mode 100644 index 000000000..3f6fe3e8f --- /dev/null +++ b/Tools/MMO.h @@ -0,0 +1,30 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * MMO.h + * + */ + +#ifndef TOOLS_MMO_H_ +#define TOOLS_MMO_H_ + +#include "Tools/aes.h" + +// Matyas-Meyer-Oseas hashing +class MMO +{ + octet IV[176] __attribute__((aligned (16))); + +public: + MMO() { zeroIV(); } + void zeroIV(); + void setIV(octet key[AES_BLK_SIZE]); + template + void hashOneBlock(octet* output, octet* input); + template + void hashBlockWise(octet* output, octet* input); + template + void outputOneBlock(octet* output); +}; + +#endif /* TOOLS_MMO_H_ */ diff --git a/Tools/Signal.cpp b/Tools/Signal.cpp new file mode 100644 index 000000000..420fdd157 --- /dev/null +++ b/Tools/Signal.cpp @@ -0,0 +1,40 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Signal.cpp + * + */ + +#include "Signal.h" + +Signal::Signal() +{ + pthread_mutex_init(&mutex, 0); + pthread_cond_init(&cond, 0); +} + +Signal::~Signal() +{ + pthread_mutex_destroy(&mutex); + pthread_cond_destroy(&cond); +} + +void Signal::lock() +{ + pthread_mutex_lock(&mutex); +} + +void Signal::unlock() +{ + pthread_mutex_unlock(&mutex); +} + +void Signal::wait() +{ + pthread_cond_wait(&cond, &mutex); +} + +void Signal::broadcast() +{ + pthread_cond_broadcast(&cond); +} diff --git a/Tools/Signal.h b/Tools/Signal.h new file mode 100644 index 000000000..27fbacc5b --- /dev/null +++ b/Tools/Signal.h @@ -0,0 +1,27 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * Signal.h + * + */ + +#ifndef TOOLS_SIGNAL_H_ +#define TOOLS_SIGNAL_H_ + +#include + +class Signal +{ + pthread_mutex_t mutex; + pthread_cond_t cond; + +public: + Signal(); + virtual ~Signal(); + void lock(); + void unlock(); + void wait(); + void broadcast(); +}; + +#endif /* TOOLS_SIGNAL_H_ */ diff --git a/Tools/WaitQueue.h b/Tools/WaitQueue.h new file mode 100644 index 000000000..20722fd2e --- /dev/null +++ b/Tools/WaitQueue.h @@ -0,0 +1,91 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * WaitQueue.h + * + */ + +#ifndef TOOLS_WAITQUEUE_H_ +#define TOOLS_WAITQUEUE_H_ + +#include +#include +using namespace std; + +template +class WaitQueue +{ + pthread_mutex_t mutex; + pthread_cond_t cond; + + deque queue; + bool running; + + // prevent copying + WaitQueue(const WaitQueue& other); + +public: + WaitQueue() : running(true) + { + pthread_mutex_init(&mutex, 0); + pthread_cond_init(&cond, 0); + } + + ~WaitQueue() + { + pthread_mutex_destroy(&mutex); + pthread_cond_destroy(&cond); + } + + void lock() + { + pthread_mutex_lock(&mutex); + } + + void unlock() + { + pthread_mutex_unlock(&mutex); + } + + void wait() + { + pthread_cond_wait(&cond, &mutex); + } + + void signal() + { + pthread_cond_signal(&cond); + } + + void push(const T& value) + { + lock(); + queue.push_back(value); + signal(); + unlock(); + } + + bool pop(T& value) + { + lock(); + if (running and queue.size() == 0) + wait(); + if (running) + { + value = queue.front(); + queue.pop_front(); + } + unlock(); + return running; + } + + void stop() + { + lock(); + running = false; + signal(); + unlock(); + } +}; + +#endif /* TOOLS_WAITQUEUE_H_ */ diff --git a/Tools/aes-ni.cpp b/Tools/aes-ni.cpp new file mode 100644 index 000000000..e1125efdc --- /dev/null +++ b/Tools/aes-ni.cpp @@ -0,0 +1,231 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "aes.h" + + +/********************** + * M-Code Version * + **********************/ + +#define cpuid(func,ax,bx,cx,dx)\ + __asm__ __volatile__ ("cpuid":\ + "=a" (ax), "=b" (bx), "=c" (cx), "=d" (dx) : "a" (func)); + + +int Check_CPU_support_AES() +{ unsigned int a,b,c,d; + cpuid(1, a,b,c,d); + return (c & 0x2000000); +} + +inline __m128i AES_128_ASSIST (__m128i temp1, __m128i temp2) +{ __m128i temp3; temp2 = _mm_shuffle_epi32 (temp2 ,0xff); + temp3 = _mm_slli_si128 (temp1, 0x4); + temp1 = _mm_xor_si128 (temp1, temp3); + temp3 = _mm_slli_si128 (temp3, 0x4); + temp1 = _mm_xor_si128 (temp1, temp3); + temp3 = _mm_slli_si128 (temp3, 0x4); + temp1 = _mm_xor_si128 (temp1, temp3); + temp1 = _mm_xor_si128 (temp1, temp2); + return temp1; +} + + +void aes_128_schedule( octet* key, const octet* userkey ) +{ __m128i temp1, temp2; + __m128i *Key_Schedule = (__m128i*)key; + temp1 = _mm_loadu_si128((__m128i*)userkey); + Key_Schedule[0] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1 ,0x1); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[1] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x2); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[2] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x4); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[3] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x8); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[4] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x10); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[5] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x20); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[6] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x40); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[7] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x80); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[8] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x1b); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[9] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x36); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[10] = temp1; +} + +inline void KEY_192_ASSIST(__m128i* temp1, __m128i * temp2, __m128i * temp3) +{ __m128i temp4; + *temp2 = _mm_shuffle_epi32 (*temp2, 0x55); + temp4 = _mm_slli_si128 (*temp1, 0x4); + *temp1 = _mm_xor_si128 (*temp1, temp4); + temp4 = _mm_slli_si128 (temp4, 0x4); + *temp1 = _mm_xor_si128 (*temp1, temp4); + temp4 = _mm_slli_si128 (temp4, 0x4); + *temp1 = _mm_xor_si128 (*temp1, temp4); + *temp1 = _mm_xor_si128 (*temp1, *temp2); + *temp2 = _mm_shuffle_epi32(*temp1, 0xff); + temp4 = _mm_slli_si128 (*temp3, 0x4); + *temp3 = _mm_xor_si128 (*temp3, temp4); + *temp3 = _mm_xor_si128 (*temp3, *temp2); +} + + +void aes_192_schedule( octet* key, const octet* userkey ) +{ __m128i temp1, temp2, temp3; + __m128i *Key_Schedule = (__m128i*)key; + temp1 = _mm_loadu_si128((__m128i*)userkey); + temp3 = _mm_loadu_si128((__m128i*)(userkey+16)); + Key_Schedule[0]=temp1; + Key_Schedule[1]=temp3; + temp2=_mm_aeskeygenassist_si128 (temp3,0x1); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[1] = (__m128i)_mm_shuffle_pd((__m128d)Key_Schedule[1],(__m128d)temp1,0); + Key_Schedule[2] = (__m128i)_mm_shuffle_pd((__m128d)temp1,(__m128d)temp3,1); + temp2=_mm_aeskeygenassist_si128 (temp3,0x2); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[3]=temp1; + Key_Schedule[4]=temp3; + temp2=_mm_aeskeygenassist_si128 (temp3,0x4); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[4] = (__m128i)_mm_shuffle_pd((__m128d)Key_Schedule[4],(__m128d)temp1,0); + Key_Schedule[5] = (__m128i)_mm_shuffle_pd((__m128d)temp1,(__m128d)temp3,1); + temp2=_mm_aeskeygenassist_si128 (temp3,0x8); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[6]=temp1; + Key_Schedule[7]=temp3; + temp2=_mm_aeskeygenassist_si128 (temp3,0x10); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[7] = (__m128i)_mm_shuffle_pd((__m128d)Key_Schedule[7],(__m128d)temp1,0); + Key_Schedule[8] = (__m128i)_mm_shuffle_pd((__m128d)temp1,(__m128d)temp3,1); + temp2=_mm_aeskeygenassist_si128 (temp3,0x20); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[9]=temp1; + Key_Schedule[10]=temp3; + temp2=_mm_aeskeygenassist_si128 (temp3,0x40); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[10] = (__m128i)_mm_shuffle_pd((__m128d)Key_Schedule[10],(__m128d)temp1,0); + Key_Schedule[11] = (__m128i)_mm_shuffle_pd((__m128d)temp1,(__m128d)temp3,1); + temp2=_mm_aeskeygenassist_si128 (temp3,0x80); + KEY_192_ASSIST(&temp1, &temp2, &temp3); + Key_Schedule[12]=temp1; +} + +inline void KEY_256_ASSIST_1(__m128i* temp1, __m128i * temp2) +{ __m128i temp4; + *temp2 = _mm_shuffle_epi32(*temp2, 0xff); + temp4 = _mm_slli_si128 (*temp1, 0x4); + *temp1 = _mm_xor_si128 (*temp1, temp4); + temp4 = _mm_slli_si128 (temp4, 0x4); + *temp1 = _mm_xor_si128 (*temp1, temp4); + temp4 = _mm_slli_si128 (temp4, 0x4); + *temp1 = _mm_xor_si128 (*temp1, temp4); + *temp1 = _mm_xor_si128 (*temp1, *temp2); +} + + +inline void KEY_256_ASSIST_2(__m128i* temp1, __m128i * temp3) +{ __m128i temp2,temp4; + temp4 = _mm_aeskeygenassist_si128 (*temp1, 0x0); + temp2 = _mm_shuffle_epi32(temp4, 0xaa); + temp4 = _mm_slli_si128 (*temp3, 0x4); + *temp3 = _mm_xor_si128 (*temp3, temp4); + temp4 = _mm_slli_si128 (temp4, 0x4); + *temp3 = _mm_xor_si128 (*temp3, temp4); + temp4 = _mm_slli_si128 (temp4, 0x4); + *temp3 = _mm_xor_si128 (*temp3, temp4); + *temp3 = _mm_xor_si128 (*temp3, temp2); +} + +void aes_256_schedule( octet* key, const octet* userkey ) +{ __m128i temp1, temp2, temp3; + __m128i *Key_Schedule = (__m128i*)key; + temp1 = _mm_loadu_si128((__m128i*)userkey); + temp3 = _mm_loadu_si128((__m128i*)(userkey+16)); + Key_Schedule[0] = temp1; + Key_Schedule[1] = temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x01); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[2]=temp1; + KEY_256_ASSIST_2(&temp1, &temp3); + Key_Schedule[3]=temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x02); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[4]=temp1; + KEY_256_ASSIST_2(&temp1, &temp3); + Key_Schedule[5]=temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x04); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[6]=temp1; + KEY_256_ASSIST_2(&temp1, &temp3); + Key_Schedule[7]=temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x08); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[8]=temp1; + KEY_256_ASSIST_2(&temp1, &temp3); + Key_Schedule[9]=temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x10); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[10]=temp1; + KEY_256_ASSIST_2(&temp1, &temp3); + Key_Schedule[11]=temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x20); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[12]=temp1; + KEY_256_ASSIST_2(&temp1, &temp3); + Key_Schedule[13]=temp3; + temp2 = _mm_aeskeygenassist_si128 (temp3,0x40); + KEY_256_ASSIST_1(&temp1, &temp2); + Key_Schedule[14]=temp1; +} + + + +void aes_128_encrypt(octet* out, const octet* in, const octet* key) +{ __m128i tmp; + tmp = _mm_loadu_si128 (&((__m128i*)in)[0]); + tmp = aes_128_encrypt(tmp,key); + _mm_storeu_si128 (&((__m128i*)out)[0],tmp); +} + +void aes_192_encrypt(octet* out, const octet* in, const octet* key) +{ __m128i tmp; + tmp = _mm_loadu_si128 (&((__m128i*)in)[0]); + tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]); + int j; + for(j=1; j <12; j++) + { tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); } + tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); + _mm_storeu_si128 (&((__m128i*)out)[0],tmp); +} + +void aes_256_encrypt(octet* out, const octet* in, const octet* key) +{ __m128i tmp; + tmp = _mm_loadu_si128 (&((__m128i*)in)[0]); + tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]); + int j; + for(j=1; j <14; j++) + { tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); } + tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); + _mm_storeu_si128 (&((__m128i*)out)[0],tmp); +} + + + + + diff --git a/Tools/aes.cpp b/Tools/aes.cpp new file mode 100644 index 000000000..23a983343 --- /dev/null +++ b/Tools/aes.cpp @@ -0,0 +1,491 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "aes.h" + + +/********************** + * C Version * + **********************/ + + +#define U8_TO_U32_LE(r,x,i) \ +{ \ + r = ( uint )( x[ i + 0 ] ) << 0; \ + r |= ( uint )( x[ i + 1 ] ) << 8; \ + r |= ( uint )( x[ i + 2 ] ) << 16; \ + r |= ( uint )( x[ i + 3 ] ) << 24; \ +} +#define U32_TO_U8_LE(r,x,i) \ +{ \ + r[ i + 0 ] = ( x >> 0 ) & 0xFF; \ + r[ i + 1 ] = ( x >> 8 ) & 0xFF; \ + r[ i + 2 ] = ( x >> 16 ) & 0xFF; \ + r[ i + 3 ] = ( x >> 24 ) & 0xFF; \ +} + +#define ROUND1(a,b,c,d) \ +{ \ + t0 = t0 ^ RK[ 0 ]; \ + t1 = t1 ^ RK[ 1 ]; \ + t2 = t2 ^ RK[ 2 ]; \ + t3 = t3 ^ RK[ 3 ]; \ +} + +#define ROUND2(a,b,c,d) \ +{ \ + t4 = ( T0[ ( t0 >> 0 ) & 0xFF ] ) ^ \ + ( T1[ ( t1 >> 8 ) & 0xFF ] ) ^ \ + ( T2[ ( t2 >> 16 ) & 0xFF ] ) ^ \ + ( T3[ ( t3 >> 24 ) & 0xFF ] ) ^ RK[ a ]; \ + t5 = ( T0[ ( t1 >> 0 ) & 0xFF ] ) ^ \ + ( T1[ ( t2 >> 8 ) & 0xFF ] ) ^ \ + ( T2[ ( t3 >> 16 ) & 0xFF ] ) ^ \ + ( T3[ ( t0 >> 24 ) & 0xFF ] ) ^ RK[ b ]; \ + t6 = ( T0[ ( t2 >> 0 ) & 0xFF ] ) ^ \ + ( T1[ ( t3 >> 8 ) & 0xFF ] ) ^ \ + ( T2[ ( t0 >> 16 ) & 0xFF ] ) ^ \ + ( T3[ ( t1 >> 24 ) & 0xFF ] ) ^ RK[ c ]; \ + t7 = ( T0[ ( t3 >> 0 ) & 0xFF ] ) ^ \ + ( T1[ ( t0 >> 8 ) & 0xFF ] ) ^ \ + ( T2[ ( t1 >> 16 ) & 0xFF ] ) ^ \ + ( T3[ ( t2 >> 24 ) & 0xFF ] ) ^ RK[ d ]; \ + \ + t0 = t4; \ + t1 = t5; \ + t2 = t6; \ + t3 = t7; \ +} + +#define ROUND3(a,b,c,d) \ +{ \ + t4 = ( T4[ ( t0 >> 0 ) & 0xFF ] & 0x000000FF ) ^ \ + ( T4[ ( t1 >> 8 ) & 0xFF ] & 0x0000FF00 ) ^ \ + ( T4[ ( t2 >> 16 ) & 0xFF ] & 0x00FF0000 ) ^ \ + ( T4[ ( t3 >> 24 ) & 0xFF ] & 0xFF000000 ) ^ RK[ a ]; \ + t5 = ( T4[ ( t1 >> 0 ) & 0xFF ] & 0x000000FF ) ^ \ + ( T4[ ( t2 >> 8 ) & 0xFF ] & 0x0000FF00 ) ^ \ + ( T4[ ( t3 >> 16 ) & 0xFF ] & 0x00FF0000 ) ^ \ + ( T4[ ( t0 >> 24 ) & 0xFF ] & 0xFF000000 ) ^ RK[ b ]; \ + t6 = ( T4[ ( t2 >> 0 ) & 0xFF ] & 0x000000FF ) ^ \ + ( T4[ ( t3 >> 8 ) & 0xFF ] & 0x0000FF00 ) ^ \ + ( T4[ ( t0 >> 16 ) & 0xFF ] & 0x00FF0000 ) ^ \ + ( T4[ ( t1 >> 24 ) & 0xFF ] & 0xFF000000 ) ^ RK[ c ]; \ + t7 = ( T4[ ( t3 >> 0 ) & 0xFF ] & 0x000000FF ) ^ \ + ( T4[ ( t0 >> 8 ) & 0xFF ] & 0x0000FF00 ) ^ \ + ( T4[ ( t1 >> 16 ) & 0xFF ] & 0x00FF0000 ) ^ \ + ( T4[ ( t2 >> 24 ) & 0xFF ] & 0xFF000000 ) ^ RK[ d ]; \ +} + +uint T0[] ={ 0xA56363C6, 0x847C7CF8, 0x997777EE, 0x8D7B7BF6, + 0x0DF2F2FF, 0xBD6B6BD6, 0xB16F6FDE, 0x54C5C591, + 0x50303060, 0x03010102, 0xA96767CE, 0x7D2B2B56, + 0x19FEFEE7, 0x62D7D7B5, 0xE6ABAB4D, 0x9A7676EC, + 0x45CACA8F, 0x9D82821F, 0x40C9C989, 0x877D7DFA, + 0x15FAFAEF, 0xEB5959B2, 0xC947478E, 0x0BF0F0FB, + 0xECADAD41, 0x67D4D4B3, 0xFDA2A25F, 0xEAAFAF45, + 0xBF9C9C23, 0xF7A4A453, 0x967272E4, 0x5BC0C09B, + 0xC2B7B775, 0x1CFDFDE1, 0xAE93933D, 0x6A26264C, + 0x5A36366C, 0x413F3F7E, 0x02F7F7F5, 0x4FCCCC83, + 0x5C343468, 0xF4A5A551, 0x34E5E5D1, 0x08F1F1F9, + 0x937171E2, 0x73D8D8AB, 0x53313162, 0x3F15152A, + 0x0C040408, 0x52C7C795, 0x65232346, 0x5EC3C39D, + 0x28181830, 0xA1969637, 0x0F05050A, 0xB59A9A2F, + 0x0907070E, 0x36121224, 0x9B80801B, 0x3DE2E2DF, + 0x26EBEBCD, 0x6927274E, 0xCDB2B27F, 0x9F7575EA, + 0x1B090912, 0x9E83831D, 0x742C2C58, 0x2E1A1A34, + 0x2D1B1B36, 0xB26E6EDC, 0xEE5A5AB4, 0xFBA0A05B, + 0xF65252A4, 0x4D3B3B76, 0x61D6D6B7, 0xCEB3B37D, + 0x7B292952, 0x3EE3E3DD, 0x712F2F5E, 0x97848413, + 0xF55353A6, 0x68D1D1B9, 0x00000000, 0x2CEDEDC1, + 0x60202040, 0x1FFCFCE3, 0xC8B1B179, 0xED5B5BB6, + 0xBE6A6AD4, 0x46CBCB8D, 0xD9BEBE67, 0x4B393972, + 0xDE4A4A94, 0xD44C4C98, 0xE85858B0, 0x4ACFCF85, + 0x6BD0D0BB, 0x2AEFEFC5, 0xE5AAAA4F, 0x16FBFBED, + 0xC5434386, 0xD74D4D9A, 0x55333366, 0x94858511, + 0xCF45458A, 0x10F9F9E9, 0x06020204, 0x817F7FFE, + 0xF05050A0, 0x443C3C78, 0xBA9F9F25, 0xE3A8A84B, + 0xF35151A2, 0xFEA3A35D, 0xC0404080, 0x8A8F8F05, + 0xAD92923F, 0xBC9D9D21, 0x48383870, 0x04F5F5F1, + 0xDFBCBC63, 0xC1B6B677, 0x75DADAAF, 0x63212142, + 0x30101020, 0x1AFFFFE5, 0x0EF3F3FD, 0x6DD2D2BF, + 0x4CCDCD81, 0x140C0C18, 0x35131326, 0x2FECECC3, + 0xE15F5FBE, 0xA2979735, 0xCC444488, 0x3917172E, + 0x57C4C493, 0xF2A7A755, 0x827E7EFC, 0x473D3D7A, + 0xAC6464C8, 0xE75D5DBA, 0x2B191932, 0x957373E6, + 0xA06060C0, 0x98818119, 0xD14F4F9E, 0x7FDCDCA3, + 0x66222244, 0x7E2A2A54, 0xAB90903B, 0x8388880B, + 0xCA46468C, 0x29EEEEC7, 0xD3B8B86B, 0x3C141428, + 0x79DEDEA7, 0xE25E5EBC, 0x1D0B0B16, 0x76DBDBAD, + 0x3BE0E0DB, 0x56323264, 0x4E3A3A74, 0x1E0A0A14, + 0xDB494992, 0x0A06060C, 0x6C242448, 0xE45C5CB8, + 0x5DC2C29F, 0x6ED3D3BD, 0xEFACAC43, 0xA66262C4, + 0xA8919139, 0xA4959531, 0x37E4E4D3, 0x8B7979F2, + 0x32E7E7D5, 0x43C8C88B, 0x5937376E, 0xB76D6DDA, + 0x8C8D8D01, 0x64D5D5B1, 0xD24E4E9C, 0xE0A9A949, + 0xB46C6CD8, 0xFA5656AC, 0x07F4F4F3, 0x25EAEACF, + 0xAF6565CA, 0x8E7A7AF4, 0xE9AEAE47, 0x18080810, + 0xD5BABA6F, 0x887878F0, 0x6F25254A, 0x722E2E5C, + 0x241C1C38, 0xF1A6A657, 0xC7B4B473, 0x51C6C697, + 0x23E8E8CB, 0x7CDDDDA1, 0x9C7474E8, 0x211F1F3E, + 0xDD4B4B96, 0xDCBDBD61, 0x868B8B0D, 0x858A8A0F, + 0x907070E0, 0x423E3E7C, 0xC4B5B571, 0xAA6666CC, + 0xD8484890, 0x05030306, 0x01F6F6F7, 0x120E0E1C, + 0xA36161C2, 0x5F35356A, 0xF95757AE, 0xD0B9B969, + 0x91868617, 0x58C1C199, 0x271D1D3A, 0xB99E9E27, + 0x38E1E1D9, 0x13F8F8EB, 0xB398982B, 0x33111122, + 0xBB6969D2, 0x70D9D9A9, 0x898E8E07, 0xA7949433, + 0xB69B9B2D, 0x221E1E3C, 0x92878715, 0x20E9E9C9, + 0x49CECE87, 0xFF5555AA, 0x78282850, 0x7ADFDFA5, + 0x8F8C8C03, 0xF8A1A159, 0x80898909, 0x170D0D1A, + 0xDABFBF65, 0x31E6E6D7, 0xC6424284, 0xB86868D0, + 0xC3414182, 0xB0999929, 0x772D2D5A, 0x110F0F1E, + 0xCBB0B07B, 0xFC5454A8, 0xD6BBBB6D, 0x3A16162C }; + +uint T1[] ={ 0x6363C6A5, 0x7C7CF884, 0x7777EE99, 0x7B7BF68D, + 0xF2F2FF0D, 0x6B6BD6BD, 0x6F6FDEB1, 0xC5C59154, + 0x30306050, 0x01010203, 0x6767CEA9, 0x2B2B567D, + 0xFEFEE719, 0xD7D7B562, 0xABAB4DE6, 0x7676EC9A, + 0xCACA8F45, 0x82821F9D, 0xC9C98940, 0x7D7DFA87, + 0xFAFAEF15, 0x5959B2EB, 0x47478EC9, 0xF0F0FB0B, + 0xADAD41EC, 0xD4D4B367, 0xA2A25FFD, 0xAFAF45EA, + 0x9C9C23BF, 0xA4A453F7, 0x7272E496, 0xC0C09B5B, + 0xB7B775C2, 0xFDFDE11C, 0x93933DAE, 0x26264C6A, + 0x36366C5A, 0x3F3F7E41, 0xF7F7F502, 0xCCCC834F, + 0x3434685C, 0xA5A551F4, 0xE5E5D134, 0xF1F1F908, + 0x7171E293, 0xD8D8AB73, 0x31316253, 0x15152A3F, + 0x0404080C, 0xC7C79552, 0x23234665, 0xC3C39D5E, + 0x18183028, 0x969637A1, 0x05050A0F, 0x9A9A2FB5, + 0x07070E09, 0x12122436, 0x80801B9B, 0xE2E2DF3D, + 0xEBEBCD26, 0x27274E69, 0xB2B27FCD, 0x7575EA9F, + 0x0909121B, 0x83831D9E, 0x2C2C5874, 0x1A1A342E, + 0x1B1B362D, 0x6E6EDCB2, 0x5A5AB4EE, 0xA0A05BFB, + 0x5252A4F6, 0x3B3B764D, 0xD6D6B761, 0xB3B37DCE, + 0x2929527B, 0xE3E3DD3E, 0x2F2F5E71, 0x84841397, + 0x5353A6F5, 0xD1D1B968, 0x00000000, 0xEDEDC12C, + 0x20204060, 0xFCFCE31F, 0xB1B179C8, 0x5B5BB6ED, + 0x6A6AD4BE, 0xCBCB8D46, 0xBEBE67D9, 0x3939724B, + 0x4A4A94DE, 0x4C4C98D4, 0x5858B0E8, 0xCFCF854A, + 0xD0D0BB6B, 0xEFEFC52A, 0xAAAA4FE5, 0xFBFBED16, + 0x434386C5, 0x4D4D9AD7, 0x33336655, 0x85851194, + 0x45458ACF, 0xF9F9E910, 0x02020406, 0x7F7FFE81, + 0x5050A0F0, 0x3C3C7844, 0x9F9F25BA, 0xA8A84BE3, + 0x5151A2F3, 0xA3A35DFE, 0x404080C0, 0x8F8F058A, + 0x92923FAD, 0x9D9D21BC, 0x38387048, 0xF5F5F104, + 0xBCBC63DF, 0xB6B677C1, 0xDADAAF75, 0x21214263, + 0x10102030, 0xFFFFE51A, 0xF3F3FD0E, 0xD2D2BF6D, + 0xCDCD814C, 0x0C0C1814, 0x13132635, 0xECECC32F, + 0x5F5FBEE1, 0x979735A2, 0x444488CC, 0x17172E39, + 0xC4C49357, 0xA7A755F2, 0x7E7EFC82, 0x3D3D7A47, + 0x6464C8AC, 0x5D5DBAE7, 0x1919322B, 0x7373E695, + 0x6060C0A0, 0x81811998, 0x4F4F9ED1, 0xDCDCA37F, + 0x22224466, 0x2A2A547E, 0x90903BAB, 0x88880B83, + 0x46468CCA, 0xEEEEC729, 0xB8B86BD3, 0x1414283C, + 0xDEDEA779, 0x5E5EBCE2, 0x0B0B161D, 0xDBDBAD76, + 0xE0E0DB3B, 0x32326456, 0x3A3A744E, 0x0A0A141E, + 0x494992DB, 0x06060C0A, 0x2424486C, 0x5C5CB8E4, + 0xC2C29F5D, 0xD3D3BD6E, 0xACAC43EF, 0x6262C4A6, + 0x919139A8, 0x959531A4, 0xE4E4D337, 0x7979F28B, + 0xE7E7D532, 0xC8C88B43, 0x37376E59, 0x6D6DDAB7, + 0x8D8D018C, 0xD5D5B164, 0x4E4E9CD2, 0xA9A949E0, + 0x6C6CD8B4, 0x5656ACFA, 0xF4F4F307, 0xEAEACF25, + 0x6565CAAF, 0x7A7AF48E, 0xAEAE47E9, 0x08081018, + 0xBABA6FD5, 0x7878F088, 0x25254A6F, 0x2E2E5C72, + 0x1C1C3824, 0xA6A657F1, 0xB4B473C7, 0xC6C69751, + 0xE8E8CB23, 0xDDDDA17C, 0x7474E89C, 0x1F1F3E21, + 0x4B4B96DD, 0xBDBD61DC, 0x8B8B0D86, 0x8A8A0F85, + 0x7070E090, 0x3E3E7C42, 0xB5B571C4, 0x6666CCAA, + 0x484890D8, 0x03030605, 0xF6F6F701, 0x0E0E1C12, + 0x6161C2A3, 0x35356A5F, 0x5757AEF9, 0xB9B969D0, + 0x86861791, 0xC1C19958, 0x1D1D3A27, 0x9E9E27B9, + 0xE1E1D938, 0xF8F8EB13, 0x98982BB3, 0x11112233, + 0x6969D2BB, 0xD9D9A970, 0x8E8E0789, 0x949433A7, + 0x9B9B2DB6, 0x1E1E3C22, 0x87871592, 0xE9E9C920, + 0xCECE8749, 0x5555AAFF, 0x28285078, 0xDFDFA57A, + 0x8C8C038F, 0xA1A159F8, 0x89890980, 0x0D0D1A17, + 0xBFBF65DA, 0xE6E6D731, 0x424284C6, 0x6868D0B8, + 0x414182C3, 0x999929B0, 0x2D2D5A77, 0x0F0F1E11, + 0xB0B07BCB, 0x5454A8FC, 0xBBBB6DD6, 0x16162C3A }; + +uint T2[] ={ 0x63C6A563, 0x7CF8847C, 0x77EE9977, 0x7BF68D7B, + 0xF2FF0DF2, 0x6BD6BD6B, 0x6FDEB16F, 0xC59154C5, + 0x30605030, 0x01020301, 0x67CEA967, 0x2B567D2B, + 0xFEE719FE, 0xD7B562D7, 0xAB4DE6AB, 0x76EC9A76, + 0xCA8F45CA, 0x821F9D82, 0xC98940C9, 0x7DFA877D, + 0xFAEF15FA, 0x59B2EB59, 0x478EC947, 0xF0FB0BF0, + 0xAD41ECAD, 0xD4B367D4, 0xA25FFDA2, 0xAF45EAAF, + 0x9C23BF9C, 0xA453F7A4, 0x72E49672, 0xC09B5BC0, + 0xB775C2B7, 0xFDE11CFD, 0x933DAE93, 0x264C6A26, + 0x366C5A36, 0x3F7E413F, 0xF7F502F7, 0xCC834FCC, + 0x34685C34, 0xA551F4A5, 0xE5D134E5, 0xF1F908F1, + 0x71E29371, 0xD8AB73D8, 0x31625331, 0x152A3F15, + 0x04080C04, 0xC79552C7, 0x23466523, 0xC39D5EC3, + 0x18302818, 0x9637A196, 0x050A0F05, 0x9A2FB59A, + 0x070E0907, 0x12243612, 0x801B9B80, 0xE2DF3DE2, + 0xEBCD26EB, 0x274E6927, 0xB27FCDB2, 0x75EA9F75, + 0x09121B09, 0x831D9E83, 0x2C58742C, 0x1A342E1A, + 0x1B362D1B, 0x6EDCB26E, 0x5AB4EE5A, 0xA05BFBA0, + 0x52A4F652, 0x3B764D3B, 0xD6B761D6, 0xB37DCEB3, + 0x29527B29, 0xE3DD3EE3, 0x2F5E712F, 0x84139784, + 0x53A6F553, 0xD1B968D1, 0x00000000, 0xEDC12CED, + 0x20406020, 0xFCE31FFC, 0xB179C8B1, 0x5BB6ED5B, + 0x6AD4BE6A, 0xCB8D46CB, 0xBE67D9BE, 0x39724B39, + 0x4A94DE4A, 0x4C98D44C, 0x58B0E858, 0xCF854ACF, + 0xD0BB6BD0, 0xEFC52AEF, 0xAA4FE5AA, 0xFBED16FB, + 0x4386C543, 0x4D9AD74D, 0x33665533, 0x85119485, + 0x458ACF45, 0xF9E910F9, 0x02040602, 0x7FFE817F, + 0x50A0F050, 0x3C78443C, 0x9F25BA9F, 0xA84BE3A8, + 0x51A2F351, 0xA35DFEA3, 0x4080C040, 0x8F058A8F, + 0x923FAD92, 0x9D21BC9D, 0x38704838, 0xF5F104F5, + 0xBC63DFBC, 0xB677C1B6, 0xDAAF75DA, 0x21426321, + 0x10203010, 0xFFE51AFF, 0xF3FD0EF3, 0xD2BF6DD2, + 0xCD814CCD, 0x0C18140C, 0x13263513, 0xECC32FEC, + 0x5FBEE15F, 0x9735A297, 0x4488CC44, 0x172E3917, + 0xC49357C4, 0xA755F2A7, 0x7EFC827E, 0x3D7A473D, + 0x64C8AC64, 0x5DBAE75D, 0x19322B19, 0x73E69573, + 0x60C0A060, 0x81199881, 0x4F9ED14F, 0xDCA37FDC, + 0x22446622, 0x2A547E2A, 0x903BAB90, 0x880B8388, + 0x468CCA46, 0xEEC729EE, 0xB86BD3B8, 0x14283C14, + 0xDEA779DE, 0x5EBCE25E, 0x0B161D0B, 0xDBAD76DB, + 0xE0DB3BE0, 0x32645632, 0x3A744E3A, 0x0A141E0A, + 0x4992DB49, 0x060C0A06, 0x24486C24, 0x5CB8E45C, + 0xC29F5DC2, 0xD3BD6ED3, 0xAC43EFAC, 0x62C4A662, + 0x9139A891, 0x9531A495, 0xE4D337E4, 0x79F28B79, + 0xE7D532E7, 0xC88B43C8, 0x376E5937, 0x6DDAB76D, + 0x8D018C8D, 0xD5B164D5, 0x4E9CD24E, 0xA949E0A9, + 0x6CD8B46C, 0x56ACFA56, 0xF4F307F4, 0xEACF25EA, + 0x65CAAF65, 0x7AF48E7A, 0xAE47E9AE, 0x08101808, + 0xBA6FD5BA, 0x78F08878, 0x254A6F25, 0x2E5C722E, + 0x1C38241C, 0xA657F1A6, 0xB473C7B4, 0xC69751C6, + 0xE8CB23E8, 0xDDA17CDD, 0x74E89C74, 0x1F3E211F, + 0x4B96DD4B, 0xBD61DCBD, 0x8B0D868B, 0x8A0F858A, + 0x70E09070, 0x3E7C423E, 0xB571C4B5, 0x66CCAA66, + 0x4890D848, 0x03060503, 0xF6F701F6, 0x0E1C120E, + 0x61C2A361, 0x356A5F35, 0x57AEF957, 0xB969D0B9, + 0x86179186, 0xC19958C1, 0x1D3A271D, 0x9E27B99E, + 0xE1D938E1, 0xF8EB13F8, 0x982BB398, 0x11223311, + 0x69D2BB69, 0xD9A970D9, 0x8E07898E, 0x9433A794, + 0x9B2DB69B, 0x1E3C221E, 0x87159287, 0xE9C920E9, + 0xCE8749CE, 0x55AAFF55, 0x28507828, 0xDFA57ADF, + 0x8C038F8C, 0xA159F8A1, 0x89098089, 0x0D1A170D, + 0xBF65DABF, 0xE6D731E6, 0x4284C642, 0x68D0B868, + 0x4182C341, 0x9929B099, 0x2D5A772D, 0x0F1E110F, + 0xB07BCBB0, 0x54A8FC54, 0xBB6DD6BB, 0x162C3A16 }; + +uint T3[] ={ 0xC6A56363, 0xF8847C7C, 0xEE997777, 0xF68D7B7B, + 0xFF0DF2F2, 0xD6BD6B6B, 0xDEB16F6F, 0x9154C5C5, + 0x60503030, 0x02030101, 0xCEA96767, 0x567D2B2B, + 0xE719FEFE, 0xB562D7D7, 0x4DE6ABAB, 0xEC9A7676, + 0x8F45CACA, 0x1F9D8282, 0x8940C9C9, 0xFA877D7D, + 0xEF15FAFA, 0xB2EB5959, 0x8EC94747, 0xFB0BF0F0, + 0x41ECADAD, 0xB367D4D4, 0x5FFDA2A2, 0x45EAAFAF, + 0x23BF9C9C, 0x53F7A4A4, 0xE4967272, 0x9B5BC0C0, + 0x75C2B7B7, 0xE11CFDFD, 0x3DAE9393, 0x4C6A2626, + 0x6C5A3636, 0x7E413F3F, 0xF502F7F7, 0x834FCCCC, + 0x685C3434, 0x51F4A5A5, 0xD134E5E5, 0xF908F1F1, + 0xE2937171, 0xAB73D8D8, 0x62533131, 0x2A3F1515, + 0x080C0404, 0x9552C7C7, 0x46652323, 0x9D5EC3C3, + 0x30281818, 0x37A19696, 0x0A0F0505, 0x2FB59A9A, + 0x0E090707, 0x24361212, 0x1B9B8080, 0xDF3DE2E2, + 0xCD26EBEB, 0x4E692727, 0x7FCDB2B2, 0xEA9F7575, + 0x121B0909, 0x1D9E8383, 0x58742C2C, 0x342E1A1A, + 0x362D1B1B, 0xDCB26E6E, 0xB4EE5A5A, 0x5BFBA0A0, + 0xA4F65252, 0x764D3B3B, 0xB761D6D6, 0x7DCEB3B3, + 0x527B2929, 0xDD3EE3E3, 0x5E712F2F, 0x13978484, + 0xA6F55353, 0xB968D1D1, 0x00000000, 0xC12CEDED, + 0x40602020, 0xE31FFCFC, 0x79C8B1B1, 0xB6ED5B5B, + 0xD4BE6A6A, 0x8D46CBCB, 0x67D9BEBE, 0x724B3939, + 0x94DE4A4A, 0x98D44C4C, 0xB0E85858, 0x854ACFCF, + 0xBB6BD0D0, 0xC52AEFEF, 0x4FE5AAAA, 0xED16FBFB, + 0x86C54343, 0x9AD74D4D, 0x66553333, 0x11948585, + 0x8ACF4545, 0xE910F9F9, 0x04060202, 0xFE817F7F, + 0xA0F05050, 0x78443C3C, 0x25BA9F9F, 0x4BE3A8A8, + 0xA2F35151, 0x5DFEA3A3, 0x80C04040, 0x058A8F8F, + 0x3FAD9292, 0x21BC9D9D, 0x70483838, 0xF104F5F5, + 0x63DFBCBC, 0x77C1B6B6, 0xAF75DADA, 0x42632121, + 0x20301010, 0xE51AFFFF, 0xFD0EF3F3, 0xBF6DD2D2, + 0x814CCDCD, 0x18140C0C, 0x26351313, 0xC32FECEC, + 0xBEE15F5F, 0x35A29797, 0x88CC4444, 0x2E391717, + 0x9357C4C4, 0x55F2A7A7, 0xFC827E7E, 0x7A473D3D, + 0xC8AC6464, 0xBAE75D5D, 0x322B1919, 0xE6957373, + 0xC0A06060, 0x19988181, 0x9ED14F4F, 0xA37FDCDC, + 0x44662222, 0x547E2A2A, 0x3BAB9090, 0x0B838888, + 0x8CCA4646, 0xC729EEEE, 0x6BD3B8B8, 0x283C1414, + 0xA779DEDE, 0xBCE25E5E, 0x161D0B0B, 0xAD76DBDB, + 0xDB3BE0E0, 0x64563232, 0x744E3A3A, 0x141E0A0A, + 0x92DB4949, 0x0C0A0606, 0x486C2424, 0xB8E45C5C, + 0x9F5DC2C2, 0xBD6ED3D3, 0x43EFACAC, 0xC4A66262, + 0x39A89191, 0x31A49595, 0xD337E4E4, 0xF28B7979, + 0xD532E7E7, 0x8B43C8C8, 0x6E593737, 0xDAB76D6D, + 0x018C8D8D, 0xB164D5D5, 0x9CD24E4E, 0x49E0A9A9, + 0xD8B46C6C, 0xACFA5656, 0xF307F4F4, 0xCF25EAEA, + 0xCAAF6565, 0xF48E7A7A, 0x47E9AEAE, 0x10180808, + 0x6FD5BABA, 0xF0887878, 0x4A6F2525, 0x5C722E2E, + 0x38241C1C, 0x57F1A6A6, 0x73C7B4B4, 0x9751C6C6, + 0xCB23E8E8, 0xA17CDDDD, 0xE89C7474, 0x3E211F1F, + 0x96DD4B4B, 0x61DCBDBD, 0x0D868B8B, 0x0F858A8A, + 0xE0907070, 0x7C423E3E, 0x71C4B5B5, 0xCCAA6666, + 0x90D84848, 0x06050303, 0xF701F6F6, 0x1C120E0E, + 0xC2A36161, 0x6A5F3535, 0xAEF95757, 0x69D0B9B9, + 0x17918686, 0x9958C1C1, 0x3A271D1D, 0x27B99E9E, + 0xD938E1E1, 0xEB13F8F8, 0x2BB39898, 0x22331111, + 0xD2BB6969, 0xA970D9D9, 0x07898E8E, 0x33A79494, + 0x2DB69B9B, 0x3C221E1E, 0x15928787, 0xC920E9E9, + 0x8749CECE, 0xAAFF5555, 0x50782828, 0xA57ADFDF, + 0x038F8C8C, 0x59F8A1A1, 0x09808989, 0x1A170D0D, + 0x65DABFBF, 0xD731E6E6, 0x84C64242, 0xD0B86868, + 0x82C34141, 0x29B09999, 0x5A772D2D, 0x1E110F0F, + 0x7BCBB0B0, 0xA8FC5454, 0x6DD6BBBB, 0x2C3A1616 }; + +uint T4[] ={ 0x63636363, 0x7C7C7C7C, 0x77777777, 0x7B7B7B7B, + 0xF2F2F2F2, 0x6B6B6B6B, 0x6F6F6F6F, 0xC5C5C5C5, + 0x30303030, 0x01010101, 0x67676767, 0x2B2B2B2B, + 0xFEFEFEFE, 0xD7D7D7D7, 0xABABABAB, 0x76767676, + 0xCACACACA, 0x82828282, 0xC9C9C9C9, 0x7D7D7D7D, + 0xFAFAFAFA, 0x59595959, 0x47474747, 0xF0F0F0F0, + 0xADADADAD, 0xD4D4D4D4, 0xA2A2A2A2, 0xAFAFAFAF, + 0x9C9C9C9C, 0xA4A4A4A4, 0x72727272, 0xC0C0C0C0, + 0xB7B7B7B7, 0xFDFDFDFD, 0x93939393, 0x26262626, + 0x36363636, 0x3F3F3F3F, 0xF7F7F7F7, 0xCCCCCCCC, + 0x34343434, 0xA5A5A5A5, 0xE5E5E5E5, 0xF1F1F1F1, + 0x71717171, 0xD8D8D8D8, 0x31313131, 0x15151515, + 0x04040404, 0xC7C7C7C7, 0x23232323, 0xC3C3C3C3, + 0x18181818, 0x96969696, 0x05050505, 0x9A9A9A9A, + 0x07070707, 0x12121212, 0x80808080, 0xE2E2E2E2, + 0xEBEBEBEB, 0x27272727, 0xB2B2B2B2, 0x75757575, + 0x09090909, 0x83838383, 0x2C2C2C2C, 0x1A1A1A1A, + 0x1B1B1B1B, 0x6E6E6E6E, 0x5A5A5A5A, 0xA0A0A0A0, + 0x52525252, 0x3B3B3B3B, 0xD6D6D6D6, 0xB3B3B3B3, + 0x29292929, 0xE3E3E3E3, 0x2F2F2F2F, 0x84848484, + 0x53535353, 0xD1D1D1D1, 0x00000000, 0xEDEDEDED, + 0x20202020, 0xFCFCFCFC, 0xB1B1B1B1, 0x5B5B5B5B, + 0x6A6A6A6A, 0xCBCBCBCB, 0xBEBEBEBE, 0x39393939, + 0x4A4A4A4A, 0x4C4C4C4C, 0x58585858, 0xCFCFCFCF, + 0xD0D0D0D0, 0xEFEFEFEF, 0xAAAAAAAA, 0xFBFBFBFB, + 0x43434343, 0x4D4D4D4D, 0x33333333, 0x85858585, + 0x45454545, 0xF9F9F9F9, 0x02020202, 0x7F7F7F7F, + 0x50505050, 0x3C3C3C3C, 0x9F9F9F9F, 0xA8A8A8A8, + 0x51515151, 0xA3A3A3A3, 0x40404040, 0x8F8F8F8F, + 0x92929292, 0x9D9D9D9D, 0x38383838, 0xF5F5F5F5, + 0xBCBCBCBC, 0xB6B6B6B6, 0xDADADADA, 0x21212121, + 0x10101010, 0xFFFFFFFF, 0xF3F3F3F3, 0xD2D2D2D2, + 0xCDCDCDCD, 0x0C0C0C0C, 0x13131313, 0xECECECEC, + 0x5F5F5F5F, 0x97979797, 0x44444444, 0x17171717, + 0xC4C4C4C4, 0xA7A7A7A7, 0x7E7E7E7E, 0x3D3D3D3D, + 0x64646464, 0x5D5D5D5D, 0x19191919, 0x73737373, + 0x60606060, 0x81818181, 0x4F4F4F4F, 0xDCDCDCDC, + 0x22222222, 0x2A2A2A2A, 0x90909090, 0x88888888, + 0x46464646, 0xEEEEEEEE, 0xB8B8B8B8, 0x14141414, + 0xDEDEDEDE, 0x5E5E5E5E, 0x0B0B0B0B, 0xDBDBDBDB, + 0xE0E0E0E0, 0x32323232, 0x3A3A3A3A, 0x0A0A0A0A, + 0x49494949, 0x06060606, 0x24242424, 0x5C5C5C5C, + 0xC2C2C2C2, 0xD3D3D3D3, 0xACACACAC, 0x62626262, + 0x91919191, 0x95959595, 0xE4E4E4E4, 0x79797979, + 0xE7E7E7E7, 0xC8C8C8C8, 0x37373737, 0x6D6D6D6D, + 0x8D8D8D8D, 0xD5D5D5D5, 0x4E4E4E4E, 0xA9A9A9A9, + 0x6C6C6C6C, 0x56565656, 0xF4F4F4F4, 0xEAEAEAEA, + 0x65656565, 0x7A7A7A7A, 0xAEAEAEAE, 0x08080808, + 0xBABABABA, 0x78787878, 0x25252525, 0x2E2E2E2E, + 0x1C1C1C1C, 0xA6A6A6A6, 0xB4B4B4B4, 0xC6C6C6C6, + 0xE8E8E8E8, 0xDDDDDDDD, 0x74747474, 0x1F1F1F1F, + 0x4B4B4B4B, 0xBDBDBDBD, 0x8B8B8B8B, 0x8A8A8A8A, + 0x70707070, 0x3E3E3E3E, 0xB5B5B5B5, 0x66666666, + 0x48484848, 0x03030303, 0xF6F6F6F6, 0x0E0E0E0E, + 0x61616161, 0x35353535, 0x57575757, 0xB9B9B9B9, + 0x86868686, 0xC1C1C1C1, 0x1D1D1D1D, 0x9E9E9E9E, + 0xE1E1E1E1, 0xF8F8F8F8, 0x98989898, 0x11111111, + 0x69696969, 0xD9D9D9D9, 0x8E8E8E8E, 0x94949494, + 0x9B9B9B9B, 0x1E1E1E1E, 0x87878787, 0xE9E9E9E9, + 0xCECECECE, 0x55555555, 0x28282828, 0xDFDFDFDF, + 0x8C8C8C8C, 0xA1A1A1A1, 0x89898989, 0x0D0D0D0D, + 0xBFBFBFBF, 0xE6E6E6E6, 0x42424242, 0x68686868, + 0x41414141, 0x99999999, 0x2D2D2D2D, 0x0F0F0F0F, + 0xB0B0B0B0, 0x54545454, 0xBBBBBBBB, 0x16161616 }; + +uint RC[] ={ 0x00000001, 0x00000002, 0x00000004, 0x00000008, + 0x00000010, 0x00000020, 0x00000040, 0x00000080, + 0x0000001B, 0x00000036 }; + + + +void aes_schedule( int nb, int nr, octet* k, uint* RK ) +{ + for( int i = 0; i < ( nb ); i++ ) + { U8_TO_U32_LE( RK[ i ], k, 4*i ); } + + for( int i = nb, j = 0; i < ( 4 * ( nr + 1 ) ); i++ ) + { uint t = RK[ i - 1 ]; + uint p = RK[ i - nb ]; + + if ( ( ( i % nb ) == 0 ) ) { + t = RC[ j++ ] ^ ( T4[ ( t >> 8 ) & 0xFF ] & 0x000000FF ) ^ + ( T4[ ( t >> 16 ) & 0xFF ] & 0x0000FF00 ) ^ + ( T4[ ( t >> 24 ) & 0xFF ] & 0x00FF0000 ) ^ + ( T4[ ( t >> 0 ) & 0xFF ] & 0xFF000000 ) ; + } + else if( ( ( i % nb ) == 4 ) && ( nb == 8 ) ) { + t = ( T4[ ( t >> 0 ) & 0xFF ] & 0x000000FF ) ^ + ( T4[ ( t >> 8 ) & 0xFF ] & 0x0000FF00 ) ^ + ( T4[ ( t >> 16 ) & 0xFF ] & 0x00FF0000 ) ^ + ( T4[ ( t >> 24 ) & 0xFF ] & 0xFF000000 ) ; + } + RK[ i ] = t ^ p; + } +} + + +void aes_128_encrypt( octet* C, octet* M, uint* RK ) +{ + uint t0, t1, t2, t3, t4, t5, t6, t7; + + U8_TO_U32_LE( t0, M, 0 ); U8_TO_U32_LE( t1, M, 4 ); + U8_TO_U32_LE( t2, M, 8 ); U8_TO_U32_LE( t3, M, 12 ); + + ROUND1( 0, 1, 2, 3 ); + ROUND2( 4, 5, 6, 7 ); ROUND2( 8, 9, 10, 11 ); ROUND2( 12, 13, 14, 15 ); + ROUND2( 16, 17, 18, 19 ); ROUND2( 20, 21, 22, 23 ); ROUND2( 24, 25, 26, 27 ); + ROUND2( 28, 29, 30, 31 ); ROUND2( 32, 33, 34, 35 ); ROUND2( 36, 37, 38, 39 ); + ROUND3( 40, 41, 42, 43 ); + + U32_TO_U8_LE( C, t4, 0 ); U32_TO_U8_LE( C, t5, 4 ); + U32_TO_U8_LE( C, t6, 8 ); U32_TO_U8_LE( C, t7, 12 ); +} + + +void aes_192_encrypt( octet* C, octet* M, uint* RK ) +{ + uint t0, t1, t2, t3, t4, t5, t6, t7; + + U8_TO_U32_LE( t0, M, 0 ); U8_TO_U32_LE( t1, M, 4 ); + U8_TO_U32_LE( t2, M, 8 ); U8_TO_U32_LE( t3, M, 12 ); + + ROUND1( 0, 1, 2, 3 ); + ROUND2( 4, 5, 6, 7 ); ROUND2( 8, 9, 10, 11 ); ROUND2( 12, 13, 14, 15 ); + ROUND2( 16, 17, 18, 19 ); ROUND2( 20, 21, 22, 23 ); ROUND2( 24, 25, 26, 27 ); + ROUND2( 28, 29, 30, 31 ); ROUND2( 32, 33, 34, 35 ); ROUND2( 36, 37, 38, 39 ); + ROUND2( 40, 41, 42, 43 ); ROUND2( 44, 45, 46, 47 ); + ROUND3( 48, 49, 50, 51 ); + + U32_TO_U8_LE( C, t4, 0 ); U32_TO_U8_LE( C, t5, 4 ); + U32_TO_U8_LE( C, t6, 8 ); U32_TO_U8_LE( C, t7, 12 ); +} + + +void aes_256_encrypt( octet* C, octet* M, uint* RK ) +{ + uint t0, t1, t2, t3, t4, t5, t6, t7; + + U8_TO_U32_LE( t0, M, 0 ); U8_TO_U32_LE( t1, M, 4 ); + U8_TO_U32_LE( t2, M, 8 ); U8_TO_U32_LE( t3, M, 12 ); + + ROUND1( 0, 1, 2, 3 ); + ROUND2( 4, 5, 6, 7 ); ROUND2( 8, 9, 10, 11 ); ROUND2( 12, 13, 14, 15 ); + ROUND2( 16, 17, 18, 19 ); ROUND2( 20, 21, 22, 23 ); ROUND2( 24, 25, 26, 27 ); + ROUND2( 28, 29, 30, 31 ); ROUND2( 32, 33, 34, 35 ); ROUND2( 36, 37, 38, 39 ); + ROUND2( 40, 41, 42, 43 ); ROUND2( 44, 45, 46, 47 ); ROUND2( 48, 49, 50, 51 ); + ROUND2( 52, 53, 54, 55 ); + ROUND3( 56, 57, 58, 59 ); + + U32_TO_U8_LE( C, t4, 0 ); U32_TO_U8_LE( C, t5, 4 ); + U32_TO_U8_LE( C, t6, 8 ); U32_TO_U8_LE( C, t7, 12 ); +} diff --git a/Tools/aes.h b/Tools/aes.h new file mode 100644 index 000000000..2924abb3c --- /dev/null +++ b/Tools/aes.h @@ -0,0 +1,98 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef __AES_H +#define __AES_H + +#include + +#include "Networking/data.h" + +typedef unsigned int uint; + +#define AES_BLK_SIZE 16 + +/************* C Version *************/ +// Key Schedule +void aes_schedule( int nb, int nr, octet* k, uint* RK ); + +inline void aes_schedule( uint* RK, octet* K ) +{ aes_schedule(4,10,K,RK); } +inline void aes_128_schedule( uint* RK, octet* K ) +{ aes_schedule(4,10,K,RK); } +inline void aes_192_schedule( uint* RK, octet* K ) +{ aes_schedule(6,12,K,RK); } +inline void aes_256_schedule( uint* RK, octet* K ) +{ aes_schedule(8,14,K,RK); } + +// Encryption Function +void aes_128_encrypt( octet* C, octet* M, uint* RK ); +void aes_192_encrypt( octet* C, octet* M, uint* RK ); +void aes_256_encrypt( octet* C, octet* M, uint* RK ); + +inline void aes_encrypt( octet* C, octet* M, uint* RK ) +{ aes_128_encrypt(C,M,RK ); } + + +/*********** M-Code Version ***********/ +// Check can support this +int Check_CPU_support_AES(); +// Key Schedule +void aes_128_schedule( octet* key, const octet* userkey ); +void aes_192_schedule( octet* key, const octet* userkey ); +void aes_256_schedule( octet* key, const octet* userkey ); + +inline void aes_schedule( octet* key, const octet* userkey ) +{ aes_128_schedule(key,userkey); } + + +// Encryption Function +void aes_128_encrypt( octet* C, const octet* M,const octet* RK ); +void aes_192_encrypt( octet* C, const octet* M,const octet* RK ); +void aes_256_encrypt( octet* C, const octet* M,const octet* RK ); + +__attribute__((optimize("unroll-loops"))) +inline __m128i aes_128_encrypt(__m128i in, const octet* key) +{ __m128i& tmp = in; + tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]); + int j; + for(j=1; j <10; j++) + { tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); } + tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); + return tmp; +} + +template +__attribute__((optimize("unroll-loops"))) +inline void ecb_aes_128_encrypt(__m128i* out, __m128i* in, const octet* key) +{ + __m128i tmp[N]; + for (int i = 0; i < N; i++) + tmp[i] = _mm_xor_si128 (in[i],((__m128i*)key)[0]); + int j; + for(j=1; j <10; j++) + for (int i = 0; i < N; i++) + tmp[i] = _mm_aesenc_si128 (tmp[i],((__m128i*)key)[j]); + for (int i = 0; i < N; i++) + out[i] = _mm_aesenclast_si128 (tmp[i],((__m128i*)key)[j]); +} + +template +inline void ecb_aes_128_encrypt(__m128i* out, const __m128i* in, const octet* key, const int* indices) +{ + __m128i tmp[N]; + for (int i = 0; i < N; i++) + tmp[i] = in[indices[i]]; + ecb_aes_128_encrypt(tmp, tmp, key); + for (int i = 0; i < N; i++) + out[indices[i]] = tmp[i]; +} + +inline void aes_encrypt( octet* C, const octet* M,const octet* RK ) +{ aes_128_encrypt(C,M,RK); } + +inline __m128i aes_encrypt( __m128i M,const octet* RK ) +{ return aes_128_encrypt(M,RK); } + + +#endif + diff --git a/Tools/ezOptionParser-MIT-LICENSE b/Tools/ezOptionParser-MIT-LICENSE new file mode 100644 index 000000000..54aee517b --- /dev/null +++ b/Tools/ezOptionParser-MIT-LICENSE @@ -0,0 +1,7 @@ +Copyright (C) 2011,2012 Remik Ziemlinski + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Tools/ezOptionParser.h b/Tools/ezOptionParser.h new file mode 100644 index 000000000..d2b09c056 --- /dev/null +++ b/Tools/ezOptionParser.h @@ -0,0 +1,2160 @@ +// (C) 2016 University of Bristol. See License.txt + +/* +This file is part of ezOptionParser. See MIT-LICENSE. + +Copyright (C) 2011,2012,2014 Remik Ziemlinski + +CHANGELOG + +v0.0.0 20110505 rsz Created. +v0.1.0 20111006 rsz Added validator. +v0.1.1 20111012 rsz Fixed validation of ulonglong. +v0.1.2 20111126 rsz Allow flag names start with alphanumeric (previously, flag had to start with alpha). +v0.1.3 20120108 rsz Created work-around for unique id generation with IDGenerator that avoids retarded c++ translation unit linker errors with single-header static variables. Forced inline on all methods to please retard compiler and avoid multiple def errors. +v0.1.4 20120629 Enforced MIT license on all files. +v0.2.0 20121120 Added parseIndex to OptionGroup. +v0.2.1 20130506 Allow disabling doublespace of OPTIONS usage descriptions. +v0.2.2 20140504 Jose Santiago added compiler warning fixes. + Bruce Shankle added a crash fix in description printing. +*/ +#ifndef EZ_OPTION_PARSER_H +#define EZ_OPTION_PARSER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ez { +#define DEBUGLINE() printf("%s:%d\n", __FILE__, __LINE__); + +/* ################################################################### */ +template +static T fromString(const std::string* s) { + std::istringstream stream (s->c_str()); + T t; + stream >> t; + return t; +}; +template +static T fromString(const char* s) { + std::istringstream stream (s); + T t; + stream >> t; + return t; +}; +/* ################################################################### */ +static inline bool isdigit(const std::string & s, int i=0) { + int n = s.length(); + for(; i < n; ++i) + switch(s[i]) { + case '0': case '1': case '2': + case '3': case '4': case '5': + case '6': case '7': case '8': case '9': break; + default: return false; + } + + return true; +}; +/* ################################################################### */ +static bool isdigit(const std::string * s, int i=0) { + int n = s->length(); + for(; i < n; ++i) + switch(s->at(i)) { + case '0': case '1': case '2': + case '3': case '4': case '5': + case '6': case '7': case '8': case '9': break; + default: return false; + } + + return true; +}; +/* ################################################################### */ +/* +Compare strings for opts, so short opt flags come before long format flags. +For example, -d < --dimension < --dmn, and also lower come before upper. The default STL std::string compare doesn't do that. +*/ +static bool CmpOptStringPtr(std::string * s1, std::string * s2) { + int c1,c2; + const char *s=s1->c_str(); + for(c1=0; c1 < (long int)s1->size(); ++c1) + if (isalnum(s[c1])) // locale sensitive. + break; + + s=s2->c_str(); + for(c2=0; c2 < (long int)s2->size(); ++c2) + if (isalnum(s[c2])) + break; + + // Test which has more symbols before its name. + if (c1 > c2) + return false; + else if (c1 < c2) + return true; + + // Both have same number of symbols, so compare first letter. + char char1 = s1->at(c1); + char char2 = s2->at(c2); + char lo1 = tolower(char1); + char lo2 = tolower(char2); + + if (lo1 != lo2) + return lo1 < lo2; + + // Their case doesn't match, so find which is lower. + char up1 = isupper(char1); + char up2 = isupper(char2); + + if (up1 && !up2) + return false; + else if (!up1 && up2) + return true; + + return (s1->compare(*s2)<0); +}; +/* ################################################################### */ +/* +Makes a vector of strings from one string, +splitting at (and excluding) delimiter "token". +*/ +static void SplitDelim( const std::string& s, const char token, std::vector * result) { + std::string::const_iterator i = s.begin(); + std::string::const_iterator j = s.begin(); + const std::string::const_iterator e = s.end(); + + while(i!=e) { + while(i!=e && *i++!=token); + std::string *newstr = new std::string(j, i); + if (newstr->at(newstr->size()-1) == token) newstr->erase(newstr->size()-1); + result->push_back(newstr); + j = i; + } +}; +/* ################################################################### */ +// Variant that uses deep copies and references instead of pointers (less efficient). +static void SplitDelim( const std::string& s, const char token, std::vector & result) { + std::string::const_iterator i = s.begin(); + std::string::const_iterator j = s.begin(); + const std::string::const_iterator e = s.end(); + + while(i!=e) { + while(i!=e && *i++!=token); + std::string newstr(j, i); + if (newstr.at(newstr.size()-1) == token) newstr.erase(newstr.size()-1); + result.push_back(newstr); + j = i; + } +}; +/* ################################################################### */ +// Variant that uses list instead of vector for efficient insertion, etc. +static void SplitDelim( const std::string& s, const char token, std::list & result) { + std::string::const_iterator i = s.begin(); + std::string::const_iterator j = s.begin(); + const std::string::const_iterator e = s.end(); + + while(i!=e) { + while(i!=e && *i++!=token); + std::string *newstr = new std::string(j, i); + if (newstr->at(newstr->size()-1) == token) newstr->erase(newstr->size()-1); + result.push_back(newstr); + j = i; + } +}; +/* ################################################################### */ +static void ToU1(std::string ** strings, unsigned char * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (unsigned char)atoi(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void ToS1(std::string ** strings, char * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (char)atoi(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void ToU2(std::string ** strings, unsigned short * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (unsigned short)atoi(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void ToS2(std::string ** strings, short * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (short)atoi(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void ToS4(std::string ** strings, int * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = atoi(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void ToU4(std::string ** strings, unsigned int * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (unsigned int)strtoul(strings[i]->c_str(), NULL, 0); + } +}; +/* ################################################################### */ +static void ToS8(std::string ** strings, long long * out, int n) { + for(int i=0; i < n; ++i) { + std::stringstream ss(strings[i]->c_str()); + ss >> out[i]; + } +}; +/* ################################################################### */ +static void ToU8(std::string ** strings, unsigned long long * out, int n) { + for(int i=0; i < n; ++i) { + std::stringstream ss(strings[i]->c_str()); + ss >> out[i]; + } +}; +/* ################################################################### */ +static void ToF(std::string ** strings, float * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (float)atof(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void ToD(std::string ** strings, double * out, int n) { + for(int i=0; i < n; ++i) { + out[i] = (double)atof(strings[i]->c_str()); + } +}; +/* ################################################################### */ +static void StringsToInts(std::vector & strings, std::vector & out) { + for(int i=0; i < (long int)strings.size(); ++i) { + out.push_back(atoi(strings[i].c_str())); + } +}; +/* ################################################################### */ +static void StringsToInts(std::vector * strings, std::vector * out) { + for(int i=0; i < (long int)strings->size(); ++i) { + out->push_back(atoi(strings->at(i)->c_str())); + } +}; +/* ################################################################### */ +static void StringsToLongs(std::vector & strings, std::vector & out) { + for(int i=0; i < (long int)strings.size(); ++i) { + out.push_back(atol(strings[i].c_str())); + } +}; +/* ################################################################### */ +static void StringsToLongs(std::vector * strings, std::vector * out) { + for(int i=0; i < (long int)strings->size(); ++i) { + out->push_back(atol(strings->at(i)->c_str())); + } +}; +/* ################################################################### */ +static void StringsToULongs(std::vector & strings, std::vector & out) { + for(int i=0; i < (long int)strings.size(); ++i) { + out.push_back(strtoul(strings[i].c_str(),0,0)); + } +}; +/* ################################################################### */ +static void StringsToULongs(std::vector * strings, std::vector * out) { + for(int i=0; i < (long int)strings->size(); ++i) { + out->push_back(strtoul(strings->at(i)->c_str(),0,0)); + } +}; +/* ################################################################### */ +static void StringsToFloats(std::vector & strings, std::vector & out) { + for(int i=0; i < (long int)strings.size(); ++i) { + out.push_back(atof(strings[i].c_str())); + } +}; +/* ################################################################### */ +static void StringsToFloats(std::vector * strings, std::vector * out) { + for(int i=0; i < (long int)strings->size(); ++i) { + out->push_back(atof(strings->at(i)->c_str())); + } +}; +/* ################################################################### */ +static void StringsToDoubles(std::vector & strings, std::vector & out) { + for(int i=0; i < (long int)strings.size(); ++i) { + out.push_back(atof(strings[i].c_str())); + } +}; +/* ################################################################### */ +static void StringsToDoubles(std::vector * strings, std::vector * out) { + for(int i=0; i < (long int)strings->size(); ++i) { + out->push_back(atof(strings->at(i)->c_str())); + } +}; +/* ################################################################### */ +static void StringsToStrings(std::vector * strings, std::vector * out) { + for(int i=0; i < (long int)strings->size(); ++i) { + out->push_back( *strings->at(i) ); + } +}; +/* ################################################################### */ +static void ToLowerASCII(std::string & s) { + int n = s.size(); + int i=0; + char c; + for(; i < n; ++i) { + c = s[i]; + if(c<='Z' && c>='A') + s[i] = c+32; + } +} +/* ################################################################### */ +static char** CommandLineToArgvA(char* CmdLine, int* _argc) { + char** argv; + char* _argv; + unsigned long len; + unsigned long argc; + char a; + unsigned long i, j; + + bool in_QM; + bool in_TEXT; + bool in_SPACE; + + len = strlen(CmdLine); + i = ((len+2)/2)*sizeof(void*) + sizeof(void*); + + argv = (char**)malloc(i + (len+2)*sizeof(char)); + + _argv = (char*)(((unsigned char*)argv)+i); + + argc = 0; + argv[argc] = _argv; + in_QM = false; + in_TEXT = false; + in_SPACE = true; + i = 0; + j = 0; + + while( (a = CmdLine[i]) ) { + if(in_QM) { + if( (a == '\"') || + (a == '\'')) // rsz. Added single quote. + { + in_QM = false; + } else { + _argv[j] = a; + j++; + } + } else { + switch(a) { + case '\"': + case '\'': // rsz. Added single quote. + in_QM = true; + in_TEXT = true; + if(in_SPACE) { + argv[argc] = _argv+j; + argc++; + } + in_SPACE = false; + break; + case ' ': + case '\t': + case '\n': + case '\r': + if(in_TEXT) { + _argv[j] = '\0'; + j++; + } + in_TEXT = false; + in_SPACE = true; + break; + default: + in_TEXT = true; + if(in_SPACE) { + argv[argc] = _argv+j; + argc++; + } + _argv[j] = a; + j++; + in_SPACE = false; + break; + } + } + i++; + } + _argv[j] = '\0'; + argv[argc] = NULL; + + (*_argc) = argc; + return argv; +}; +/* ################################################################### */ +// Create unique ids with static and still allow single header that avoids multiple definitions linker error. +class ezOptionParserIDGenerator { +public: + static ezOptionParserIDGenerator& instance () { static ezOptionParserIDGenerator Generator; return Generator; } + short next () { return ++_id; } +private: + ezOptionParserIDGenerator() : _id(-1) {} + short _id; +}; +/* ################################################################### */ +/* Validate a value by checking: +- if as string, see if converted value is within datatype's limits, +- and see if falls within a desired range, +- or see if within set of given list of values. + +If comparing with a range, the values list must contain one or two values. One value is required when comparing with <, <=, >, >=. Use two values when requiring a test such as list[0] */ + GE, /* value >= list[0] */ + GTLT, /* list[0] < value < list[1] */ + GELT, /* list[0] <= value < list[1] */ + GELE, /* list[0] <= value <= list[1] */ + GTLE, /* list[0] < value <= list[1] */ + IN /* if value is in list */ + }; + + enum TYPE { NOTYPE=0, S1, U1, S2, U2, S4, U4, S8, U8, F, D, T }; + enum TYPE2 { NOTYPE2=0, INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64, FLOAT, DOUBLE, TEXT }; + + union { + unsigned char *u1; + char *s1; + unsigned short *u2; + short *s2; + unsigned int *u4; + int *s4; + unsigned long long *u8; + long long *s8; + float *f; + double *d; + std::string** t; + }; + + char op; + bool quiet; + short id; + char type; + int size; + bool insensitive; +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::~ezOptionValidator() { + reset(); +}; +/* ------------------------------------------------------------------- */ +void ezOptionValidator::reset() { + #define CLEAR(TYPE,P) case TYPE: if (P) delete [] P; P = 0; break; + switch(type) { + CLEAR(S1,s1); + CLEAR(U1,u1); + CLEAR(S2,s2); + CLEAR(U2,u2); + CLEAR(S4,s4); + CLEAR(U4,u4); + CLEAR(S8,s8); + CLEAR(U8,u8); + CLEAR(F,f); + CLEAR(D,d); + case T: + for(int i=0; i < size; ++i) + delete t[i]; + + delete [] t; + t = 0; + break; + default: break; + } + + size = 0; + op = NOOP; + type = NOTYPE; +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type) : s1(0), op(0), quiet(0), type(_type), size(0), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const char* list, int _size) : s1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + s1 = new char[size]; + memcpy(s1, list, size); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned char* list, int _size) : u1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + u1 = new unsigned char[size]; + memcpy(u1, list, size); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const short* list, int _size) : s2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + s2 = new short[size]; + memcpy(s2, list, size*sizeof(short)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned short* list, int _size) : u2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + u2 = new unsigned short[size]; + memcpy(u2, list, size*sizeof(unsigned short)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const int* list, int _size) : s4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + s4 = new int[size]; + memcpy(s4, list, size*sizeof(int)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned int* list, int _size) : u4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + u4 = new unsigned int[size]; + memcpy(u4, list, size*sizeof(unsigned int)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const long long* list, int _size) : s8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + s8 = new long long[size]; + memcpy(s8, list, size*sizeof(long long)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned long long* list, int _size) : u8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + u8 = new unsigned long long[size]; + memcpy(u8, list, size*sizeof(unsigned long long)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const float* list, int _size) : f(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + f = new float[size]; + memcpy(f, list, size*sizeof(float)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const double* list, int _size) : d(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { + id = ezOptionParserIDGenerator::instance().next(); + d = new double[size]; + memcpy(d, list, size*sizeof(double)); +}; +/* ------------------------------------------------------------------- */ +ezOptionValidator::ezOptionValidator(char _type, char _op, const char** list, int _size, bool _insensitive) : t(0), op(_op), quiet(0), type(_type), size(_size), insensitive(_insensitive) { + id = ezOptionParserIDGenerator::instance().next(); + t = new std::string*[size]; + int i=0; + + for(; i < size; ++i) { + t[i] = new std::string(list[i]); + } +}; +/* ------------------------------------------------------------------- */ +/* Less efficient but convenient ctor that parses strings to setup validator. +_type: s1, u1, s2, u2, ..., f, d, t +_op: lt, gt, ..., in +_list: comma-delimited string +*/ +ezOptionValidator::ezOptionValidator(const char* _type, const char* _op, const char* _list, bool _insensitive) : t(0), quiet(0), type(0), size(0), insensitive(_insensitive) { + id = ezOptionParserIDGenerator::instance().next(); + + switch(_type[0]) { + case 'u': + switch(_type[1]) { + case '1': type = U1; break; + case '2': type = U2; break; + case '4': type = U4; break; + case '8': type = U8; break; + default: break; + } + break; + case 's': + switch(_type[1]) { + case '1': type = S1; + break; + case '2': type = S2; break; + case '4': type = S4; break; + case '8': type = S8; break; + default: break; + } + break; + case 'f': type = F; break; + case 'd': type = D; break; + case 't': type = T; break; + default: + if (!quiet) + std::cerr << "ERROR: Unknown validator datatype \"" << _type << "\".\n"; + break; + } + + int nop = 0; + if (_op != 0) + nop = strlen(_op); + + switch(nop) { + case 0: op = NOOP; break; + case 2: + switch(_op[0]) { + case 'g': + switch(_op[1]) { + case 'e': op = GE; break; + default: op = GT; break; + } + break; + case 'i': op = IN; + break; + default: + switch(_op[1]) { + case 'e': op = LE; break; + default: op = LT; break; + } + break; + } + break; + case 4: + switch(_op[1]) { + case 'e': + switch(_op[3]) { + case 'e': op = GELE; break; + default: op = GELT; break; + } + break; + default: + switch(_op[3]) { + case 'e': op = GTLE; break; + default: op = GTLT; break; + } + break; + } + break; + default: + if (!quiet) + std::cerr << "ERROR: Unknown validator operation \"" << _op << "\".\n"; + break; + } + + if (_list == 0) return; + // Create list of strings and then cast to native datatypes. + std::string unsplit(_list); + std::list split; + std::list::iterator it; + SplitDelim(unsplit, ',', split); + size = split.size(); + std::string **strings = new std::string*[size]; + + int i = 0; + for(it = split.begin(); it != split.end(); ++it) + strings[i++] = *it; + + if (insensitive) + for(i=0; i < size; ++i) + ToLowerASCII(*strings[i]); + + #define FreeStrings() { \ + for(i=0; i < size; ++i)\ + delete strings[i];\ + delete [] strings;\ + } + + #define ToArray(T,P,Y) case T: P = new Y[size]; To##T(strings, P, size); FreeStrings(); break; + switch(type) { + ToArray(S1,s1,char); + ToArray(U1,u1,unsigned char); + ToArray(S2,s2,short); + ToArray(U2,u2,unsigned short); + ToArray(S4,s4,int); + ToArray(U4,u4,unsigned int); + ToArray(S8,s8,long long); + ToArray(U8,u8,unsigned long long); + ToArray(F,f,float); + ToArray(D,d,double); + case T: t = strings; break; /* Don't erase strings array. */ + default: break; + } +}; +/* ------------------------------------------------------------------- */ +void ezOptionValidator::print() { + printf("id=%d, op=%d, type=%d, size=%d, insensitive=%d\n", id, op, type, size, insensitive); +}; +/* ------------------------------------------------------------------- */ +bool ezOptionValidator::isValid(const std::string * valueAsString) { + if (valueAsString == 0) return false; + +#define CHECKRANGE(E,T) {\ + std::stringstream ss(valueAsString->c_str()); \ + long long E##value; \ + ss >> E##value; \ + long long E##min = static_cast(std::numeric_limits::min()); \ + if (E##value < E##min) { \ + if (!quiet) \ + std::cerr << "ERROR: Invalid value " << E##value << " is less than datatype min " << E##min << ".\n"; \ + return false; \ + } \ + \ + long long E##max = static_cast(std::numeric_limits::max()); \ + if (E##value > E##max) { \ + if (!quiet) \ + std::cerr << "ERROR: Invalid value " << E##value << " is greater than datatype max " << E##max << ".\n"; \ + return false; \ + } \ +} + // Check if within datatype limits. + if (type != T) { + switch(type) { + case S1: CHECKRANGE(S1,char); break; + case U1: CHECKRANGE(U1,unsigned char); break; + case S2: CHECKRANGE(S2,short); break; + case U2: CHECKRANGE(U2,unsigned short); break; + case S4: CHECKRANGE(S4,int); break; + case U4: CHECKRANGE(U4,unsigned int); break; + case S8: { + if ( (valueAsString->at(0) == '-') && + isdigit(valueAsString,1) && + (valueAsString->size() > 19) && + (valueAsString->compare(1, 19, "9223372036854775808") > 0) ) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << *valueAsString << " is less than datatype min -9223372036854775808.\n"; + return false; + } + + if (isdigit(valueAsString) && + (valueAsString->size() > 18) && + valueAsString->compare("9223372036854775807") > 0) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << *valueAsString << " is greater than datatype max 9223372036854775807.\n"; + return false; + } + } break; + case U8: { + if (valueAsString->compare("0") < 0) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << *valueAsString << " is less than datatype min 0.\n"; + return false; + } + + if (isdigit(valueAsString) && + (valueAsString->size() > 19) && + valueAsString->compare("18446744073709551615") > 0) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << *valueAsString << " is greater than datatype max 18446744073709551615.\n"; + return false; + } + } break; + case F: { + double dmax = static_cast(std::numeric_limits::max()); + double dvalue = atof(valueAsString->c_str()); + double dmin = -dmax; + if (dvalue < dmin) { + if (!quiet) { + fprintf(stderr, "ERROR: Invalid value %g is less than datatype min %g.\n", dvalue, dmin); + } + return false; + } + + if (dvalue > dmax) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << dvalue << " is greater than datatype max " << dmax << ".\n"; + return false; + } + } break; + case D: { + long double ldmax = static_cast(std::numeric_limits::max()); + std::stringstream ss(valueAsString->c_str()); + long double ldvalue; + ss >> ldvalue; + long double ldmin = -ldmax; + + if (ldvalue < ldmin) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << ldvalue << " is less than datatype min " << ldmin << ".\n"; + return false; + } + + if (ldvalue > ldmax) { + if (!quiet) + std::cerr << "ERROR: Invalid value " << ldvalue << " is greater than datatype max " << ldmax << ".\n"; + return false; + } + } break; + case NOTYPE: default: break; + } + } else { + if (op == IN) { + int i=0; + if (insensitive) { + std::string valueAsStringLower(*valueAsString); + ToLowerASCII(valueAsStringLower); + for(; i < size; ++i) { + if (valueAsStringLower.compare(t[i]->c_str()) == 0) + return true; + } + } else { + for(; i < size; ++i) { + if (valueAsString->compare(t[i]->c_str()) == 0) + return true; + } + } + return false; + } + } + + // Only check datatype limits, and return; + if (op == NOOP) return true; + +#define VALIDATE(T, U, LIST) { \ + /* Value string converted to true native type. */ \ + std::stringstream ss(valueAsString->c_str());\ + U v;\ + ss >> v;\ + /* Check if within list. */ \ + if (op == IN) { \ + T * last = LIST + size;\ + return (last != std::find(LIST, last, v)); \ + } \ + \ + /* Check if within user's custom range. */ \ + T v0, v1; \ + if (size > 0) { \ + v0 = LIST[0]; \ + } \ + \ + if (size > 1) { \ + v1 = LIST[1]; \ + } \ + \ + switch (op) {\ + case LT:\ + if (size > 0) {\ + return v < v0;\ + } else {\ + std::cerr << "ERROR: No value given to validate if " << v << " < X.\n";\ + return false;\ + }\ + break;\ + case LE:\ + if (size > 0) {\ + return v <= v0;\ + } else {\ + std::cerr << "ERROR: No value given to validate if " << v << " <= X.\n";\ + return false;\ + }\ + break;\ + case GT:\ + if (size > 0) {\ + return v > v0;\ + } else {\ + std::cerr << "ERROR: No value given to validate if " << v << " > X.\n";\ + return false;\ + }\ + break;\ + case GE:\ + if (size > 0) {\ + return v >= v0;\ + } else {\ + std::cerr << "ERROR: No value given to validate if " << v << " >= X.\n";\ + return false;\ + }\ + break;\ + case GTLT:\ + if (size > 1) {\ + return (v0 < v) && (v < v1);\ + } else {\ + std::cerr << "ERROR: Missing values to validate if X1 < " << v << " < X2.\n";\ + return false;\ + }\ + break;\ + case GELT:\ + if (size > 1) {\ + return (v0 <= v) && (v < v1);\ + } else {\ + std::cerr << "ERROR: Missing values to validate if X1 <= " << v << " < X2.\n";\ + return false;\ + }\ + break;\ + case GELE:\ + if (size > 1) {\ + return (v0 <= v) && (v <= v1);\ + } else {\ + std::cerr << "ERROR: Missing values to validate if X1 <= " << v << " <= X2.\n";\ + return false;\ + }\ + break;\ + case GTLE:\ + if (size > 1) {\ + return (v0 < v) && (v <= v1);\ + } else {\ + std::cerr << "ERROR: Missing values to validate if X1 < " << v << " <= X2.\n";\ + return false;\ + }\ + break;\ + case NOOP: case IN: default: break;\ + } \ + } + + switch(type) { + case U1: VALIDATE(unsigned char, int, u1); break; + case S1: VALIDATE(char, int, s1); break; + case U2: VALIDATE(unsigned short, int, u2); break; + case S2: VALIDATE(short, int, s2); break; + case U4: VALIDATE(unsigned int, unsigned int, u4); break; + case S4: VALIDATE(int, int, s4); break; + case U8: VALIDATE(unsigned long long, unsigned long long, u8); break; + case S8: VALIDATE(long long, long long, s8); break; + case F: VALIDATE(float, float, f); break; + case D: VALIDATE(double, double, d); break; + default: break; + } + + return true; +}; +/* ################################################################### */ +class OptionGroup { +public: + OptionGroup() : delim(0), expectArgs(0), isRequired(false), isSet(false) { } + + ~OptionGroup() { + int i; + for(i=0; i < (long int)flags.size(); ++i) + delete flags[i]; + + flags.clear(); + parseIndex.clear(); + clearArgs(); + }; + + inline void clearArgs(); + inline void getInt(int&); + inline void getLong(long&); + inline void getLongLong(long long&); + inline void getULong(unsigned long&); + inline void getULongLong(unsigned long long&); + inline void getFloat(float&); + inline void getDouble(double&); + inline void getString(std::string&); + inline void getInts(std::vector&); + inline void getLongs(std::vector&); + inline void getULongs(std::vector&); + inline void getFloats(std::vector&); + inline void getDoubles(std::vector&); + inline void getStrings(std::vector&); + inline void getMultiInts(std::vector< std::vector >&); + inline void getMultiLongs(std::vector< std::vector >&); + inline void getMultiULongs(std::vector< std::vector >&); + inline void getMultiFloats(std::vector< std::vector >&); + inline void getMultiDoubles(std::vector< std::vector >&); + inline void getMultiStrings(std::vector< std::vector >&); + + // defaults value regardless of being set by user. + std::string defaults; + // If expects arguments, this will delimit arg list. + char delim; + // If not 0, then number of delimited args. -1 for arbitrary number. + int expectArgs; + // Descriptive help message shown in usage instructions for option. + std::string help; + // 0 or 1. + bool isRequired; + // A list of flags that denote this option, i.e. -d, --dimension. + std::vector< std::string* > flags; + // If was set (or found). + bool isSet; + // Lists of arguments, per flag instance, after splitting by delimiter. + std::vector< std::vector< std::string* > * > args; + // Index where each group was parsed from input stream to track order. + std::vector parseIndex; +}; +/* ################################################################### */ +void OptionGroup::clearArgs() { + int i,j; + for(i=0; i < (long int)args.size(); ++i) { + for(j=0; j < (long int)args[i]->size(); ++j) + delete args[i]->at(j); + + delete args[i]; + } + + args.clear(); + isSet = false; +}; +/* ################################################################### */ +void OptionGroup::getInt(int & out) { + if (!isSet) { + if (defaults.empty()) + out = 0; + else + out = atoi(defaults.c_str()); + } else { + if (args.empty() || args[0]->empty()) + out = 0; + else { + out = atoi(args[0]->at(0)->c_str()); + } + } +}; +/* ################################################################### */ +void OptionGroup::getLong(long & out) { + if (!isSet) { + if (defaults.empty()) + out = 0; + else + out = atoi(defaults.c_str()); + } else { + if (args.empty() || args[0]->empty()) + out = 0; + else { + out = atol(args[0]->at(0)->c_str()); + } + } +}; +/* ################################################################### */ +void OptionGroup::getLongLong(long long & out) { + if (!isSet) { + if (defaults.empty()) + out = 0; + else { + std::stringstream ss(defaults.c_str()); + ss >> out; + } + } else { + if (args.empty() || args[0]->empty()) + out = 0; + else { + std::stringstream ss(args[0]->at(0)->c_str()); + ss >> out; + } + } +}; +/* ################################################################### */ +void OptionGroup::getULong(unsigned long & out) { + if (!isSet) { + if (defaults.empty()) + out = 0; + else + out = atoi(defaults.c_str()); + } else { + if (args.empty() || args[0]->empty()) + out = 0; + else { + out = strtoul(args[0]->at(0)->c_str(),0,0); + } + } +}; +/* ################################################################### */ +void OptionGroup::getULongLong(unsigned long long & out) { + if (!isSet) { + if (defaults.empty()) + out = 0; + else { + std::stringstream ss(defaults.c_str()); + ss >> out; + } + } else { + if (args.empty() || args[0]->empty()) + out = 0; + else { + std::stringstream ss(args[0]->at(0)->c_str()); + ss >> out; + } + } +}; +/* ################################################################### */ +void OptionGroup::getFloat(float & out) { + if (!isSet) { + if (defaults.empty()) + out = 0.0; + else + out = (float)atof(defaults.c_str()); + } else { + if (args.empty() || args[0]->empty()) + out = 0.0; + else { + out = (float)atof(args[0]->at(0)->c_str()); + } + } +}; +/* ################################################################### */ +void OptionGroup::getDouble(double & out) { + if (!isSet) { + if (defaults.empty()) + out = 0.0; + else + out = atof(defaults.c_str()); + } else { + if (args.empty() || args[0]->empty()) + out = 0.0; + else { + out = atof(args[0]->at(0)->c_str()); + } + } +}; +/* ################################################################### */ +void OptionGroup::getString(std::string & out) { + if (!isSet) { + out = defaults; + } else { + if (args.empty() || args[0]->empty()) + out = ""; + else { + out = *args[0]->at(0); + } + } +}; +/* ################################################################### */ +void OptionGroup::getInts(std::vector & out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + StringsToInts(strings, out); + } + } else { + if (!(args.empty() || args[0]->empty())) + StringsToInts(args[0], &out); + } +}; +/* ################################################################### */ +void OptionGroup::getLongs(std::vector & out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + StringsToLongs(strings, out); + } + } else { + if (!(args.empty() || args[0]->empty())) + StringsToLongs(args[0], &out); + } +}; +/* ################################################################### */ +void OptionGroup::getULongs(std::vector & out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + StringsToULongs(strings, out); + } + } else { + if (!(args.empty() || args[0]->empty())) + StringsToULongs(args[0], &out); + } +}; +/* ################################################################### */ +void OptionGroup::getFloats(std::vector & out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + StringsToFloats(strings, out); + } + } else { + if (!(args.empty() || args[0]->empty())) + StringsToFloats(args[0], &out); + } +}; +/* ################################################################### */ +void OptionGroup::getDoubles(std::vector & out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + StringsToDoubles(strings, out); + } + } else { + if (!(args.empty() || args[0]->empty())) + StringsToDoubles(args[0], &out); + } +}; +/* ################################################################### */ +void OptionGroup::getStrings(std::vector& out) { + if (!isSet) { + if (!defaults.empty()) { + SplitDelim(defaults, delim, out); + } + } else { + if (!(args.empty() || args[0]->empty())) + StringsToStrings(args[0], &out); + } +}; +/* ################################################################### */ +void OptionGroup::getMultiInts(std::vector< std::vector >& out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + if (out.size() < 1) out.resize(1); + StringsToInts(strings, out[0]); + } + } else { + if (!args.empty()) { + int n = args.size(); + if ((long int)out.size() < n) out.resize(n); + for(int i=0; i < n; ++i) { + StringsToInts(args[i], &out[i]); + } + } + } +}; +/* ################################################################### */ +void OptionGroup::getMultiLongs(std::vector< std::vector >& out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + if (out.size() < 1) out.resize(1); + StringsToLongs(strings, out[0]); + } + } else { + if (!args.empty()) { + int n = args.size(); + if ((long int)out.size() < n) out.resize(n); + for(int i=0; i < n; ++i) { + StringsToLongs(args[i], &out[i]); + } + } + } +}; +/* ################################################################### */ +void OptionGroup::getMultiULongs(std::vector< std::vector >& out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + if (out.size() < 1) out.resize(1); + StringsToULongs(strings, out[0]); + } + } else { + if (!args.empty()) { + int n = args.size(); + if ((long int)out.size() < n) out.resize(n); + for(int i=0; i < n; ++i) { + StringsToULongs(args[i], &out[i]); + } + } + } +}; +/* ################################################################### */ +void OptionGroup::getMultiFloats(std::vector< std::vector >& out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + if (out.size() < 1) out.resize(1); + StringsToFloats(strings, out[0]); + } + } else { + if (!args.empty()) { + int n = args.size(); + if ((long int)out.size() < n) out.resize(n); + for(int i=0; i < n; ++i) { + StringsToFloats(args[i], &out[i]); + } + } + } +}; +/* ################################################################### */ +void OptionGroup::getMultiDoubles(std::vector< std::vector >& out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + if (out.size() < 1) out.resize(1); + StringsToDoubles(strings, out[0]); + } + } else { + if (!args.empty()) { + int n = args.size(); + if ((long int)out.size() < n) out.resize(n); + for(int i=0; i < n; ++i) { + StringsToDoubles(args[i], &out[i]); + } + } + } +}; +/* ################################################################### */ +void OptionGroup::getMultiStrings(std::vector< std::vector >& out) { + if (!isSet) { + if (!defaults.empty()) { + std::vector< std::string > strings; + SplitDelim(defaults, delim, strings); + if (out.size() < 1) out.resize(1); + out[0] = strings; + } + } else { + if (!args.empty()) { + int n = args.size(); + if ((long int)out.size() < n) out.resize(n); + + for(int i=0; i < n; ++i) { + for(int j=0; j < (long int)args[i]->size(); ++j) + out[i].push_back( *args[i]->at(j) ); + } + } + } +}; +/* ################################################################### */ +typedef std::map< int, ezOptionValidator* > ValidatorMap; + +class ezOptionParser { +public: + // How to layout usage descriptions with the option flags. + enum Layout { ALIGN, INTERLEAVE, STAGGER }; + + inline ~ezOptionParser(); + + inline void add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, ezOptionValidator* validator=0); + inline void add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, const char * flag2, ezOptionValidator* validator=0); + inline void add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, const char * flag2, const char * flag3, ezOptionValidator* validator=0); + inline void add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, const char * flag2, const char * flag3, const char * flag4, ezOptionValidator* validator=0); + inline bool exportFile(const char * filename, bool all=false); + inline OptionGroup * get(const char * name); + inline void getUsage(std::string & usage, int width=80, Layout layout=ALIGN); + inline void getUsageDescriptions(std::string & usage, int width=80, Layout layout=STAGGER); + inline bool gotExpected(std::vector & badOptions); + inline bool gotRequired(std::vector & badOptions); + inline bool gotValid(std::vector & badOptions, std::vector & badArgs); + inline bool importFile(const char * filename, char comment='#'); + inline int isSet(const char * name); + inline int isSet(std::string & name); + inline void parse(int argc, const char * argv[]); + inline void prettyPrint(std::string & out); + inline void reset(); + inline void resetArgs(); + + // Insert extra empty line betwee each option's usage description. + char doublespace; + // General description in human language on what the user's tool does. + // It's the first section to get printed in the full usage message. + std::string overview; + // A synopsis of command and options usage to show expected order of input arguments. + // It's the second section to get printed in the full usage message. + std::string syntax; + // Example (third) section in usage message. + std::string example; + // Final section printed in usage message. For contact, copyrights, version info. + std::string footer; + // Map from an option to an Id of its parent group. + std::map< std::string, int > optionGroupIds; + // Unordered collection of the option groups. + std::vector< OptionGroup* > groups; + // Store unexpected args in input. + std::vector< std::string* > unknownArgs; + // List of args that occur left-most before first option flag. + std::vector< std::string* > firstArgs; + // List of args that occur after last right-most option flag and its args. + std::vector< std::string* > lastArgs; + // List of validators. + ValidatorMap validators; + // Maps group id to a validator index into vector of validators. Validator index is -1 if there is no validator for group. + std::map< int, int > groupValidators; +}; +/* ################################################################### */ +ezOptionParser::~ezOptionParser() { + reset(); +} +/* ################################################################### */ +void ezOptionParser::reset() { + this->doublespace = 1; + + int i; + for(i=0; i < (long int)groups.size(); ++i) + delete groups[i]; + groups.clear(); + + for(i=0; i < (long int)unknownArgs.size(); ++i) + delete unknownArgs[i]; + unknownArgs.clear(); + + for(i=0; i < (long int)firstArgs.size(); ++i) + delete firstArgs[i]; + firstArgs.clear(); + + for(i=0; i < (long int)lastArgs.size(); ++i) + delete lastArgs[i]; + lastArgs.clear(); + + ValidatorMap::iterator it; + for(it = validators.begin(); it != validators.end(); ++it) + delete it->second; + + validators.clear(); + optionGroupIds.clear(); + groupValidators.clear(); +}; +/* ################################################################### */ +void ezOptionParser::resetArgs() { + int i; + for(i=0; i < (long int)groups.size(); ++i) + groups[i]->clearArgs(); + + for(i=0; i < (long int)unknownArgs.size(); ++i) + delete unknownArgs[i]; + unknownArgs.clear(); + + for(i=0; i < (long int)firstArgs.size(); ++i) + delete firstArgs[i]; + firstArgs.clear(); + + for(i=0; i < (long int)lastArgs.size(); ++i) + delete lastArgs[i]; + lastArgs.clear(); +}; +/* ################################################################### */ +void ezOptionParser::add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, ezOptionValidator* validator) { + int id = this->groups.size(); + OptionGroup * g = new OptionGroup; + g->defaults = defaults; + g->isRequired = required; + g->expectArgs = expectArgs; + g->delim = delim; + g->isSet = 0; + g->help = help; + std::string *f1 = new std::string(flag1); + g->flags.push_back( f1 ); + this->optionGroupIds[flag1] = id; + this->groups.push_back(g); + + if (validator) { + int vid = validator->id; + validators[vid] = validator; + groupValidators[id] = vid; + } else { + groupValidators[id] = -1; + } +}; +/* ################################################################### */ +void ezOptionParser::add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, const char * flag2, ezOptionValidator* validator) { + int id = this->groups.size(); + OptionGroup * g = new OptionGroup; + g->defaults = defaults; + g->isRequired = required; + g->expectArgs = expectArgs; + g->delim = delim; + g->isSet = 0; + g->help = help; + std::string *f1 = new std::string(flag1); + g->flags.push_back( f1 ); + std::string *f2 = new std::string(flag2); + g->flags.push_back( f2 ); + this->optionGroupIds[flag1] = id; + this->optionGroupIds[flag2] = id; + + this->groups.push_back(g); + + if (validator) { + int vid = validator->id; + validators[vid] = validator; + groupValidators[id] = vid; + } else { + groupValidators[id] = -1; + } +}; +/* ################################################################### */ +void ezOptionParser::add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, const char * flag2, const char * flag3, ezOptionValidator* validator) { + int id = this->groups.size(); + OptionGroup * g = new OptionGroup; + g->defaults = defaults; + g->isRequired = required; + g->expectArgs = expectArgs; + g->delim = delim; + g->isSet = 0; + g->help = help; + std::string *f1 = new std::string(flag1); + g->flags.push_back( f1 ); + std::string *f2 = new std::string(flag2); + g->flags.push_back( f2 ); + std::string *f3 = new std::string(flag3); + g->flags.push_back( f3 ); + this->optionGroupIds[flag1] = id; + this->optionGroupIds[flag2] = id; + this->optionGroupIds[flag3] = id; + + this->groups.push_back(g); + + if (validator) { + int vid = validator->id; + validators[vid] = validator; + groupValidators[id] = vid; + } else { + groupValidators[id] = -1; + } +}; +/* ################################################################### */ +void ezOptionParser::add(const char * defaults, bool required, int expectArgs, char delim, const char * help, const char * flag1, const char * flag2, const char * flag3, const char * flag4, ezOptionValidator* validator) { + int id = this->groups.size(); + OptionGroup * g = new OptionGroup; + g->defaults = defaults; + g->isRequired = required; + g->expectArgs = expectArgs; + g->delim = delim; + g->isSet = 0; + g->help = help; + std::string *f1 = new std::string(flag1); + g->flags.push_back( f1 ); + std::string *f2 = new std::string(flag2); + g->flags.push_back( f2 ); + std::string *f3 = new std::string(flag3); + g->flags.push_back( f3 ); + std::string *f4 = new std::string(flag4); + g->flags.push_back( f4 ); + this->optionGroupIds[flag1] = id; + this->optionGroupIds[flag2] = id; + this->optionGroupIds[flag3] = id; + this->optionGroupIds[flag4] = id; + + this->groups.push_back(g); + + if (validator) { + int vid = validator->id; + validators[vid] = validator; + groupValidators[id] = vid; + } else { + groupValidators[id] = -1; + } +}; +/* ################################################################### */ +bool ezOptionParser::exportFile(const char * filename, bool all) { + int i; + std::string out; + bool quote; + + // Export the first args, except the program name, so start from 1. + for(i=1; i < (long int)firstArgs.size(); ++i) { + quote = ((firstArgs[i]->find_first_of(" \t") != std::string::npos) && (firstArgs[i]->find_first_of("\'\"") == std::string::npos)); + + if (quote) + out.append("\""); + + out.append(*firstArgs[i]); + if (quote) + out.append("\""); + + out.append(" "); + } + + if (firstArgs.size() > 1) + out.append("\n"); + + std::vector stringPtrs(groups.size()); + int m; + int n = groups.size(); + for(i=0; i < n; ++i) { + stringPtrs[i] = groups[i]->flags[0]; + } + + OptionGroup *g; + // Sort first flag of each group with other groups. + std::sort(stringPtrs.begin(), stringPtrs.end(), CmpOptStringPtr); + for(i=0; i < n; ++i) { + g = get(stringPtrs[i]->c_str()); + if (g->isSet || all) { + if (!g->isSet || g->args.empty()) { + if (!g->defaults.empty()) { + out.append(*stringPtrs[i]); + out.append(" "); + quote = ((g->defaults.find_first_of(" \t") != std::string::npos) && (g->defaults.find_first_of("\'\"") == std::string::npos)); + if (quote) + out.append("\""); + + out.append(g->defaults); + if (quote) + out.append("\""); + + out.append("\n"); + } + } else { + int n = g->args.size(); + for(int j=0; j < n; ++j) { + out.append(*stringPtrs[i]); + out.append(" "); + m = g->args[j]->size(); + + for(int k=0; k < m; ++k) { + quote = ( (*g->args[j]->at(k)).find_first_of(" \t") != std::string::npos ); + if (quote) + out.append("\""); + + out.append(*g->args[j]->at(k)); + if (quote) + out.append("\""); + + if ((g->delim) && ((k+1) != m)) + out.append(1,g->delim); + } + out.append("\n"); + } + } + } + } + + // Export the last args. + for(i=0; i < (long int)lastArgs.size(); ++i) { + quote = ( lastArgs[i]->find_first_of(" \t") != std::string::npos ); + if (quote) + out.append("\""); + + out.append(*lastArgs[i]); + if (quote) + out.append("\""); + + out.append(" "); + } + + std::ofstream file(filename); + if (!file.is_open()) + return false; + + file << out; + file.close(); + + return true; +}; +/* ################################################################### */ +// Does not overwrite current options. +// Returns true if file was read successfully. +// So if this is used before parsing CLI, then option values will reflect +// this file, but if used after parsing CLI, then values will contain +// both CLI values and file's values. +// +// Comment lines are allowed if prefixed with #. +// Strings should be quoted as usual. +bool ezOptionParser::importFile(const char * filename, char comment) { + std::ifstream file (filename, std::ios::in | std::ios::ate); + if (!file.is_open()) + return false; + + // Read entire file contents. + std::ifstream::pos_type size = file.tellg(); + char * memblock = new char[(int)size+1]; // Add one for end of string. + file.seekg (0, std::ios::beg); + file.read (memblock, size); + memblock[size] = '\0'; + file.close(); + + // Find comment lines. + std::list lines; + std::string memblockstring(memblock); + delete[] memblock; + SplitDelim(memblockstring, '\n', lines); + int i,j,n; + std::list::iterator iter; + std::vector sq, dq; // Single and double quote indices. + std::vector::iterator lo; // For searching quote indices. + size_t pos; + const char *str; + std::string *line; + // Find all single and double quotes to correctly handle comment tokens. + for(iter=lines.begin(); iter != lines.end(); ++iter) { + line = *iter; + str = line->c_str(); + n = line->size(); + sq.clear(); + dq.clear(); + if (n) { + // If first char is comment, then erase line and continue. + pos = line->find_first_not_of(" \t\r"); + if ((pos==std::string::npos) || (line->at(pos)==comment)) { + line->erase(); + continue; + } else { + // Erase whitespace prefix. + line->erase(0,pos); + n = line->size(); + } + + if (line->at(0)=='"') + dq.push_back(0); + + if (line->at(0)=='\'') + sq.push_back(0); + } else { // Empty line. + continue; + } + + for(i=1; i < n; ++i) { + if ( (str[i]=='"') && (str[i-1]!='\\') ) + dq.push_back(i); + else if ( (str[i]=='\'') && (str[i-1]!='\\') ) + sq.push_back(i); + } + // Scan for comments, and when found, check bounds of quotes. + // Start with second char because already checked first char. + for(i=1; i < n; ++i) { + if ( (line->at(i)==comment) && (line->at(i-1)!='\\') ) { + // If within open/close quote pair, then not real comment. + if (sq.size()) { + lo = std::lower_bound(sq.begin(), sq.end(), i); + // All start of strings will be even indices, closing quotes is odd indices. + j = (int)(lo-sq.begin()); + if ( (j % 2) == 0) { // Even implies comment char not in quote pair. + // Erase from comment char to end of line. + line->erase(i); + break; + } + } else if (dq.size()) { + // Repeat tests for double quotes. + lo = std::lower_bound(dq.begin(), dq.end(), i); + j = (int)(lo-dq.begin()); + if ( (j % 2) == 0) { + line->erase(i); + break; + } + } else { + // Not in quotes. + line->erase(i); + break; + } + } + } + } + + std::string cmd; + // Convert list to string without newlines to simulate commandline. + for(iter=lines.begin(); iter != lines.end(); ++iter) { + if (! (*iter)->empty()) { + cmd.append(**iter); + cmd.append(" "); + } + } + + // Now parse as if from command line. + int argc=0; + char** argv = CommandLineToArgvA((char*)cmd.c_str(), &argc); + + // Parse. + parse(argc, (const char**)argv); + if (argv) free(argv); + for(iter=lines.begin(); iter != lines.end(); ++iter) + delete *iter; + + return true; +}; +/* ################################################################### */ +int ezOptionParser::isSet(const char * name) { + std::string sname(name); + + if (this->optionGroupIds.count(sname)) { + return this->groups[this->optionGroupIds[sname]]->isSet; + } + + return 0; +}; +/* ################################################################### */ +int ezOptionParser::isSet(std::string & name) { + if (this->optionGroupIds.count(name)) { + return this->groups[this->optionGroupIds[name]]->isSet; + } + + return 0; +}; +/* ################################################################### */ +OptionGroup * ezOptionParser::get(const char * name) { + if (optionGroupIds.count(name)) { + return groups[optionGroupIds[name]]; + } + + return 0; +}; +/* ################################################################### */ +void ezOptionParser::getUsage(std::string & usage, int width, Layout layout) { + + usage.append(overview); + usage.append("\n\n"); + usage.append("USAGE: "); + usage.append(syntax); + usage.append("\n\nOPTIONS:\n\n"); + getUsageDescriptions(usage, width, layout); + + if (!example.empty()) { + usage.append("EXAMPLES:\n\n"); + usage.append(example); + } + + if (!footer.empty()) { + usage.append(footer); + } +}; +/* ################################################################### */ +// Creates 2 column formatted help descriptions for each option flag. +void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout layout) { + // Sort each flag list amongst each group. + int i; + // Store index of flag groups before sort for easy lookup later. + std::map stringPtrToIndexMap; + std::vector stringPtrs(groups.size()); + + for(i=0; i < (long int)groups.size(); ++i) { + std::sort(groups[i]->flags.begin(), groups[i]->flags.end(), CmpOptStringPtr); + stringPtrToIndexMap[groups[i]->flags[0]] = i; + stringPtrs[i] = groups[i]->flags[0]; + } + + size_t j, k; + std::string opts; + std::vector sortedOpts; + // Sort first flag of each group with other groups. + std::sort(stringPtrs.begin(), stringPtrs.end(), CmpOptStringPtr); + for(i=0; i < (long int)groups.size(); ++i) { + //printf("DEBUG:%d: %d %d %s\n", __LINE__, i, stringPtrToIndexMap[stringPtrs[i]], stringPtrs[i]->c_str()); + k = stringPtrToIndexMap[stringPtrs[i]]; + opts.clear(); + for(j=0; j < groups[k]->flags.size()-1; ++j) { + opts.append(*groups[k]->flags[j]); + opts.append(", "); + + if ((long int)opts.size() > width) + opts.append("\n"); + } + // The last flag. No need to append comma anymore. + opts.append( *groups[k]->flags[j] ); + + if (groups[k]->expectArgs) { + opts.append(" ARG"); + + if (groups[k]->delim) { + opts.append("1["); + opts.append(1, groups[k]->delim); + opts.append("ARGn]"); + } + } + + sortedOpts.push_back(opts); + } + + // Each option group will use this to build multiline help description. + std::list desc; + // Number of whitespaces from start of line to description (interleave layout) or + // gap between flag names and description (align, stagger layouts). + int gutter = 3; + + // Find longest opt flag string to set column start for help usage descriptions. + int maxlen=0; + if (layout == ALIGN) { + for(i=0; i < (long int)groups.size(); ++i) { + if (maxlen < (long int)sortedOpts[i].size()) + maxlen = sortedOpts[i].size(); + } + } + + // The amount of space remaining on a line for help text after flags. + int helpwidth; + std::list::iterator cIter, insertionIter; + size_t pos; + for(i=0; i < (long int)groups.size(); ++i) { + k = stringPtrToIndexMap[stringPtrs[i]]; + + if (layout == STAGGER) + maxlen = sortedOpts[i].size(); + + int pad = gutter + maxlen; + helpwidth = width - pad; + + // All the following split-fu could be optimized by just using substring (offset, length) tuples, but just to get it done, we'll do some not-too expensive string copying. + SplitDelim(groups[k]->help, '\n', desc); + // Split lines longer than allowable help width. + for(insertionIter=desc.begin(), cIter=insertionIter++; + cIter != desc.end(); + cIter=insertionIter++) { + if ((long int)((*cIter)->size()) > helpwidth) { + // Get pointer to next string to insert new strings before it. + std::string *rem = *cIter; + // Remove this line and add back in pieces. + desc.erase(cIter); + // Loop until remaining string is short enough. + while ((long int)rem->size() > helpwidth) { + // Find whitespace to split before helpwidth. + if (rem->at(helpwidth) == ' ') { + // If word ends exactly at helpwidth, then split after it. + pos = helpwidth; + } else { + // Otherwise, split occurs midword, so find whitespace before this word. + pos = rem->rfind(" ", helpwidth); + } + // Insert split string. + desc.insert(insertionIter, new std::string(*rem, 0, pos)); + // Now skip any whitespace to start new line. + pos = rem->find_first_not_of(' ', pos); + rem->erase(0, pos); + } + + if (rem->size()) + desc.insert(insertionIter, rem); + else + delete rem; + } + } + + usage.append(sortedOpts[i]); + if (layout != INTERLEAVE) + // Add whitespace between option names and description. + usage.append(pad - sortedOpts[i].size(), ' '); + else { + usage.append("\n"); + usage.append(gutter, ' '); + } + + if (desc.size() > 0) { // Crash fix by Bruce Shankle. + // First line already padded above (before calling SplitDelim) after option flag names. + cIter = desc.begin(); + usage.append(**cIter); + usage.append("\n"); + // Now inject the pad for each line. + for(++cIter; cIter != desc.end(); ++cIter) { + usage.append(pad, ' '); + usage.append(**cIter); + usage.append("\n"); + } + + if (this->doublespace) usage.append("\n"); + + for(cIter=desc.begin(); cIter != desc.end(); ++cIter) + delete *cIter; + + desc.clear(); + } + + } +}; +/* ################################################################### */ +bool ezOptionParser::gotExpected(std::vector & badOptions) { + int i,j; + + for(i=0; i < (long int)groups.size(); ++i) { + OptionGroup *g = groups[i]; + // If was set, ensure number of args is correct. + if (g->isSet) { + if ((g->expectArgs != 0) && g->args.empty()) { + badOptions.push_back(*g->flags[0]); + continue; + } + + for(j=0; j < (long int)g->args.size(); ++j) { + if ((g->expectArgs != -1) && (g->expectArgs != (long int)g->args[j]->size())) + badOptions.push_back(*g->flags[0]); + } + } + } + + return badOptions.empty(); +}; +/* ################################################################### */ +bool ezOptionParser::gotRequired(std::vector & badOptions) { + int i; + + for(i=0; i < (long int)groups.size(); ++i) { + OptionGroup *g = groups[i]; + // Simple case when required but user never set it. + if (g->isRequired && (!g->isSet)) { + badOptions.push_back(*g->flags[0]); + continue; + } + } + + return badOptions.empty(); +}; +/* ################################################################### */ +bool ezOptionParser::gotValid(std::vector & badOptions, std::vector & badArgs) { + int groupid, validatorid; + std::map< int, int >::iterator it; + + for(it = groupValidators.begin(); it != groupValidators.end(); ++it) { + groupid = it->first; + validatorid = it->second; + if (validatorid < 0) continue; + + OptionGroup *g = groups[groupid]; + ezOptionValidator *v = validators[validatorid]; + bool nextgroup = false; + + for (int i = 0; i < (long int)g->args.size(); ++i) { + if (nextgroup) break; + std::vector< std::string* > * args = g->args[i]; + for (int j = 0; j < (long int)args->size(); ++j) { + if (!v->isValid(args->at(j))) { + badOptions.push_back(*g->flags[0]); + badArgs.push_back(*args->at(j)); + nextgroup = true; + break; + } + } + } + } + + return badOptions.empty(); +}; +/* ################################################################### */ +void ezOptionParser::parse(int argc, const char * argv[]) { + if (argc < 1) return; + + /* + std::map::iterator it; + for ( it=optionGroupIds.begin() ; it != optionGroupIds.end(); it++ ) + std::cout << (*it).first << " => " << (*it).second << std::endl; + */ + + int i, k, firstOptIndex=0, lastOptIndex=0; + std::string s; + OptionGroup *g; + + for(i=0; i < argc; ++i) { + s = argv[i]; + + if (optionGroupIds.count(s)) + break; + } + + firstOptIndex = i; + + if (firstOptIndex == argc) { + // No flags encountered, so set last args. + this->firstArgs.push_back(new std::string(argv[0])); + + for(k=1; k < argc; ++k) + this->lastArgs.push_back(new std::string(argv[k])); + + return; + } + + // Store initial args before opts appear. + for(k=0; k < i; ++k) { + this->firstArgs.push_back(new std::string(argv[k])); + } + + for(; i < argc; ++i) { + s = argv[i]; + + if (optionGroupIds.count(s)) { + k = optionGroupIds[s]; + g = groups[k]; + g->isSet = 1; + g->parseIndex.push_back(i); + + if (g->expectArgs) { + // Read ahead to get args. + ++i; + if (i >= argc) return; + g->args.push_back(new std::vector); + SplitDelim(argv[i], g->delim, g->args.back()); + } + lastOptIndex = i; + } + } + + // Scan for unknown opts/arguments. + for(i=firstOptIndex; i <= lastOptIndex; ++i) { + s = argv[i]; + + if (optionGroupIds.count(s)) { + k = optionGroupIds[s]; + g = groups[k]; + if (g->expectArgs) { + // Read ahead for args and skip them. + ++i; + } + } else { + unknownArgs.push_back(new std::string(argv[i])); + } + } + + if ( lastOptIndex >= (argc-1) ) return; + + // Store final args without flags. + for(k=lastOptIndex + 1; k < argc; ++k) { + this->lastArgs.push_back(new std::string(argv[k])); + } +}; +/* ################################################################### */ +void ezOptionParser::prettyPrint(std::string & out) { + char tmp[256]; + int i,j,k; + + out += "First Args:\n"; + for(i=0; i < (long int)firstArgs.size(); ++i) { + sprintf(tmp, "%d: %s\n", i+1, firstArgs[i]->c_str()); + out += tmp; + } + + // Sort the option flag names. + int n = groups.size(); + std::vector stringPtrs(n); + for(i=0; i < n; ++i) { + stringPtrs[i] = groups[i]->flags[0]; + } + + // Sort first flag of each group with other groups. + std::sort(stringPtrs.begin(), stringPtrs.end(), CmpOptStringPtr); + + out += "\nOptions:\n"; + OptionGroup *g; + for(i=0; i < n; ++i) { + g = get(stringPtrs[i]->c_str()); + out += "\n"; + // The flag names: + for(j=0; j < (long int)g->flags.size()-1; ++j) { + sprintf(tmp, "%s, ", g->flags[j]->c_str()); + out += tmp; + } + sprintf(tmp, "%s:\n", g->flags.back()->c_str()); + out += tmp; + + if (g->isSet) { + if (g->expectArgs) { + if (g->args.empty()) { + sprintf(tmp, "%s (default)\n", g->defaults.c_str()); + out += tmp; + } else { + for(k=0; k < (long int)g->args.size(); ++k) { + for(j=0; j < (long int)g->args[k]->size()-1; ++j) { + sprintf(tmp, "%s%c", g->args[k]->at(j)->c_str(), g->delim); + out += tmp; + } + sprintf(tmp, "%s\n", g->args[k]->back()->c_str()); + out += tmp; + } + } + } else { // Set but no args expected. + sprintf(tmp, "Set\n"); + out += tmp; + } + } else { + sprintf(tmp, "Not set\n"); + out += tmp; + } + } + + out += "\nLast Args:\n"; + for(i=0; i < (long int)lastArgs.size(); ++i) { + sprintf(tmp, "%d: %s\n", i+1, lastArgs[i]->c_str()); + out += tmp; + } + + out += "\nUnknown Args:\n"; + for(i=0; i < (long int)unknownArgs.size(); ++i) { + sprintf(tmp, "%d: %s\n", i+1, unknownArgs[i]->c_str()); + out += tmp; + } +}; +} +/* ################################################################### */ +#endif /* EZ_OPTION_PARSER_H */ diff --git a/Tools/int.h b/Tools/int.h new file mode 100644 index 000000000..78e253bcb --- /dev/null +++ b/Tools/int.h @@ -0,0 +1,68 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * int.h + * + */ + +#ifndef TOOLS_INT_H_ +#define TOOLS_INT_H_ + + +typedef unsigned char octet; + +// Assumes word is a 64 bit value +#ifdef WIN32 + typedef unsigned __int64 word; +#else + typedef unsigned long word; +#endif + + +inline int CEIL_LOG2(int x) +{ + int result = 0; + x--; + while (x > 0) + { + result++; + x >>= 1; + } + return result; +} + +inline int FLOOR_LOG2(int x) +{ + int result = 0; + while (x > 1) + { + result++; + x >>= 1; + } + return result; +} + +// ceil(n / k) +inline int DIV_CEIL(long n, int k) +{ + return (n + k - 1)/k; +} + +inline void INT_TO_BYTES(octet *buff, int x) +{ + buff[0] = x&255; + buff[1] = (x>>8)&255; + buff[2] = (x>>16)&255; + buff[3] = (x>>24)&255; +} + +inline int BYTES_TO_INT(octet *buff) +{ + return buff[0] + 256*buff[1] + 65536*buff[2] + 16777216*buff[3]; +} + +inline int positive_modulo(int i, int n) { + return (i % n + n) % n; +} + +#endif /* TOOLS_INT_H_ */ diff --git a/Tools/mkpath.cpp b/Tools/mkpath.cpp new file mode 100644 index 000000000..a09525c8e --- /dev/null +++ b/Tools/mkpath.cpp @@ -0,0 +1,47 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Tools/mkpath.h" +#include +#include /* PATH_MAX */ +#include /* mkdir(2) */ +#include + +// mkdir -p, from https://gist.github.com/JonathonReinhart/8c0d90191c38af2dcadb102c4e202950 +int mkdir_p(const char *path) +{ + /* Adapted from http://stackoverflow.com/a/2336245/119527 */ + const size_t len = strlen(path); + char _path[PATH_MAX]; + char *p; + + errno = 0; + + /* Copy string so its mutable */ + if (len > sizeof(_path)-1) { + errno = ENAMETOOLONG; + return -1; + } + strcpy(_path, path); + + /* Iterate the string */ + for (p = _path + 1; *p; p++) { + if (*p == '/') { + /* Temporarily truncate */ + *p = '\0'; + + if (mkdir(_path, S_IRWXU) != 0) { + if (errno != EEXIST) + return -1; + } + + *p = '/'; + } + } + + if (mkdir(_path, S_IRWXU) != 0) { + if (errno != EEXIST) + return -1; + } + + return 0; +} \ No newline at end of file diff --git a/Tools/mkpath.h b/Tools/mkpath.h new file mode 100644 index 000000000..4eca401e3 --- /dev/null +++ b/Tools/mkpath.h @@ -0,0 +1,9 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef TOOLS_MKPATH_H_ +#define TOOLS_MKPATH_H_ + +// mkdir -p, from https://gist.github.com/JonathonReinhart/8c0d90191c38af2dcadb102c4e202950 +int mkdir_p(const char *path); + +#endif /* TOOLS_MKPATH_H_ */ diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp new file mode 100644 index 000000000..d3d83dbb7 --- /dev/null +++ b/Tools/octetStream.cpp @@ -0,0 +1,183 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include + +#include "octetStream.h" +#include +#include "Networking/sockets.h" +#include "Tools/sha1.h" +#include "Exceptions/Exceptions.h" +#include "Networking/data.h" + + +void octetStream::assign(const octetStream& os) +{ + if (os.len>=mxlen) + { + if (data) + delete[] data; + mxlen=os.mxlen; + data=new octet[mxlen]; + } + len=os.len; + memcpy(data,os.data,len*sizeof(octet)); + ptr=os.ptr; +} + + +octetStream::octetStream(int maxlen) +{ + mxlen=maxlen; len=0; ptr=0; + data=new octet[mxlen]; +} + + +octetStream::octetStream(const octetStream& os) +{ + mxlen=os.mxlen; + len=os.len; + data=new octet[mxlen]; + memcpy(data,os.data,len*sizeof(octet)); + ptr=os.ptr; +} + + +void octetStream::hash(octetStream& output) const +{ + blk_SHA_CTX ctx; + blk_SHA1_Init(&ctx); + blk_SHA1_Update(&ctx,data,len); + blk_SHA1_Final(output.data,&ctx); + output.len=HASH_SIZE; +} + + +octetStream octetStream::hash() const +{ + octetStream h(HASH_SIZE); + hash(h); + return h; +} + + +bigint octetStream::check_sum() const +{ + unsigned char hash[HASH_SIZE]; + + blk_SHA_CTX ctx; + blk_SHA1_Init(&ctx); + blk_SHA1_Update(&ctx,data,len); + blk_SHA1_Final(hash,&ctx); + + bigint ans; + bigintFromBytes(ans,hash,HASH_SIZE); + return ans; +} + + + +bool octetStream::equals(const octetStream& a) const +{ + if (len!=a.len) { return false; } + for (int i=0; i>4; + s << hex << t1 << t0 << dec; + } + return s; +} + + + + diff --git a/Tools/octetStream.h b/Tools/octetStream.h new file mode 100644 index 000000000..13487decb --- /dev/null +++ b/Tools/octetStream.h @@ -0,0 +1,161 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _octetStream +#define _octetStream + + +/* This class creates a stream of data and adds stuff onto it. + * This is used to pack and unpack stuff which is sent over the + * network + * + * Unlike SPDZ-1.0 this class ONLY deals with native types + * For our types we assume pack/unpack operations defined within + * that class. This is to make sure this class is relatively independent + * of the rest of the application; and so can be re-used. + */ + + +#include "Math/bigint.h" +#include "Networking/data.h" +#include "Networking/sockets.h" + +#include +#include +#include +#include +using namespace std; + + +class octetStream +{ + int len,mxlen,ptr; // len is the "write head", ptr is the "read head" + octet *data; + + public: + + void resize(int l); + + void assign(const octetStream& os); + + octetStream() : len(0), mxlen(0), ptr(0), data(0) {} + octetStream(int maxlen); + octetStream(const octetStream& os); + octetStream& operator=(const octetStream& os) + { if (this!=&os) { assign(os); } + return *this; + } + ~octetStream() { if(data) delete[] data; } + + int get_ptr() const { return ptr; } + int get_length() const { return len; } + octet* get_data() const { return data; } + + octetStream hash() const; + // output must have length at least HASH_SIZE + void hash(octetStream& output) const; + // The following produces a check sum for debugging purposes + bigint check_sum() const; + + void concat(const octetStream& os); + + void reset_read_head() { ptr=0; } + /* If we reset write head then we should reset the read head as well */ + void reset_write_head() { len=0; ptr=0; } + + // Move len back num + void rewind_write_head(int num) { len-=num; } + + bool equals(const octetStream& a) const; + + /* Append NUM random bytes from dev/random */ + void append_random(int num); + + // Append with no padding for decoding + void append(const octet* x,const int l); + // Read l octets, with no padding for decoding + void consume(octet* x,const int l); + // Return pointer to next l octets and advance pointer + octet* consume(int l) { octet* res = data+ptr; ptr += l; return res; } + + /* Now store and restore different types of data (with padding for decoding) */ + + void store_bytes(octet* x, const int l); //not really "bytes"... + void get_bytes(octet* ans, int& l); //Assumes enough space in ans + + void store(unsigned int a); + void store(int a) { store((unsigned int) a); } + void get(unsigned int& a); + void get(int& a) { get((unsigned int&) a); } + + void store(const bigint& x); + void get(bigint& ans); + + void consume(octetStream& s,int l) + { s.resize(l); + consume(s.data,l); + s.len=l; + } + + void Send(int socket_num) const; + void Receive(int socket_num); + + friend ostream& operator<<(ostream& s,const octetStream& o); + friend class PRNG; +}; + + +inline void octetStream::resize(int l) +{ + if (lmxlen) + resize(len+l); + memcpy(data+len,x,l*sizeof(octet)); + len+=l; +} + + +inline void octetStream::consume(octet* x,const int l) +{ + memcpy(x,data+ptr,l*sizeof(octet)); + ptr+=l; +} + + +inline void octetStream::Send(int socket_num) const +{ + octet blen[4]; + encode_length(blen,len); + send(socket_num,blen,4); + send(socket_num,data,len); +} + + +inline void octetStream::Receive(int socket_num) +{ + octet blen[4]; + receive(socket_num,blen,4); + + int nlen=decode_length(blen); + len=0; + resize(nlen); + len=nlen; + + receive(socket_num,data,len); +} + +#endif + diff --git a/Tools/random.cpp b/Tools/random.cpp new file mode 100644 index 000000000..6d1f236a8 --- /dev/null +++ b/Tools/random.cpp @@ -0,0 +1,195 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Tools/random.h" +#include + +#include +using namespace std; + + +PRNG::PRNG() : cnt(0) +{ + #ifdef USE_AES + useC=(Check_CPU_support_AES()==0); + #endif + +} + +void PRNG::ReSeed() +{ + FILE* rD=fopen("/dev/urandom", "r"); + fread(seed,sizeof(octet),SEED_SIZE,rD); + fclose(rD); + InitSeed(); +} + + +void PRNG::SetSeed(octet* inp) +{ + memcpy(seed,inp,SEED_SIZE*sizeof(octet)); + InitSeed(); +} + +void PRNG::InitSeed() +{ + #ifdef USE_AES + if (useC) + { aes_schedule(KeyScheduleC,seed); } + else + { aes_schedule(KeySchedule,seed); } + memset(state,0,RAND_SIZE*sizeof(octet)); + for (int i = 0; i < PIPELINES; i++) + state[i*AES_BLK_SIZE] = i; + #else + memcpy(state,seed,SEED_SIZE*sizeof(octet)); + #endif + next(); + //cout << "SetSeed : "; print_state(); cout << endl; +} + + +void PRNG::print_state() const +{ + int i; + for (i=0; i((__m128i*)random,(__m128i*)state,KeySchedule); } + #endif + // This is a new random value so we have not used any of it yet + cnt=0; +} + + + +void PRNG::next() +{ + // Increment state + for (int i = 0; i < PIPELINES; i++) + { + int64_t* s = (int64_t*)&state[i*AES_BLK_SIZE]; + s[0] += PIPELINES; + if (s[0] == 0) + s[1]++; + } + hash(); +} + + +double PRNG::get_double() +{ + // We need four bytes of randomness + if (cnt>RAND_SIZE-4) { next(); } + unsigned int a0=random[cnt],a1=random[cnt+1],a2=random[cnt+2],a3=random[cnt+3]; + double ans=(a0+(a1<<8)+(a2<<16)+(a3<<24)); + cnt=cnt+4; + unsigned int den=0xFFFFFFFF; + ans=ans/den; + //print_state(); cout << " DBLE " << ans << endl; + return ans; +} + + +unsigned int PRNG::get_uint() +{ + // We need four bytes of randomness + if (cnt>RAND_SIZE-4) { next(); } + unsigned int a0=random[cnt],a1=random[cnt+1],a2=random[cnt+2],a3=random[cnt+3]; + cnt=cnt+4; + unsigned int ans=(a0+(a1<<8)+(a2<<16)+(a3<<24)); + // print_state(); cout << " UINT " << ans << endl; + return ans; +} + + + +unsigned char PRNG::get_uchar() +{ + if (cnt>=RAND_SIZE) { next(); } + unsigned char ans=random[cnt]; + cnt++; + // print_state(); cout << " UCHA " << (int) ans << endl; + return ans; +} + + +__m128i PRNG::get_doubleword() +{ + if (cnt > RAND_SIZE - 16) + next(); + __m128i ans = _mm_loadu_si128((__m128i*)&random[cnt]); + cnt += 16; + return ans; +} + + +void PRNG::get_octetStream(octetStream& ans,int len) +{ + ans.resize(len); + for (int i=0; i +#include + +#if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__)) + +/* + * Force usage of rol or ror by selecting the one with the smaller constant. + * It _can_ generate slightly smaller code (a constant of 1 is special), but + * perhaps more importantly it's possibly faster on any uarch that does a + * rotate with a loop. + */ + +#define SHA_ASM(op, x, n) ({ unsigned int __res; __asm__(op " %1,%0":"=r" (__res):"i" (n), "0" (x)); __res; }) +#define SHA_ROL(x,n) SHA_ASM("rol", x, n) +#define SHA_ROR(x,n) SHA_ASM("ror", x, n) + +#else + +#define SHA_ROT(X,l,r) (((X) << (l)) | ((X) >> (r))) +#define SHA_ROL(X,n) SHA_ROT(X,n,32-(n)) +#define SHA_ROR(X,n) SHA_ROT(X,32-(n),n) + +#endif + +/* + * If you have 32 registers or more, the compiler can (and should) + * try to change the array[] accesses into registers. However, on + * machines with less than ~25 registers, that won't really work, + * and at least gcc will make an unholy mess of it. + * + * So to avoid that mess which just slows things down, we force + * the stores to memory to actually happen (we might be better off + * with a 'W(t)=(val);asm("":"+m" (W(t))' there instead, as + * suggested by Artur Skawina - that will also make gcc unable to + * try to do the silly "optimize away loads" part because it won't + * see what the value will be). + * + * Ben Herrenschmidt reports that on PPC, the C version comes close + * to the optimized asm with this (ie on PPC you don't want that + * 'volatile', since there are lots of registers). + * + * On ARM we get the best code generation by forcing a full memory barrier + * between each SHA_ROUND, otherwise gcc happily get wild with spilling and + * the stack frame size simply explode and performance goes down the drain. + */ + +#if defined(__i386__) || defined(__x86_64__) + #define setW(x, val) (*(volatile unsigned int *)&W(x) = (val)) +#elif defined(__GNUC__) && defined(__arm__) + #define setW(x, val) do { W(x) = (val); __asm__("":::"memory"); } while (0) +#else + #define setW(x, val) (W(x) = (val)) +#endif + +/* + * Performance might be improved if the CPU architecture is OK with + * unaligned 32-bit loads and a fast ntohl() is available. + * Otherwise fall back to byte loads and shifts which is portable, + * and is faster on architectures with memory alignment issues. + */ + +#if defined(__i386__) || defined(__x86_64__) || \ + defined(_M_IX86) || defined(_M_X64) || \ + defined(__ppc__) || defined(__ppc64__) || \ + defined(__powerpc__) || defined(__powerpc64__) || \ + defined(__s390__) || defined(__s390x__) + +#define get_be32(p) ntohl(*(unsigned int *)(p)) +#define put_be32(p, v) do { *(unsigned int *)(p) = htonl(v); } while (0) + +#else + +#define get_be32(p) ( \ + (*((unsigned char *)(p) + 0) << 24) | \ + (*((unsigned char *)(p) + 1) << 16) | \ + (*((unsigned char *)(p) + 2) << 8) | \ + (*((unsigned char *)(p) + 3) << 0) ) +#define put_be32(p, v) do { \ + unsigned int __v = (v); \ + *((unsigned char *)(p) + 0) = __v >> 24; \ + *((unsigned char *)(p) + 1) = __v >> 16; \ + *((unsigned char *)(p) + 2) = __v >> 8; \ + *((unsigned char *)(p) + 3) = __v >> 0; } while (0) + +#endif + +/* This "rolls" over the 512-bit array */ +#define W(x) (array[(x)&15]) + +/* + * Where do we get the source from? The first 16 iterations get it from + * the input data, the next mix it from the 512-bit array. + */ +#define SHA_SRC(t) get_be32(data + t) +#define SHA_MIX(t) SHA_ROL(W(t+13) ^ W(t+8) ^ W(t+2) ^ W(t), 1) + +#define SHA_ROUND(t, input, fn, constant, A, B, C, D, E) do { \ + unsigned int TEMP = input(t); setW(t, TEMP); \ + E += TEMP + SHA_ROL(A,5) + (fn) + (constant); \ + B = SHA_ROR(B, 2); } while (0) + +#define T_0_15(t, A, B, C, D, E) SHA_ROUND(t, SHA_SRC, (((C^D)&B)^D) , 0x5a827999, A, B, C, D, E ) +#define T_16_19(t, A, B, C, D, E) SHA_ROUND(t, SHA_MIX, (((C^D)&B)^D) , 0x5a827999, A, B, C, D, E ) +#define T_20_39(t, A, B, C, D, E) SHA_ROUND(t, SHA_MIX, (B^C^D) , 0x6ed9eba1, A, B, C, D, E ) +#define T_40_59(t, A, B, C, D, E) SHA_ROUND(t, SHA_MIX, ((B&C)+(D&(B^C))) , 0x8f1bbcdc, A, B, C, D, E ) +#define T_60_79(t, A, B, C, D, E) SHA_ROUND(t, SHA_MIX, (B^C^D) , 0xca62c1d6, A, B, C, D, E ) + +static void blk_SHA1_Block(blk_SHA_CTX *ctx, const unsigned int *data) +{ + unsigned int A,B,C,D,E; + unsigned int array[16]; + + A = ctx->H[0]; + B = ctx->H[1]; + C = ctx->H[2]; + D = ctx->H[3]; + E = ctx->H[4]; + + /* Round 1 - iterations 0-16 take their input from 'data' */ + T_0_15( 0, A, B, C, D, E); + T_0_15( 1, E, A, B, C, D); + T_0_15( 2, D, E, A, B, C); + T_0_15( 3, C, D, E, A, B); + T_0_15( 4, B, C, D, E, A); + T_0_15( 5, A, B, C, D, E); + T_0_15( 6, E, A, B, C, D); + T_0_15( 7, D, E, A, B, C); + T_0_15( 8, C, D, E, A, B); + T_0_15( 9, B, C, D, E, A); + T_0_15(10, A, B, C, D, E); + T_0_15(11, E, A, B, C, D); + T_0_15(12, D, E, A, B, C); + T_0_15(13, C, D, E, A, B); + T_0_15(14, B, C, D, E, A); + T_0_15(15, A, B, C, D, E); + + /* Round 1 - tail. Input from 512-bit mixing array */ + T_16_19(16, E, A, B, C, D); + T_16_19(17, D, E, A, B, C); + T_16_19(18, C, D, E, A, B); + T_16_19(19, B, C, D, E, A); + + /* Round 2 */ + T_20_39(20, A, B, C, D, E); + T_20_39(21, E, A, B, C, D); + T_20_39(22, D, E, A, B, C); + T_20_39(23, C, D, E, A, B); + T_20_39(24, B, C, D, E, A); + T_20_39(25, A, B, C, D, E); + T_20_39(26, E, A, B, C, D); + T_20_39(27, D, E, A, B, C); + T_20_39(28, C, D, E, A, B); + T_20_39(29, B, C, D, E, A); + T_20_39(30, A, B, C, D, E); + T_20_39(31, E, A, B, C, D); + T_20_39(32, D, E, A, B, C); + T_20_39(33, C, D, E, A, B); + T_20_39(34, B, C, D, E, A); + T_20_39(35, A, B, C, D, E); + T_20_39(36, E, A, B, C, D); + T_20_39(37, D, E, A, B, C); + T_20_39(38, C, D, E, A, B); + T_20_39(39, B, C, D, E, A); + + /* Round 3 */ + T_40_59(40, A, B, C, D, E); + T_40_59(41, E, A, B, C, D); + T_40_59(42, D, E, A, B, C); + T_40_59(43, C, D, E, A, B); + T_40_59(44, B, C, D, E, A); + T_40_59(45, A, B, C, D, E); + T_40_59(46, E, A, B, C, D); + T_40_59(47, D, E, A, B, C); + T_40_59(48, C, D, E, A, B); + T_40_59(49, B, C, D, E, A); + T_40_59(50, A, B, C, D, E); + T_40_59(51, E, A, B, C, D); + T_40_59(52, D, E, A, B, C); + T_40_59(53, C, D, E, A, B); + T_40_59(54, B, C, D, E, A); + T_40_59(55, A, B, C, D, E); + T_40_59(56, E, A, B, C, D); + T_40_59(57, D, E, A, B, C); + T_40_59(58, C, D, E, A, B); + T_40_59(59, B, C, D, E, A); + + /* Round 4 */ + T_60_79(60, A, B, C, D, E); + T_60_79(61, E, A, B, C, D); + T_60_79(62, D, E, A, B, C); + T_60_79(63, C, D, E, A, B); + T_60_79(64, B, C, D, E, A); + T_60_79(65, A, B, C, D, E); + T_60_79(66, E, A, B, C, D); + T_60_79(67, D, E, A, B, C); + T_60_79(68, C, D, E, A, B); + T_60_79(69, B, C, D, E, A); + T_60_79(70, A, B, C, D, E); + T_60_79(71, E, A, B, C, D); + T_60_79(72, D, E, A, B, C); + T_60_79(73, C, D, E, A, B); + T_60_79(74, B, C, D, E, A); + T_60_79(75, A, B, C, D, E); + T_60_79(76, E, A, B, C, D); + T_60_79(77, D, E, A, B, C); + T_60_79(78, C, D, E, A, B); + T_60_79(79, B, C, D, E, A); + + ctx->H[0] += A; + ctx->H[1] += B; + ctx->H[2] += C; + ctx->H[3] += D; + ctx->H[4] += E; +} + +void blk_SHA1_Init(blk_SHA_CTX *ctx) +{ + ctx->size = 0; + + /* Initialize H with the magic constants (see FIPS180 for constants) */ + ctx->H[0] = 0x67452301; + ctx->H[1] = 0xefcdab89; + ctx->H[2] = 0x98badcfe; + ctx->H[3] = 0x10325476; + ctx->H[4] = 0xc3d2e1f0; +} + +void blk_SHA1_Update(blk_SHA_CTX *ctx, const void *data, unsigned long len) +{ + unsigned int lenW = ctx->size & 63; + + ctx->size += len; + + /* Read the data into W and process blocks as they get full */ + if (lenW) { + unsigned int left = 64 - lenW; + if (len < left) + left = len; + memcpy(lenW + (char *)ctx->W, data, left); + lenW = (lenW + left) & 63; + len -= left; + data = ((const char *)data + left); + if (lenW) + return; + blk_SHA1_Block(ctx, ctx->W); + } + while (len >= 64) { + blk_SHA1_Block(ctx,(const unsigned int*) data); + data = ((const char *)data + 64); + len -= 64; + } + if (len) + memcpy(ctx->W, data, len); +} + +void blk_SHA1_Final(unsigned char hashout[20], blk_SHA_CTX *ctx) +{ + static const unsigned char pad[64] = { 0x80 }; + unsigned int padlen[2]; + int i; + + /* Pad with a binary 1 (ie 0x80), then zeroes, then length */ + padlen[0] = htonl((uint32_t)(ctx->size >> 29)); + padlen[1] = htonl((uint32_t)(ctx->size << 3)); + + i = ctx->size & 63; + blk_SHA1_Update(ctx, pad, 1+ (63 & (55 - i))); + blk_SHA1_Update(ctx, padlen, 8); + + /* Output hash */ + for (i = 0; i < 5; i++) + put_be32(hashout + i*4, ctx->H[i]); +} diff --git a/Tools/sha1.h b/Tools/sha1.h new file mode 100644 index 000000000..6d2cd1bd2 --- /dev/null +++ b/Tools/sha1.h @@ -0,0 +1,31 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _SHA1 +#define _SHA1 + +/* + * SHA1 routine optimized to do word accesses rather than byte accesses, + * and to avoid unnecessary copies into the context array. + * + * This was initially based on the Mozilla SHA1 implementation, although + * none of the original Mozilla code remains. + */ + +#define HASH_SIZE 20 + +typedef struct { + unsigned long long size; + unsigned int H[5]; + unsigned int W[16]; +} blk_SHA_CTX; + +void blk_SHA1_Init(blk_SHA_CTX *ctx); +void blk_SHA1_Update(blk_SHA_CTX *ctx, const void *dataIn, unsigned long len); +void blk_SHA1_Final(unsigned char hashout[20], blk_SHA_CTX *ctx); + +#define git_SHA_CTX blk_SHA_CTX +#define git_SHA1_Init blk_SHA1_Init +#define git_SHA1_Update blk_SHA1_Update +#define git_SHA1_Final blk_SHA1_Final + +#endif diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp new file mode 100644 index 000000000..54c17d03f --- /dev/null +++ b/Tools/time-func.cpp @@ -0,0 +1,59 @@ +// (C) 2016 University of Bristol. See License.txt + + +#include "Tools/time-func.h" +#include "Exceptions/Exceptions.h" + + +long long timeval_diff(struct timeval *start_time, struct timeval *end_time) +{ struct timeval temp_diff; + struct timeval *difference; + difference=&temp_diff; + difference->tv_sec =end_time->tv_sec -start_time->tv_sec ; + difference->tv_usec=end_time->tv_usec-start_time->tv_usec; + while(difference->tv_usec<0) + { difference->tv_usec+=1000000; + difference->tv_sec -=1; + } + return 1000000LL*difference->tv_sec+difference->tv_usec; +} + +double timeval_diff_in_seconds(struct timeval *start_time, struct timeval *end_time) +{ + return double(timeval_diff(start_time, end_time)) / 1e6; +} + + +long long timespec_diff(struct timespec *start_time, struct timespec *end_time) +{ + long long sec =end_time->tv_sec -start_time->tv_sec ; + long long nsec=end_time->tv_nsec-start_time->tv_nsec; + while(nsec<0) + { nsec+=1000000000; + sec -=1; + } + return 1000000000*sec+nsec; +} + + +double convert_ns_to_seconds(long long x) +{ + return double(x) / 1e9; +} + + +double Timer::elapsed() +{ + long long res = elapsed_time; + if (running) + res += elapsed_since_last_start(); + return convert_ns_to_seconds(res); +} + +double Timer::idle() +{ + if (running) + throw Processor_Error("Timer running."); + else + return convert_ns_to_seconds(elapsed_since_last_start()); +} diff --git a/Tools/time-func.h b/Tools/time-func.h new file mode 100644 index 000000000..858985ff5 --- /dev/null +++ b/Tools/time-func.h @@ -0,0 +1,63 @@ +// (C) 2016 University of Bristol. See License.txt + +#ifndef _timer +#define _timer + +#include /* Wait for Process Termination */ +#include +#include + +#include "Exceptions/Exceptions.h" + +long long timeval_diff(struct timeval *start_time, struct timeval *end_time); +double timeval_diff_in_seconds(struct timeval *start_time, struct timeval *end_time); +long long timespec_diff(struct timespec *start_time, struct timespec *end_time); + +class Timer +{ + public: + Timer(clockid_t clock_id = CLOCK_MONOTONIC) : running(false), elapsed_time(0), clock_id(clock_id) + { clock_gettime(clock_id, &startv); } + Timer& start(); + void stop(); + double elapsed(); + double idle(); + + private: + timespec startv; + bool running; + long long elapsed_time; + clockid_t clock_id; + + long long elapsed_since_last_start(); +}; + +inline Timer& Timer::start() +{ + if (running) + throw Processor_Error("Timer already running."); + // clock() is not suitable in threaded programs so time using something else + clock_gettime(clock_id, &startv); + running = true; + return *this; +} + +inline void Timer::stop() +{ + if (!running) + throw Processor_Error("Time not running."); + elapsed_time += elapsed_since_last_start(); + + running = false; + clock_gettime(clock_id, &startv); +} + +inline long long Timer::elapsed_since_last_start() +{ + timespec endv; + clock_gettime(clock_id, &endv); + return timespec_diff(&startv, &endv); +} + +#endif + diff --git a/check-passive.cpp b/check-passive.cpp new file mode 100644 index 000000000..53b9c4e8a --- /dev/null +++ b/check-passive.cpp @@ -0,0 +1,75 @@ +// (C) 2016 University of Bristol. See License.txt + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Setup.h" + +#include +#include +#include + +template +void check_triples(int n_players) +{ + ifstream* inputFiles = new ifstream[n_players]; + for (int i = 0; i < n_players; i++) + { + stringstream ss; + ss << get_prep_dir(n_players, 128, 128) << "Triples-" << T::type_char() << "-P" << i; + inputFiles[i].open(ss.str().c_str()); + cout << "Opening file " << ss.str() << endl; + } + + int j = 0; + while (inputFiles[0].peek() != EOF) + { + T a,b,c,cc,tmp; + vector as(n_players), bs(n_players), cs(n_players); + for (int i = 0; i < n_players; i++) + { + as[i].input(inputFiles[i], false); + bs[i].input(inputFiles[i], false); + cs[i].input(inputFiles[i], false); + } + + a = accumulate(as.begin(), as.end(), T()); + b = accumulate(bs.begin(), bs.end(), T()); + c = accumulate(cs.begin(), cs.end(), T()); + + if (a * b != c) + { + cout << T::type_string() << ": Error in " << j << endl; + cout << "a " << a << " " << as[0] << " " << as[1] << endl; + cout << "b " << b << " " << bs[0] << " " << bs[1] << endl; + cout << "c " << c << " " << cs[0] << " " << cs[1] << endl; + for (int i = 0; i < 2; i++) + for (int j = 0; j < 2; j++) + { + tmp = as[i] * bs[j]; + cc += tmp; + cout << "a" << i << " * b" << j << " " << tmp << endl; + } + cout << "cc " << cc << endl; + cout << "a*b " << a*b << endl; + cout << "DID YOU INDICATE THE CORRECT NUMBER OF PLAYERS?" << endl; + + return; + } + + j++; + } + + cout << j << " correct triples of type " << T::type_string() << endl; + delete[] inputFiles; +} + +int main(int argc, char** argv) +{ + int n_players = 2; + if (argc > 1) + n_players = atoi(argv[1]); + read_setup(n_players, 128, 128); + gfp::init_field(gfp::pr(), false); + check_triples(n_players); + check_triples(n_players); +} diff --git a/compile.py b/compile.py new file mode 100755 index 000000000..d3cfdd6bb --- /dev/null +++ b/compile.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python + +# (C) 2016 University of Bristol. See License.txt + + +# ===== Compiler usage instructions ===== +# +# ./compile.py input_file +# +# will compile Programs/Source/input_file.asm onto +# Programs/Bytecode/input_file.bc +# +# (run with --help for more options) +# +# See Compiler/README for details on the Compiler package + + +from optparse import OptionParser +import Compiler + +def main(): + usage = "usage: %prog [options] filename" + parser = OptionParser(usage=usage) + parser.add_option("-n", "--nomerge", + action="store_false", dest="merge_opens", default=True, + help="don't attempt to merge open instructions") + parser.add_option("-o", "--output", dest="outfile", + help="specify output file") + parser.add_option("-a", "--asm-output", dest="asmoutfile", + help="asm output file for debugging") + parser.add_option("-l", "--asm-input", action="store_true", dest="assemblymode", + help="old-style asm input") + parser.add_option("-p", "--primesize", dest="param", default=-1, + help="bit length of modulus") + parser.add_option("-g", "--galoissize", dest="galois", default=40, + help="bit length of Galois field") + parser.add_option("-d", "--debug", action="store_true", dest="debug", + help="keep track of trace for debugging") + parser.add_option("-e", "--emulate", action="store_true", dest="emulate", default=False, + help="emulate register values for debugging") + parser.add_option("-c", "--comparison", dest="comparison", default="log", + help="comparison variant: log|plain|inv|sinv") + parser.add_option("-r", "--noreorder", dest="reorder_between_opens", + action="store_false", default=True, + help="don't attempt to place instructions between start/stop opens") + parser.add_option("-M", "--preserve-mem-order", action="store_true", + dest="preserve_mem_order", default=False, + help="preserve order of memory instructions; possible efficiency loss") + parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate", + default=False, help="don't reallocate") + parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open", + default=False, help="restrict number of parallel opens") + parser.add_option("-D", "--dead-code-elimination", action="store_true", + dest="dead_code_elimination", default=False, + help="eliminate instructions with unused result") + parser.add_option("-P", "--profile", action="store_true", dest="profile", + help="profile compilation") + parser.add_option("-C", "--continous", action="store_true", dest="continuous", + help="continuous computation") + options,args = parser.parse_args() + if len(args) != 1: + parser.print_help() + return + + def compilation(): + prog = Compiler.run(args[0], options, param=int(options.param), + merge_opens=options.merge_opens, emulate=options.emulate, + assemblymode=options.assemblymode, debug=options.debug) + prog.write_bytes(options.outfile) + + if options.asmoutfile: + for tape in prog.tapes: + tape.write_str(options.asmoutfile + '-' + tape.name) + + if options.profile: + import cProfile + p = cProfile.Profile().runctx('compilation()', globals(), locals()) + p.dump_stats(args[0] + '.prof') + p.print_stats(2) + else: + compilation() + +if __name__ == '__main__': + main() diff --git a/ot-offline.cpp b/ot-offline.cpp new file mode 100644 index 000000000..13c356339 --- /dev/null +++ b/ot-offline.cpp @@ -0,0 +1,13 @@ +// (C) 2016 University of Bristol. See License.txt + +/* + * OT-Offline.cpp + * + */ + +#include "OT/NPartyTripleGenerator.h" + +int main(int argc, const char** argv) +{ + TripleMachine(argc, argv).run(); +} diff --git a/tutorial.md b/tutorial.md new file mode 100644 index 000000000..a58873c43 --- /dev/null +++ b/tutorial.md @@ -0,0 +1,161 @@ +(C) 2016 University of Bristol. See License.txt + +Suppose we want to add 2 integers mod p in clear, where p has 128 bits and compute over 2 parties inputs: P0, P1. + +First create a file named "addition.mpc" in Programs/Source/ folder containing the following: + + +Computation on Clear Data +================== + +``` +a = cint(2) +b = cint(10) + +c = a + b +print_ln('Result is %s', c) +``` + +Next step is to transform the file into bytecode which will be run later on the VM between different parties. +For that, type in terminal: + +``` +./compile -p 128 addition + +``` + +The command will output to stderr the number of registers, rounds and other parameters used for measuring the requirements from the offline phase. + +To simply run the program between 2 parties simulated locally, type in terminal: + +``` +sh Scripts/run-online.sh addition + +``` + +The output will be a summary for the online phase including the result of the computation. + +```` + Result is 12 + +```` + + +Computation on Secret Shared data +================================= + +``` +a = sint(2) +b = sint(10) + +c = a + b +print_ln('Result is %s', c.reveal()) + +``` + +This means that a = a_0 + a_1, b = b_0 + b_1 where a_i belongs to party P_i. + +It's that simple! When we reveal a secret register, the MAC-checking - according +to SPDZ online phase - is done automatically. If there are multiple calls to +reveal() then the rounds of communication are merged automatically by the +compiler. + +Remember that a and b are hard-coded constants so the data is shared by one +party having the actualy inputs (a_0=2, b_0=10) where the other one has (a_1 = +0, b_1 = 0). Usually it's easier to debug when things are written in this way. + +If we want to run a real MPC computation - P1 shares a and P2 shares b - and +reveal the sum of the values then we can write the following. + +``` +a = sint.get_raw_input(0) +b = sint.get_raw_input(1) + +c = a + b +print_ln('Result is %s', c.reveal()) + +``` + +SPDZ also supports multiplication, division (in a prime field) as well as +GF(2^n) data types. All types can be seen in types.py file. + + +Array Lookup +============= + +Suppose party P0 inputs an array of fixed length n: A[1]...A[n] and party P1 +inputs a SS index [index]. + +``` +A = [sint.get_raw_input(0) for _ in range(100)] +index = sint.get_raw_input(1) +``` + +Let's see how can we obtain [A[i]]. + +The standard way to solve this task in a non-MPC way is the following: + +``` +clear_index = index.reveal() +print_ln('%s', A[clear_index].reveal()) + +``` + +But this solution reveals party P1's input which isn't secure! + +Another way of doing this is:inary form. There is + +``` +accumulator = 0 +for i in range(len(A)): + accumulator = accumulator + (i == index) * A[i] + +print_ln('%s', accumulator.reveal()) + +``` + +In this case the only revealed value is A[index]. The main trick was that +comparison between sint() and clear data returns a sint(). In conclusion, all +data is masked except the last step. + +Providing input to SPDZ +======================== + +Looking back at the addition example, in Player-Data/Private-Input-{i} for i in +{0,1} we have to provide one integer in each file. Because the conversion +between human readable form to SPDZ data types is time expensive, we have to +feed the numbers in binary form. There is a script for that: gen_input_fp.cpp +and gen_input_f2.cpp in the Scripts directory designed for generating gfp +inputs. The executables can be found after compiling SPDZ. Customizing those +should be straightforward. Make sure you copy the output files to Player-Data +/Private-Input-{i} files. + + +Other examples +============== + +aes.mpc has the AES evaluation in MPC where the key belongs to party 0 +and the message to party 1. + +For example, say that we have a key equal to + +``` +28aed2a6abf7158809cf4f3c +``` + +We have to convert this into SPDZ datatypes so we use the scripts +gen_input_f2n.x which can be found in the main directory after compiling +SPDZ. Next, the file 'gf2n_vals.in' should contain the following: + +``` +16 +0x2b 0x7e 0x15 0x16 0x28 0xae 0xd2 0xa6 0xab 0xf7 0x15 0x88 0x9 0xcf 0x4f 0x3c +``` + +Now execute the gen_input_f2n.x script and copy the output file to +Player-Data/Private-Input-0 + +The same method applies to generate the message as an input to SPDZ. + +Since all items are in place we can run the online phase and see the +revealed encryption result: