diff --git a/Compiler/allocator.py b/Compiler/allocator.py index f2154cabe..2aaac2015 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -581,6 +581,70 @@ def keep_text_order(inst, n): keep_text_order(instr, n) elif isinstance(instr, RawInputInstruction): keep_merged_order(instr, n, RawInputInstruction) + elif isinstance(instr, matmulsm): + if options.preserve_mem_order: + strict_mem_access(n, last_mem_read, last_mem_write) + else: + if instr.indices_values is not None and instr.first_factor_base_addresses is not None and instr.second_factor_base_addresses is not None: + # Determine which values get accessed by the MATMULSM instruction and only add the according dependencies. + for matmul_idx in range(len(instr.first_factor_base_addresses)): + start_time = time.time() + first_base = instr.first_factor_base_addresses[matmul_idx] + second_base = instr.second_factor_base_addresses[matmul_idx] + + first_factor_row_indices = instr.indices_values[4 * matmul_idx] + first_factor_column_indices = instr.indices_values[4 * matmul_idx + 1] + second_factor_row_indices = instr.indices_values[4 * matmul_idx + 2] + second_factor_column_indices = instr.indices_values[4 * matmul_idx + 3] + + first_factor_row_length = instr.args[12 * matmul_idx + 10] + second_factor_row_length = instr.args[12 * matmul_idx + 11] + + # Due to the potentially very large number of inputs on large matrices, adding dependencies to + # all inputs may take a long time. Therefore, we only partially build the dependencies on + # large matrices and output a warning. + # The threshold of 2_250_000 values per matrix is equivalent to multiplying two 1500x1500 + # matrices. Experiments showed that multiplying two 1700x1700 matrices requires roughly 10 seconds on an i7-1370P, + # so this threshold should lead to acceptable compile times even on slower processors. + first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4] + second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5] + max_dependencies_per_matrix = 1500**2 + if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix: + if block.warn_about_mem and not block.parent.warned_about_mem: + print('WARNING: Order of memory instructions not preserved due to long vector, errors possible') + block.parent.warned_about_mem = True + + # Add dependencies to the first factor. + # If the size of the matrix exceeds the max_dependencies_per_matrix, only a limited number + # of rows will be processed. + for i in range(min(instr.args[12 * matmul_idx + 3], max_dependencies_per_matrix // instr.args[12 * matmul_idx + 4] + 1)): + for k in range(instr.args[12 * matmul_idx + 4]): + first_factor_addr = first_base + \ + first_factor_row_length * first_factor_row_indices[i] + \ + first_factor_column_indices[k] + handle_mem_access(first_factor_addr, 's', last_mem_read_of, last_mem_write_of) + + # Add dependencies to the second factor. + # If the size of the matrix exceeds the max_dependencies_per_matrix, only a limited number + # of rows will be processed. + for k in range(min(instr.args[12 * matmul_idx + 4], max_dependencies_per_matrix // instr.args[12 * matmul_idx + 5] + 1)): + if (time.time() - start_time) > 10: + # Abort building the dependencies if that takes too much time. + if block.warn_about_mem and not block.parent.warned_about_mem: + print('WARNING: Order of memory instructions not preserved due to long vector, errors possible') + block.parent.warned_about_mem = True + break + + for j in range(instr.args[12 * matmul_idx + 5]): + second_factor_addr = second_base + \ + second_factor_row_length * second_factor_row_indices[k] + \ + second_factor_column_indices[j] + handle_mem_access(second_factor_addr, 's', last_mem_read_of, last_mem_write_of) + else: + # If the accessed values cannot be determined, be cautious I guess. + for i in last_mem_write_of.values(): + for j in i: + add_edge(j, n) if isinstance(instr, merge_classes): open_nodes.add(n) @@ -622,13 +686,6 @@ def keep_text_order(inst, n): strict_mem_access(n, scope.write, scope.read) if not options.preserve_mem_order: mem_access(n, instr, last_mem_write_of, last_mem_read_of) - elif isinstance(instr, matmulsm): - if options.preserve_mem_order: - strict_mem_access(n, last_mem_read, last_mem_write) - else: - for i in last_mem_write_of.values(): - for j in i: - add_edge(j, n) # keep I/O instructions in order elif isinstance(instr, IOInstruction): if last_print_str is not None: diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 9e7b23e7d..230f62539 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2484,7 +2484,7 @@ def get_repeat(self): return sum(reduce(operator.mul, self.args[i + 3:i + 6]) for i in range(0, len(self.args), 6)) -class matmulsm(matmul_base): +class matmulsm(matmul_base, base.Mergeable): """ Secret matrix multiplication reading directly from memory. :param: result (sint vector in row-first order) @@ -2494,26 +2494,46 @@ class matmulsm(matmul_base): :param: number of columns in first factor and rows in second factor (int) :param: number of columns in second factor and result (int) :param: rows of first factor to use (regint vector, length as number of rows in first factor) - :param: columns of first factor to use (regint vector, length below) - :param: rows of second factor to use (regint vector, length below) - :param: columns of second factor to use (regint vector, length below) - :param: number of columns of first / rows of second factor to use (int) - :param: number of columns of second factor to use (int) + :param: columns of first factor to use (regint vector, length as number of columns in the first factor) + :param: rows of second factor to use (regint vector, length as number of columns in the first factor) + :param: columns of second factor to use (regint vector, length as number of columns in the second factor) + :param: total number of columns in the first factor, equal to used number of columns when all columns are used (int) + :param: total number of columns in the second factor, equal to used number of columns when all columns are used (int) """ code = base.opcodes['MATMULSM'] - arg_format = ['sw','ci','ci','int','int','int','ci','ci','ci','ci', - 'int','int'] - - def __init__(self, *args, **kwargs): + arg_format = itertools.cycle(['sw','ci','ci','int','int','int','ci','ci','ci','ci', + 'int','int']) + + def __init__(self, *args, + first_factor_base_addresses=None, + second_factor_base_addresses=None, + indices_values=None, + **kwargs): matmul_base.__init__(self, *args, **kwargs) - for i in range(2): - assert args[6 + i].size == args[3 + i] - for i in range(2): - assert args[8 + i].size == args[4 + i] + for matmul_index in range(len(args) // 12): + for i in range(2): + assert args[12 * matmul_index + 6 + i].size == args[12 * matmul_index + 3 + i] + for i in range(2): + assert args[12 * matmul_index + 8 + i].size == args[12 * matmul_index + 4 + i] + + # These are used to reconstruct that accessed memory addresses in the allocator. + self.first_factor_base_addresses = first_factor_base_addresses + self.second_factor_base_addresses = second_factor_base_addresses + self.indices_values = indices_values + + if first_factor_base_addresses is not None: + assert len(first_factor_base_addresses) == len(second_factor_base_addresses) + if indices_values is not None: + assert len(indices_values) == 4 * len(first_factor_base_addresses) def add_usage(self, req_node): super(matmulsm, self).add_usage(req_node) - req_node.increment(('matmul', tuple(self.args[3:6])), 1) + for i in range(0, len(self.args), 12): + req_node.increment(('matmul', (self.args[i + 3], self.args[i + 4], self.args[i + 5])), 1) + + def get_repeat(self): + return sum(reduce(operator.mul, self.args[i + 3:i + 6]) + for i in range(0, len(self.args), 12)) class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable): """ Secret 2D convolution. diff --git a/Compiler/types.py b/Compiler/types.py index 9606058b9..6aa31c4e5 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2668,12 +2668,24 @@ def store_in_mem(self, address): self._store_in_mem(address, stms, stmsi) @classmethod - def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None): + def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None, indices_values=None): if indices is None: indices = [regint.inc(i) for i in (n, m, m, l)] + indices_values = [list(range(i)) for i in (n, m, m, l)] res = cls(size=indices[0].size * indices[3].size) + + if isinstance(A, int) and isinstance(B, int): + first_factor_base_addresses = [A] + second_factor_base_addresses = [B] + else: + first_factor_base_addresses = None + second_factor_base_addresses = None + matmulsm(res, regint(A), regint(B), len(indices[0]), len(indices[1]), - len(indices[3]), *(list(indices) + [m, l])) + len(indices[3]), *(list(indices) + [m, l]), + first_factor_base_addresses=first_factor_base_addresses, + second_factor_base_addresses=second_factor_base_addresses, + indices_values=indices_values) return res @vectorize_init diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index ef21e3728..90f4db526 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -323,8 +323,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_vector(num_var_args, start, s); break; case MATMULSM: - get_ints(r, s, 3); - get_vector(9, start, s); + num_var_args = get_int(s); + get_vector(num_var_args, start, s); break; // read from file, input is opcode num_args, @@ -1117,8 +1117,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.Procp.matmuls(Proc.Procp.get_S(), *this); return; case MATMULSM: - Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this, - Proc.read_Ci(r[1]), Proc.read_Ci(r[2])); + Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this); return; case CONV2DS: Proc.Procp.protocol.conv2ds(Proc.Procp, *this); diff --git a/Processor/Processor.h b/Processor/Processor.h index 9b4757f4e..08f4cd269 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -77,8 +77,12 @@ class SubProcessor void mulrs(const vector& reg); void dotprods(const vector& reg, int size); void matmuls(const vector& source, const Instruction& instruction); - void matmulsm(const MemoryPart& source, const Instruction& instruction, size_t a, - size_t b); + void matmulsm(const MemoryPart& source, const Instruction& instruction); + + void matmulsm_finalize_batch(vector::const_iterator startMatmul, int startI, int startJ, + vector::const_iterator endMatmul, + int endI, int endJ); + void conv2ds(const Instruction& instruction); void secure_shuffle(const Instruction& instruction); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d468a6302..aab2ba9d0 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -601,73 +601,156 @@ void SubProcessor::matmuls(const vector& source, } } + template void SubProcessor::matmulsm(const MemoryPart& source, - const Instruction& instruction, size_t a, size_t b) + const Instruction& instruction) { - auto& dim = instruction.get_start(); - auto C = S.begin() + (instruction.get_r(0)); - assert(C + dim[0] * dim[2] <= S.end()); assert(Proc); - int base = 0; - int base2 = 0; + auto& start = instruction.get_start(); + + auto batchStartMatrix = start.begin(); + int batchStartI = 0; + int batchStartJ = 0; + + size_t sourceSize = source.size(); + const T* sourceData = source.data(); + protocol.init_dotprod(); - for (int i = 0; i < dim[0]; i++) - { - auto ii = Proc->get_Ci().at(dim[3] + i).get(); - for (int j = 0; j < dim[2]; j++) - { -#ifdef DEBUG_MATMULSM - cerr << "matmulsm prep " << i << " " << j << endl; -#endif - matmulsm_prep(ii, j, source, dim, a, b); - if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size) - { -#ifdef DEBUG_MATMULSM - cerr << "matmulsm round " << protocol.get_buffer_size() << endl; + for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) { + auto output = S.begin() + matmulArgs[0]; + size_t firstFactorBase = Proc->get_Ci().at(matmulArgs[1]).get(); + size_t secondFactorBase = Proc->get_Ci().at(matmulArgs[2]).get(); + auto resultNumberOfRows = matmulArgs[3]; + auto usedNumberOfFirstFactorColumns = matmulArgs[4]; + auto resultNumberOfColumns = matmulArgs[5]; + auto firstFactorTotalNumberOfColumns = matmulArgs[10]; + auto secondFactorTotalNumberOfColumns = matmulArgs[11]; + + assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end()); + + for (int i = 0; i < resultNumberOfRows; i += 1) { + auto actualFirstFactorRow = Proc->get_Ci().at(matmulArgs[6] + i).get(); + + for (int j = 0; j < resultNumberOfColumns; j += 1) { + auto actualSecondFactorColumn = Proc->get_Ci().at(matmulArgs[9] + j).get(); + +#ifdef MATMULSM_DEBUG + cout << "Preparing " << i << "," << j << "(buffer size: " << protocol.get_buffer_size() << ")" << endl; #endif - protocol.exchange(); - if (base < i) - for (int l = base2; l < dim[2]; l++) - matmulsm_finalize(base, l, dim, C); - for (int k = base + 1; k < i; k++) - for (int l = 0; l < dim[2]; l++) - matmulsm_finalize(k, l, dim, C); - for (int l = base < i ? 0 : base2; l <= j; l++) - matmulsm_finalize(i, l, dim, C); - base = i; - base2 = j + 1; - protocol.init_dotprod(); + + for (int k = 0; k < usedNumberOfFirstFactorColumns; k += 1) { + auto actualFirstFactorColumn = Proc->get_Ci().at(matmulArgs[7] + k).get(); + auto actualSecondFactorRow = Proc->get_Ci().at(matmulArgs[8] + k).get(); + + auto firstAddress = firstFactorBase + actualFirstFactorRow * firstFactorTotalNumberOfColumns + actualFirstFactorColumn; + auto secondAddress = secondFactorBase + actualSecondFactorRow * secondFactorTotalNumberOfColumns + actualSecondFactorColumn; + + assert(firstAddress < sourceSize); + assert(secondAddress < sourceSize); + + protocol.prepare_dotprod(sourceData[firstAddress], sourceData[secondAddress]); + } + protocol.next_dotprod(); + + if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size) { + protocol.exchange(); + + matmulsm_finalize_batch(batchStartMatrix, batchStartI, batchStartJ, + matmulArgs, i, j); + batchStartMatrix = matmulArgs; + batchStartI = i; + batchStartJ = j + 1; + + protocol.init_dotprod(); + } } } } + protocol.exchange(); - for (int j = base2; j < dim[2]; j++) - matmulsm_finalize(base, j, dim, C); - for (int i = base + 1; i < dim[0]; i++) - for (int j = 0; j < dim[2]; j++) - matmulsm_finalize(i, j, dim, C); + auto lastMatmulsArgs = start.end() - 12; + auto lastMatrixRows = lastMatmulsArgs[3]; + auto lastMatrixColumns = lastMatmulsArgs[5]; + matmulsm_finalize_batch(batchStartMatrix, batchStartI, batchStartJ, + lastMatmulsArgs, lastMatrixRows - 1, lastMatrixColumns - 1); } template -void SubProcessor::matmulsm_prep(int ii, int j, const MemoryPart& source, - const vector& dim, size_t a, size_t b) -{ - auto jj = Proc->get_Ci().at(dim[6] + j).get(); - const T* base = source.data(); - size_t size = source.size(); - for (int k = 0; k < dim[1]; k++) - { - auto kk = Proc->get_Ci().at(dim[4] + k).get(); - auto ll = Proc->get_Ci().at(dim[5] + k).get(); - auto aa = a + ii * dim[7] + kk; - auto bb = b + ll * dim[8] + jj; - assert(aa < size); - assert(bb < size); - protocol.prepare_dotprod(base[aa], base[bb]); +void SubProcessor::matmulsm_finalize_batch(vector::const_iterator startMatmul, int startI, int startJ, + vector::const_iterator endMatmul, int endI, int endJ) { + + for (auto matmulArgs = startMatmul; matmulArgs <= endMatmul; matmulArgs += 12) { + auto output = S.begin() + matmulArgs[0]; + auto resultNumberOfRows = matmulArgs[3]; + auto usedNumberOfFirstFactorColumns = matmulArgs[4]; + auto resultNumberOfColumns = matmulArgs[5]; + + assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end()); + + // Finish the first unfinished row in the current matrix. + int firstRowEndJ = resultNumberOfColumns - 1; + if (matmulArgs == endMatmul && startI == endI) // For the case that the batch covers only a part of the first row of current matrix or only part of a single row. + firstRowEndJ = endJ; + #ifdef MATMULSM_DEBUG + cout << "Batch is in single row " << endJ << endl; + #endif + for (int j = startJ; j <= firstRowEndJ; j += 1) { +#ifdef MATMULSM_DEBUG + cout << "Finalizing (first row) " << startI << "," << j << endl; +#endif + *(output + startI * resultNumberOfColumns + j) = protocol.finalize_dotprod(usedNumberOfFirstFactorColumns); + } + if (firstRowEndJ == resultNumberOfColumns - 1) { + startJ = 0; + startI += 1; + } + else { + // The whole batch covers only a part of a single row. + startJ = endJ + 1; + } + + // Determine the point up until which the batch runs in the current matrix. + int currentMatrixEndI = resultNumberOfRows - 1; + int currentMatrixEndJ = resultNumberOfColumns - 1; + if (matmulArgs == endMatmul) { + currentMatrixEndI = endI; + currentMatrixEndJ = endJ; + } + + // Finish the rows that always are complete, i.e., the second to the "second to last" row. + for (; startI <= currentMatrixEndI - 1; startI += 1) { + for (int j = 0; j < resultNumberOfColumns; j += 1) { +#ifdef MATMULSM_DEBUG + cout << "Finalizing (main part) " << startI << "," << j << endl; +#endif + *(output + startI * resultNumberOfColumns + j) = protocol.finalize_dotprod(usedNumberOfFirstFactorColumns); + } + } + + // (Partially) finish the last row. + if (startI == currentMatrixEndI) { + for (; startJ <= currentMatrixEndJ; startJ += 1) { +#ifdef MATMULSM_DEBUG + cout << "Finalizing (last row) " << startI << "," << startJ << endl; +#endif + *(output + startI * resultNumberOfColumns + startJ) = protocol.finalize_dotprod(usedNumberOfFirstFactorColumns); + } + } + else { +#ifdef MATMULSM_DEBUG + // This happens when there is only one row. + cout << "Skipping final row of matrix because it was handled previously." << endl; +#endif + } + + if (matmulArgs < endMatmul) { + // Reset startI and startJ to the beginning of the matrix. + startI = 0; + startJ = 0; + } } - protocol.next_dotprod(); } template diff --git a/Programs/Source/test_dot.mpc b/Programs/Source/test_dot.mpc new file mode 100644 index 000000000..92f0bad0c --- /dev/null +++ b/Programs/Source/test_dot.mpc @@ -0,0 +1,165 @@ +a = Array.create_from([sint(1), sint(2), sint(3), sint(4)]) +b = Array.create_from([sint(3), sint(2), sint(1)]) + +c = Matrix.create_from([ + [sint(1), sint(2), sint(3)], + [sint(4), sint(5), sint(6)], + [sint(7), sint(8), sint(9)], + [sint(10), sint(11), sint(12)] +]) + +d = Matrix.create_from([ + [sint(12), sint(11), sint(10), sint(9)], + [sint(8), sint(7), sint(6), sint(5)], + [sint(4), sint(3), sint(2), sint(1)] +]) + + +def test_array(expected, actual): + actual = actual.reveal() + expected = Array.create_from([cint(x) for x in expected]) + @for_range(len(expected)) + def _(i): + @if_(actual[i] != expected[i]) + def fail(): + print_ln("Unexpected entry at index %s", i) + print_ln("Expected:") + expected.print_reveal_nested() + print_ln("Actual:") + actual.print_reveal_nested() + + crash() + + +def test_matrix(expected, actual): + actual = actual.reveal() + expected = Matrix.create_from([[cint(x) for x in row] for row in expected]) + @for_range(len(expected)) + def outer(i): + + @for_range(len(expected[0])) + def inner(j): + @if_(actual[i][j] != expected[i][j]) + def fail(): + print_ln("Unexpected entry at index %s,%s", i, j) + print_ln("Expected:") + expected.print_reveal_nested() + print_ln("Actual:") + actual.print_reveal_nested() + + crash() + +break_point() +def hacky_array_dot_matrix(arr, mat): + # Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying memory addresses. + tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address) + result = tmp.dot(mat) + return sint.Array(mat.shape[1], result.address) + +start_timer(3) + +e3 = hacky_array_dot_matrix(a, c) +# b[0] = e3[0] +f3 = hacky_array_dot_matrix(b, d) + +stop_timer(3) + +e3 = e3.reveal() +f3 = f3.reveal() + +e3.print_reveal_nested() +f3.print_reveal_nested() + +test_array([70, 80, 90], e3) +test_array([56, 50, 44, 38], f3) + +start_timer(4) + +e4 = hacky_array_dot_matrix(a, c) +b[-1] = e4[0] +f4 = hacky_array_dot_matrix(b, d) + +stop_timer(4) + +test_array([70, 80, 90], e4) +test_array([332, 257, 182, 107], f4) + +f4.print_reveal_nested() + +# TODO: Crashes + + +start_timer(5) +g = c.dot(d) +stop_timer(5) + +test_matrix([ + [ 40, 34, 28, 22], + [112, 97, 82, 67], + [184, 160, 136, 112], + [256, 223, 190, 157] +], g) +g.print_reveal_nested() + + +# Big matrix tests. +# These are intended to test matrix multiplications that require multiple batches. + +def identity(size): + result = sint.Matrix(rows=size, columns=size) + result.assign_all(0) + for i in range(size): + result[i][i] = 1 + return result + + +def counting_matrix(rows, columns): + result = sint.Matrix(rows, columns) + @for_range(rows) + def outer(i): + @for_range(columns) + def inner(j): + result[i][j] = i * columns + j + return result + + +def clear_counting_matrix(rows, columns): + return [list(range(i * columns, (i + 1) * columns)) for i in range(rows)] + + +# Single matrix multiplication requiring multiple batches. +a = counting_matrix(20, 20) +b = identity(20) + +start_timer(6) +c = a * b +stop_timer(6) + +test_matrix(clear_counting_matrix(20, 20), c) + +# Multiple matrix multiplications requiring multiple batches. +start_timer(7) +d = a * b +e = c * b +stop_timer(7) + +test_matrix(clear_counting_matrix(20, 20), d) +test_matrix(clear_counting_matrix(20, 20), e) + + +start_timer(8) +d = a.dot(b, n_threads=2) +stop_timer(8) + +test_matrix(clear_counting_matrix(20, 20), d) + +start_timer(9) +M = sint.Matrix(10, 10) +M.direct_mul(M, indices=[regint(0), regint.inc(10), regint.inc(10), + regint(0)]) +stop_timer(9) + + +start_timer(10) +sint.Matrix(1000, 1000) * sint.Matrix(1000, 1000) +stop_timer(10) diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 2073eac26..2bceb9f31 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -34,7 +34,7 @@ class Hemi : public T::BasicProtocol SubProcessor& processor); void matmulsm(SubProcessor& processor, MemoryPart& source, - const Instruction& instruction, int a, int b); + const Instruction& instruction); void conv2ds(SubProcessor& processor, const Instruction& instruction); }; diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index b232bc42d..807f5fc79 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -34,51 +34,70 @@ typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, template void Hemi::matmulsm(SubProcessor& processor, MemoryPart& source, - const Instruction& instruction, int a, int b) + const Instruction& instruction) { if (HemiOptions::singleton.plain_matmul or not OnlineOptions::singleton.live_prep) { - processor.matmulsm(source, instruction, a, b); + processor.matmulsm(source, instruction); return; } - auto& dim = instruction.get_start(); - auto& S = processor.get_S(); - auto C = S.begin() + (instruction.get_r(0)); - assert(C + dim[0] * dim[2] <= S.end()); + // Perform the matrix multiplications in sequence. + // They are not merged into one communication round since that would require multiple matrix_preps to + // merge rounds. + // An improvement might be to merge the communication of multiple matrices with the same dimension into one round, + // which is not implemented yet. auto Proc = processor.Proc; assert(Proc); + auto& S = processor.get_S(); + auto& start = instruction.get_start(); + + for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) { + auto C = S.begin() + matmulArgs[0]; + size_t firstFactorBase = Proc->get_Ci().at(matmulArgs[1]).get(); + size_t secondFactorBase = Proc->get_Ci().at(matmulArgs[2]).get(); + auto resultNumberOfRows = matmulArgs[3]; + auto usedNumberOfFirstFactorColumns = matmulArgs[4]; + auto resultNumberOfColumns = matmulArgs[5]; + auto firstFactorTotalNumberOfColumns = matmulArgs[10]; + auto secondFactorTotalNumberOfColumns = matmulArgs[11]; + + assert(C + resultNumberOfRows * resultNumberOfColumns <= S.end()); + + ShareMatrix A(resultNumberOfRows, usedNumberOfFirstFactorColumns), B(usedNumberOfFirstFactorColumns, resultNumberOfColumns); + if (not T::real_shares(processor.P)) + { + matrix_multiply(A, B, processor); + return; + } - ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); + for (int i = 0; i < resultNumberOfRows; i++) { + auto actualFirstFactorRow = Proc->get_Ci().at(matmulArgs[6] + i).get(); - if (not T::real_shares(processor.P)) - { - matrix_multiply(A, B, processor); - return; - } - - for (int i = 0; i < dim[0]; i++) - for (int k = 0; k < dim[1]; k++) - { - auto kk = Proc->get_Ci().at(dim[4] + k).get(); - auto ii = Proc->get_Ci().at(dim[3] + i).get(); - A.entries.v.push_back(source.at(a + ii * dim[7] + kk)); + for (int k = 0; k < usedNumberOfFirstFactorColumns; k++) + { + auto actualFirstFactorColumn = Proc->get_Ci().at(matmulArgs[7] + k).get(); + A.entries.v.push_back(source.at(firstFactorBase + actualFirstFactorRow * firstFactorTotalNumberOfColumns + actualFirstFactorColumn)); + } } - for (int k = 0; k < dim[1]; k++) - for (int j = 0; j < dim[2]; j++) - { - auto jj = Proc->get_Ci().at(dim[6] + j).get(); - auto ll = Proc->get_Ci().at(dim[5] + k).get(); - B.entries.v.push_back(source.at(b + ll * dim[8] + jj)); + + for (int k = 0; k < usedNumberOfFirstFactorColumns; k++) { + auto actualSecondFactorRow = Proc->get_Ci().at(matmulArgs[8] + k).get(); + for (int j = 0; j < resultNumberOfColumns; j++) + { + auto actualSecondFactorColumn = Proc->get_Ci().at(matmulArgs[9] + j).get(); + B.entries.v.push_back(source.at(secondFactorBase + actualSecondFactorRow * secondFactorTotalNumberOfColumns + actualSecondFactorColumn)); + } } - auto res = matrix_multiply(A, B, processor); + auto res = matrix_multiply(A, B, processor); - for (int i = 0; i < dim[0]; i++) - for (int j = 0; j < dim[2]; j++) - *(C + i * dim[2] + j) = res[{i, j}]; + for (int i = 0; i < resultNumberOfRows; i++) + for (int j = 0; j < resultNumberOfColumns; j++) + *(C + i * resultNumberOfColumns + j) = res[{i, j}]; + } } template diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 4fb5a6317..1f1176ff7 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -111,8 +111,8 @@ class ProtocolBase template void matmulsm(SubProcessor & proc, MemoryPart& source, - const Instruction& instruction, int a, int b) - { proc.matmulsm(source, instruction, a, b); } + const Instruction& instruction) + { proc.matmulsm(source, instruction); } template void conv2ds(SubProcessor& proc, const Instruction& instruction)