diff --git a/ecc/bls12-377/fp/asm.go b/ecc/bls12-377/fp/asm_adx.go similarity index 100% rename from ecc/bls12-377/fp/asm.go rename to ecc/bls12-377/fp/asm_adx.go diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 81a730fbd..393f45744 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -521,32 +521,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [7]uint64 var D uint64 diff --git a/ecc/bls12-377/fp/element_mul_amd64.s b/ecc/bls12-377/fp/element_mul_amd64.s deleted file mode 100644 index 1e19c4d3f..000000000 --- a/ecc/bls12-377/fp/element_mul_amd64.s +++ /dev/null @@ -1,857 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), R8 - - // 
x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // 
(A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - 
ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - 
MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - 
// (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, 
BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - 
ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - 
MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-377/fp/element_ops_amd64.go b/ecc/bls12-377/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.go +++ b/ecc/bls12-377/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-377/fp/element_ops_amd64.s b/ecc/bls12-377/fp/element_ops_amd64.s index 7242622a4..cabff26f7 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.s +++ b/ecc/bls12-377/fp/element_ops_amd64.s @@ -1,306 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 11124594824487954849 +#include "../../../field/asm/element_6w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce 
element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $40-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, 
s3-32(SP) - MOVQ R8, s4-40(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ R15, DX - ADCQ s0-8(SP), CX - ADCQ s1-16(SP), BX - ADCQ s2-24(SP), SI - ADCQ s3-32(SP), DI - ADCQ s4-40(SP), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $48-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - MOVQ R8, R14 - MOVQ R9, R15 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - ADCQ 40(DX), R9 - SUBQ 0(DX), R10 - SBBQ 8(DX), R11 - SBBQ 16(DX), R12 - SBBQ 24(DX), R13 - SBBQ 32(DX), R14 - SBBQ 40(DX), R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ $0x8508c00000000001, CX - MOVQ $0x170b5d4430000000, BX - MOVQ $0x1ef3622fba094800, SI - MOVQ $0x1a22d9f300f5138f, DI - MOVQ $0xc63b05c06ca1493b, R8 - MOVQ $0x01ae3a4617c510ea, R9 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - CMOVQCC AX, R9 - ADDQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - ADCQ R8, R14 - ADCQ R9, R15 - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ 
s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ R10, 0(DX) - MOVQ R11, 8(DX) - MOVQ R12, 16(DX) - MOVQ R13, 24(DX) - MOVQ R14, 32(DX) - MOVQ R15, 40(DX) - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - MOVQ R9, 40(AX) - RET diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bls12-377/fp/element_ops_purego.go index a4c3796b9..072fb87c0 100644 --- a/ecc/bls12-377/fp/element_ops_purego.go +++ b/ecc/bls12-377/fp/element_ops_purego.go @@ -67,48 +67,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5 uint64 var u0, u1, u2, u3, u4, u5 uint64 diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 582d8b4af..a060095a0 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -641,7 +641,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -711,77 +710,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - 
for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2353,42 +2281,42 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[5] != ^uint64(0) { + g[5] %= (qElement[5] + 1) + } - if qElement[5] != ^uint64(0) { - g[5] %= (qElement[5] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + g[5] %= (qElement[5] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), 
- genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[5] != ^uint64(0) { - g[5] %= (qElement[5] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2403,6 +2331,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls12-377/fp/vector.go b/ecc/bls12-377/fp/vector.go index 0df05e337..f1d659e76 100644 --- a/ecc/bls12-377/fp/vector.go +++ b/ecc/bls12-377/fp/vector.go @@ -219,6 +219,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -246,6 +265,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-377/fp/vector_test.go b/ecc/bls12-377/fp/vector_test.go index 5d88af91c..a8deef945 100644 --- a/ecc/bls12-377/fp/vector_test.go +++ b/ecc/bls12-377/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,283 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + 
return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + mixer[5] %= 
(qElement[5] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + mixer[5] %= (qElement[5] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls12-377/fr/asm.go b/ecc/bls12-377/fr/asm_adx.go similarity index 100% rename from ecc/bls12-377/fr/asm.go rename to ecc/bls12-377/fr/asm_adx.go diff --git a/ecc/bls12-377/fr/asm_avx.go b/ecc/bls12-377/fr/asm_avx.go new file mode 100644 index 000000000..955f55979 --- /dev/null +++ b/ecc/bls12-377/fr/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/bls12-377/fr/asm_noavx.go b/ecc/bls12-377/fr/asm_noavx.go new file mode 100644 index 000000000..e5a5b1f2c --- /dev/null +++ b/ecc/bls12-377/fr/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +const supportAvx512 = false diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index 07be74489..af277e8bb 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 725501752471715839 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 58893420465 + func init() { _modulus.SetString("12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls12-377/fr/element_mul_amd64.s b/ecc/bls12-377/fr/element_mul_amd64.s deleted file mode 100644 index ab1816245..000000000 --- a/ecc/bls12-377/fr/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x0a11800000000001 -DATA q<>+8(SB)/8, $0x59aa76fed0000001 -DATA q<>+16(SB)/8, $0x60b44d1e5c37b001 -DATA q<>+24(SB)/8, $0x12ab655e9a2ca556 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a117fffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - 
// t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - 
MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index 21568255d..b653e8006 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} 
+ +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index ffa3b7bca..6c42136a7 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x0a11800000000001 -DATA q<>+8(SB)/8, $0x59aa76fed0000001 -DATA q<>+16(SB)/8, $0x60b44d1e5c37b001 -DATA q<>+24(SB)/8, $0x12ab655e9a2ca556 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a117fffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0x0a11800000000001, R12 - MOVQ $0x59aa76fed0000001, R13 - MOVQ $0x60b44d1e5c37b001, R14 - MOVQ $0x12ab655e9a2ca556, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x0a11800000000001, R11 - MOVQ $0x59aa76fed0000001, R12 - MOVQ $0x60b44d1e5c37b001, R13 - 
MOVQ $0x12ab655e9a2ca556, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go index 9c34ebecc..f107066c7 100644 --- a/ecc/bls12-377/fr/element_ops_purego.go +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b 
*Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 9b4190285..27f878c17 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < 
N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - 
return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls12-377/fr/vector.go b/ecc/bls12-377/fr/vector.go index f39828547..867cabbc3 100644 --- a/ecc/bls12-377/fr/vector.go +++ b/ecc/bls12-377/fr/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-377/fr/vector_test.go b/ecc/bls12-377/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/bls12-377/fr/vector_test.go +++ b/ecc/bls12-377/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, 
b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= 
(qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls12-377/internal/fptower/e2_amd64.go b/ecc/bls12-377/internal/fptower/e2_amd64.go index ac68ffa57..7df0c375f 100644 --- a/ecc/bls12-377/internal/fptower/e2_amd64.go +++ b/ecc/bls12-377/internal/fptower/e2_amd64.go @@ -16,6 +16,33 @@ package fptower +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" +) + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 9586122913090633727 + +// Field modulus q (Fp) +const ( + q0 uint64 = 9586122913090633729 + q1 uint64 = 1660523435060625408 + q2 uint64 = 2230234197602682880 + q3 uint64 = 1883307231910630287 + q4 uint64 = 14284016967150029115 + q5 uint64 = 121098312706494698 +) + +var qElement = fp.Element{ + q0, + q1, + q2, + q3, + q4, + q5, +} + //go:noescape func addE2(res, x, y *E2) diff --git a/ecc/bls12-377/internal/fptower/e2_amd64.s b/ecc/bls12-377/internal/fptower/e2_amd64.s index 053bd8ded..b52a56a85 100644 --- a/ecc/bls12-377/internal/fptower/e2_amd64.s +++ b/ecc/bls12-377/internal/fptower/e2_amd64.s @@ -14,39 +14,27 @@ #include "textflag.h" #include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +#include "go_asm.h" #define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, 
rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + MOVQ ra4, rb4; \ + SBBQ ·qElement+32(SB), ra4; \ + MOVQ ra5, rb5; \ + SBBQ ·qElement+40(SB), ra5; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + CMOVQCS rb4, ra4; \ + CMOVQCS rb5, ra5; \ TEXT ·addE2(SB), NOSPLIT, $0-24 MOVQ x+8(FP), AX diff --git a/ecc/bls12-381/fp/asm.go b/ecc/bls12-381/fp/asm_adx.go similarity index 100% rename from ecc/bls12-381/fp/asm.go rename to ecc/bls12-381/fp/asm_adx.go diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index f5c2df0c2..f0bcfe51b 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -521,32 +521,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [7]uint64 var D uint64 diff --git a/ecc/bls12-381/fp/element_mul_amd64.s b/ecc/bls12-381/fp/element_mul_amd64.s deleted file mode 100644 index e95c98403..000000000 --- a/ecc/bls12-381/fp/element_mul_amd64.s +++ /dev/null @@ -1,857 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xb9feffffffffaaab -DATA q<>+8(SB)/8, $0x1eabfffeb153ffff -DATA q<>+16(SB)/8, $0x6730d2a0f6b0f624 -DATA q<>+24(SB)/8, $0x64774b84f38512bf -DATA q<>+32(SB)/8, $0x4b1ba7b6434bacd7 -DATA q<>+40(SB)/8, $0x1a0111ea397fe69a -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, 
AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - 
- // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := 
t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, 
DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ 
AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - 
- // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-381/fp/element_ops_amd64.go b/ecc/bls12-381/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.go +++ b/ecc/bls12-381/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's 
thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-381/fp/element_ops_amd64.s b/ecc/bls12-381/fp/element_ops_amd64.s index 830b2dd63..cabff26f7 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.s +++ b/ecc/bls12-381/fp/element_ops_amd64.s @@ -1,306 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 11124594824487954849 +#include "../../../field/asm/element_6w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xb9feffffffffaaab -DATA q<>+8(SB)/8, $0x1eabfffeb153ffff -DATA q<>+16(SB)/8, $0x6730d2a0f6b0f624 -DATA q<>+24(SB)/8, $0x64774b84f38512bf -DATA q<>+32(SB)/8, $0x4b1ba7b6434bacd7 -DATA q<>+40(SB)/8, $0x1a0111ea397fe69a -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ 
CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $40-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), 
DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ R15, DX - ADCQ s0-8(SP), CX - ADCQ s1-16(SP), BX - ADCQ s2-24(SP), SI - ADCQ s3-32(SP), DI - ADCQ s4-40(SP), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $48-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - MOVQ R8, R14 - MOVQ R9, R15 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - ADCQ 40(DX), R9 - SUBQ 0(DX), R10 - SBBQ 8(DX), R11 - SBBQ 16(DX), R12 - SBBQ 24(DX), 
R13 - SBBQ 32(DX), R14 - SBBQ 40(DX), R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ $0xb9feffffffffaaab, CX - MOVQ $0x1eabfffeb153ffff, BX - MOVQ $0x6730d2a0f6b0f624, SI - MOVQ $0x64774b84f38512bf, DI - MOVQ $0x4b1ba7b6434bacd7, R8 - MOVQ $0x1a0111ea397fe69a, R9 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - CMOVQCC AX, R9 - ADDQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - ADCQ R8, R14 - ADCQ R9, R15 - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ R10, 0(DX) - MOVQ R11, 8(DX) - MOVQ R12, 16(DX) - MOVQ R13, 24(DX) - MOVQ R14, 32(DX) - MOVQ R15, 40(DX) - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - MOVQ R9, 40(AX) - RET diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/bls12-381/fp/element_ops_purego.go index fc10b3df3..ee3f7e740 100644 --- a/ecc/bls12-381/fp/element_ops_purego.go +++ b/ecc/bls12-381/fp/element_ops_purego.go @@ -67,48 +67,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5 uint64 var u0, u1, u2, u3, u4, u5 uint64 diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index d070a1814..af57409da 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -641,7 +641,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -711,77 +710,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) 
-} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2353,42 +2281,42 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[5] != ^uint64(0) { + g[5] %= (qElement[5] + 1) + } - if qElement[5] != ^uint64(0) { - g[5] %= (qElement[5] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + g[5] %= (qElement[5] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[5] != ^uint64(0) { - g[5] %= (qElement[5] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2403,6 +2331,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus 
big.Int var aIntMod big.Int diff --git a/ecc/bls12-381/fp/vector.go b/ecc/bls12-381/fp/vector.go index 0df05e337..f1d659e76 100644 --- a/ecc/bls12-381/fp/vector.go +++ b/ecc/bls12-381/fp/vector.go @@ -219,6 +219,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -246,6 +265,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-381/fp/vector_test.go b/ecc/bls12-381/fp/vector_test.go index 5d88af91c..a8deef945 100644 --- a/ecc/bls12-381/fp/vector_test.go +++ b/ecc/bls12-381/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,283 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, 
b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + 
if qElement[5] != ^uint64(0) { + mixer[5] %= (qElement[5] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + mixer[5] %= (qElement[5] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls12-381/fr/asm.go b/ecc/bls12-381/fr/asm_adx.go similarity index 100% rename from ecc/bls12-381/fr/asm.go rename to ecc/bls12-381/fr/asm_adx.go diff --git a/ecc/bls12-381/fr/asm_avx.go b/ecc/bls12-381/fr/asm_avx.go new file mode 100644 index 000000000..955f55979 --- /dev/null +++ b/ecc/bls12-381/fr/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/bls12-381/fr/asm_noavx.go b/ecc/bls12-381/fr/asm_noavx.go new file mode 100644 index 000000000..e5a5b1f2c --- /dev/null +++ b/ecc/bls12-381/fr/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +const supportAvx512 = false diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index aa6c47cdd..dc38f08cd 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 18446744069414584319 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 9484408045 + func init() { _modulus.SetString("73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls12-381/fr/element_mul_amd64.s b/ecc/bls12-381/fr/element_mul_amd64.s deleted file mode 100644 index 396d990b7..000000000 --- a/ecc/bls12-381/fr/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear 
the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] 
mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described 
here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), 
DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index 21568255d..b653e8006 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call 
scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index caffb72b1..6c42136a7 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0xffffffff00000001, R12 - MOVQ $0x53bda402fffe5bfe, R13 - MOVQ $0x3339d80809a1d805, R14 - MOVQ $0x73eda753299d7d48, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0xffffffff00000001, R11 - MOVQ $0x53bda402fffe5bfe, R12 - MOVQ $0x3339d80809a1d805, R13 - 
MOVQ $0x73eda753299d7d48, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go index 50e839865..8c1049643 100644 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b 
*Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index 684ea1525..b9bc1e397 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < 
N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - 
return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls12-381/fr/vector.go b/ecc/bls12-381/fr/vector.go index f39828547..867cabbc3 100644 --- a/ecc/bls12-381/fr/vector.go +++ b/ecc/bls12-381/fr/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-381/fr/vector_test.go b/ecc/bls12-381/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/bls12-381/fr/vector_test.go +++ b/ecc/bls12-381/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, 
b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= 
(qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls12-381/internal/fptower/e2_amd64.go b/ecc/bls12-381/internal/fptower/e2_amd64.go index 5121f7cca..469d31927 100644 --- a/ecc/bls12-381/internal/fptower/e2_amd64.go +++ b/ecc/bls12-381/internal/fptower/e2_amd64.go @@ -16,6 +16,33 @@ package fptower +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" +) + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 9940570264628428797 + +// Field modulus q (Fp) +const ( + q0 uint64 = 13402431016077863595 + q1 uint64 = 2210141511517208575 + q2 uint64 = 7435674573564081700 + q3 uint64 = 7239337960414712511 + q4 uint64 = 5412103778470702295 + q5 uint64 = 1873798617647539866 +) + +var qElement = fp.Element{ + q0, + q1, + q2, + q3, + q4, + q5, +} + //go:noescape func addE2(res, x, y *E2) diff --git a/ecc/bls12-381/internal/fptower/e2_amd64.s b/ecc/bls12-381/internal/fptower/e2_amd64.s index 7fc53f463..e90e3ed54 100644 --- a/ecc/bls12-381/internal/fptower/e2_amd64.s +++ b/ecc/bls12-381/internal/fptower/e2_amd64.s @@ -14,39 +14,27 @@ #include "textflag.h" #include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xb9feffffffffaaab -DATA q<>+8(SB)/8, $0x1eabfffeb153ffff -DATA q<>+16(SB)/8, $0x6730d2a0f6b0f624 -DATA q<>+24(SB)/8, $0x64774b84f38512bf -DATA q<>+32(SB)/8, $0x4b1ba7b6434bacd7 -DATA q<>+40(SB)/8, $0x1a0111ea397fe69a -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +#include "go_asm.h" #define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, 
rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + MOVQ ra4, rb4; \ + SBBQ ·qElement+32(SB), ra4; \ + MOVQ ra5, rb5; \ + SBBQ ·qElement+40(SB), ra5; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + CMOVQCS rb4, ra4; \ + CMOVQCS rb5, ra5; \ TEXT ·addE2(SB), NOSPLIT, $0-24 MOVQ x+8(FP), AX @@ -421,496 +409,79 @@ TEXT ·squareAdxE2(SB), $48-16 // t[3] -> R11 // t[4] -> R12 // t[5] -> R13 - // clear the flags - XORQ AX, AX +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R8, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R8, AX \ + MOVQ BP, R8 \ + POPQ BP \ + MACC(R9, R8, ·qElement+8(SB)) \ + MACC(R10, R9, ·qElement+16(SB)) \ + MACC(R11, R10, ·qElement+24(SB)) \ + MACC(R12, R11, ·qElement+32(SB)) \ + MACC(R13, R12, ·qElement+40(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ R14, R8, R9 \ + MULXQ R15, AX, R10 \ + ADOXQ AX, R9 \ + MULXQ CX, AX, R11 \ + ADOXQ AX, R10 \ + MULXQ BX, AX, R12 \ + ADOXQ AX, R11 \ + MULXQ SI, AX, R13 \ + ADOXQ AX, R12 \ + MULXQ DI, AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ R14, AX, BP \ + ADOXQ AX, R8 \ + MACC(BP, R9, R15) \ + MACC(BP, R10, CX) \ + MACC(BP, 
R11, BX) \ + MACC(BP, R12, SI) \ + MACC(BP, R13, DI) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + + // mul body MOVQ x+8(FP), DX MOVQ 0(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_0() MOVQ x+8(FP), DX MOVQ 8(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ 
AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N() MOVQ x+8(FP), DX MOVQ 16(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 
- MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N() MOVQ x+8(FP), DX MOVQ 24(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N() MOVQ x+8(FP), DX MOVQ 32(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - 
ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N() MOVQ x+8(FP), DX MOVQ 40(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - 
// clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 + MUL_WORD_N() // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) @@ -997,490 +568,73 @@ TEXT ·squareAdxE2(SB), $48-16 // t[3] -> R11 // t[4] -> R12 // t[5] -> R13 - // clear the flags - XORQ AX, AX +#define MACC_0(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT_0() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R8, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R8, AX \ + MOVQ BP, R8 \ + POPQ BP \ + MACC_0(R9, R8, ·qElement+8(SB)) \ + MACC_0(R10, R9, ·qElement+16(SB)) \ + MACC_0(R11, R10, ·qElement+24(SB)) \ + MACC_0(R12, R11, ·qElement+32(SB)) \ + MACC_0(R13, R12, ·qElement+40(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0_0() \ + XORQ AX, AX \ + MULXQ R14, R8, R9 \ + MULXQ R15, AX, R10 \ + ADOXQ AX, R9 \ + MULXQ CX, AX, R11 \ + ADOXQ AX, R10 \ + MULXQ BX, AX, R12 \ + ADOXQ AX, R11 \ + MULXQ SI, AX, R13 \ + ADOXQ AX, R12 \ + MULXQ DI, AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT_0() \ + +#define MUL_WORD_N_0() \ + XORQ AX, AX \ + MULXQ R14, AX, BP \ + ADOXQ AX, R8 \ + MACC_0(BP, R9, R15) \ + MACC_0(BP, R10, CX) \ + MACC_0(BP, R11, BX) \ + MACC_0(BP, R12, SI) \ + MACC_0(BP, R13, DI) \ + MOVQ 
$0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT_0() \ + + // mul body MOVQ s0-8(SP), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_0_0() MOVQ s1-16(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ 
AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_0() MOVQ s2-24(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 
- MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_0() MOVQ s3-32(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_0() MOVQ s4-40(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := 
t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_0() MOVQ s5-48(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // 
(C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 + MUL_WORD_N_0() // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) @@ -1531,496 +685,79 @@ TEXT ·mulAdxE2(SB), $96-24 // t[3] -> R11 // t[4] -> R12 // t[5] -> R13 - // clear the flags - XORQ AX, AX +#define MACC_1(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT_1() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R8, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R8, AX \ + MOVQ BP, R8 \ + POPQ BP \ + MACC_1(R9, R8, ·qElement+8(SB)) \ + MACC_1(R10, R9, ·qElement+16(SB)) \ + MACC_1(R11, R10, ·qElement+24(SB)) \ + MACC_1(R12, R11, ·qElement+32(SB)) \ + MACC_1(R13, R12, ·qElement+40(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0_1() \ + XORQ AX, AX \ + MULXQ R14, R8, R9 \ + MULXQ R15, AX, R10 \ + ADOXQ AX, R9 \ + MULXQ CX, AX, R11 \ + ADOXQ AX, R10 \ + MULXQ BX, AX, R12 \ + ADOXQ AX, R11 \ + MULXQ SI, AX, R13 \ + ADOXQ AX, R12 \ + MULXQ DI, AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT_1() \ + +#define MUL_WORD_N_1() \ + XORQ AX, AX \ + MULXQ R14, AX, BP \ + ADOXQ AX, R8 \ + MACC_1(BP, R9, R15) \ + MACC_1(BP, R10, CX) \ + MACC_1(BP, R11, BX) \ + MACC_1(BP, R12, SI) \ + MACC_1(BP, R13, DI) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT_1() \ + + // mul body MOVQ y+16(FP), DX MOVQ 48(DX), DX - - // (A,t[0]) := 
x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_0_1() MOVQ y+16(FP), DX MOVQ 56(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear 
the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_1() MOVQ y+16(FP), DX MOVQ 64(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - 
ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_1() MOVQ y+16(FP), DX MOVQ 72(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_1() MOVQ y+16(FP), DX MOVQ 80(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, 
BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_1() MOVQ y+16(FP), DX MOVQ 88(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 + MUL_WORD_N_1() // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) @@ -2071,490 +808,73 @@ TEXT ·mulAdxE2(SB), $96-24 // t[3] -> R11 // t[4] -> R12 // t[5] -> R13 - // clear the flags - XORQ AX, AX +#define MACC_2(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT_2() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R8, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R8, AX \ + MOVQ BP, R8 \ + POPQ BP \ + MACC_2(R9, R8, ·qElement+8(SB)) \ + MACC_2(R10, R9, ·qElement+16(SB)) \ + MACC_2(R11, R10, ·qElement+24(SB)) \ + MACC_2(R12, R11, ·qElement+32(SB)) \ + MACC_2(R13, R12, ·qElement+40(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0_2() \ + XORQ AX, AX \ + MULXQ R14, R8, R9 \ + MULXQ R15, AX, R10 \ + ADOXQ AX, R9 \ + MULXQ CX, AX, R11 \ + ADOXQ AX, R10 \ + MULXQ BX, AX, R12 \ + ADOXQ AX, R11 \ + MULXQ SI, AX, R13 \ + ADOXQ AX, R12 \ + MULXQ DI, AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT_2() \ + +#define MUL_WORD_N_2() \ + XORQ AX, AX \ + MULXQ R14, AX, BP \ + ADOXQ AX, R8 \ + MACC_2(BP, R9, R15) \ + MACC_2(BP, R10, CX) \ + MACC_2(BP, R11, BX) \ + MACC_2(BP, R12, SI) \ + MACC_2(BP, R13, DI) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT_2() \ + + // mul body MOVQ s0-8(SP), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := 
x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_0_2() MOVQ s1-16(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ 
q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_2() MOVQ s2-24(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C 
+ A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_2() MOVQ s3-32(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_2() MOVQ s4-40(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ 
AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_2() MOVQ s5-48(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ 
q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 + MUL_WORD_N_2() // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) @@ -2580,496 +900,79 @@ TEXT ·mulAdxE2(SB), $96-24 // t[3] -> R11 // t[4] -> R12 // t[5] -> R13 - // clear the flags - XORQ AX, AX +#define MACC_3(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT_3() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R8, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R8, AX \ + MOVQ BP, R8 \ + POPQ BP \ + MACC_3(R9, R8, ·qElement+8(SB)) \ + MACC_3(R10, R9, ·qElement+16(SB)) \ + MACC_3(R11, R10, ·qElement+24(SB)) \ + MACC_3(R12, R11, ·qElement+32(SB)) \ + MACC_3(R13, R12, ·qElement+40(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0_3() \ + XORQ AX, AX \ + MULXQ R14, R8, R9 \ + MULXQ R15, AX, R10 \ + ADOXQ AX, R9 \ + MULXQ CX, AX, R11 \ + ADOXQ AX, R10 \ + MULXQ BX, AX, R12 \ + ADOXQ AX, R11 \ + MULXQ SI, AX, R13 \ + ADOXQ AX, R12 \ + MULXQ DI, AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT_3() \ + +#define MUL_WORD_N_3() \ + XORQ AX, AX \ + MULXQ R14, AX, BP \ + ADOXQ AX, R8 \ + MACC_3(BP, R9, R15) \ + MACC_3(BP, R10, CX) \ + MACC_3(BP, R11, BX) \ + MACC_3(BP, R12, SI) \ + MACC_3(BP, R13, DI) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT_3() \ + + // mul body MOVQ y+16(FP), DX MOVQ 0(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, 
R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_0_3() MOVQ y+16(FP), DX MOVQ 8(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C 
- ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_3() MOVQ y+16(FP), DX MOVQ 16(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - 
XORQ AX, AX + MUL_WORD_N_3() MOVQ y+16(FP), DX MOVQ 24(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_3() MOVQ y+16(FP), DX MOVQ 32(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A 
- ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX + MUL_WORD_N_3() MOVQ y+16(FP), DX MOVQ 40(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // 
(C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 + MUL_WORD_N_3() // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) diff --git a/ecc/bls24-315/fp/asm.go b/ecc/bls24-315/fp/asm_adx.go similarity index 100% rename from ecc/bls24-315/fp/asm.go rename to ecc/bls24-315/fp/asm_adx.go diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index 4d6138686..4ab67695e 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -499,32 +499,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [6]uint64 var D uint64 diff --git a/ecc/bls24-315/fp/element_mul_amd64.s b/ecc/bls24-315/fp/element_mul_amd64.s deleted file mode 100644 index 92bba4f58..000000000 --- a/ecc/bls24-315/fp/element_mul_amd64.s +++ /dev/null @@ -1,656 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := 
t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + 
x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ 
q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // 
https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - 
// (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-315/fp/element_ops_amd64.go b/ecc/bls24-315/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.go +++ b/ecc/bls24-315/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ 
func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-315/fp/element_ops_amd64.s b/ecc/bls24-315/fp/element_ops_amd64.s index 9528ab595..29314843d 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.s +++ b/ecc/bls24-315/fp/element_ops_amd64.s @@ -1,272 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 18184981773209750009 +#include "../../../field/asm/element_5w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 
8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $16-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - 
ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $24-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ CX, R9 - MOVQ BX, R10 - MOVQ SI, R11 - MOVQ DI, R12 - MOVQ R8, R13 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - SUBQ 0(DX), R9 - SBBQ 8(DX), R10 - SBBQ 16(DX), R11 - SBBQ 24(DX), R12 - SBBQ 32(DX), R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ $0x6fe802ff40300001, CX - MOVQ $0x421ee5da52bde502, BX - MOVQ $0xdec1d01aa27a1ae0, SI - MOVQ $0xd3f7498be97c5eaf, DI - MOVQ $0x04c23a02b586d650, R8 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - ADDQ 
CX, R9 - ADCQ BX, R10 - ADCQ SI, R11 - ADCQ DI, R12 - ADCQ R8, R13 - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ R9, 0(DX) - MOVQ R10, 8(DX) - MOVQ R11, 16(DX) - MOVQ R12, 24(DX) - MOVQ R13, 32(DX) - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - RET diff --git a/ecc/bls24-315/fp/element_ops_purego.go b/ecc/bls24-315/fp/element_ops_purego.go index 9a557a358..4796fc3c5 100644 --- a/ecc/bls24-315/fp/element_ops_purego.go +++ b/ecc/bls24-315/fp/element_ops_purego.go @@ -66,48 +66,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4 uint64 var u0, u1, u2, u3, u4 uint64 diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 665ffce6a..d4656a26d 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -639,7 +639,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -709,77 +708,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i 
:= 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2349,40 +2277,40 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[4] != ^uint64(0) { + g[4] %= (qElement[4] + 1) + } - if qElement[4] != ^uint64(0) { - g[4] %= (qElement[4] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + g[4] %= (qElement[4] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - 
genParams.NextUint64(), - } - if qElement[4] != ^uint64(0) { - g[4] %= (qElement[4] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2396,6 +2324,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls24-315/fp/vector.go b/ecc/bls24-315/fp/vector.go index 01b326d49..ce61e70ea 100644 --- a/ecc/bls24-315/fp/vector.go +++ b/ecc/bls24-315/fp/vector.go @@ -218,6 +218,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -245,6 +264,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-315/fp/vector_test.go b/ecc/bls24-315/fp/vector_test.go index 5d88af91c..c60ee0844 100644 --- a/ecc/bls24-315/fp/vector_test.go +++ b/ecc/bls24-315/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,281 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + 
return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + mixer[4] %= (qElement[4] + 1) + } + + for 
!mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + mixer[4] %= (qElement[4] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls24-315/fr/asm.go b/ecc/bls24-315/fr/asm_adx.go similarity index 100% rename from ecc/bls24-315/fr/asm.go rename to ecc/bls24-315/fr/asm_adx.go diff --git a/ecc/bls24-315/fr/asm_avx.go b/ecc/bls24-315/fr/asm_avx.go new file mode 100644 index 000000000..955f55979 --- /dev/null +++ b/ecc/bls24-315/fr/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/bls24-315/fr/asm_noavx.go b/ecc/bls24-315/fr/asm_noavx.go new file mode 100644 index 000000000..e5a5b1f2c --- /dev/null +++ b/ecc/bls24-315/fr/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +const supportAvx512 = false diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index c24a104a6..abdb822ac 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 2184305180030271487 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 43237874697 + func init() { _modulus.SetString("196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls24-315/fr/element_mul_amd64.s b/ecc/bls24-315/fr/element_mul_amd64.s deleted file mode 100644 index d028fed20..000000000 --- a/ecc/bls24-315/fr/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x19d0c5fd00c00001 -DATA q<>+8(SB)/8, $0xc8c480ece644e364 -DATA q<>+16(SB)/8, $0x25fc7ec9cf927a98 -DATA q<>+24(SB)/8, $0x196deac24a9da12b -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x1e5035fd00bfffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - 
// t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - 
MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index 21568255d..b653e8006 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} 
+ +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 2e52c653b..6c42136a7 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x19d0c5fd00c00001 -DATA q<>+8(SB)/8, $0xc8c480ece644e364 -DATA q<>+16(SB)/8, $0x25fc7ec9cf927a98 -DATA q<>+24(SB)/8, $0x196deac24a9da12b -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x1e5035fd00bfffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0x19d0c5fd00c00001, R12 - MOVQ $0xc8c480ece644e364, R13 - MOVQ $0x25fc7ec9cf927a98, R14 - MOVQ $0x196deac24a9da12b, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x19d0c5fd00c00001, R11 - MOVQ $0xc8c480ece644e364, R12 - MOVQ $0x25fc7ec9cf927a98, R13 - 
MOVQ $0x196deac24a9da12b, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go index 7b6cfd87b..e7a8817f0 100644 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b 
*Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index ac030b6d0..7933f3aa3 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < 
N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - 
return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls24-315/fr/vector.go b/ecc/bls24-315/fr/vector.go index f39828547..867cabbc3 100644 --- a/ecc/bls24-315/fr/vector.go +++ b/ecc/bls24-315/fr/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-315/fr/vector_test.go b/ecc/bls24-315/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/bls24-315/fr/vector_test.go +++ b/ecc/bls24-315/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, 
b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= 
(qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls24-317/fp/asm.go b/ecc/bls24-317/fp/asm_adx.go similarity index 100% rename from ecc/bls24-317/fp/asm.go rename to ecc/bls24-317/fp/asm_adx.go diff --git a/ecc/bls24-317/fp/element.go b/ecc/bls24-317/fp/element.go index 652a4a78e..77818de47 100644 --- a/ecc/bls24-317/fp/element.go +++ b/ecc/bls24-317/fp/element.go @@ -499,32 +499,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [6]uint64 var D uint64 diff --git a/ecc/bls24-317/fp/element_mul_amd64.s b/ecc/bls24-317/fp/element_mul_amd64.s deleted file mode 100644 index bfc863eeb..000000000 --- a/ecc/bls24-317/fp/element_mul_amd64.s +++ /dev/null @@ -1,656 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8d512e565dab2aab -DATA q<>+8(SB)/8, $0xd6f339e43424bf7e -DATA q<>+16(SB)/8, $0x169a61e684c73446 -DATA q<>+24(SB)/8, $0xf28fc5a0b7f9d039 -DATA q<>+32(SB)/8, $0x1058ca226f60892c -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x55b5e0028b047ffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := 
t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + 
x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ 
q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // 
https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - 
// (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-317/fp/element_ops_amd64.go b/ecc/bls24-317/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.go +++ b/ecc/bls24-317/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ 
func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-317/fp/element_ops_amd64.s b/ecc/bls24-317/fp/element_ops_amd64.s index cb68645b3..29314843d 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.s +++ b/ecc/bls24-317/fp/element_ops_amd64.s @@ -1,272 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 18184981773209750009 +#include "../../../field/asm/element_5w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8d512e565dab2aab -DATA q<>+8(SB)/8, $0xd6f339e43424bf7e -DATA q<>+16(SB)/8, $0x169a61e684c73446 -DATA q<>+24(SB)/8, $0xf28fc5a0b7f9d039 -DATA q<>+32(SB)/8, $0x1058ca226f60892c -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x55b5e0028b047ffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 
8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $16-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - 
ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $24-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ CX, R9 - MOVQ BX, R10 - MOVQ SI, R11 - MOVQ DI, R12 - MOVQ R8, R13 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - SUBQ 0(DX), R9 - SBBQ 8(DX), R10 - SBBQ 16(DX), R11 - SBBQ 24(DX), R12 - SBBQ 32(DX), R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ $0x8d512e565dab2aab, CX - MOVQ $0xd6f339e43424bf7e, BX - MOVQ $0x169a61e684c73446, SI - MOVQ $0xf28fc5a0b7f9d039, DI - MOVQ $0x1058ca226f60892c, R8 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - ADDQ 
CX, R9 - ADCQ BX, R10 - ADCQ SI, R11 - ADCQ DI, R12 - ADCQ R8, R13 - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ R9, 0(DX) - MOVQ R10, 8(DX) - MOVQ R11, 16(DX) - MOVQ R12, 24(DX) - MOVQ R13, 32(DX) - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - RET diff --git a/ecc/bls24-317/fp/element_ops_purego.go b/ecc/bls24-317/fp/element_ops_purego.go index aed04e01f..9f72e6f84 100644 --- a/ecc/bls24-317/fp/element_ops_purego.go +++ b/ecc/bls24-317/fp/element_ops_purego.go @@ -66,48 +66,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4 uint64 var u0, u1, u2, u3, u4 uint64 diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 7bbabe259..7d179efec 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -639,7 +639,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -709,77 +708,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i 
:= 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2349,40 +2277,40 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[4] != ^uint64(0) { + g[4] %= (qElement[4] + 1) + } - if qElement[4] != ^uint64(0) { - g[4] %= (qElement[4] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + g[4] %= (qElement[4] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - 
genParams.NextUint64(), - } - if qElement[4] != ^uint64(0) { - g[4] %= (qElement[4] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2396,6 +2324,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls24-317/fp/vector.go b/ecc/bls24-317/fp/vector.go index 01b326d49..ce61e70ea 100644 --- a/ecc/bls24-317/fp/vector.go +++ b/ecc/bls24-317/fp/vector.go @@ -218,6 +218,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -245,6 +264,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-317/fp/vector_test.go b/ecc/bls24-317/fp/vector_test.go index 5d88af91c..c60ee0844 100644 --- a/ecc/bls24-317/fp/vector_test.go +++ b/ecc/bls24-317/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,281 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + 
return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + mixer[4] %= (qElement[4] + 1) + } + + for 
!mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + mixer[4] %= (qElement[4] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bls24-317/fr/asm.go b/ecc/bls24-317/fr/asm_adx.go similarity index 100% rename from ecc/bls24-317/fr/asm.go rename to ecc/bls24-317/fr/asm_adx.go diff --git a/ecc/bls24-317/fr/asm_avx.go b/ecc/bls24-317/fr/asm_avx.go new file mode 100644 index 000000000..955f55979 --- /dev/null +++ b/ecc/bls24-317/fr/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/bls24-317/fr/asm_noavx.go b/ecc/bls24-317/fr/asm_noavx.go new file mode 100644 index 000000000..e5a5b1f2c --- /dev/null +++ b/ecc/bls24-317/fr/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +const supportAvx512 = false diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index bf3215dad..3aefaebe6 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 17293822569102704639 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 16110458503 + func init() { _modulus.SetString("443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls24-317/fr/element_mul_amd64.s b/ecc/bls24-317/fr/element_mul_amd64.s deleted file mode 100644 index 6e58b40d6..000000000 --- a/ecc/bls24-317/fr/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf000000000000001 -DATA q<>+8(SB)/8, $0x1cd1e79196bf0e7a -DATA q<>+16(SB)/8, $0xd0b097f28d83cd49 -DATA q<>+24(SB)/8, $0x443f917ea68dafc2 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xefffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - 
// t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - 
MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index 21568255d..b653e8006 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} 
+ +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index fd237dad9..6c42136a7 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf000000000000001 -DATA q<>+8(SB)/8, $0x1cd1e79196bf0e7a -DATA q<>+16(SB)/8, $0xd0b097f28d83cd49 -DATA q<>+24(SB)/8, $0x443f917ea68dafc2 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xefffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0xf000000000000001, R12 - MOVQ $0x1cd1e79196bf0e7a, R13 - MOVQ $0xd0b097f28d83cd49, R14 - MOVQ $0x443f917ea68dafc2, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0xf000000000000001, R11 - MOVQ $0x1cd1e79196bf0e7a, R12 - MOVQ $0xd0b097f28d83cd49, R13 - 
MOVQ $0x443f917ea68dafc2, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go index 14505483c..7afd9cc8d 100644 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b 
*Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index c533cc1c9..b5db750d6 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < 
N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - 
return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bls24-317/fr/vector.go b/ecc/bls24-317/fr/vector.go index f39828547..867cabbc3 100644 --- a/ecc/bls24-317/fr/vector.go +++ b/ecc/bls24-317/fr/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-317/fr/vector_test.go b/ecc/bls24-317/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/bls24-317/fr/vector_test.go +++ b/ecc/bls24-317/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, 
b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= 
(qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bn254/fp/asm.go b/ecc/bn254/fp/asm_adx.go similarity index 100% rename from ecc/bn254/fp/asm.go rename to ecc/bn254/fp/asm_adx.go diff --git a/ecc/bn254/fp/asm_avx.go b/ecc/bn254/fp/asm_avx.go new file mode 100644 index 000000000..cea035ee8 --- /dev/null +++ b/ecc/bn254/fp/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/bn254/fp/asm_noavx.go b/ecc/bn254/fp/asm_noavx.go new file mode 100644 index 000000000..9ca08a375 --- /dev/null +++ b/ecc/bn254/fp/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +const supportAvx512 = false diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 5ba388e73..25fcdb67c 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 9786893198990664585 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 22721021478 + func init() { _modulus.SetString("30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bn254/fp/element_mul_amd64.s b/ecc/bn254/fp/element_mul_amd64.s deleted file mode 100644 index 9357a21d7..000000000 --- a/ecc/bn254/fp/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 -DATA q<>+8(SB)/8, $0x97816a916871ca8d -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x87d20782e4866389 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - 
// t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - 
MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index 6f16baf68..2ab1a9839 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + 
+//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index cbfba4ee5..6c42136a7 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 -DATA q<>+8(SB)/8, $0x97816a916871ca8d -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x87d20782e4866389 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0x3c208c16d87cfd47, R12 - MOVQ $0x97816a916871ca8d, R13 - MOVQ $0xb85045b68181585d, R14 - MOVQ $0x30644e72e131a029, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x3c208c16d87cfd47, R11 - MOVQ $0x97816a916871ca8d, R12 - MOVQ $0xb85045b68181585d, R13 - 
MOVQ $0x30644e72e131a029, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go index 250ac5bce..454376da5 100644 --- a/ecc/bn254/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { 
scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index a923ef657..22d11d4ac 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var 
expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - return g - } - a := 
genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bn254/fp/vector.go b/ecc/bn254/fp/vector.go index 850b3603d..c97b4283c 100644 --- a/ecc/bn254/fp/vector.go +++ b/ecc/bn254/fp/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bn254/fp/vector_test.go b/ecc/bn254/fp/vector_test.go index 5d88af91c..12f17e21f 100644 --- a/ecc/bn254/fp/vector_test.go +++ b/ecc/bn254/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { 
+ c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + 
mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bn254/fr/asm.go b/ecc/bn254/fr/asm_adx.go similarity index 100% rename from ecc/bn254/fr/asm.go rename to ecc/bn254/fr/asm_adx.go diff --git a/ecc/bn254/fr/asm_avx.go b/ecc/bn254/fr/asm_avx.go new file mode 100644 index 000000000..955f55979 --- /dev/null +++ b/ecc/bn254/fr/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/bn254/fr/asm_noavx.go b/ecc/bn254/fr/asm_noavx.go new file mode 100644 index 000000000..e5a5b1f2c --- /dev/null +++ b/ecc/bn254/fr/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +const supportAvx512 = false diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index cda0b2c28..3650c954c 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 14042775128853446655 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 22721021478 + func init() { _modulus.SetString("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. 
This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bn254/fr/element_mul_amd64.s b/ecc/bn254/fr/element_mul_amd64.s deleted file mode 100644 index 4a9321837..000000000 --- a/ecc/bn254/fr/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x43e1f593f0000001 -DATA q<>+8(SB)/8, $0x2833e84879b97091 -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xc2e1f593efffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - 
// t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - 
MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index 21568255d..b653e8006 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ b/ecc/bn254/fr/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + 
+//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index d077b1124..6c42136a7 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x43e1f593f0000001 -DATA q<>+8(SB)/8, $0x2833e84879b97091 -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xc2e1f593efffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0x43e1f593f0000001, R12 - MOVQ $0x2833e84879b97091, R13 - MOVQ $0xb85045b68181585d, R14 - MOVQ $0x30644e72e131a029, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x43e1f593f0000001, R11 - MOVQ $0x2833e84879b97091, R12 - MOVQ $0xb85045b68181585d, R13 - 
MOVQ $0x30644e72e131a029, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go index cd5c53d8f..4ea220c18 100644 --- a/ecc/bn254/fr/element_ops_purego.go +++ b/ecc/bn254/fr/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { 
scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 3be23d96a..5bac70c4b 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var 
expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - return g - } - a := 
genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bn254/fr/vector.go b/ecc/bn254/fr/vector.go index f39828547..867cabbc3 100644 --- a/ecc/bn254/fr/vector.go +++ b/ecc/bn254/fr/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bn254/fr/vector_test.go b/ecc/bn254/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/bn254/fr/vector_test.go +++ b/ecc/bn254/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { 
+ c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + 
mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bn254/internal/fptower/e2_amd64.go b/ecc/bn254/internal/fptower/e2_amd64.go index 259609bd8..b6db5715a 100644 --- a/ecc/bn254/internal/fptower/e2_amd64.go +++ b/ecc/bn254/internal/fptower/e2_amd64.go @@ -16,6 +16,29 @@ package fptower +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fp" +) + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 9786893198990664585 + +// Field modulus q (Fp) +const ( + q0 uint64 = 4332616871279656263 + q1 uint64 = 10917124144477883021 + q2 uint64 = 13281191951274694749 + q3 uint64 = 3486998266802970665 +) + +var qElement = fp.Element{ + q0, + q1, + q2, + q3, +} + //go:noescape func addE2(res, x, y *E2) diff --git a/ecc/bn254/internal/fptower/e2_amd64.s b/ecc/bn254/internal/fptower/e2_amd64.s index 43ffb7f16..172cd67e8 100644 --- a/ecc/bn254/internal/fptower/e2_amd64.s +++ b/ecc/bn254/internal/fptower/e2_amd64.s @@ -14,173 +14,83 @@ #include "textflag.h" #include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 -DATA q<>+8(SB)/8, $0x97816a916871ca8d -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x87d20782e4866389 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +#include "go_asm.h" #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - 
CMOVQCS rb3, ra3; \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ // this code is generated and identical to fp.Mul(...) +// A -> BP +// t[0] -> R10 +// t[1] -> R11 +// t[2] -> R12 +// t[3] -> R13 +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R10, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R10, AX \ + MOVQ BP, R10 \ + POPQ BP \ + MACC(R11, R10, ·qElement+8(SB)) \ + MACC(R12, R11, ·qElement+16(SB)) \ + MACC(R13, R12, ·qElement+24(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ R14, R10, R11 \ + MULXQ R15, AX, R12 \ + ADOXQ AX, R11 \ + MULXQ CX, AX, R13 \ + ADOXQ AX, R12 \ + MULXQ BX, AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ R14, AX, BP \ + ADOXQ AX, R10 \ + MACC(BP, R11, R15) \ + MACC(BP, R12, CX) \ + MACC(BP, R13, BX) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + #define MUL() \ - XORQ AX, AX; \ - MOVQ SI, DX; \ - MULXQ R14, R10, R11; \ - MULXQ R15, AX, R12; \ - ADOXQ AX, R11; \ - MULXQ CX, AX, R13; \ - ADOXQ AX, R12; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - XORQ AX, AX; \ - 
MOVQ DI, DX; \ - MULXQ R14, AX, BP; \ - ADOXQ AX, R10; \ - ADCXQ BP, R11; \ - MULXQ R15, AX, BP; \ - ADOXQ AX, R11; \ - ADCXQ BP, R12; \ - MULXQ CX, AX, BP; \ - ADOXQ AX, R12; \ - ADCXQ BP, R13; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADCXQ AX, BP; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - XORQ AX, AX; \ - MOVQ R8, DX; \ - MULXQ R14, AX, BP; \ - ADOXQ AX, R10; \ - ADCXQ BP, R11; \ - MULXQ R15, AX, BP; \ - ADOXQ AX, R11; \ - ADCXQ BP, R12; \ - MULXQ CX, AX, BP; \ - ADOXQ AX, R12; \ - ADCXQ BP, R13; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADCXQ AX, BP; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - XORQ AX, AX; \ - MOVQ R9, DX; \ - MULXQ R14, AX, BP; \ - ADOXQ AX, R10; \ - ADCXQ BP, R11; \ - MULXQ R15, AX, BP; \ - ADOXQ AX, R11; \ - ADCXQ BP, R12; \ - MULXQ CX, AX, BP; \ - ADOXQ AX, R12; \ - ADCXQ BP, R13; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADCXQ AX, BP; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - 
MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ + MOVQ SI, DX; \ + MUL_WORD_0(); \ + MOVQ DI, DX; \ + MUL_WORD_N(); \ + MOVQ R8, DX; \ + MUL_WORD_N(); \ + MOVQ R9, DX; \ + MUL_WORD_N(); \ TEXT ·addE2(SB), NOSPLIT, $0-24 MOVQ x+8(FP), AX diff --git a/ecc/bw6-633/fp/asm.go b/ecc/bw6-633/fp/asm_adx.go similarity index 100% rename from ecc/bw6-633/fp/asm.go rename to ecc/bw6-633/fp/asm_adx.go diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 475abd7e5..7656002f4 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -609,32 +609,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [11]uint64 var D uint64 diff --git a/ecc/bw6-633/fp/element_mul_amd64.s b/ecc/bw6-633/fp/element_mul_amd64.s deleted file mode 100644 index 62a7d4dda..000000000 --- a/ecc/bw6-633/fp/element_mul_amd64.s +++ /dev/null @@ -1,1974 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xd74916ea4570000d -DATA q<>+8(SB)/8, $0x3d369bd31147f73c -DATA q<>+16(SB)/8, $0xd7b5ce7ab839c225 -DATA q<>+24(SB)/8, $0x7e0e8850edbda407 -DATA q<>+32(SB)/8, $0xb8da9f5e83f57c49 -DATA q<>+40(SB)/8, $0x8152a6c0fadea490 -DATA q<>+48(SB)/8, $0x4e59769ad9bbda2f -DATA q<>+56(SB)/8, $0xa8fcd8c75d79d2c7 -DATA q<>+64(SB)/8, $0xfc1a174f01d72ab5 -DATA q<>+72(SB)/8, $0x0126633cc0f35f63 -GLOBL q<>(SB), (RODATA+NOPTR), $80 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xb50f29ab0b03b13b -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - 
MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $64-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // t[6] -> R8 - // t[7] -> R9 - // t[8] -> R10 - // t[9] -> R11 - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ 0(R12), R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ 8(R12), AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ 16(R12), AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R12), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R12), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R12), AX, R8 - ADOXQ AX, DI - - // (A,t[6]) := x[6]*y[0] + A - MULXQ 48(R12), AX, R9 - ADOXQ AX, R8 - - // (A,t[7]) := x[7]*y[0] + A - MULXQ 56(R12), AX, R10 - ADOXQ AX, R9 - - // (A,t[8]) := x[8]*y[0] + A - MULXQ 64(R12), AX, R11 - ADOXQ AX, R10 - - // (A,t[9]) := x[9]*y[0] + A - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - 
MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[1] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[1] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[1] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[1] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from 
ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[2] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[2] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // 
(A,t[8]) := t[8] + x[8]*y[2] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[2] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - 
ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[3] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[3] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[3] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[3] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 
24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[4] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[4] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[4] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[4] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - 
ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[5] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[5] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[5] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[5] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - 
ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 48(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[6] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[6] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[6] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[6] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[6] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[6] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[6] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[6] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[6] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[6] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] 
+ m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 56(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[7] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[7] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[7] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[7] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[7] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[7] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[7] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[7] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[7] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[7] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // 
(C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 64(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[8] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[8] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[8] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[8] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[8] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[8] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[8] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[8] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[8] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[8] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 72(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[9] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[9] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[9] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[9] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[9] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[9] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[9] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[9] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[9] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[9] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ 
q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) using temp registers (R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $64-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 
24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - MOVQ 48(DX), R8 - MOVQ 56(DX), R9 - MOVQ 64(DX), R10 - MOVQ 72(DX), R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ 
R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ 
q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX 
- ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // 
(C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) using temp registers (R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bw6-633/fp/element_ops_amd64.go b/ecc/bw6-633/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.go +++ b/ecc/bw6-633/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // 
https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-633/fp/element_ops_amd64.s b/ecc/bw6-633/fp/element_ops_amd64.s index 12a078963..db6a61c53 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.s +++ b/ecc/bw6-633/fp/element_ops_amd64.s @@ -1,436 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 747913930085520082 +#include "../../../field/asm/element_10w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xd74916ea4570000d -DATA q<>+8(SB)/8, $0x3d369bd31147f73c -DATA q<>+16(SB)/8, $0xd7b5ce7ab839c225 -DATA q<>+24(SB)/8, $0x7e0e8850edbda407 -DATA q<>+32(SB)/8, $0xb8da9f5e83f57c49 -DATA q<>+40(SB)/8, $0x8152a6c0fadea490 -DATA q<>+48(SB)/8, $0x4e59769ad9bbda2f -DATA q<>+56(SB)/8, $0xa8fcd8c75d79d2c7 -DATA q<>+64(SB)/8, $0xfc1a174f01d72ab5 -DATA q<>+72(SB)/8, $0x0126633cc0f35f63 -GLOBL q<>(SB), (RODATA+NOPTR), $80 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xb50f29ab0b03b13b -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, 
ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - -TEXT ·reduce(SB), $56-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), $56-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ 
DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), $56-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), 
$136-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) - - MOVQ DX, s7-64(SP) - MOVQ CX, s8-72(SP) - MOVQ BX, s9-80(SP) - MOVQ SI, s10-88(SP) - MOVQ DI, s11-96(SP) - MOVQ R8, s12-104(SP) - MOVQ R9, s13-112(SP) - MOVQ R10, s14-120(SP) - MOVQ R11, s15-128(SP) - MOVQ R12, s16-136(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ s7-64(SP), DX - ADCQ s8-72(SP), CX - ADCQ s9-80(SP), BX - ADCQ s10-88(SP), SI - ADCQ s11-96(SP), DI - ADCQ s12-104(SP), R8 - ADCQ s13-112(SP), R9 - ADCQ s14-120(SP), R10 - ADCQ s15-128(SP), R11 - ADCQ 
s16-136(SP), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $56-16 - MOVQ b+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ a+0(FP), AX - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ R9, s3-32(SP) - MOVQ R10, s4-40(SP) - MOVQ R11, s5-48(SP) - MOVQ R12, s6-56(SP) - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ b+8(FP), AX - SUBQ 0(AX), DX - SBBQ 8(AX), CX - SBBQ 16(AX), BX - SBBQ 24(AX), SI - SBBQ 32(AX), DI - SBBQ 40(AX), R8 - SBBQ 48(AX), R9 - SBBQ 56(AX), R10 - SBBQ 64(AX), R11 
- SBBQ 72(AX), R12 - JCC noReduce_1 - MOVQ $0xd74916ea4570000d, AX - ADDQ AX, DX - MOVQ $0x3d369bd31147f73c, AX - ADCQ AX, CX - MOVQ $0xd7b5ce7ab839c225, AX - ADCQ AX, BX - MOVQ $0x7e0e8850edbda407, AX - ADCQ AX, SI - MOVQ $0xb8da9f5e83f57c49, AX - ADCQ AX, DI - MOVQ $0x8152a6c0fadea490, AX - ADCQ AX, R8 - MOVQ $0x4e59769ad9bbda2f, AX - ADCQ AX, R9 - MOVQ $0xa8fcd8c75d79d2c7, AX - ADCQ AX, R10 - MOVQ $0xfc1a174f01d72ab5, AX - ADCQ AX, R11 - MOVQ $0x0126633cc0f35f63, AX - ADCQ AX, R12 - -noReduce_1: - MOVQ b+8(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, DX - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ s3-32(SP), R9 - MOVQ s4-40(SP), R10 - MOVQ s5-48(SP), R11 - MOVQ s6-56(SP), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ a+0(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET diff --git a/ecc/bw6-633/fp/element_ops_purego.go b/ecc/bw6-633/fp/element_ops_purego.go index 69c68919e..3b5d489a3 100644 --- a/ecc/bw6-633/fp/element_ops_purego.go +++ b/ecc/bw6-633/fp/element_ops_purego.go @@ -71,48 +71,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] 
+ C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9 uint64 var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9 uint64 diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 169cd6701..2aaadfa90 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -649,7 +649,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,77 +718,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - 
c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2369,50 +2297,50 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[9] != ^uint64(0) { + g[9] %= (qElement[9] + 1) + } - if qElement[9] != ^uint64(0) { - g[9] %= (qElement[9] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[9] != ^uint64(0) { + g[9] %= (qElement[9] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[9] != ^uint64(0) { - g[9] %= (qElement[9] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult 
{ + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2431,6 +2359,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bw6-633/fp/vector.go b/ecc/bw6-633/fp/vector.go index 1bd71a36e..90e2236c7 100644 --- a/ecc/bw6-633/fp/vector.go +++ b/ecc/bw6-633/fp/vector.go @@ -223,6 +223,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -250,6 +269,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-633/fp/vector_test.go b/ecc/bw6-633/fp/vector_test.go index 5d88af91c..94d6557e2 100644 --- a/ecc/bw6-633/fp/vector_test.go +++ b/ecc/bw6-633/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,291 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return 
false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + 
genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[9] != ^uint64(0) { + mixer[9] %= (qElement[9] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[9] != ^uint64(0) { + mixer[9] %= (qElement[9] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bw6-633/fr/asm.go b/ecc/bw6-633/fr/asm_adx.go similarity index 100% rename from ecc/bw6-633/fr/asm.go rename to ecc/bw6-633/fr/asm_adx.go diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index 208f672b1..8841cd342 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -499,32 +499,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. 
This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [6]uint64 var D uint64 diff --git a/ecc/bw6-633/fr/element_mul_amd64.s b/ecc/bw6-633/fr/element_mul_amd64.s deleted file mode 100644 index 92bba4f58..000000000 --- a/ecc/bw6-633/fr/element_mul_amd64.s +++ /dev/null @@ -1,656 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := 
t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + 
x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ 
q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // 
https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - 
// (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bw6-633/fr/element_ops_amd64.go b/ecc/bw6-633/fr/element_ops_amd64.go index e40a9caed..83d40c28c 100644 --- a/ecc/bw6-633/fr/element_ops_amd64.go +++ b/ecc/bw6-633/fr/element_ops_amd64.go @@ -50,48 +50,8 @@ func 
Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-633/fr/element_ops_amd64.s b/ecc/bw6-633/fr/element_ops_amd64.s index 9528ab595..29314843d 100644 --- a/ecc/bw6-633/fr/element_ops_amd64.s +++ b/ecc/bw6-633/fr/element_ops_amd64.s @@ -1,272 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 18184981773209750009 +#include "../../../field/asm/element_5w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), 
CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $16-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ 
BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $24-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ CX, R9 - MOVQ BX, R10 - MOVQ SI, R11 - MOVQ DI, R12 - MOVQ R8, R13 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - SUBQ 0(DX), R9 - SBBQ 8(DX), R10 - SBBQ 16(DX), R11 - SBBQ 24(DX), R12 - SBBQ 32(DX), R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ $0x6fe802ff40300001, CX - MOVQ $0x421ee5da52bde502, BX - MOVQ $0xdec1d01aa27a1ae0, SI - MOVQ $0xd3f7498be97c5eaf, DI - MOVQ $0x04c23a02b586d650, R8 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - ADDQ CX, 
R9 - ADCQ BX, R10 - ADCQ SI, R11 - ADCQ DI, R12 - ADCQ R8, R13 - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ R9, 0(DX) - MOVQ R10, 8(DX) - MOVQ R11, 16(DX) - MOVQ R12, 24(DX) - MOVQ R13, 32(DX) - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - RET diff --git a/ecc/bw6-633/fr/element_ops_purego.go b/ecc/bw6-633/fr/element_ops_purego.go index 34d6c54fb..4a7cdbfe4 100644 --- a/ecc/bw6-633/fr/element_ops_purego.go +++ b/ecc/bw6-633/fr/element_ops_purego.go @@ -66,48 +66,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4 uint64 var u0, u1, u2, u3, u4 uint64 diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index e232de8c8..b7acc3840 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -639,7 +639,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -709,77 +708,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < 
N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2349,40 +2277,40 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[4] != ^uint64(0) { + g[4] %= (qElement[4] + 1) + } - if qElement[4] != ^uint64(0) { - g[4] %= (qElement[4] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + g[4] %= (qElement[4] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - 
genParams.NextUint64(), - } - if qElement[4] != ^uint64(0) { - g[4] %= (qElement[4] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2396,6 +2324,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bw6-633/fr/vector.go b/ecc/bw6-633/fr/vector.go index 1c9b6b975..e3bee5fbd 100644 --- a/ecc/bw6-633/fr/vector.go +++ b/ecc/bw6-633/fr/vector.go @@ -218,6 +218,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -245,6 +264,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-633/fr/vector_test.go b/ecc/bw6-633/fr/vector_test.go index e58f2d9a3..8245cc928 100644 --- a/ecc/bw6-633/fr/vector_test.go +++ b/ecc/bw6-633/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,281 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return 
false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + mixer[4] %= (qElement[4] + 1) + } + + for 
!mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[4] != ^uint64(0) { + mixer[4] %= (qElement[4] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bw6-761/fp/asm.go b/ecc/bw6-761/fp/asm_adx.go similarity index 100% rename from ecc/bw6-761/fp/asm.go rename to ecc/bw6-761/fp/asm_adx.go diff --git a/ecc/bw6-761/fp/element.go b/ecc/bw6-761/fp/element.go index 36232ebff..8cdd31218 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -653,32 +653,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. 
El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [13]uint64 var D uint64 diff --git a/ecc/bw6-761/fp/element_mul_amd64.s b/ecc/bw6-761/fp/element_mul_amd64.s deleted file mode 100644 index fd48d8606..000000000 --- a/ecc/bw6-761/fp/element_mul_amd64.s +++ /dev/null @@ -1,2758 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf49d00000000008b -DATA q<>+8(SB)/8, $0xe6913e6870000082 -DATA q<>+16(SB)/8, $0x160cf8aeeaf0a437 -DATA q<>+24(SB)/8, $0x98a116c25667a8f8 -DATA q<>+32(SB)/8, $0x71dcd3dc73ebff2e -DATA q<>+40(SB)/8, $0x8689c8ed12f9fd90 -DATA q<>+48(SB)/8, $0x03cebaff25b42304 -DATA q<>+56(SB)/8, $0x707ba638e584e919 -DATA q<>+64(SB)/8, $0x528275ef8087be41 -DATA q<>+72(SB)/8, $0xb926186a81d14688 -DATA q<>+80(SB)/8, $0xd187c94004faff3e -DATA q<>+88(SB)/8, $0x0122e824fb83ce0a -GLOBL q<>(SB), (RODATA+NOPTR), $96 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a5593568fa798dd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, 
rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - MOVQ ra10, rb10; \ - SBBQ q<>+80(SB), ra10; \ - MOVQ ra11, rb11; \ - SBBQ q<>+88(SB), ra11; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - CMOVQCS rb10, ra10; \ - CMOVQCS rb11, ra11; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $96-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), AX - - // x[0] -> s0-8(SP) - // x[1] -> s1-16(SP) - // x[2] -> s2-24(SP) - // x[3] -> s3-32(SP) - // x[4] -> s4-40(SP) - // x[5] -> s5-48(SP) - // x[6] -> s6-56(SP) - // x[7] -> s7-64(SP) - // x[8] -> s8-72(SP) - // x[9] -> s9-80(SP) - // x[10] -> s10-88(SP) - // x[11] -> s11-96(SP) - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 32(AX), SI - MOVQ 40(AX), DI - MOVQ 48(AX), R8 - MOVQ 56(AX), R9 - MOVQ 64(AX), R10 - MOVQ 72(AX), R11 - MOVQ 80(AX), R12 - MOVQ 88(AX), R13 - MOVQ R14, s0-8(SP) - MOVQ R15, s1-16(SP) - MOVQ CX, s2-24(SP) - MOVQ BX, s3-32(SP) - MOVQ SI, s4-40(SP) - MOVQ DI, s5-48(SP) - MOVQ R8, s6-56(SP) - MOVQ R9, s7-64(SP) - MOVQ R10, s8-72(SP) - MOVQ R11, s9-80(SP) - MOVQ R12, s10-88(SP) - MOVQ R13, s11-96(SP) - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // t[6] -> R8 - // t[7] -> R9 - // 
t[8] -> R10 - // t[9] -> R11 - // t[10] -> R12 - // t[11] -> R13 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 0(AX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ s0-8(SP), R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ s1-16(SP), AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ s2-24(SP), AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ s3-32(SP), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ s4-40(SP), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ s5-48(SP), AX, R8 - ADOXQ AX, DI - - // (A,t[6]) := x[6]*y[0] + A - MULXQ s6-56(SP), AX, R9 - ADOXQ AX, R8 - - // (A,t[7]) := x[7]*y[0] + A - MULXQ s7-64(SP), AX, R10 - ADOXQ AX, R9 - - // (A,t[8]) := x[8]*y[0] + A - MULXQ s8-72(SP), AX, R11 - ADOXQ AX, R10 - - // (A,t[9]) := x[9]*y[0] + A - MULXQ s9-80(SP), AX, R12 - ADOXQ AX, R11 - - // (A,t[10]) := x[10]*y[0] + A - MULXQ s10-88(SP), AX, R13 - ADOXQ AX, R12 - - // (A,t[11]) := x[11]*y[0] + A - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + 
m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 8(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[1] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[1] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[1] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[1] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[1] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[1] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - 
ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 16(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[2] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[2] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[2] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - 
ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[2] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[2] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[2] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 24(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // 
(A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[3] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[3] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[3] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[3] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[3] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[3] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ 
q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 32(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[4] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[4] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[4] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[4] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[4] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[4] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), 
AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 40(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[5] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[5] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[5] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := 
t[9] + x[9]*y[5] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[5] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[5] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 48(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[6] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[6] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[6] + A - 
ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[6] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[6] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[6] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[6] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[6] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[6] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[6] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[6] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[6] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // 
(C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 56(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[7] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[7] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[7] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[7] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[7] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[7] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[7] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[7] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[7] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[7] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[7] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[7] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) 
:= t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 64(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[8] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[8] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[8] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[8] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[8] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[8] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[8] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[8] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[8] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[8] + A - ADCXQ BP, R11 - 
MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[8] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[8] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 72(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[9] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[9] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[9] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP 
- ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[9] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[9] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[9] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[9] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[9] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[9] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[9] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[9] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[9] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, 
R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 80(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[10] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[10] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[10] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[10] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[10] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[10] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[10] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[10] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[10] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[10] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[10] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[10] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, 
R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 88(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[11] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[11] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[11] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[11] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[11] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[11] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[11] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[11] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[11] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[11] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP 
- ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[11] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[11] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - 
REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - MOVQ R12, 80(AX) - MOVQ R13, 88(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $96-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - MOVQ 48(DX), R8 - MOVQ 56(DX), R9 - MOVQ 64(DX), R10 - MOVQ 72(DX), R11 - MOVQ 80(DX), R12 - MOVQ 88(DX), R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ 
R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - 
ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ 
AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - 
ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + 
C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C 
- ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, 
R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ 
AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - MOVQ R12, 80(AX) - MOVQ R13, 88(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bw6-761/fp/element_ops_amd64.go b/ecc/bw6-761/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.go +++ b/ecc/bw6-761/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // 
- // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-761/fp/element_ops_amd64.s b/ecc/bw6-761/fp/element_ops_amd64.s index 476e9e39e..3c8e045ed 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.s +++ b/ecc/bw6-761/fp/element_ops_amd64.s @@ -1,502 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 13892629867042773109 +#include "../../../field/asm/element_12w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf49d00000000008b -DATA q<>+8(SB)/8, $0xe6913e6870000082 -DATA q<>+16(SB)/8, $0x160cf8aeeaf0a437 -DATA q<>+24(SB)/8, $0x98a116c25667a8f8 -DATA q<>+32(SB)/8, $0x71dcd3dc73ebff2e -DATA q<>+40(SB)/8, $0x8689c8ed12f9fd90 -DATA q<>+48(SB)/8, $0x03cebaff25b42304 -DATA q<>+56(SB)/8, $0x707ba638e584e919 -DATA q<>+64(SB)/8, $0x528275ef8087be41 -DATA q<>+72(SB)/8, $0xb926186a81d14688 -DATA q<>+80(SB)/8, $0xd187c94004faff3e -DATA q<>+88(SB)/8, $0x0122e824fb83ce0a -GLOBL q<>(SB), (RODATA+NOPTR), $96 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a5593568fa798dd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - MOVQ ra10, rb10; \ - SBBQ q<>+80(SB), ra10; \ - MOVQ ra11, rb11; \ - SBBQ q<>+88(SB), ra11; \ - CMOVQCS rb0, ra0; 
\ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - CMOVQCS rb10, ra10; \ - CMOVQCS rb11, ra11; \ - -TEXT ·reduce(SB), $88-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), $88-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - 
ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), $88-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $184-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // 
reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) - - MOVQ DX, s11-96(SP) - MOVQ CX, s12-104(SP) - MOVQ BX, s13-112(SP) - MOVQ SI, s14-120(SP) - MOVQ DI, s15-128(SP) - MOVQ R8, s16-136(SP) - MOVQ R9, s17-144(SP) - MOVQ R10, s18-152(SP) - MOVQ R11, s19-160(SP) - MOVQ R12, s20-168(SP) - MOVQ R13, s21-176(SP) - MOVQ R14, s22-184(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ s11-96(SP), DX - ADCQ s12-104(SP), CX - ADCQ s13-112(SP), BX - ADCQ s14-120(SP), SI - ADCQ s15-128(SP), DI - ADCQ s16-136(SP), R8 - ADCQ s17-144(SP), R9 - ADCQ s18-152(SP), R10 - ADCQ s19-160(SP), R11 - ADCQ s20-168(SP), R12 - ADCQ s21-176(SP), R13 - ADCQ s22-184(SP), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 
- ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $88-16 - MOVQ b+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - MOVQ a+0(FP), AX - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ R10, s6-56(SP) - MOVQ R11, s7-64(SP) - MOVQ R12, s8-72(SP) - MOVQ R13, s9-80(SP) - MOVQ R14, s10-88(SP) - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - MOVQ b+8(FP), AX - SUBQ 0(AX), DX - SBBQ 8(AX), CX - SBBQ 16(AX), BX - SBBQ 24(AX), SI - SBBQ 32(AX), DI - SBBQ 40(AX), R8 - SBBQ 48(AX), R9 - SBBQ 56(AX), R10 - SBBQ 64(AX), R11 - SBBQ 72(AX), R12 - SBBQ 80(AX), R13 - SBBQ 88(AX), R14 - JCC noReduce_1 - MOVQ 
$0xf49d00000000008b, AX - ADDQ AX, DX - MOVQ $0xe6913e6870000082, AX - ADCQ AX, CX - MOVQ $0x160cf8aeeaf0a437, AX - ADCQ AX, BX - MOVQ $0x98a116c25667a8f8, AX - ADCQ AX, SI - MOVQ $0x71dcd3dc73ebff2e, AX - ADCQ AX, DI - MOVQ $0x8689c8ed12f9fd90, AX - ADCQ AX, R8 - MOVQ $0x03cebaff25b42304, AX - ADCQ AX, R9 - MOVQ $0x707ba638e584e919, AX - ADCQ AX, R10 - MOVQ $0x528275ef8087be41, AX - ADCQ AX, R11 - MOVQ $0xb926186a81d14688, AX - ADCQ AX, R12 - MOVQ $0xd187c94004faff3e, AX - ADCQ AX, R13 - MOVQ $0x0122e824fb83ce0a, AX - ADCQ AX, R14 - -noReduce_1: - MOVQ b+8(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - MOVQ R15, DX - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ s6-56(SP), R10 - MOVQ s7-64(SP), R11 - MOVQ s8-72(SP), R12 - MOVQ s9-80(SP), R13 - MOVQ s10-88(SP), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ a+0(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET diff --git a/ecc/bw6-761/fp/element_ops_purego.go b/ecc/bw6-761/fp/element_ops_purego.go index 3c1ffa245..59d6d1d52 100644 --- a/ecc/bw6-761/fp/element_ops_purego.go +++ b/ecc/bw6-761/fp/element_ops_purego.go @@ -73,48 +73,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // 
Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 uint64 var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11 uint64 diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index fbba1f286..5df3d0de5 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -653,7 +653,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -723,77 +722,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < 
b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2377,54 +2305,54 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[11] != ^uint64(0) { + g[11] %= (qElement[11] + 1) + } - if qElement[11] != ^uint64(0) { - g[11] %= (qElement[11] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[11] != ^uint64(0) { + g[11] %= (qElement[11] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - 
genParams.NextUint64(), - } - if qElement[11] != ^uint64(0) { - g[11] %= (qElement[11] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2445,6 +2373,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/bw6-761/fp/vector.go b/ecc/bw6-761/fp/vector.go index 87105028b..8b9107620 100644 --- a/ecc/bw6-761/fp/vector.go +++ b/ecc/bw6-761/fp/vector.go @@ -225,6 +225,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -252,6 +271,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-761/fp/vector_test.go b/ecc/bw6-761/fp/vector_test.go index 5d88af91c..9f59efdd3 100644 --- a/ecc/bw6-761/fp/vector_test.go +++ b/ecc/bw6-761/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,295 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return 
false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + 
genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[11] != ^uint64(0) { + mixer[11] %= (qElement[11] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[11] != ^uint64(0) { + mixer[11] %= (qElement[11] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/bw6-761/fr/asm.go b/ecc/bw6-761/fr/asm_adx.go similarity index 100% rename from ecc/bw6-761/fr/asm.go rename to ecc/bw6-761/fr/asm_adx.go diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index 3e7eacc9e..6784bc911 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -521,32 +521,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [7]uint64 var D uint64 diff --git a/ecc/bw6-761/fr/element_mul_amd64.s b/ecc/bw6-761/fr/element_mul_amd64.s deleted file mode 100644 index 1e19c4d3f..000000000 --- a/ecc/bw6-761/fr/element_mul_amd64.s +++ /dev/null @@ -1,857 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) 
:= x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // 
(C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + 
x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - 
ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce 
element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ 
AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - 
ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bw6-761/fr/element_ops_amd64.go b/ecc/bw6-761/fr/element_ops_amd64.go index e40a9caed..83d40c28c 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.go +++ b/ecc/bw6-761/fr/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // 
The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-761/fr/element_ops_amd64.s b/ecc/bw6-761/fr/element_ops_amd64.s index 7242622a4..cabff26f7 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.s +++ b/ecc/bw6-761/fr/element_ops_amd64.s @@ -1,306 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 11124594824487954849 +#include "../../../field/asm/element_6w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ 
CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $40-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), 
DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ R15, DX - ADCQ s0-8(SP), CX - ADCQ s1-16(SP), BX - ADCQ s2-24(SP), SI - ADCQ s3-32(SP), DI - ADCQ s4-40(SP), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $48-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - MOVQ R8, R14 - MOVQ R9, R15 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - ADCQ 40(DX), R9 - SUBQ 0(DX), R10 - SBBQ 8(DX), R11 - SBBQ 16(DX), R12 - SBBQ 24(DX), 
R13 - SBBQ 32(DX), R14 - SBBQ 40(DX), R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ $0x8508c00000000001, CX - MOVQ $0x170b5d4430000000, BX - MOVQ $0x1ef3622fba094800, SI - MOVQ $0x1a22d9f300f5138f, DI - MOVQ $0xc63b05c06ca1493b, R8 - MOVQ $0x01ae3a4617c510ea, R9 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - CMOVQCC AX, R9 - ADDQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - ADCQ R8, R14 - ADCQ R9, R15 - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ R10, 0(DX) - MOVQ R11, 8(DX) - MOVQ R12, 16(DX) - MOVQ R13, 24(DX) - MOVQ R14, 32(DX) - MOVQ R15, 40(DX) - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - MOVQ R9, 40(AX) - RET diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/bw6-761/fr/element_ops_purego.go index bd2d33293..bdf76428d 100644 --- a/ecc/bw6-761/fr/element_ops_purego.go +++ b/ecc/bw6-761/fr/element_ops_purego.go @@ -67,48 +67,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5 uint64 var u0, u1, u2, u3, u4, u5 uint64 diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 0596297e8..bcc35c484 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -641,7 +641,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -711,77 +710,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - 
func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2353,42 +2281,42 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[5] != ^uint64(0) { + g[5] %= (qElement[5] + 1) + } - if qElement[5] != ^uint64(0) { - g[5] %= (qElement[5] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + g[5] %= (qElement[5] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[5] != ^uint64(0) { - g[5] %= (qElement[5] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2403,6 +2331,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus 
big.Int var aIntMod big.Int diff --git a/ecc/bw6-761/fr/vector.go b/ecc/bw6-761/fr/vector.go index 8dd4774c5..af400c4e4 100644 --- a/ecc/bw6-761/fr/vector.go +++ b/ecc/bw6-761/fr/vector.go @@ -219,6 +219,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -246,6 +265,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-761/fr/vector_test.go b/ecc/bw6-761/fr/vector_test.go index e58f2d9a3..ad574704b 100644 --- a/ecc/bw6-761/fr/vector_test.go +++ b/ecc/bw6-761/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,283 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b 
Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + 
if qElement[5] != ^uint64(0) { + mixer[5] %= (qElement[5] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[5] != ^uint64(0) { + mixer[5] %= (qElement[5] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/secp256k1/fp/element.go b/ecc/secp256k1/fp/element.go index 0a242dd37..73045a133 100644 --- a/ecc/secp256k1/fp/element.go +++ b/ecc/secp256k1/fp/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 15580212934572586289 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 4294967296 + func init() { _modulus.SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) } @@ -505,32 +508,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. 
This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_ops_purego.go index a8624a511..f53ffa325 100644 --- a/ecc/secp256k1/fp/element_ops_purego.go +++ b/ecc/secp256k1/fp/element_ops_purego.go @@ -57,53 +57,11 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 6f8165b18..cbe2b50a0 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -635,7 +635,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -705,77 +704,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) 
{ t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2343,38 +2271,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2386,3 +2314,11 @@ func genFull() gopter.Gen { return genResult } } + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/secp256k1/fp/vector.go b/ecc/secp256k1/fp/vector.go index 850b3603d..fa22cb416 100644 --- a/ecc/secp256k1/fp/vector.go +++ b/ecc/secp256k1/fp/vector.go @@ -199,6 +199,43 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors 
element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -226,6 +263,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/secp256k1/fp/vector_test.go b/ecc/secp256k1/fp/vector_test.go index 5d88af91c..12f17e21f 100644 --- a/ecc/secp256k1/fp/vector_test.go +++ b/ecc/secp256k1/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + 
return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + 
mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/secp256k1/fr/element.go b/ecc/secp256k1/fr/element.go index 6afe3590b..e2f81b66b 100644 --- a/ecc/secp256k1/fr/element.go +++ b/ecc/secp256k1/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 5408259542528602431 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 4294967296 + func init() { _modulus.SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) } @@ -505,32 +508,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_ops_purego.go index 1a46f6d79..ef83ea20a 100644 --- a/ecc/secp256k1/fr/element_ops_purego.go +++ b/ecc/secp256k1/fr/element_ops_purego.go @@ -57,53 +57,11 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index f554db8e3..2ab020991 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -635,7 +635,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -705,77 +704,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := 
make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2343,38 +2271,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2386,3 +2314,11 @@ func genFull() gopter.Gen { return genResult } } + +func genElement() 
gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/secp256k1/fr/vector.go b/ecc/secp256k1/fr/vector.go index f39828547..bcc71efcd 100644 --- a/ecc/secp256k1/fr/vector.go +++ b/ecc/secp256k1/fr/vector.go @@ -199,6 +199,43 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -226,6 +263,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/secp256k1/fr/vector_test.go b/ecc/secp256k1/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/secp256k1/fr/vector_test.go +++ b/ecc/secp256k1/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + 
return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + 
properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + 
mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/stark-curve/fp/asm.go b/ecc/stark-curve/fp/asm_adx.go similarity index 100% rename from ecc/stark-curve/fp/asm.go rename to ecc/stark-curve/fp/asm_adx.go diff --git a/ecc/stark-curve/fp/asm_avx.go b/ecc/stark-curve/fp/asm_avx.go new file mode 100644 index 000000000..cea035ee8 --- /dev/null +++ b/ecc/stark-curve/fp/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/stark-curve/fp/asm_noavx.go b/ecc/stark-curve/fp/asm_noavx.go new file mode 100644 index 000000000..9ca08a375 --- /dev/null +++ b/ecc/stark-curve/fp/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +const supportAvx512 = false diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 7a057be06..1c53dcb09 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 18446744073709551615 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 137438953471 + func init() { _modulus.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/stark-curve/fp/element_mul_amd64.s b/ecc/stark-curve/fp/element_mul_amd64.s deleted file mode 100644 index fab328c86..000000000 --- a/ecc/stark-curve/fp/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $1 -DATA q<>+8(SB)/8, $0 -DATA q<>+16(SB)/8, $0 -DATA q<>+24(SB)/8, $0x0800000000000011 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xffffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - 
MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := 
t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] 
+ C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/element_ops_amd64.go index 6f16baf68..2ab1a9839 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.go +++ b/ecc/stark-curve/fp/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func 
scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index 914653b70..6c42136a7 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $1 -DATA q<>+8(SB)/8, $0 -DATA q<>+16(SB)/8, $0 -DATA q<>+24(SB)/8, $0x0800000000000011 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xffffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - 
ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $1, R12 - MOVQ $0, R13 - MOVQ $0, R14 - MOVQ $0x0800000000000011, 
R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $1, R11 - MOVQ $0, R12 - MOVQ $0, R13 - MOVQ $0x0800000000000011, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ 
R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + 
x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - 
MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go index 4906d13e0..19cb3649b 100644 --- a/ecc/stark-curve/fp/element_ops_purego.go +++ b/ecc/stark-curve/fp/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. 
+func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 87e38f7c1..c33f3f21b 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i 
:= 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2345,38 +2273,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g 
+} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2389,6 +2317,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/stark-curve/fp/vector.go b/ecc/stark-curve/fp/vector.go index 850b3603d..c97b4283c 100644 --- a/ecc/stark-curve/fp/vector.go +++ b/ecc/stark-curve/fp/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/stark-curve/fp/vector_test.go b/ecc/stark-curve/fp/vector_test.go index 5d88af91c..12f17e21f 100644 --- a/ecc/stark-curve/fp/vector_test.go +++ b/ecc/stark-curve/fp/vector_test.go @@ -18,10 +18,15 @@ package fp import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := 
func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= 
(qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/stark-curve/fr/asm.go b/ecc/stark-curve/fr/asm_adx.go similarity index 100% rename from ecc/stark-curve/fr/asm.go rename to ecc/stark-curve/fr/asm_adx.go diff --git a/ecc/stark-curve/fr/asm_avx.go b/ecc/stark-curve/fr/asm_avx.go new file mode 100644 index 000000000..955f55979 --- /dev/null +++ b/ecc/stark-curve/fr/asm_avx.go @@ -0,0 +1,27 @@ +//go:build !noavx +// +build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/ecc/stark-curve/fr/asm_noavx.go b/ecc/stark-curve/fr/asm_noavx.go new file mode 100644 index 000000000..e5a5b1f2c --- /dev/null +++ b/ecc/stark-curve/fr/asm_noavx.go @@ -0,0 +1,22 @@ +//go:build noavx +// +build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +const supportAvx512 = false diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index a7ab8e217..216e287eb 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -81,6 +81,9 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 13504954208620504625 +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = 137438953471 + func init() { _modulus.SetString("800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f", 16) } @@ -477,32 +480,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/stark-curve/fr/element_mul_amd64.s b/ecc/stark-curve/fr/element_mul_amd64.s deleted file mode 100644 index 8eb931e77..000000000 --- a/ecc/stark-curve/fr/element_mul_amd64.s +++ /dev/null @@ -1,487 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x1e66a241adc64d2f -DATA q<>+8(SB)/8, $0xb781126dcae7b232 -DATA q<>+16(SB)/8, $0xffffffffffffffff -DATA q<>+24(SB)/8, $0x0800000000000010 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xbb6b3c4ce8bde631 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + 
m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - 
// t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - 
MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go index 21568255d..b653e8006 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ b/ecc/stark-curve/fr/element_ops_amd64.go @@ -51,7 +51,8 @@ func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -75,59 +76,123 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], 
b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *Element, n uint64) +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 245dcb895..6c42136a7 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -1,627 +1,6 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x1e66a241adc64d2f -DATA q<>+8(SB)/8, $0xb781126dcae7b232 -DATA q<>+16(SB)/8, $0xffffffffffffffff -DATA q<>+24(SB)/8, $0x0800000000000010 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xbb6b3c4ce8bde631 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), 
R11 - MOVQ $0x1e66a241adc64d2f, R12 - MOVQ $0xb781126dcae7b232, R13 - MOVQ $0xffffffffffffffff, R14 - MOVQ $0x0800000000000010, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x1e66a241adc64d2f, R11 - MOVQ $0xb781126dcae7b232, R12 - MOVQ $0xffffffffffffffff, R13 - 
MOVQ $0x0800000000000010, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, 
BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - 
ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go index b04f5202f..2d0db6915 100644 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_ops_purego.go @@ -78,53 +78,32 @@ func (vector *Vector) ScalarMul(a 
Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index b81aff116..fb10e5a55 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -707,77 +706,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i 
:= 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2297,38 +2225,38 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } + return g 
+} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], carry = bits.Add64(a[0], qElement[0], carry) @@ -2341,6 +2269,14 @@ func genFull() gopter.Gen { } } +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int var aIntMod big.Int diff --git a/ecc/stark-curve/fr/vector.go b/ecc/stark-curve/fr/vector.go index f39828547..867cabbc3 100644 --- a/ecc/stark-curve/fr/vector.go +++ b/ecc/stark-curve/fr/vector.go @@ -226,6 +226,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/stark-curve/fr/vector_test.go b/ecc/stark-curve/fr/vector_test.go index e58f2d9a3..b6344c18b 100644 --- a/ecc/stark-curve/fr/vector_test.go +++ b/ecc/stark-curve/fr/vector_test.go @@ -18,10 +18,15 @@ package fr import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,279 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := 
func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= 
(qElement[3] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + mixer[3] %= (qElement[3] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/field/asm/.gitignore b/field/asm/.gitignore new file mode 100644 index 000000000..7c22f7f93 --- /dev/null +++ b/field/asm/.gitignore @@ -0,0 +1,6 @@ +# generated by integration tests +element_2w_amd64.s +element_3w_amd64.s +element_7w_amd64.s +element_8w_amd64.s +*.h \ No newline at end of file diff --git a/field/asm/element_10w_amd64.s b/field/asm/element_10w_amd64.s new file mode 100644 index 000000000..e5f10dca8 --- /dev/null +++ b/field/asm/element_10w_amd64.s @@ -0,0 +1,1187 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9) \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + MOVQ ra4, rb4; \ + SBBQ ·qElement+32(SB), ra4; \ + MOVQ ra5, rb5; \ + SBBQ ·qElement+40(SB), ra5; \ + MOVQ ra6, rb6; \ + SBBQ ·qElement+48(SB), ra6; \ + MOVQ ra7, rb7; \ + SBBQ ·qElement+56(SB), ra7; \ + MOVQ ra8, rb8; \ + SBBQ ·qElement+64(SB), ra8; \ + MOVQ ra9, rb9; \ + SBBQ ·qElement+72(SB), ra9; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + CMOVQCS rb4, ra4; \ + CMOVQCS rb5, ra5; \ + CMOVQCS rb6, ra6; \ + CMOVQCS rb7, ra7; \ + CMOVQCS rb8, ra8; \ + CMOVQCS rb9, ra9; \ + +TEXT ·reduce(SB), $56-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 
16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), $56-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), $56-8 
+ MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $136-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + ADDQ DX, DX + ADCQ CX, CX 
+ ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) + + MOVQ DX, s7-64(SP) + MOVQ CX, s8-72(SP) + MOVQ BX, s9-80(SP) + MOVQ SI, s10-88(SP) + MOVQ DI, s11-96(SP) + MOVQ R8, s12-104(SP) + MOVQ R9, s13-112(SP) + MOVQ R10, s14-120(SP) + MOVQ R11, s15-128(SP) + MOVQ R12, s16-136(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ s7-64(SP), DX + ADCQ s8-72(SP), CX + ADCQ s9-80(SP), BX + ADCQ s10-88(SP), SI + ADCQ s11-96(SP), DI + ADCQ s12-104(SP), R8 + ADCQ s13-112(SP), R9 + ADCQ s14-120(SP), R10 + ADCQ s15-128(SP), R11 + ADCQ s16-136(SP), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $56-16 + MOVQ b+8(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ a+0(FP), AX + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + MOVQ DX, R13 + MOVQ CX, R14 + MOVQ BX, R15 + MOVQ SI, s0-8(SP) + MOVQ DI, s1-16(SP) + MOVQ R8, s2-24(SP) + MOVQ R9, s3-32(SP) + MOVQ R10, s4-40(SP) + MOVQ R11, s5-48(SP) + MOVQ R12, s6-56(SP) + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ b+8(FP), AX + SUBQ 0(AX), DX + SBBQ 8(AX), CX + SBBQ 16(AX), BX + SBBQ 24(AX), SI + SBBQ 32(AX), DI + SBBQ 40(AX), R8 + SBBQ 48(AX), R9 + SBBQ 56(AX), R10 + SBBQ 64(AX), R11 + SBBQ 72(AX), R12 + JCC noReduce_1 + MOVQ $const_q0, AX + ADDQ AX, DX + MOVQ $const_q1, AX + ADCQ AX, CX + MOVQ $const_q2, AX + ADCQ AX, BX + MOVQ $const_q3, AX + ADCQ AX, SI + 
MOVQ $const_q4, AX + ADCQ AX, DI + MOVQ $const_q5, AX + ADCQ AX, R8 + MOVQ $const_q6, AX + ADCQ AX, R9 + MOVQ $const_q7, AX + ADCQ AX, R10 + MOVQ $const_q8, AX + ADCQ AX, R11 + MOVQ $const_q9, AX + ADCQ AX, R12 + +noReduce_1: + MOVQ b+8(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, DX + MOVQ R14, CX + MOVQ R15, BX + MOVQ s0-8(SP), SI + MOVQ s1-16(SP), DI + MOVQ s2-24(SP), R8 + MOVQ s3-32(SP), R9 + MOVQ s4-40(SP), R10 + MOVQ s5-48(SP), R11 + MOVQ s6-56(SP), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ a+0(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), $64-24 + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // See github.com/gnark-crypto/field/generator for more comments. 
+ + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE noAdx_2 + MOVQ x+8(FP), R12 + MOVQ y+16(FP), R13 + + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // t[4] -> SI + // t[5] -> DI + // t[6] -> R8 + // t[7] -> R9 + // t[8] -> R10 + // t[9] -> R11 +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R14, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R14, AX \ + MOVQ BP, R14 \ + POPQ BP \ + MACC(R15, R14, ·qElement+8(SB)) \ + MACC(CX, R15, ·qElement+16(SB)) \ + MACC(BX, CX, ·qElement+24(SB)) \ + MACC(SI, BX, ·qElement+32(SB)) \ + MACC(DI, SI, ·qElement+40(SB)) \ + MACC(R8, DI, ·qElement+48(SB)) \ + MACC(R9, R8, ·qElement+56(SB)) \ + MACC(R10, R9, ·qElement+64(SB)) \ + MACC(R11, R10, ·qElement+72(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R11 \ + ADOXQ BP, R11 \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ 0(R12), R14, R15 \ + MULXQ 8(R12), AX, CX \ + ADOXQ AX, R15 \ + MULXQ 16(R12), AX, BX \ + ADOXQ AX, CX \ + MULXQ 24(R12), AX, SI \ + ADOXQ AX, BX \ + MULXQ 32(R12), AX, DI \ + ADOXQ AX, SI \ + MULXQ 40(R12), AX, R8 \ + ADOXQ AX, DI \ + MULXQ 48(R12), AX, R9 \ + ADOXQ AX, R8 \ + MULXQ 56(R12), AX, R10 \ + ADOXQ AX, R9 \ + MULXQ 64(R12), AX, R11 \ + ADOXQ AX, R10 \ + MULXQ 72(R12), AX, BP \ + ADOXQ AX, R11 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ 0(R12), AX, BP \ + ADOXQ AX, R14 \ + MACC(BP, R15, 8(R12)) \ + MACC(BP, CX, 16(R12)) \ + MACC(BP, BX, 24(R12)) \ + MACC(BP, SI, 32(R12)) \ + MACC(BP, DI, 40(R12)) \ + MACC(BP, R8, 48(R12)) \ + MACC(BP, R9, 56(R12)) \ + MACC(BP, R10, 64(R12)) \ + MACC(BP, R11, 72(R12)) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + + // mul body + MOVQ 0(R13), DX + MUL_WORD_0() + MOVQ 8(R13), DX + MUL_WORD_N() + MOVQ 16(R13), DX + MUL_WORD_N() + MOVQ 24(R13), DX + MUL_WORD_N() + MOVQ 32(R13), DX + 
MUL_WORD_N() + MOVQ 40(R13), DX + MUL_WORD_N() + MOVQ 48(R13), DX + MUL_WORD_N() + MOVQ 56(R13), DX + MUL_WORD_N() + MOVQ 64(R13), DX + MUL_WORD_N() + MOVQ 72(R13), DX + MUL_WORD_N() + + // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) using temp registers (R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + MOVQ DI, 40(AX) + MOVQ R8, 48(AX) + MOVQ R9, 56(AX) + MOVQ R10, 64(AX) + MOVQ R11, 72(AX) + RET + +noAdx_2: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $64-8 + NO_LOCAL_POINTERS + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE noAdx_3 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R15 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + MOVQ 32(DX), SI + MOVQ 40(DX), DI + MOVQ 48(DX), R8 + MOVQ 56(DX), R9 + MOVQ 64(DX), R10 + MOVQ 72(DX), R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 
+ MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ 
$const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // 
(C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // 
(C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // 
(C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + MOVQ $0, AX + ADCXQ AX, R11 + ADOXQ AX, R11 + + 
// reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) using temp registers (R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + MOVQ DI, 40(AX) + MOVQ R8, 48(AX) + MOVQ R9, 56(AX) + MOVQ R10, 64(AX) + MOVQ R11, 72(AX) + RET + +noAdx_3: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET diff --git a/field/asm/element_12w_amd64.s b/field/asm/element_12w_amd64.s new file mode 100644 index 000000000..52cf02adc --- /dev/null +++ b/field/asm/element_12w_amd64.s @@ -0,0 +1,1557 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + MOVQ ra4, rb4; \ + SBBQ ·qElement+32(SB), ra4; \ + MOVQ ra5, rb5; \ + SBBQ ·qElement+40(SB), ra5; \ + MOVQ ra6, rb6; \ + SBBQ ·qElement+48(SB), ra6; \ + MOVQ ra7, rb7; \ + SBBQ ·qElement+56(SB), ra7; \ + MOVQ ra8, rb8; \ + SBBQ ·qElement+64(SB), ra8; \ + MOVQ ra9, rb9; \ + SBBQ ·qElement+72(SB), ra9; \ + MOVQ ra10, rb10; \ + SBBQ ·qElement+80(SB), ra10; \ + MOVQ ra11, rb11; \ + SBBQ ·qElement+88(SB), ra11; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + CMOVQCS rb4, ra4; \ + CMOVQCS rb5, ra5; \ + CMOVQCS rb6, ra6; \ + CMOVQCS rb7, ra7; \ + CMOVQCS rb8, ra8; \ + CMOVQCS rb9, ra9; \ + CMOVQCS rb10, ra10; \ + CMOVQCS rb11, ra11; \ + +TEXT ·reduce(SB), $88-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + 
MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), $88-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers 
(R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), $88-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + 
ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $184-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) + 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) + + MOVQ DX, s11-96(SP) + MOVQ CX, s12-104(SP) + MOVQ BX, s13-112(SP) + MOVQ SI, s14-120(SP) + MOVQ DI, s15-128(SP) + MOVQ R8, s16-136(SP) + MOVQ R9, s17-144(SP) + MOVQ R10, s18-152(SP) + MOVQ R11, s19-160(SP) + MOVQ R12, s20-168(SP) + MOVQ R13, s21-176(SP) + MOVQ R14, s22-184(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ s11-96(SP), DX + ADCQ s12-104(SP), CX + ADCQ s13-112(SP), BX + ADCQ s14-120(SP), SI + ADCQ s15-128(SP), DI + ADCQ s16-136(SP), R8 + ADCQ s17-144(SP), R9 + ADCQ s18-152(SP), R10 + ADCQ s19-160(SP), R11 + ADCQ s20-168(SP), R12 + ADCQ s21-176(SP), R13 + ADCQ s22-184(SP), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers 
(R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $88-16 + MOVQ b+8(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + MOVQ a+0(FP), AX + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + MOVQ DX, R15 + MOVQ CX, s0-8(SP) + MOVQ BX, s1-16(SP) + MOVQ SI, s2-24(SP) + MOVQ DI, s3-32(SP) + MOVQ R8, s4-40(SP) + MOVQ R9, s5-48(SP) + MOVQ R10, s6-56(SP) + MOVQ R11, s7-64(SP) + MOVQ R12, s8-72(SP) + MOVQ R13, s9-80(SP) + MOVQ R14, s10-88(SP) + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + MOVQ b+8(FP), AX + SUBQ 0(AX), DX + SBBQ 8(AX), CX + SBBQ 16(AX), BX + SBBQ 24(AX), SI + SBBQ 32(AX), DI + SBBQ 40(AX), R8 + SBBQ 48(AX), R9 + SBBQ 56(AX), R10 + SBBQ 64(AX), R11 + SBBQ 72(AX), R12 + SBBQ 80(AX), R13 + SBBQ 88(AX), R14 + JCC noReduce_1 + MOVQ $const_q0, AX + ADDQ AX, DX + MOVQ $const_q1, AX + ADCQ AX, CX + MOVQ $const_q2, AX + ADCQ AX, BX + MOVQ $const_q3, AX + ADCQ AX, SI + MOVQ $const_q4, AX + ADCQ AX, DI + MOVQ $const_q5, AX + ADCQ 
AX, R8 + MOVQ $const_q6, AX + ADCQ AX, R9 + MOVQ $const_q7, AX + ADCQ AX, R10 + MOVQ $const_q8, AX + ADCQ AX, R11 + MOVQ $const_q9, AX + ADCQ AX, R12 + MOVQ $const_q10, AX + ADCQ AX, R13 + MOVQ $const_q11, AX + ADCQ AX, R14 + +noReduce_1: + MOVQ b+8(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + MOVQ R15, DX + MOVQ s0-8(SP), CX + MOVQ s1-16(SP), BX + MOVQ s2-24(SP), SI + MOVQ s3-32(SP), DI + MOVQ s4-40(SP), R8 + MOVQ s5-48(SP), R9 + MOVQ s6-56(SP), R10 + MOVQ s7-64(SP), R11 + MOVQ s8-72(SP), R12 + MOVQ s9-80(SP), R13 + MOVQ s10-88(SP), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ a+0(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), $96-24 + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // See github.com/gnark-crypto/field/generator for more comments. 
+ + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE noAdx_2 + MOVQ x+8(FP), AX + + // x[0] -> s0-8(SP) + // x[1] -> s1-16(SP) + // x[2] -> s2-24(SP) + // x[3] -> s3-32(SP) + // x[4] -> s4-40(SP) + // x[5] -> s5-48(SP) + // x[6] -> s6-56(SP) + // x[7] -> s7-64(SP) + // x[8] -> s8-72(SP) + // x[9] -> s9-80(SP) + // x[10] -> s10-88(SP) + // x[11] -> s11-96(SP) + MOVQ 0(AX), R14 + MOVQ 8(AX), R15 + MOVQ 16(AX), CX + MOVQ 24(AX), BX + MOVQ 32(AX), SI + MOVQ 40(AX), DI + MOVQ 48(AX), R8 + MOVQ 56(AX), R9 + MOVQ 64(AX), R10 + MOVQ 72(AX), R11 + MOVQ 80(AX), R12 + MOVQ 88(AX), R13 + MOVQ R14, s0-8(SP) + MOVQ R15, s1-16(SP) + MOVQ CX, s2-24(SP) + MOVQ BX, s3-32(SP) + MOVQ SI, s4-40(SP) + MOVQ DI, s5-48(SP) + MOVQ R8, s6-56(SP) + MOVQ R9, s7-64(SP) + MOVQ R10, s8-72(SP) + MOVQ R11, s9-80(SP) + MOVQ R12, s10-88(SP) + MOVQ R13, s11-96(SP) + + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // t[4] -> SI + // t[5] -> DI + // t[6] -> R8 + // t[7] -> R9 + // t[8] -> R10 + // t[9] -> R11 + // t[10] -> R12 + // t[11] -> R13 +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + PUSHQ BP \ + MOVQ $const_qInvNeg, DX \ + IMULQ R14, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, BP \ + ADCXQ R14, AX \ + MOVQ BP, R14 \ + POPQ BP \ + MACC(R15, R14, ·qElement+8(SB)) \ + MACC(CX, R15, ·qElement+16(SB)) \ + MACC(BX, CX, ·qElement+24(SB)) \ + MACC(SI, BX, ·qElement+32(SB)) \ + MACC(DI, SI, ·qElement+40(SB)) \ + MACC(R8, DI, ·qElement+48(SB)) \ + MACC(R9, R8, ·qElement+56(SB)) \ + MACC(R10, R9, ·qElement+64(SB)) \ + MACC(R11, R10, ·qElement+72(SB)) \ + MACC(R12, R11, ·qElement+80(SB)) \ + MACC(R13, R12, ·qElement+88(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, R13 \ + ADOXQ BP, R13 \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ s0-8(SP), R14, R15 \ + MULXQ s1-16(SP), AX, CX \ + ADOXQ AX, R15 \ + MULXQ s2-24(SP), AX, BX \ + ADOXQ AX, CX \ + MULXQ s3-32(SP), AX, SI \ + ADOXQ AX, BX \ + MULXQ 
s4-40(SP), AX, DI \ + ADOXQ AX, SI \ + MULXQ s5-48(SP), AX, R8 \ + ADOXQ AX, DI \ + MULXQ s6-56(SP), AX, R9 \ + ADOXQ AX, R8 \ + MULXQ s7-64(SP), AX, R10 \ + ADOXQ AX, R9 \ + MULXQ s8-72(SP), AX, R11 \ + ADOXQ AX, R10 \ + MULXQ s9-80(SP), AX, R12 \ + ADOXQ AX, R11 \ + MULXQ s10-88(SP), AX, R13 \ + ADOXQ AX, R12 \ + MULXQ s11-96(SP), AX, BP \ + ADOXQ AX, R13 \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ s0-8(SP), AX, BP \ + ADOXQ AX, R14 \ + MACC(BP, R15, s1-16(SP)) \ + MACC(BP, CX, s2-24(SP)) \ + MACC(BP, BX, s3-32(SP)) \ + MACC(BP, SI, s4-40(SP)) \ + MACC(BP, DI, s5-48(SP)) \ + MACC(BP, R8, s6-56(SP)) \ + MACC(BP, R9, s7-64(SP)) \ + MACC(BP, R10, s8-72(SP)) \ + MACC(BP, R11, s9-80(SP)) \ + MACC(BP, R12, s10-88(SP)) \ + MACC(BP, R13, s11-96(SP)) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + + // mul body + MOVQ y+16(FP), AX + MOVQ 0(AX), DX + MUL_WORD_0() + MOVQ y+16(FP), AX + MOVQ 8(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 16(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 24(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 32(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 40(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 48(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 56(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 64(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 72(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 80(AX), DX + MUL_WORD_N() + MOVQ y+16(FP), AX + MOVQ 88(AX), DX + MUL_WORD_N() + + // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 
24(AX) + MOVQ SI, 32(AX) + MOVQ DI, 40(AX) + MOVQ R8, 48(AX) + MOVQ R9, 56(AX) + MOVQ R10, 64(AX) + MOVQ R11, 72(AX) + MOVQ R12, 80(AX) + MOVQ R13, 88(AX) + RET + +noAdx_2: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $96-8 + NO_LOCAL_POINTERS + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE noAdx_3 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R15 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + MOVQ 32(DX), SI + MOVQ 40(DX), DI + MOVQ 48(DX), R8 + MOVQ 56(DX), R9 + MOVQ 64(DX), R10 + MOVQ 72(DX), R11 + MOVQ 80(DX), R12 + MOVQ 88(DX), R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ 
·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ 
R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), 
AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), 
AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, 
DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + 
IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) 
:= t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + 
ADOXQ AX, R13 + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + + // (C,t[5]) := t[6] + m*q[6] + C + ADCXQ R8, DI + MULXQ ·qElement+48(SB), AX, R8 + ADOXQ AX, DI + + // (C,t[6]) := t[7] + m*q[7] + C + ADCXQ R9, R8 + MULXQ ·qElement+56(SB), AX, R9 + ADOXQ AX, R8 + + // (C,t[7]) := t[8] + m*q[8] + C + ADCXQ R10, R9 + MULXQ ·qElement+64(SB), AX, R10 + ADOXQ AX, R9 + + // (C,t[8]) := t[9] + m*q[9] + C + ADCXQ R11, R10 + MULXQ ·qElement+72(SB), AX, R11 + ADOXQ AX, R10 + + // (C,t[9]) := t[10] + m*q[10] + C + ADCXQ R12, R11 + MULXQ ·qElement+80(SB), AX, R12 + ADOXQ AX, R11 + + // (C,t[10]) := t[11] + m*q[11] + C + ADCXQ R13, R12 + MULXQ ·qElement+88(SB), AX, R13 + ADOXQ AX, R12 + MOVQ $0, AX + ADCXQ AX, R13 + ADOXQ AX, R13 + + // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + MOVQ DI, 40(AX) + MOVQ R8, 48(AX) + MOVQ R9, 56(AX) + MOVQ R10, 64(AX) + MOVQ R11, 72(AX) + MOVQ R12, 80(AX) + MOVQ R13, 88(AX) 
+ RET + +noAdx_3: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET diff --git a/field/asm/element_4w_amd64.s b/field/asm/element_4w_amd64.s new file mode 100644 index 000000000..6f62d6310 --- /dev/null +++ b/field/asm/element_4w_amd64.s @@ -0,0 +1,2438 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, 
SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) + REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + MOVQ DX, R11 + MOVQ CX, R12 + MOVQ BX, R13 + MOVQ SI, R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ R11, DX + ADCQ R12, CX + ADCQ R13, BX + ADCQ R14, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), NOSPLIT, $0-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ CX, R8 + MOVQ BX, R9 + MOVQ SI, R10 + MOVQ DI, R11 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + SUBQ 0(DX), R8 + SBBQ 8(DX), R9 + SBBQ 16(DX), R10 + SBBQ 24(DX), R11 + MOVQ $const_q0, R12 + MOVQ $const_q1, R13 + MOVQ $const_q2, R14 + MOVQ $const_q3, 
R15 + CMOVQCC AX, R12 + CMOVQCC AX, R13 + CMOVQCC AX, R14 + CMOVQCC AX, R15 + ADDQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + ADCQ R15, R11 + MOVQ R8, 0(DX) + MOVQ R9, 8(DX) + MOVQ R10, 16(DX) + MOVQ R11, 24(DX) + + // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), $24-24 + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // See github.com/gnark-crypto/field/generator for more comments. + + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE noAdx_1 + MOVQ x+8(FP), SI + + // x[0] -> DI + // x[1] -> R8 + // x[2] -> R9 + // x[3] -> R10 + MOVQ 0(SI), DI + MOVQ 8(SI), R8 + MOVQ 16(SI), R9 + MOVQ 24(SI), R10 + MOVQ y+16(FP), R11 + + // A -> BP + // t[0] -> R14 + // t[1] -> R13 + // t[2] -> CX + // t[3] -> BX +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + MOVQ $const_qInvNeg, DX \ + IMULQ R14, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, R12 \ + ADCXQ R14, AX \ + MOVQ R12, R14 \ + MACC(R13, R14, ·qElement+8(SB)) \ + MACC(CX, R13, ·qElement+16(SB)) \ + MACC(BX, CX, ·qElement+24(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, BX \ + ADOXQ BP, BX \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ DI, R14, R13 \ + MULXQ R8, AX, CX \ + ADOXQ AX, R13 \ + MULXQ R9, AX, BX \ + ADOXQ AX, CX \ + MULXQ R10, AX, BP \ + ADOXQ AX, BX \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ DI, AX, BP \ + ADOXQ AX, R14 \ + MACC(BP, R13, R8) \ + MACC(BP, CX, R9) \ + MACC(BP, BX, R10) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + + // mul body + MOVQ 0(R11), DX + MUL_WORD_0() + MOVQ 8(R11), DX + MUL_WORD_N() + MOVQ 16(R11), DX + 
MUL_WORD_N() + MOVQ 24(R11), DX + MUL_WORD_N() + + // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) + REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R13, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +noAdx_1: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $8-8 + NO_LOCAL_POINTERS + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE noAdx_2 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R13 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), 
AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + + // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) + REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R13, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +noAdx_2: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET + +// Vector operations are partially derived from Dag Arne Osvik's work in github.com/a16z/vectorized-fields + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 
24(DX), R9 + PREFETCHT0 2048(AX) + PREFETCHT0 2048(DX) + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_5: + TESTQ BX, BX + JEQ done_6 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + PREFETCHT0 2048(AX) + PREFETCHT0 2048(DX) + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ $const_q0, R11 + MOVQ $const_q1, R12 + MOVQ $const_q2, R13 + MOVQ $const_q3, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_5 + +done_6: + RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's 
note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 + + // n % 8 -> CX + // n / 8 -> R15 + MOVQ R15, CX + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_9: + TESTQ CX, CX + JEQ loop8by8_7 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_9 + +loop8by8_7: + TESTQ R15, R15 + JEQ accumulate_10 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 4096(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // 
increment pointers to visit next 8 elements + ADDQ $256, R14 + DECQ R15 // decrement n + JMP loop8by8_7 + +accumulate_10: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 +#define SPLIT_LO_HI(in0, in1) \ + MOVQ in1, in0 \ + ANDQ $0xffffffff, in0 \ + SHLQ $32, in0 \ + SHRQ $32, in1 \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // see see Handbook of Applied Cryptography, Algorithm 14.42. 
+ // mu=2^288 / q -> SI + MOVQ $const_mu, SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI // high bits of res stored in DX + MULXQ ·qElement+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ ·qElement+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ ·qElement+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ ·qElement+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ ·qElement+0(SB), BX + SBBQ ·qElement+8(SB), DI + SBBQ ·qElement+16(SB), R9 + SBBQ ·qElement+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_11 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ ·qElement+0(SB), BX + SBBQ ·qElement+8(SB), DI + SBBQ ·qElement+16(SB), R9 + SBBQ ·qElement+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_11 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_11: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) + +done_8: + RET + +// innerProdVec(res, a,b *Element, n uint64) res = sum(a[0...n] * b[0...n]) +TEXT ·innerProdVec(SB), NOSPLIT, $0-32 + MOVQ a+8(FP), R14 + MOVQ b+16(FP), R15 + MOVQ n+24(FP), CX + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + VPMOVZXDQ Y0, Z5 + VPXORQ Z16, Z16, Z16 + VMOVDQA64 Z16, Z17 + VMOVDQA64 Z16, Z18 + VMOVDQA64 Z16, Z19 + VMOVDQA64 Z16, Z20 + VMOVDQA64 Z16, Z21 + VMOVDQA64 Z16, Z22 + VMOVDQA64 Z16, Z23 + VMOVDQA64 Z16, Z24 + VMOVDQA64 Z16, Z25 + VMOVDQA64 Z16, Z26 + VMOVDQA64 Z16, Z27 + VMOVDQA64 Z16, Z28 + VMOVDQA64 Z16, Z29 + VMOVDQA64 Z16, Z30 + VMOVDQA64 Z16, Z31 + TESTQ CX, CX + JEQ done_13 // n == 0, we are done + +loop_12: + TESTQ CX, CX + JEQ accumulate_14 // n == 0 we can accumulate + VPMOVZXDQ (R15), Z4 + ADDQ $32, R15 + + // we multiply and accumulate partial products of 4 bytes * 32 bytes +#define MAC(in0, in1, in2) \ + VPMULUDQ.BCST in0, Z4, Z2 \ + VPSRLQ $32, Z2, Z3 \ + VPANDQ Z5, Z2, Z2 \ + VPADDQ Z2, in1, in1 \ + VPADDQ Z3, in2, in2 \ + + MAC(0*4(R14), Z16, 
Z24) + MAC(1*4(R14), Z17, Z25) + MAC(2*4(R14), Z18, Z26) + MAC(3*4(R14), Z19, Z27) + MAC(4*4(R14), Z20, Z28) + MAC(5*4(R14), Z21, Z29) + MAC(6*4(R14), Z22, Z30) + MAC(7*4(R14), Z23, Z31) + ADDQ $32, R14 + DECQ CX // decrement n + JMP loop_12 + +accumulate_14: + // we accumulate the partial products into 544bits in Z1:Z0 + MOVQ $0x0000000000001555, AX + KMOVD AX, K1 + MOVQ $1, AX + KMOVD AX, K2 + + // store the least significant 32 bits of ACC (starts with A0L) in Z0 + VALIGND.Z $16, Z16, Z16, K2, Z0 + KSHIFTLW $1, K2, K2 + VPSRLQ $32, Z16, Z2 + VALIGND.Z $2, Z16, Z16, K1, Z16 + VPADDQ Z2, Z16, Z16 + VPANDQ Z5, Z24, Z2 + VPADDQ Z2, Z16, Z16 + VPANDQ Z5, Z17, Z2 + VPADDQ Z2, Z16, Z16 + VALIGND $15, Z16, Z16, K2, Z0 + KSHIFTLW $1, K2, K2 + + // macro to add partial products and store the result in Z0 +#define ADDPP(in0, in1, in2, in3, in4) \ + VPSRLQ $32, Z16, Z2 \ + VALIGND.Z $2, Z16, Z16, K1, Z16 \ + VPADDQ Z2, Z16, Z16 \ + VPSRLQ $32, in0, in0 \ + VPADDQ in0, Z16, Z16 \ + VPSRLQ $32, in1, in1 \ + VPADDQ in1, Z16, Z16 \ + VPANDQ Z5, in2, Z2 \ + VPADDQ Z2, Z16, Z16 \ + VPANDQ Z5, in3, Z2 \ + VPADDQ Z2, Z16, Z16 \ + VALIGND $16-in4, Z16, Z16, K2, Z0 \ + KADDW K2, K2, K2 \ + + ADDPP(Z24, Z17, Z25, Z18, 2) + ADDPP(Z25, Z18, Z26, Z19, 3) + ADDPP(Z26, Z19, Z27, Z20, 4) + ADDPP(Z27, Z20, Z28, Z21, 5) + ADDPP(Z28, Z21, Z29, Z22, 6) + ADDPP(Z29, Z22, Z30, Z23, 7) + VPSRLQ $32, Z16, Z2 + VALIGND.Z $2, Z16, Z16, K1, Z16 + VPADDQ Z2, Z16, Z16 + VPSRLQ $32, Z30, Z30 + VPADDQ Z30, Z16, Z16 + VPSRLQ $32, Z23, Z23 + VPADDQ Z23, Z16, Z16 + VPANDQ Z5, Z31, Z2 + VPADDQ Z2, Z16, Z16 + VALIGND $16-8, Z16, Z16, K2, Z0 + KSHIFTLW $1, K2, K2 + VPSRLQ $32, Z16, Z2 + VALIGND.Z $2, Z16, Z16, K1, Z16 + VPADDQ Z2, Z16, Z16 + VPSRLQ $32, Z31, Z31 + VPADDQ Z31, Z16, Z16 + VALIGND $16-9, Z16, Z16, K2, Z0 + KSHIFTLW $1, K2, K2 + +#define ADDPP2(in0) \ + VPSRLQ $32, Z16, Z2 \ + VALIGND.Z $2, Z16, Z16, K1, Z16 \ + VPADDQ Z2, Z16, Z16 \ + VALIGND $16-in0, Z16, Z16, K2, Z0 \ + KSHIFTLW $1, K2, K2 \ + + 
ADDPP2(10) + ADDPP2(11) + ADDPP2(12) + ADDPP2(13) + ADDPP2(14) + ADDPP2(15) + VPSRLQ $32, Z16, Z2 + VALIGND.Z $2, Z16, Z16, K1, Z16 + VPADDQ Z2, Z16, Z16 + VMOVDQA64.Z Z16, K1, Z1 + + // Extract the 4 least significant qwords of Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z1, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + XORQ BX, BX + MOVQ $const_qInvNeg, DX + MULXQ SI, DX, R10 + MULXQ ·qElement+0(SB), AX, R10 + ADDQ AX, SI + ADCQ R10, DI + MULXQ ·qElement+16(SB), AX, R10 + ADCQ AX, R8 + ADCQ R10, R9 + ADCQ $0, BX + MULXQ ·qElement+8(SB), AX, R10 + ADDQ AX, DI + ADCQ R10, R8 + MULXQ ·qElement+24(SB), AX, R10 + ADCQ AX, R9 + ADCQ R10, BX + ADCQ $0, SI + MOVQ $const_qInvNeg, DX + MULXQ DI, DX, R10 + MULXQ ·qElement+0(SB), AX, R10 + ADDQ AX, DI + ADCQ R10, R8 + MULXQ ·qElement+16(SB), AX, R10 + ADCQ AX, R9 + ADCQ R10, BX + ADCQ $0, SI + MULXQ ·qElement+8(SB), AX, R10 + ADDQ AX, R8 + ADCQ R10, R9 + MULXQ ·qElement+24(SB), AX, R10 + ADCQ AX, BX + ADCQ R10, SI + ADCQ $0, DI + MOVQ $const_qInvNeg, DX + MULXQ R8, DX, R10 + MULXQ ·qElement+0(SB), AX, R10 + ADDQ AX, R8 + ADCQ R10, R9 + MULXQ ·qElement+16(SB), AX, R10 + ADCQ AX, BX + ADCQ R10, SI + ADCQ $0, DI + MULXQ ·qElement+8(SB), AX, R10 + ADDQ AX, R9 + ADCQ R10, BX + MULXQ ·qElement+24(SB), AX, R10 + ADCQ AX, SI + ADCQ R10, DI + ADCQ $0, R8 + MOVQ $const_qInvNeg, DX + MULXQ R9, DX, R10 + MULXQ ·qElement+0(SB), AX, R10 + ADDQ AX, R9 + ADCQ R10, BX + MULXQ ·qElement+16(SB), AX, R10 + ADCQ AX, SI + ADCQ R10, DI + ADCQ $0, R8 + MULXQ ·qElement+8(SB), AX, R10 + ADDQ AX, BX + ADCQ R10, SI + MULXQ ·qElement+24(SB), AX, R10 + ADCQ AX, DI + ADCQ R10, R8 + ADCQ $0, R9 + VMOVQ X0, AX + ADDQ AX, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, AX + ADCQ AX, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, AX + ADCQ AX, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, AX + ADCQ AX, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, AX + ADCQ AX, R9 + + // Barrett reduction; see Handbook of Applied 
Cryptography, Algorithm 14.42. + MOVQ R8, AX + SHRQ $32, R9, AX + MOVQ $const_mu, DX + MULQ DX + MULXQ ·qElement+0(SB), AX, R10 + SUBQ AX, BX + SBBQ R10, SI + MULXQ ·qElement+16(SB), AX, R10 + SBBQ AX, DI + SBBQ R10, R8 + SBBQ $0, R9 + MULXQ ·qElement+8(SB), AX, R10 + SUBQ AX, SI + SBBQ R10, DI + MULXQ ·qElement+24(SB), AX, R10 + SBBQ AX, R8 + SBBQ R10, R9 + + // we need up to 2 conditional substractions to be < q + MOVQ res+0(FP), R11 + MOVQ BX, 0(R11) + MOVQ SI, 8(R11) + MOVQ DI, 16(R11) + MOVQ R8, 24(R11) + SUBQ ·qElement+0(SB), BX + SBBQ ·qElement+8(SB), SI + SBBQ ·qElement+16(SB), DI + SBBQ ·qElement+24(SB), R8 + SBBQ $0, R9 + JCS done_13 + MOVQ BX, 0(R11) + MOVQ SI, 8(R11) + MOVQ DI, 16(R11) + MOVQ R8, 24(R11) + SUBQ ·qElement+0(SB), BX + SBBQ ·qElement+8(SB), SI + SBBQ ·qElement+16(SB), DI + SBBQ ·qElement+24(SB), R8 + SBBQ $0, R9 + JCS done_13 + MOVQ BX, 0(R11) + MOVQ SI, 8(R11) + MOVQ DI, 16(R11) + MOVQ R8, 24(R11) + +done_13: + RET + +TEXT ·scalarMulVec(SB), $8-40 +#define AVX_MUL_Q_LO() \ + VPMULUDQ.BCST ·qElement+0(SB), Z9, Z10 \ + VPADDQ Z10, Z0, Z0 \ + VPMULUDQ.BCST ·qElement+4(SB), Z9, Z11 \ + VPADDQ Z11, Z1, Z1 \ + VPMULUDQ.BCST ·qElement+8(SB), Z9, Z12 \ + VPADDQ Z12, Z2, Z2 \ + VPMULUDQ.BCST ·qElement+12(SB), Z9, Z13 \ + VPADDQ Z13, Z3, Z3 \ + +#define AVX_MUL_Q_HI() \ + VPMULUDQ.BCST ·qElement+16(SB), Z9, Z14 \ + VPADDQ Z14, Z4, Z4 \ + VPMULUDQ.BCST ·qElement+20(SB), Z9, Z15 \ + VPADDQ Z15, Z5, Z5 \ + VPMULUDQ.BCST ·qElement+24(SB), Z9, Z16 \ + VPADDQ Z16, Z6, Z6 \ + VPMULUDQ.BCST ·qElement+28(SB), Z9, Z17 \ + VPADDQ Z17, Z7, Z7 \ + +#define SHIFT_ADD_AND(in0, in1, in2, in3) \ + VPSRLQ $32, in0, in1 \ + VPADDQ in1, in2, in2 \ + VPANDQ in3, in2, in0 \ + +#define CARRY1() \ + SHIFT_ADD_AND(Z0, Z10, Z1, Z8) \ + SHIFT_ADD_AND(Z1, Z11, Z2, Z8) \ + SHIFT_ADD_AND(Z2, Z12, Z3, Z8) \ + SHIFT_ADD_AND(Z3, Z13, Z4, Z8) \ + +#define CARRY2() \ + SHIFT_ADD_AND(Z4, Z14, Z5, Z8) \ + SHIFT_ADD_AND(Z5, Z15, Z6, Z8) \ + SHIFT_ADD_AND(Z6, Z16, Z7, Z8) \ + VPSRLQ 
$32, Z7, Z7 \ + +#define CARRY3() \ + VPSRLQ $32, Z0, Z10 \ + VPANDQ Z8, Z0, Z0 \ + VPADDQ Z10, Z1, Z1 \ + VPSRLQ $32, Z1, Z11 \ + VPANDQ Z8, Z1, Z1 \ + VPADDQ Z11, Z2, Z2 \ + VPSRLQ $32, Z2, Z12 \ + VPANDQ Z8, Z2, Z2 \ + VPADDQ Z12, Z3, Z3 \ + VPSRLQ $32, Z3, Z13 \ + VPANDQ Z8, Z3, Z3 \ + VPADDQ Z13, Z4, Z4 \ + +#define CARRY4() \ + VPSRLQ $32, Z4, Z14 \ + VPANDQ Z8, Z4, Z4 \ + VPADDQ Z14, Z5, Z5 \ + VPSRLQ $32, Z5, Z15 \ + VPANDQ Z8, Z5, Z5 \ + VPADDQ Z15, Z6, Z6 \ + VPSRLQ $32, Z6, Z16 \ + VPANDQ Z8, Z6, Z6 \ + VPADDQ Z16, Z7, Z7 \ + + // t[0] -> R14 + // t[1] -> R13 + // t[2] -> CX + // t[3] -> BX + // y[0] -> DI + // y[1] -> R8 + // y[2] -> R9 + // y[3] -> R10 + MOVQ res+0(FP), SI + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R15 + MOVQ n+24(FP), R12 + MOVQ 0(R15), DI + MOVQ 8(R15), R8 + MOVQ 16(R15), R9 + MOVQ 24(R15), R10 + MOVQ R12, s0-8(SP) + + // Create mask for low dword in each qword + VPCMPEQB Y8, Y8, Y8 + VPMOVZXDQ Y8, Z8 + MOVQ $0x5555, DX + KMOVD DX, K1 + +loop_16: + TESTQ R12, R12 + JEQ done_15 // n == 0, we are done + MOVQ 0(R11), DX + VMOVDQU64 256+0*64(R11), Z16 + VMOVDQU64 256+1*64(R11), Z17 + VMOVDQU64 256+2*64(R11), Z18 + VMOVDQU64 256+3*64(R11), Z19 + VMOVDQU64 0(R15), Z24 + VMOVDQU64 0(R15), Z25 + VMOVDQU64 0(R15), Z26 + VMOVDQU64 0(R15), Z27 + + // Transpose and expand x and y + VSHUFI64X2 $0x88, Z17, Z16, Z20 + VSHUFI64X2 $0xdd, Z17, Z16, Z22 + VSHUFI64X2 $0x88, Z19, Z18, Z21 + VSHUFI64X2 $0xdd, Z19, Z18, Z23 + VSHUFI64X2 $0x88, Z25, Z24, Z28 + VSHUFI64X2 $0xdd, Z25, Z24, Z30 + VSHUFI64X2 $0x88, Z27, Z26, Z29 + VSHUFI64X2 $0xdd, Z27, Z26, Z31 + VPERMQ $0xd8, Z20, Z20 + VPERMQ $0xd8, Z21, Z21 + VPERMQ $0xd8, Z22, Z22 + VPERMQ $0xd8, Z23, Z23 + + // z[0] -> y * x[0] + MUL_WORD_0() + VPERMQ $0xd8, Z28, Z28 + VPERMQ $0xd8, Z29, Z29 + VPERMQ $0xd8, Z30, Z30 + VPERMQ $0xd8, Z31, Z31 + VSHUFI64X2 $0xd8, Z20, Z20, Z20 + VSHUFI64X2 $0xd8, Z21, Z21, Z21 + VSHUFI64X2 $0xd8, Z22, Z22, Z22 + VSHUFI64X2 $0xd8, Z23, Z23, Z23 + + // z[0] -> y * x[1] + MOVQ 
8(R11), DX + MUL_WORD_N() + VSHUFI64X2 $0xd8, Z28, Z28, Z28 + VSHUFI64X2 $0xd8, Z29, Z29, Z29 + VSHUFI64X2 $0xd8, Z30, Z30, Z30 + VSHUFI64X2 $0xd8, Z31, Z31, Z31 + VSHUFI64X2 $0x44, Z21, Z20, Z16 + VSHUFI64X2 $0xee, Z21, Z20, Z18 + VSHUFI64X2 $0x44, Z23, Z22, Z20 + VSHUFI64X2 $0xee, Z23, Z22, Z22 + + // z[0] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + VSHUFI64X2 $0x44, Z29, Z28, Z24 + VSHUFI64X2 $0xee, Z29, Z28, Z26 + VSHUFI64X2 $0x44, Z31, Z30, Z28 + VSHUFI64X2 $0xee, Z31, Z30, Z30 + PREFETCHT0 1024(R11) + VPSRLQ $32, Z16, Z17 + VPSRLQ $32, Z18, Z19 + VPSRLQ $32, Z20, Z21 + VPSRLQ $32, Z22, Z23 + VPSRLQ $32, Z24, Z25 + VPSRLQ $32, Z26, Z27 + VPSRLQ $32, Z28, Z29 + VPSRLQ $32, Z30, Z31 + + // z[0] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + VPANDQ Z8, Z16, Z16 + VPANDQ Z8, Z18, Z18 + VPANDQ Z8, Z20, Z20 + VPANDQ Z8, Z22, Z22 + VPANDQ Z8, Z24, Z24 + VPANDQ Z8, Z26, Z26 + VPANDQ Z8, Z28, Z28 + VPANDQ Z8, Z30, Z30 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[0] + MOVQ R14, 0(SI) + MOVQ R13, 8(SI) + MOVQ CX, 16(SI) + MOVQ BX, 24(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // For each 256-bit input value, each zmm register now represents a 32-bit input word zero-extended to 64 bits. 
+ // Multiply y by doubleword 0 of x + VPMULUDQ Z16, Z24, Z0 + VPMULUDQ Z16, Z25, Z1 + VPMULUDQ Z16, Z26, Z2 + VPMULUDQ Z16, Z27, Z3 + VPMULUDQ Z16, Z28, Z4 + VPMULUDQ Z16, Z29, Z5 + VPMULUDQ Z16, Z30, Z6 + VPMULUDQ Z16, Z31, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + VPSRLQ $32, Z0, Z10 + VPANDQ Z8, Z0, Z0 + VPADDQ Z10, Z1, Z1 + VPSRLQ $32, Z1, Z11 + VPANDQ Z8, Z1, Z1 + VPADDQ Z11, Z2, Z2 + VPSRLQ $32, Z2, Z12 + VPANDQ Z8, Z2, Z2 + VPADDQ Z12, Z3, Z3 + VPSRLQ $32, Z3, Z13 + VPANDQ Z8, Z3, Z3 + VPADDQ Z13, Z4, Z4 + + // z[1] -> y * x[0] + MUL_WORD_0() + VPSRLQ $32, Z4, Z14 + VPANDQ Z8, Z4, Z4 + VPADDQ Z14, Z5, Z5 + VPSRLQ $32, Z5, Z15 + VPANDQ Z8, Z5, Z5 + VPADDQ Z15, Z6, Z6 + VPSRLQ $32, Z6, Z16 + VPANDQ Z8, Z6, Z6 + VPADDQ Z16, Z7, Z7 + VPMULUDQ.BCST ·qElement+0(SB), Z9, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ.BCST ·qElement+4(SB), Z9, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ.BCST ·qElement+8(SB), Z9, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ.BCST ·qElement+12(SB), Z9, Z13 + VPADDQ Z13, Z3, Z3 + + // z[1] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + VPMULUDQ.BCST ·qElement+16(SB), Z9, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ.BCST ·qElement+20(SB), Z9, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ.BCST ·qElement+24(SB), Z9, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ.BCST ·qElement+28(SB), Z9, Z10 + VPADDQ Z10, Z7, Z7 + CARRY1() + + // z[1] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + SHIFT_ADD_AND(Z4, Z14, Z5, Z8) + SHIFT_ADD_AND(Z5, Z15, Z6, Z8) + SHIFT_ADD_AND(Z6, Z16, Z7, Z8) + VPSRLQ $32, Z7, Z7 + + // Process doubleword 1 of x + VPMULUDQ Z17, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z17, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z17, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z17, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // z[1] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + VPMULUDQ Z17, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z17, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z17, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z17, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // 
reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[1] + MOVQ R14, 32(SI) + MOVQ R13, 40(SI) + MOVQ CX, 48(SI) + MOVQ BX, 56(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + VPSRLQ $32, Z0, Z10 + VPANDQ Z8, Z0, Z0 + VPADDQ Z10, Z1, Z1 + VPSRLQ $32, Z1, Z11 + VPANDQ Z8, Z1, Z1 + VPADDQ Z11, Z2, Z2 + VPSRLQ $32, Z2, Z12 + VPANDQ Z8, Z2, Z2 + VPADDQ Z12, Z3, Z3 + VPSRLQ $32, Z3, Z13 + VPANDQ Z8, Z3, Z3 + VPADDQ Z13, Z4, Z4 + CARRY4() + + // z[2] -> y * x[0] + MUL_WORD_0() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[2] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + CARRY1() + CARRY2() + + // z[2] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + + // Process doubleword 2 of x + VPMULUDQ Z18, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z18, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z18, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z18, Z27, Z13 + VPADDQ Z13, Z3, Z3 + VPMULUDQ Z18, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z18, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z18, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z18, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // z[2] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[2] + MOVQ R14, 64(SI) + MOVQ R13, 72(SI) + MOVQ CX, 80(SI) + MOVQ BX, 88(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + CARRY4() + AVX_MUL_Q_LO() + + // z[3] -> y * x[0] + MUL_WORD_0() + AVX_MUL_Q_HI() + CARRY1() + CARRY2() + + // Process doubleword 3 of x + VPMULUDQ Z19, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z19, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z19, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z19, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // 
z[3] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + VPMULUDQ Z19, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z19, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z19, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z19, Z31, Z17 + VPADDQ Z17, Z7, Z7 + + // z[3] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + CARRY3() + CARRY4() + + // z[3] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[3] + MOVQ R14, 96(SI) + MOVQ R13, 104(SI) + MOVQ CX, 112(SI) + MOVQ BX, 120(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Propagate carries and shift down by one dword + CARRY1() + CARRY2() + + // Process doubleword 4 of x + VPMULUDQ Z20, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z20, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z20, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z20, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // z[4] -> y * x[0] + MUL_WORD_0() + VPMULUDQ Z20, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z20, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z20, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z20, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // z[4] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + CARRY4() + + // z[4] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + + // zmm7 keeps all 64 bits + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[4] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + + // Propagate carries and shift down by one dword + CARRY1() + CARRY2() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[4] + MOVQ R14, 128(SI) + MOVQ R13, 136(SI) + MOVQ CX, 144(SI) + MOVQ BX, 152(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Process doubleword 5 of x + VPMULUDQ Z21, Z24, Z10 + VPADDQ 
Z10, Z0, Z0 + VPMULUDQ Z21, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z21, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z21, Z27, Z13 + VPADDQ Z13, Z3, Z3 + VPMULUDQ Z21, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z21, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z21, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z21, Z31, Z17 + VPADDQ Z17, Z7, Z7 + + // z[5] -> y * x[0] + MUL_WORD_0() + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + CARRY4() + + // z[5] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[5] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + CARRY1() + CARRY2() + + // z[5] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + + // Process doubleword 6 of x + VPMULUDQ Z22, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z22, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z22, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z22, Z27, Z13 + VPADDQ Z13, Z3, Z3 + VPMULUDQ Z22, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z22, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z22, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z22, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[5] + MOVQ R14, 160(SI) + MOVQ R13, 168(SI) + MOVQ CX, 176(SI) + MOVQ BX, 184(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + CARRY4() + + // z[6] -> y * x[0] + MUL_WORD_0() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[6] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + CARRY1() + CARRY2() + + // z[6] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + + // Process doubleword 7 of x + VPMULUDQ Z23, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z23, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z23, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z23, Z27, Z13 + VPADDQ Z13, 
Z3, Z3 + VPMULUDQ Z23, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z23, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z23, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z23, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // z[6] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + CARRY3() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[6] + MOVQ R14, 192(SI) + MOVQ R13, 200(SI) + MOVQ CX, 208(SI) + MOVQ BX, 216(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + CARRY4() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[7] -> y * x[0] + MUL_WORD_0() + CARRY1() + CARRY2() + + // z[7] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + + // Conditional subtraction of the modulus + VPERMD.BCST.Z ·qElement+0(SB), Z8, K1, Z10 + VPERMD.BCST.Z ·qElement+4(SB), Z8, K1, Z11 + VPERMD.BCST.Z ·qElement+8(SB), Z8, K1, Z12 + VPERMD.BCST.Z ·qElement+12(SB), Z8, K1, Z13 + VPERMD.BCST.Z ·qElement+16(SB), Z8, K1, Z14 + VPERMD.BCST.Z ·qElement+20(SB), Z8, K1, Z15 + VPERMD.BCST.Z ·qElement+24(SB), Z8, K1, Z16 + VPERMD.BCST.Z ·qElement+28(SB), Z8, K1, Z17 + VPSUBQ Z10, Z0, Z10 + VPSRLQ $63, Z10, Z20 + VPANDQ Z8, Z10, Z10 + VPSUBQ Z11, Z1, Z11 + VPSUBQ Z20, Z11, Z11 + VPSRLQ $63, Z11, Z21 + VPANDQ Z8, Z11, Z11 + VPSUBQ Z12, Z2, Z12 + VPSUBQ Z21, Z12, Z12 + VPSRLQ $63, Z12, Z22 + VPANDQ Z8, Z12, Z12 + VPSUBQ Z13, Z3, Z13 + VPSUBQ Z22, Z13, Z13 + VPSRLQ $63, Z13, Z23 + VPANDQ Z8, Z13, Z13 + VPSUBQ Z14, Z4, Z14 + VPSUBQ Z23, Z14, Z14 + VPSRLQ $63, Z14, Z24 + VPANDQ Z8, Z14, Z14 + VPSUBQ Z15, Z5, Z15 + VPSUBQ Z24, Z15, Z15 + VPSRLQ $63, Z15, Z25 + VPANDQ Z8, Z15, Z15 + VPSUBQ Z16, Z6, Z16 + VPSUBQ Z25, Z16, Z16 + VPSRLQ $63, Z16, Z26 + VPANDQ Z8, Z16, Z16 + VPSUBQ Z17, Z7, Z17 + VPSUBQ Z26, Z17, Z17 + VPMOVQ2M Z17, K2 + KNOTB K2, K2 + VMOVDQU64 Z10, K2, Z0 + VMOVDQU64 Z11, K2, Z1 + VMOVDQU64 Z12, K2, Z2 + VMOVDQU64 Z13, K2, Z3 + VMOVDQU64 Z14, K2, Z4 + + // z[7] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + VMOVDQU64 
Z15, K2, Z5 + VMOVDQU64 Z16, K2, Z6 + VMOVDQU64 Z17, K2, Z7 + + // Transpose results back + VALIGND $0, ·pattern1+0(SB), Z11, Z11 + VALIGND $0, ·pattern2+0(SB), Z12, Z12 + VALIGND $0, ·pattern3+0(SB), Z13, Z13 + VALIGND $0, ·pattern4+0(SB), Z14, Z14 + VPSLLQ $32, Z1, Z1 + VPORQ Z1, Z0, Z0 + VPSLLQ $32, Z3, Z3 + VPORQ Z3, Z2, Z1 + VPSLLQ $32, Z5, Z5 + VPORQ Z5, Z4, Z2 + VPSLLQ $32, Z7, Z7 + VPORQ Z7, Z6, Z3 + VMOVDQU64 Z0, Z4 + VMOVDQU64 Z2, Z6 + + // z[7] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + VPERMT2Q Z1, Z11, Z0 + VPERMT2Q Z4, Z12, Z1 + VPERMT2Q Z3, Z11, Z2 + VPERMT2Q Z6, Z12, Z3 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[7] + MOVQ R14, 224(SI) + MOVQ R13, 232(SI) + MOVQ CX, 240(SI) + MOVQ BX, 248(SI) + ADDQ $288, R11 + VMOVDQU64 Z0, Z4 + VMOVDQU64 Z1, Z5 + VPERMT2Q Z2, Z13, Z0 + VPERMT2Q Z4, Z14, Z2 + VPERMT2Q Z3, Z13, Z1 + VPERMT2Q Z5, Z14, Z3 + + // Save AVX-512 results + VMOVDQU64 Z0, 256+0*64(SI) + VMOVDQU64 Z2, 256+1*64(SI) + VMOVDQU64 Z1, 256+2*64(SI) + VMOVDQU64 Z3, 256+3*64(SI) + ADDQ $512, SI + MOVQ s0-8(SP), R12 + DECQ R12 // decrement n + MOVQ R12, s0-8(SP) + JMP loop_16 + +done_15: + RET + +TEXT ·mulVec(SB), $8-40 + // t[0] -> R14 + // t[1] -> R13 + // t[2] -> CX + // t[3] -> BX + // y[0] -> DI + // y[1] -> R8 + // y[2] -> R9 + // y[3] -> R10 + MOVQ res+0(FP), SI + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R15 + MOVQ n+24(FP), R12 + MOVQ R12, s0-8(SP) + + // Create mask for low dword in each qword + VPCMPEQB Y8, Y8, Y8 + VPMOVZXDQ Y8, Z8 + MOVQ $0x5555, DX + KMOVD DX, K1 + +loop_18: + TESTQ R12, R12 + JEQ done_17 // n == 0, we are done + MOVQ 0(R11), DX + VMOVDQU64 256+0*64(R11), Z16 + VMOVDQU64 256+1*64(R11), Z17 + VMOVDQU64 256+2*64(R11), Z18 + VMOVDQU64 256+3*64(R11), Z19 + + // load input y[0] + MOVQ 0(R15), DI + MOVQ 8(R15), R8 + MOVQ 16(R15), R9 + MOVQ 24(R15), R10 + VMOVDQU64 256+0*64(R15), Z24 + VMOVDQU64 256+1*64(R15), Z25 + VMOVDQU64 256+2*64(R15), 
Z26 + VMOVDQU64 256+3*64(R15), Z27 + + // Transpose and expand x and y + VSHUFI64X2 $0x88, Z17, Z16, Z20 + VSHUFI64X2 $0xdd, Z17, Z16, Z22 + VSHUFI64X2 $0x88, Z19, Z18, Z21 + VSHUFI64X2 $0xdd, Z19, Z18, Z23 + VSHUFI64X2 $0x88, Z25, Z24, Z28 + VSHUFI64X2 $0xdd, Z25, Z24, Z30 + VSHUFI64X2 $0x88, Z27, Z26, Z29 + VSHUFI64X2 $0xdd, Z27, Z26, Z31 + VPERMQ $0xd8, Z20, Z20 + VPERMQ $0xd8, Z21, Z21 + VPERMQ $0xd8, Z22, Z22 + VPERMQ $0xd8, Z23, Z23 + + // z[0] -> y * x[0] + MUL_WORD_0() + VPERMQ $0xd8, Z28, Z28 + VPERMQ $0xd8, Z29, Z29 + VPERMQ $0xd8, Z30, Z30 + VPERMQ $0xd8, Z31, Z31 + VSHUFI64X2 $0xd8, Z20, Z20, Z20 + VSHUFI64X2 $0xd8, Z21, Z21, Z21 + VSHUFI64X2 $0xd8, Z22, Z22, Z22 + VSHUFI64X2 $0xd8, Z23, Z23, Z23 + + // z[0] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + VSHUFI64X2 $0xd8, Z28, Z28, Z28 + VSHUFI64X2 $0xd8, Z29, Z29, Z29 + VSHUFI64X2 $0xd8, Z30, Z30, Z30 + VSHUFI64X2 $0xd8, Z31, Z31, Z31 + VSHUFI64X2 $0x44, Z21, Z20, Z16 + VSHUFI64X2 $0xee, Z21, Z20, Z18 + VSHUFI64X2 $0x44, Z23, Z22, Z20 + VSHUFI64X2 $0xee, Z23, Z22, Z22 + + // z[0] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + VSHUFI64X2 $0x44, Z29, Z28, Z24 + VSHUFI64X2 $0xee, Z29, Z28, Z26 + VSHUFI64X2 $0x44, Z31, Z30, Z28 + VSHUFI64X2 $0xee, Z31, Z30, Z30 + PREFETCHT0 1024(R11) + VPSRLQ $32, Z16, Z17 + VPSRLQ $32, Z18, Z19 + VPSRLQ $32, Z20, Z21 + VPSRLQ $32, Z22, Z23 + VPSRLQ $32, Z24, Z25 + VPSRLQ $32, Z26, Z27 + VPSRLQ $32, Z28, Z29 + VPSRLQ $32, Z30, Z31 + + // z[0] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + VPANDQ Z8, Z16, Z16 + VPANDQ Z8, Z18, Z18 + VPANDQ Z8, Z20, Z20 + VPANDQ Z8, Z22, Z22 + VPANDQ Z8, Z24, Z24 + VPANDQ Z8, Z26, Z26 + VPANDQ Z8, Z28, Z28 + VPANDQ Z8, Z30, Z30 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[0] + MOVQ R14, 0(SI) + MOVQ R13, 8(SI) + MOVQ CX, 16(SI) + MOVQ BX, 24(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // For each 256-bit input value, each zmm register now represents a 
32-bit input word zero-extended to 64 bits. + // Multiply y by doubleword 0 of x + VPMULUDQ Z16, Z24, Z0 + VPMULUDQ Z16, Z25, Z1 + VPMULUDQ Z16, Z26, Z2 + VPMULUDQ Z16, Z27, Z3 + VPMULUDQ Z16, Z28, Z4 + PREFETCHT0 1024(R15) + VPMULUDQ Z16, Z29, Z5 + VPMULUDQ Z16, Z30, Z6 + VPMULUDQ Z16, Z31, Z7 + + // load input y[1] + MOVQ 32(R15), DI + MOVQ 40(R15), R8 + MOVQ 48(R15), R9 + MOVQ 56(R15), R10 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + VPSRLQ $32, Z0, Z10 + VPANDQ Z8, Z0, Z0 + VPADDQ Z10, Z1, Z1 + VPSRLQ $32, Z1, Z11 + VPANDQ Z8, Z1, Z1 + VPADDQ Z11, Z2, Z2 + VPSRLQ $32, Z2, Z12 + VPANDQ Z8, Z2, Z2 + VPADDQ Z12, Z3, Z3 + VPSRLQ $32, Z3, Z13 + VPANDQ Z8, Z3, Z3 + VPADDQ Z13, Z4, Z4 + + // z[1] -> y * x[0] + MUL_WORD_0() + VPSRLQ $32, Z4, Z14 + VPANDQ Z8, Z4, Z4 + VPADDQ Z14, Z5, Z5 + VPSRLQ $32, Z5, Z15 + VPANDQ Z8, Z5, Z5 + VPADDQ Z15, Z6, Z6 + VPSRLQ $32, Z6, Z16 + VPANDQ Z8, Z6, Z6 + VPADDQ Z16, Z7, Z7 + VPMULUDQ.BCST ·qElement+0(SB), Z9, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ.BCST ·qElement+4(SB), Z9, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ.BCST ·qElement+8(SB), Z9, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ.BCST ·qElement+12(SB), Z9, Z13 + VPADDQ Z13, Z3, Z3 + + // z[1] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + VPMULUDQ.BCST ·qElement+16(SB), Z9, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ.BCST ·qElement+20(SB), Z9, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ.BCST ·qElement+24(SB), Z9, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ.BCST ·qElement+28(SB), Z9, Z10 + VPADDQ Z10, Z7, Z7 + CARRY1() + + // z[1] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + SHIFT_ADD_AND(Z4, Z14, Z5, Z8) + SHIFT_ADD_AND(Z5, Z15, Z6, Z8) + SHIFT_ADD_AND(Z6, Z16, Z7, Z8) + VPSRLQ $32, Z7, Z7 + + // Process doubleword 1 of x + VPMULUDQ Z17, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z17, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z17, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z17, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // z[1] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + VPMULUDQ Z17, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z17, 
Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z17, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z17, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[1] + MOVQ R14, 32(SI) + MOVQ R13, 40(SI) + MOVQ CX, 48(SI) + MOVQ BX, 56(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + VPSRLQ $32, Z0, Z10 + VPANDQ Z8, Z0, Z0 + VPADDQ Z10, Z1, Z1 + VPSRLQ $32, Z1, Z11 + VPANDQ Z8, Z1, Z1 + VPADDQ Z11, Z2, Z2 + VPSRLQ $32, Z2, Z12 + VPANDQ Z8, Z2, Z2 + VPADDQ Z12, Z3, Z3 + + // load input y[2] + MOVQ 64(R15), DI + MOVQ 72(R15), R8 + MOVQ 80(R15), R9 + MOVQ 88(R15), R10 + VPSRLQ $32, Z3, Z13 + VPANDQ Z8, Z3, Z3 + VPADDQ Z13, Z4, Z4 + CARRY4() + + // z[2] -> y * x[0] + MUL_WORD_0() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[2] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + CARRY1() + CARRY2() + + // z[2] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + + // Process doubleword 2 of x + VPMULUDQ Z18, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z18, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z18, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z18, Z27, Z13 + VPADDQ Z13, Z3, Z3 + VPMULUDQ Z18, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z18, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z18, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z18, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // z[2] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[2] + MOVQ R14, 64(SI) + MOVQ R13, 72(SI) + MOVQ CX, 80(SI) + MOVQ BX, 88(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // load input y[3] + MOVQ 96(R15), DI + MOVQ 104(R15), R8 
+ MOVQ 112(R15), R9 + MOVQ 120(R15), R10 + CARRY4() + AVX_MUL_Q_LO() + + // z[3] -> y * x[0] + MUL_WORD_0() + AVX_MUL_Q_HI() + CARRY1() + CARRY2() + + // Process doubleword 3 of x + VPMULUDQ Z19, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z19, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z19, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z19, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // z[3] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + VPMULUDQ Z19, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z19, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z19, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z19, Z31, Z17 + VPADDQ Z17, Z7, Z7 + + // z[3] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + CARRY3() + CARRY4() + + // z[3] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[3] + MOVQ R14, 96(SI) + MOVQ R13, 104(SI) + MOVQ CX, 112(SI) + MOVQ BX, 120(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Propagate carries and shift down by one dword + CARRY1() + CARRY2() + + // load input y[4] + MOVQ 128(R15), DI + MOVQ 136(R15), R8 + MOVQ 144(R15), R9 + MOVQ 152(R15), R10 + + // Process doubleword 4 of x + VPMULUDQ Z20, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z20, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z20, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z20, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // z[4] -> y * x[0] + MUL_WORD_0() + VPMULUDQ Z20, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z20, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z20, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z20, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // z[4] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + CARRY4() + + // z[4] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + + // zmm7 keeps all 64 bits + 
AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[4] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + + // Propagate carries and shift down by one dword + CARRY1() + CARRY2() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[4] + MOVQ R14, 128(SI) + MOVQ R13, 136(SI) + MOVQ CX, 144(SI) + MOVQ BX, 152(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Process doubleword 5 of x + VPMULUDQ Z21, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z21, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z21, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z21, Z27, Z13 + VPADDQ Z13, Z3, Z3 + + // load input y[5] + MOVQ 160(R15), DI + MOVQ 168(R15), R8 + MOVQ 176(R15), R9 + MOVQ 184(R15), R10 + VPMULUDQ Z21, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z21, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z21, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z21, Z31, Z17 + VPADDQ Z17, Z7, Z7 + + // z[5] -> y * x[0] + MUL_WORD_0() + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + CARRY4() + + // z[5] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[5] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + CARRY1() + CARRY2() + + // z[5] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + + // Process doubleword 6 of x + VPMULUDQ Z22, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z22, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z22, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z22, Z27, Z13 + VPADDQ Z13, Z3, Z3 + VPMULUDQ Z22, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z22, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z22, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z22, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[5] + MOVQ R14, 160(SI) + MOVQ R13, 168(SI) + MOVQ CX, 176(SI) + MOVQ BX, 
184(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + + // load input y[6] + MOVQ 192(R15), DI + MOVQ 200(R15), R8 + MOVQ 208(R15), R9 + MOVQ 216(R15), R10 + CARRY4() + + // z[6] -> y * x[0] + MUL_WORD_0() + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[6] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + CARRY1() + CARRY2() + + // z[6] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + + // Process doubleword 7 of x + VPMULUDQ Z23, Z24, Z10 + VPADDQ Z10, Z0, Z0 + VPMULUDQ Z23, Z25, Z11 + VPADDQ Z11, Z1, Z1 + VPMULUDQ Z23, Z26, Z12 + VPADDQ Z12, Z2, Z2 + VPMULUDQ Z23, Z27, Z13 + VPADDQ Z13, Z3, Z3 + VPMULUDQ Z23, Z28, Z14 + VPADDQ Z14, Z4, Z4 + VPMULUDQ Z23, Z29, Z15 + VPADDQ Z15, Z5, Z5 + VPMULUDQ Z23, Z30, Z16 + VPADDQ Z16, Z6, Z6 + VPMULUDQ Z23, Z31, Z17 + VPADDQ Z17, Z7, Z7 + VPMULUDQ.BCST qInvNeg+32(FP), Z0, Z9 + + // z[6] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + CARRY3() + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[6] + MOVQ R14, 192(SI) + MOVQ R13, 200(SI) + MOVQ CX, 208(SI) + MOVQ BX, 216(SI) + ADDQ $32, R11 + MOVQ 0(R11), DX + CARRY4() + + // load input y[7] + MOVQ 224(R15), DI + MOVQ 232(R15), R8 + MOVQ 240(R15), R9 + MOVQ 248(R15), R10 + AVX_MUL_Q_LO() + AVX_MUL_Q_HI() + + // z[7] -> y * x[0] + MUL_WORD_0() + CARRY1() + CARRY2() + + // z[7] -> y * x[1] + MOVQ 8(R11), DX + MUL_WORD_N() + + // Conditional subtraction of the modulus + VPERMD.BCST.Z ·qElement+0(SB), Z8, K1, Z10 + VPERMD.BCST.Z ·qElement+4(SB), Z8, K1, Z11 + VPERMD.BCST.Z ·qElement+8(SB), Z8, K1, Z12 + VPERMD.BCST.Z ·qElement+12(SB), Z8, K1, Z13 + VPERMD.BCST.Z ·qElement+16(SB), Z8, K1, Z14 + VPERMD.BCST.Z ·qElement+20(SB), Z8, K1, Z15 + VPERMD.BCST.Z ·qElement+24(SB), Z8, K1, Z16 + VPERMD.BCST.Z ·qElement+28(SB), Z8, K1, Z17 + VPSUBQ Z10, Z0, Z10 + VPSRLQ $63, Z10, Z20 + VPANDQ Z8, Z10, Z10 + VPSUBQ Z11, 
Z1, Z11 + VPSUBQ Z20, Z11, Z11 + VPSRLQ $63, Z11, Z21 + VPANDQ Z8, Z11, Z11 + VPSUBQ Z12, Z2, Z12 + VPSUBQ Z21, Z12, Z12 + VPSRLQ $63, Z12, Z22 + VPANDQ Z8, Z12, Z12 + VPSUBQ Z13, Z3, Z13 + VPSUBQ Z22, Z13, Z13 + VPSRLQ $63, Z13, Z23 + VPANDQ Z8, Z13, Z13 + VPSUBQ Z14, Z4, Z14 + VPSUBQ Z23, Z14, Z14 + VPSRLQ $63, Z14, Z24 + VPANDQ Z8, Z14, Z14 + VPSUBQ Z15, Z5, Z15 + VPSUBQ Z24, Z15, Z15 + VPSRLQ $63, Z15, Z25 + VPANDQ Z8, Z15, Z15 + VPSUBQ Z16, Z6, Z16 + VPSUBQ Z25, Z16, Z16 + VPSRLQ $63, Z16, Z26 + VPANDQ Z8, Z16, Z16 + VPSUBQ Z17, Z7, Z17 + VPSUBQ Z26, Z17, Z17 + VPMOVQ2M Z17, K2 + KNOTB K2, K2 + VMOVDQU64 Z10, K2, Z0 + VMOVDQU64 Z11, K2, Z1 + VMOVDQU64 Z12, K2, Z2 + VMOVDQU64 Z13, K2, Z3 + VMOVDQU64 Z14, K2, Z4 + + // z[7] -> y * x[2] + MOVQ 16(R11), DX + MUL_WORD_N() + VMOVDQU64 Z15, K2, Z5 + VMOVDQU64 Z16, K2, Z6 + VMOVDQU64 Z17, K2, Z7 + + // Transpose results back + VALIGND $0, ·pattern1+0(SB), Z11, Z11 + VALIGND $0, ·pattern2+0(SB), Z12, Z12 + VALIGND $0, ·pattern3+0(SB), Z13, Z13 + VALIGND $0, ·pattern4+0(SB), Z14, Z14 + VPSLLQ $32, Z1, Z1 + VPORQ Z1, Z0, Z0 + VPSLLQ $32, Z3, Z3 + VPORQ Z3, Z2, Z1 + VPSLLQ $32, Z5, Z5 + VPORQ Z5, Z4, Z2 + VPSLLQ $32, Z7, Z7 + VPORQ Z7, Z6, Z3 + VMOVDQU64 Z0, Z4 + VMOVDQU64 Z2, Z6 + + // z[7] -> y * x[3] + MOVQ 24(R11), DX + MUL_WORD_N() + VPERMT2Q Z1, Z11, Z0 + VPERMT2Q Z4, Z12, Z1 + VPERMT2Q Z3, Z11, Z2 + VPERMT2Q Z6, Z12, Z3 + + // reduce element(R14,R13,CX,BX) using temp registers (BP,R12,AX,DX) + REDUCE(R14,R13,CX,BX,BP,R12,AX,DX) + + // store output z[7] + MOVQ R14, 224(SI) + MOVQ R13, 232(SI) + MOVQ CX, 240(SI) + MOVQ BX, 248(SI) + ADDQ $288, R11 + VMOVDQU64 Z0, Z4 + VMOVDQU64 Z1, Z5 + VPERMT2Q Z2, Z13, Z0 + VPERMT2Q Z4, Z14, Z2 + VPERMT2Q Z3, Z13, Z1 + VPERMT2Q Z5, Z14, Z3 + + // Save AVX-512 results + VMOVDQU64 Z0, 256+0*64(SI) + VMOVDQU64 Z2, 256+1*64(SI) + VMOVDQU64 Z1, 256+2*64(SI) + VMOVDQU64 Z3, 256+3*64(SI) + ADDQ $512, SI + ADDQ $512, R15 + MOVQ s0-8(SP), R12 + DECQ R12 // decrement n + MOVQ R12, s0-8(SP) + 
JMP loop_18 + +done_17: + RET diff --git a/field/asm/element_5w_amd64.s b/field/asm/element_5w_amd64.s new file mode 100644 index 000000000..b3ccd84b1 --- /dev/null +++ b/field/asm/element_5w_amd64.s @@ -0,0 +1,563 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + MOVQ ra4, rb4; \ + SBBQ ·qElement+32(SB), ra4; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + CMOVQCS rb4, ra4; \ + +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) + REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + ADDQ DX, DX + ADCQ CX, CX + 
ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) + REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $16-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) + REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) + + MOVQ DX, R13 + MOVQ CX, R14 + MOVQ BX, R15 + MOVQ SI, s0-8(SP) + MOVQ DI, s1-16(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ R13, DX + ADCQ R14, CX + ADCQ R15, BX + ADCQ s0-8(SP), SI + ADCQ s1-16(SP), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + 
MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $24-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ 32(AX), R8 + MOVQ CX, R9 + MOVQ BX, R10 + MOVQ SI, R11 + MOVQ DI, R12 + MOVQ R8, R13 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + ADCQ 32(DX), R8 + SUBQ 0(DX), R9 + SBBQ 8(DX), R10 + SBBQ 16(DX), R11 + SBBQ 24(DX), R12 + SBBQ 32(DX), R13 + MOVQ CX, R14 + MOVQ BX, R15 + MOVQ SI, s0-8(SP) + MOVQ DI, s1-16(SP) + MOVQ R8, s2-24(SP) + MOVQ $const_q0, CX + MOVQ $const_q1, BX + MOVQ $const_q2, SI + MOVQ $const_q3, DI + MOVQ $const_q4, R8 + CMOVQCC AX, CX + CMOVQCC AX, BX + CMOVQCC AX, SI + CMOVQCC AX, DI + CMOVQCC AX, R8 + ADDQ CX, R9 + ADCQ BX, R10 + ADCQ SI, R11 + ADCQ DI, R12 + ADCQ R8, R13 + MOVQ R14, CX + MOVQ R15, BX + MOVQ s0-8(SP), SI + MOVQ s1-16(SP), DI + MOVQ s2-24(SP), R8 + MOVQ R9, 0(DX) + MOVQ R10, 8(DX) + MOVQ R11, 16(DX) + MOVQ R12, 24(DX) + MOVQ R13, 32(DX) + + // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + MOVQ R8, 32(AX) + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), $24-24 + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // See github.com/gnark-crypto/field/generator for more comments. 
+ + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE noAdx_1 + MOVQ x+8(FP), DI + + // x[0] -> R9 + // x[1] -> R10 + // x[2] -> R11 + MOVQ 0(DI), R9 + MOVQ 8(DI), R10 + MOVQ 16(DI), R11 + MOVQ y+16(FP), R12 + + // A -> BP + // t[0] -> R14 + // t[1] -> R13 + // t[2] -> CX + // t[3] -> BX + // t[4] -> SI +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + MOVQ $const_qInvNeg, DX \ + IMULQ R14, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, R8 \ + ADCXQ R14, AX \ + MOVQ R8, R14 \ + MACC(R13, R14, ·qElement+8(SB)) \ + MACC(CX, R13, ·qElement+16(SB)) \ + MACC(BX, CX, ·qElement+24(SB)) \ + MACC(SI, BX, ·qElement+32(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, SI \ + ADOXQ BP, SI \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ R9, R14, R13 \ + MULXQ R10, AX, CX \ + ADOXQ AX, R13 \ + MULXQ R11, AX, BX \ + ADOXQ AX, CX \ + MULXQ 24(DI), AX, SI \ + ADOXQ AX, BX \ + MULXQ 32(DI), AX, BP \ + ADOXQ AX, SI \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ R9, AX, BP \ + ADOXQ AX, R14 \ + MACC(BP, R13, R10) \ + MACC(BP, CX, R11) \ + MACC(BP, BX, 24(DI)) \ + MACC(BP, SI, 32(DI)) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + + // mul body + MOVQ 0(R12), DX + MUL_WORD_0() + MOVQ 8(R12), DX + MUL_WORD_N() + MOVQ 16(R12), DX + MUL_WORD_N() + MOVQ 24(R12), DX + MUL_WORD_N() + MOVQ 32(R12), DX + MUL_WORD_N() + + // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) + REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R13, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + RET + +noAdx_1: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $8-8 + NO_LOCAL_POINTERS + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by 
Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE noAdx_2 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R13 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + MOVQ 32(DX), SI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + MOVQ $0, AX + ADCXQ AX, SI + ADOXQ AX, SI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + MOVQ $0, AX + ADCXQ AX, SI + ADOXQ AX, SI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ 
AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + MOVQ $0, AX + ADCXQ AX, SI + ADOXQ AX, SI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + MOVQ $0, AX + ADCXQ AX, SI + ADOXQ AX, SI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ ·qElement+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + MOVQ $0, AX + ADCXQ AX, SI + ADOXQ AX, SI + + // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) + REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R13, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + RET + +noAdx_2: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET diff --git 
a/field/asm/element_6w_amd64.s b/field/asm/element_6w_amd64.s new file mode 100644 index 000000000..b8172b287 --- /dev/null +++ b/field/asm/element_6w_amd64.s @@ -0,0 +1,670 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ + MOVQ ra0, rb0; \ + SUBQ ·qElement(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ ·qElement+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ ·qElement+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ ·qElement+24(SB), ra3; \ + MOVQ ra4, rb4; \ + SBBQ ·qElement+32(SB), ra4; \ + MOVQ ra5, rb5; \ + SBBQ ·qElement+40(SB), ra5; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + CMOVQCS rb4, ra4; \ + CMOVQCS rb5, ra5; \ + +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) + REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + 
+// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) + REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $40-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) + + MOVQ DX, R15 + MOVQ CX, s0-8(SP) + MOVQ BX, s1-16(SP) + MOVQ SI, s2-24(SP) + MOVQ DI, s3-32(SP) + MOVQ R8, s4-40(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ R15, DX + ADCQ s0-8(SP), CX + ADCQ s1-16(SP), BX + ADCQ s2-24(SP), SI + ADCQ s3-32(SP), DI + ADCQ s4-40(SP), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $48-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ 32(AX), R8 + MOVQ 40(AX), R9 + MOVQ CX, R10 + MOVQ BX, R11 + MOVQ SI, R12 + MOVQ DI, R13 + MOVQ R8, R14 + MOVQ R9, R15 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + ADCQ 32(DX), R8 + ADCQ 40(DX), R9 + SUBQ 0(DX), R10 + SBBQ 8(DX), R11 + SBBQ 16(DX), R12 + SBBQ 24(DX), R13 + SBBQ 32(DX), R14 + SBBQ 40(DX), R15 + MOVQ CX, s0-8(SP) + MOVQ BX, s1-16(SP) + MOVQ SI, s2-24(SP) + MOVQ DI, s3-32(SP) + MOVQ R8, s4-40(SP) + MOVQ R9, s5-48(SP) + MOVQ $const_q0, CX + MOVQ $const_q1, BX + MOVQ $const_q2, SI + MOVQ $const_q3, DI + MOVQ $const_q4, R8 + MOVQ $const_q5, R9 + CMOVQCC AX, CX + CMOVQCC AX, BX + CMOVQCC AX, SI + CMOVQCC AX, DI + CMOVQCC AX, R8 + CMOVQCC AX, R9 + ADDQ CX, R10 + ADCQ BX, R11 + ADCQ SI, R12 + ADCQ DI, R13 + ADCQ R8, R14 + ADCQ R9, R15 + MOVQ s0-8(SP), CX + MOVQ s1-16(SP), BX + MOVQ s2-24(SP), SI + MOVQ s3-32(SP), DI + MOVQ s4-40(SP), R8 + MOVQ s5-48(SP), R9 + MOVQ R10, 0(DX) + MOVQ R11, 8(DX) + MOVQ R12, 16(DX) + MOVQ R13, 24(DX) + MOVQ R14, 32(DX) + MOVQ R15, 40(DX) + + // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) + 
REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + MOVQ R8, 32(AX) + MOVQ R9, 40(AX) + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), $24-24 + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // See github.com/gnark-crypto/field/generator for more comments. + + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE noAdx_1 + MOVQ x+8(FP), R8 + + // x[0] -> R10 + // x[1] -> R11 + // x[2] -> R12 + MOVQ 0(R8), R10 + MOVQ 8(R8), R11 + MOVQ 16(R8), R12 + MOVQ y+16(FP), R13 + + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // t[4] -> SI + // t[5] -> DI +#define MACC(in0, in1, in2) \ + ADCXQ in0, in1 \ + MULXQ in2, AX, in0 \ + ADOXQ AX, in1 \ + +#define DIV_SHIFT() \ + MOVQ $const_qInvNeg, DX \ + IMULQ R14, DX \ + XORQ AX, AX \ + MULXQ ·qElement+0(SB), AX, R9 \ + ADCXQ R14, AX \ + MOVQ R9, R14 \ + MACC(R15, R14, ·qElement+8(SB)) \ + MACC(CX, R15, ·qElement+16(SB)) \ + MACC(BX, CX, ·qElement+24(SB)) \ + MACC(SI, BX, ·qElement+32(SB)) \ + MACC(DI, SI, ·qElement+40(SB)) \ + MOVQ $0, AX \ + ADCXQ AX, DI \ + ADOXQ BP, DI \ + +#define MUL_WORD_0() \ + XORQ AX, AX \ + MULXQ R10, R14, R15 \ + MULXQ R11, AX, CX \ + ADOXQ AX, R15 \ + MULXQ R12, AX, BX \ + ADOXQ AX, CX \ + MULXQ 24(R8), AX, SI \ + ADOXQ AX, BX \ + MULXQ 32(R8), AX, DI \ + ADOXQ AX, SI \ + MULXQ 40(R8), AX, BP \ + ADOXQ AX, DI \ + MOVQ $0, AX \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + +#define MUL_WORD_N() \ + XORQ AX, AX \ + MULXQ R10, AX, BP \ + ADOXQ AX, R14 \ + MACC(BP, R15, R11) \ + MACC(BP, CX, R12) \ + MACC(BP, BX, 24(R8)) \ + MACC(BP, SI, 32(R8)) \ + MACC(BP, DI, 40(R8)) \ + MOVQ $0, AX \ + ADCXQ AX, BP \ + ADOXQ AX, BP \ + DIV_SHIFT() \ + + // mul body + MOVQ 0(R13), DX + MUL_WORD_0() + MOVQ 8(R13), DX + MUL_WORD_N() + MOVQ 16(R13), DX + MUL_WORD_N() + MOVQ 
24(R13), DX + MUL_WORD_N() + MOVQ 32(R13), DX + MUL_WORD_N() + MOVQ 40(R13), DX + MUL_WORD_N() + + // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) + REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + MOVQ DI, 40(AX) + RET + +noAdx_1: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $8-8 + NO_LOCAL_POINTERS + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE noAdx_2 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R15 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + MOVQ 32(DX), SI + MOVQ 40(DX), DI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + MOVQ $0, AX + ADCXQ AX, DI + ADOXQ AX, DI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] 
+ MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + MOVQ $0, AX + ADCXQ AX, DI + ADOXQ AX, DI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + MOVQ $0, AX + ADCXQ AX, DI + ADOXQ AX, DI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + 
MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + MOVQ $0, AX + ADCXQ AX, DI + ADOXQ AX, DI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + MOVQ $0, AX + ADCXQ AX, DI + ADOXQ AX, DI + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ $const_qInvNeg, DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ ·qElement+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ ·qElement+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ ·qElement+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ ·qElement+24(SB), AX, BX + ADOXQ AX, CX + + // (C,t[3]) := t[4] + m*q[4] + C + ADCXQ SI, BX + MULXQ ·qElement+32(SB), AX, SI + ADOXQ AX, BX + + // (C,t[4]) := t[5] + m*q[5] + C + ADCXQ DI, SI + MULXQ ·qElement+40(SB), AX, DI + ADOXQ AX, SI + MOVQ $0, AX + ADCXQ AX, DI + ADOXQ AX, DI + + // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + MOVQ SI, 32(AX) + MOVQ DI, 40(AX) + RET + +noAdx_2: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET diff --git a/field/generator/asm/amd64/asm_macros.go 
b/field/generator/asm/amd64/asm_macros.go index 45d324c94..b2c6341cf 100644 --- a/field/generator/asm/amd64/asm_macros.go +++ b/field/generator/asm/amd64/asm_macros.go @@ -61,37 +61,25 @@ func (f *FFAmd64) ReduceElement(t, scratch []amd64.Register) { f.WriteLn("") } -// TODO @gbotrel: figure out if interleaving MOVQ and SUBQ or CMOVQ and MOVQ instructions makes sense -const tmplDefines = ` - -// modulus q -{{- range $i, $w := .Q}} -DATA q<>+{{mul $i 8}}(SB)/8, {{imm $w}} -{{- end}} -GLOBL q<>(SB), (RODATA+NOPTR), ${{mul 8 $.NbWords}} - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, {{$qinv0 := index .QInverse 0}}{{imm $qinv0}} -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +const tmplReduceDefine = ` #define REDUCE( {{- range $i := .NbWordsIndexesFull}}ra{{$i}},{{- end}} {{- range $i := .NbWordsIndexesFull}}rb{{$i}}{{- if ne $.NbWordsLastIndex $i}},{{- end}}{{- end}}) \ MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ + SUBQ ·qElement(SB), ra0; \ {{- range $i := .NbWordsIndexesNoZero}} MOVQ ra{{$i}}, rb{{$i}}; \ - SBBQ q<>+{{mul $i 8}}(SB), ra{{$i}}; \ + SBBQ ·qElement+{{mul $i 8}}(SB), ra{{$i}}; \ {{- end}} {{- range $i := .NbWordsIndexesFull}} CMOVQCS rb{{$i}}, ra{{$i}}; \ {{- end}} - ` -func (f *FFAmd64) GenerateDefines() { +func (f *FFAmd64) GenerateReduceDefine() { tmpl := template.Must(template.New(""). Funcs(helpers()). 
- Parse(tmplDefines)) + Parse(tmplReduceDefine)) // execute template var buf bytes.Buffer diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index b760ad3e3..ccf41c4e8 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -17,28 +17,58 @@ package amd64 import ( "fmt" + "hash/fnv" "io" + "os" + "path/filepath" "strings" - "github.com/consensys/bavard" - "github.com/consensys/bavard/amd64" "github.com/consensys/gnark-crypto/field/generator/config" ) const SmallModulus = 6 +const ( + ElementASMFileName = "element_%dw_amd64.s" +) -func NewFFAmd64(w io.Writer, F *config.FieldConfig) *FFAmd64 { - return &FFAmd64{F, amd64.NewAmd64(w), 0, 0} +func NewFFAmd64(w io.Writer, nbWords int) *FFAmd64 { + F := &FFAmd64{ + amd64.NewAmd64(w), + 0, + 0, + nbWords, + nbWords - 1, + make([]int, nbWords), + make([]int, nbWords-1), + make(map[string]defineFn), + } + + // indexes (template helpers) + for i := 0; i < F.NbWords; i++ { + F.NbWordsIndexesFull[i] = i + if i > 0 { + F.NbWordsIndexesNoZero[i-1] = i + } + } + + return F } type FFAmd64 struct { - *config.FieldConfig + // *config.FieldConfig *amd64.Amd64 - nbElementsOnStack int - maxOnStack int + nbElementsOnStack int + maxOnStack int + NbWords int + NbWordsLastIndex int + NbWordsIndexesFull []int + NbWordsIndexesNoZero []int + mDefines map[string]defineFn } +type defineFn func(args ...amd64.Register) + func (f *FFAmd64) StackSize(maxNbRegistersNeeded, nbRegistersReserved, minStackSize int) int { got := amd64.NbRegisters - nbRegistersReserved r := got - maxNbRegistersNeeded @@ -49,6 +79,64 @@ func (f *FFAmd64) StackSize(maxNbRegistersNeeded, nbRegistersReserved, minStackS return max(r, minStackSize) } +func (f *FFAmd64) DefineFn(name string) (fn defineFn, err error) { + fn, ok := f.mDefines[name] + if !ok { + return nil, fmt.Errorf("function %s not defined", name) + } + return fn, nil +} + +func (f *FFAmd64) Define(name string, nbInputs int, fn defineFn) 
defineFn { + + inputs := make([]string, nbInputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = fmt.Sprintf("in%d", i) + } + name = strings.ToUpper(name) + + for _, ok := f.mDefines[name]; ok; { + // name already exist, for code generation purpose we add a suffix + // should happen only with e2 deprecated functions + fmt.Println("WARNING: function name already defined, adding suffix") + i := 0 + for { + newName := fmt.Sprintf("%s_%d", name, i) + if _, ok := f.mDefines[newName]; !ok { + name = newName + goto startDefine + } + i++ + } + } +startDefine: + + f.StartDefine() + f.WriteLn("#define " + name + "(" + strings.Join(inputs, ", ") + ")") + inputsRegisters := make([]amd64.Register, nbInputs) + for i := 0; i < nbInputs; i++ { + inputsRegisters[i] = amd64.Register(inputs[i]) + } + fn(inputsRegisters...) + f.EndDefine() + f.WriteLn("") + + toReturn := func(args ...amd64.Register) { + if len(args) != nbInputs { + panic("invalid number of arguments") + } + inputsStr := make([]string, len(args)) + for i := 0; i < len(args); i++ { + inputsStr[i] = string(args[i]) + } + f.WriteLn(name + "(" + strings.Join(inputsStr, ", ") + ")") + } + + f.mDefines[name] = toReturn + + return toReturn +} + func max(a, b int) int { if a > b { return a @@ -128,24 +216,68 @@ func (f *FFAmd64) PopN(registers *amd64.Registers, forceStack ...bool) []amd64.R } func (f *FFAmd64) qAt(index int) string { - return fmt.Sprintf("q<>+%d(SB)", index*8) + return fmt.Sprintf("·qElement+%d(SB)", index*8) +} + +func (f *FFAmd64) qAt_bcst(index int) string { + return fmt.Sprintf("·qElement+%d(SB)", index*4) } func (f *FFAmd64) qInv0() string { - return "qInv0<>(SB)" + return "$const_qInvNeg" +} + +func (f *FFAmd64) mu() string { + return "$const_mu" +} + +func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDirBuildPath, asmDirIncludePath string) error { + // for each field we generate the defines for the modulus and the montgomery constant + f := NewFFAmd64(w, F.NbWords) + + // we add the 
defines first, then the common asm, then the global variable section + // to enable correct compilations with #include in order. + f.WriteLn("") + + hashAndInclude := func(fileName string) error { + // we hash the file content and include the hash in comment of the generated file + // to force the Go compiler to recompile the file if the content has changed + fData, err := os.ReadFile(filepath.Join(asmDirBuildPath, fileName)) + if err != nil { + return err + } + // hash the file using FNV + hasher := fnv.New64() + hasher.Write(fData) + hash := hasher.Sum64() + + f.WriteLn("// Code generated by gnark-crypto/generator. DO NOT EDIT.") + f.WriteLn(fmt.Sprintf("// We include the hash to force the Go compiler to recompile: %d", hash)) + f.WriteLn(fmt.Sprintf("#include \"%s\"\n", filepath.Join(asmDirIncludePath, fileName))) + + return nil + } + + toInclude := fmt.Sprintf(ElementASMFileName, F.NbWords) + if err := hashAndInclude(toInclude); err != nil { + return err + } + + return nil } -// Generate generates assembly code for the base field provided to goff +// GenerateCommonASM generates assembly code for the base field provided to goff // see internal/templates/ops* -func Generate(w io.Writer, F *config.FieldConfig) error { - f := NewFFAmd64(w, F) - f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) +func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { + f := NewFFAmd64(w, nbWords) + f.Comment("Code generated by gnark-crypto/generator. 
DO NOT EDIT.") f.WriteLn("#include \"textflag.h\"") f.WriteLn("#include \"funcdata.h\"") + f.WriteLn("#include \"go_asm.h\"") f.WriteLn("") - f.GenerateDefines() + f.GenerateReduceDefine() // reduce f.generateReduce() @@ -158,48 +290,24 @@ func Generate(w io.Writer, F *config.FieldConfig) error { // fft butterflies f.generateButterfly() - // generate vector operations for "small" modulus - if f.NbWords == 4 { - f.generateAddVec() - f.generateSubVec() - f.generateScalarMulVec() - } - - return nil -} - -func GenerateMul(w io.Writer, F *config.FieldConfig) error { - f := NewFFAmd64(w, F) - f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) - - f.WriteLn("#include \"textflag.h\"") - f.WriteLn("#include \"funcdata.h\"") - f.WriteLn("") - f.GenerateDefines() - // mul f.generateMul(false) // from mont f.generateFromMont(false) - return nil -} - -func GenerateMulADX(w io.Writer, F *config.FieldConfig) error { - f := NewFFAmd64(w, F) - f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) + if hasVector { + f.WriteLn("") + f.Comment("Vector operations are partially derived from Dag Arne Osvik's work in github.com/a16z/vectorized-fields") + f.WriteLn("") - f.WriteLn("#include \"textflag.h\"") - f.WriteLn("#include \"funcdata.h\"") - f.WriteLn("") - f.GenerateDefines() - - // mul - f.generateMul(true) - - // from mont - f.generateFromMont(true) + f.generateAddVec() + f.generateSubVec() + f.generateSumVec() + f.generateInnerProduct() + f.generateMulVec("scalarMulVec") + f.generateMulVec("mulVec") + } return nil } diff --git a/field/generator/asm/amd64/element_butterfly.go b/field/generator/asm/amd64/element_butterfly.go index 2d996754a..8792b6dce 100644 --- a/field/generator/asm/amd64/element_butterfly.go +++ b/field/generator/asm/amd64/element_butterfly.go @@ -14,6 +14,8 @@ package amd64 +import "fmt" + // Butterfly sets // a = a + b // b = a - b @@ -56,7 +58,9 @@ func (f *FFAmd64) generateButterfly() { if f.NbWords >= 5 { // q is on the stack, can't 
use for CMOVQCC f.Mov(t0, q) // save t0 - f.Mov(f.Q, t0) + for i := 0; i < f.NbWords; i++ { + f.MOVQ(fmt.Sprintf("$const_q%d", i), t0[i]) + } for i := 0; i < f.NbWords; i++ { f.CMOVQCC(a, t0[i]) } @@ -64,7 +68,9 @@ func (f *FFAmd64) generateButterfly() { f.Add(t0, t1) f.Mov(q, t0) // restore t0 } else { - f.Mov(f.Q, q) + for i := 0; i < f.NbWords; i++ { + f.MOVQ(fmt.Sprintf("$const_q%d", i), q[i]) + } for i := 0; i < f.NbWords; i++ { f.CMOVQCC(a, q[i]) } @@ -110,10 +116,11 @@ func (f *FFAmd64) generateButterfly() { noReduce := f.NewLabel("noReduce") f.JCC(noReduce) q := r - f.MOVQ(f.Q[0], q) + f.MOVQ("$const_q0", q) + f.ADDQ(q, t0[0]) for i := 1; i < f.NbWords; i++ { - f.MOVQ(f.Q[i], q) + f.MOVQ(fmt.Sprintf("$const_q%d", i), q) f.ADCQ(q, t0[i]) } f.LABEL(noReduce) diff --git a/field/generator/asm/amd64/element_frommont.go b/field/generator/asm/amd64/element_frommont.go index 79b717dcc..de9d0e3c4 100644 --- a/field/generator/asm/amd64/element_frommont.go +++ b/field/generator/asm/amd64/element_frommont.go @@ -42,8 +42,8 @@ func (f *FFAmd64) generateFromMont(forceADX bool) { f.WriteLn("NO_LOCAL_POINTERS") } f.WriteLn(` - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] diff --git a/field/generator/asm/amd64/element_mul.go b/field/generator/asm/amd64/element_mul.go index df63f1ad4..55f816c26 100644 --- a/field/generator/asm/amd64/element_mul.go +++ b/field/generator/asm/amd64/element_mul.go @@ -20,6 +20,18 @@ import ( "github.com/consensys/bavard/amd64" ) +// Registers used when f.NbWords == 4 +// for the multiplication. +// They are re-referenced in defines in the vectorized operations. 
+var mul4Registers = []amd64.Register{ + // t + amd64.R14, amd64.R13, amd64.CX, amd64.BX, + // x + amd64.DI, amd64.R8, amd64.R9, amd64.R10, + // tr + amd64.R12, +} + // MulADX uses AX, DX and BP // sets x * y into t, without modular reduction // x() will have more accesses than y() @@ -40,69 +52,50 @@ func (f *FFAmd64) MulADX(registers *amd64.Registers, x, y func(int) string, t [] f.LabelRegisters("A", A) f.LabelRegisters("t", t...) - for i := 0; i < f.NbWords; i++ { - f.Comment("clear the flags") - f.XORQ(amd64.AX, amd64.AX) - - f.MOVQ(y(i), amd64.DX) - - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - if i == 0 { - for j := 0; j < f.NbWords; j++ { - f.Comment(fmt.Sprintf("(A,t[%[1]d]) := x[%[1]d]*y[%[2]d] + A", j, i)) - - if j == 0 && f.NbWords == 1 { - f.MULXQ(x(j), t[j], A) - } else if j == 0 { - f.MULXQ(x(j), t[j], t[j+1]) - } else { - highBits := A - if j != f.NbWordsLastIndex { - highBits = t[j+1] - } - f.MULXQ(x(j), amd64.AX, highBits) - f.ADOXQ(amd64.AX, t[j]) - } + if f.NbWords == 4 && hasFreeRegister { + // ensure the registers match the "hardcoded ones" in mul4Registers for the vecops + match := true + for i := 0; i < 4; i++ { + if mul4Registers[i] != t[i] { + match = false + fmt.Printf("expected %s, got t[%d] %s\n", mul4Registers[i], i, t[i]) } - } else { - for j := 0; j < f.NbWords; j++ { - f.Comment(fmt.Sprintf("(A,t[%[1]d]) := t[%[1]d] + x[%[1]d]*y[%[2]d] + A", j, i)) - - if j != 0 { - f.ADCXQ(A, t[j]) - } - f.MULXQ(x(j), amd64.AX, A) - f.ADOXQ(amd64.AX, t[j]) + if mul4Registers[i+4] != amd64.Register(x(i)) { + match = false + fmt.Printf("expected %s, got x[%d] %s\n", mul4Registers[i+4], i, x(i)) } } - - f.Comment("A += carries from ADCXQ and ADOXQ") - f.MOVQ(0, amd64.AX) - if i != 0 { - f.ADCXQ(amd64.AX, A) + if tr != mul4Registers[8] { + match = false + fmt.Printf("expected %s, got tr %s\n", mul4Registers[8], tr) } - f.ADOXQ(amd64.AX, A) + if !match { + panic("registers do not match hardcoded ones") + } + } + + mac := f.Define("MACC", 3, 
func(args ...amd64.Register) { + in0 := args[0] + in1 := args[1] + in2 := args[2] + f.ADCXQ(in0, in1) + f.MULXQ(in2, amd64.AX, in0) + f.ADOXQ(amd64.AX, in1) + }) + divShift := f.Define("DIV_SHIFT", 0, func(_ ...amd64.Register) { if !hasFreeRegister { f.PUSHQ(A) } - // m := t[0]*q'[0] mod W - f.Comment("m := t[0]*q'[0] mod W") m := amd64.DX - // f.MOVQ(t[0], m) - // f.MULXQ(f.qInv0(), m, amd64.AX) f.MOVQ(f.qInv0(), m) f.IMULQ(t[0], m) // clear the carry flags - f.Comment("clear the flags") f.XORQ(amd64.AX, amd64.AX) // C,_ := t[0] + m*q[0] - f.Comment("C,_ := t[0] + m*q[0]") - f.MULXQ(f.qAt(0), amd64.AX, tr) f.ADCXQ(t[0], amd64.AX) f.MOVQ(tr, t[0]) @@ -110,20 +103,68 @@ func (f *FFAmd64) MulADX(registers *amd64.Registers, x, y func(int) string, t [] if !hasFreeRegister { f.POPQ(A) } + // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[j-1]) := t[j] + m*q[j] + C for j := 1; j < f.NbWords; j++ { - f.Comment(fmt.Sprintf("(C,t[%[1]d]) := t[%[2]d] + m*q[%[2]d] + C", j-1, j)) - f.ADCXQ(t[j], t[j-1]) - f.MULXQ(f.qAt(j), amd64.AX, t[j]) - f.ADOXQ(amd64.AX, t[j-1]) + mac(t[j], t[j-1], amd64.Register(f.qAt(j))) } - f.Comment(fmt.Sprintf("t[%d] = C + A", f.NbWordsLastIndex)) f.MOVQ(0, amd64.AX) f.ADCXQ(amd64.AX, t[f.NbWordsLastIndex]) f.ADOXQ(A, t[f.NbWordsLastIndex]) + }) + + mulWord0 := f.Define("MUL_WORD_0", 0, func(_ ...amd64.Register) { + f.XORQ(amd64.AX, amd64.AX) + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + for j := 0; j < f.NbWords; j++ { + if j == 0 && f.NbWords == 1 { + f.MULXQ(x(j), t[j], A) + } else if j == 0 { + f.MULXQ(x(j), t[j], t[j+1]) + } else { + highBits := A + if j != f.NbWordsLastIndex { + highBits = t[j+1] + } + f.MULXQ(x(j), amd64.AX, highBits) + f.ADOXQ(amd64.AX, t[j]) + } + } + f.MOVQ(0, amd64.AX) + f.ADOXQ(amd64.AX, A) + divShift() + }) + + mulWordN := f.Define("MUL_WORD_N", 0, func(args ...amd64.Register) { + f.XORQ(amd64.AX, amd64.AX) + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + f.MULXQ(x(0), 
amd64.AX, A) + f.ADOXQ(amd64.AX, t[0]) + for j := 1; j < f.NbWords; j++ { + mac(A, t[j], amd64.Register(x(j))) + } + f.MOVQ(0, amd64.AX) + f.ADCXQ(amd64.AX, A) + f.ADOXQ(amd64.AX, A) + divShift() + }) + + f.Comment("mul body") + + for i := 0; i < f.NbWords; i++ { + f.MOVQ(y(i), amd64.DX) + + if i == 0 { + mulWord0() + } else { + mulWordN() + } } if hasFreeRegister { @@ -152,19 +193,11 @@ func (f *FFAmd64) generateMul(forceADX bool) { registers := f.FnHeader("mul", stackSize, argSize, reserved...) defer f.AssertCleanStack(stackSize, minStackSize) - f.WriteLn(fmt.Sprintf(` - // the algorithm is described in the %s.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - `, f.ElementName)) + f.WriteLn(` + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // See github.com/gnark-crypto/field/generator for more comments. + `) if stackSize > 0 { f.WriteLn("NO_LOCAL_POINTERS") } diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index 05c2cf3f1..c67137d6c 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -14,7 +14,12 @@ package amd64 -import "github.com/consensys/bavard/amd64" +import ( + "fmt" + "strconv" + + "github.com/consensys/bavard/amd64" +) // addVec res = a + b // func addVec(res, a, b *{{.ElementName}}, n uint64) @@ -53,13 +58,14 @@ func (f *FFAmd64) generateAddVec() { f.LabelRegisters("a", a...) 
f.Mov(addrA, a) f.Add(addrB, a) + f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrA)) + f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrB)) // reduce a f.ReduceElement(a, t) // save a into res f.Mov(a, addrRes) - f.Comment("increment pointers to visit next element") f.ADDQ("$32", addrA) f.ADDQ("$32", addrB) @@ -117,11 +123,15 @@ func (f *FFAmd64) generateSubVec() { f.LabelRegisters("a", a...) f.Mov(addrA, a) f.Sub(addrB, a) + f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrA)) + f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrB)) // reduce a f.Comment("reduce (a-b) mod q") f.LabelRegisters("q", q...) - f.Mov(f.Q, q) + for i := 0; i < f.NbWords; i++ { + f.MOVQ(fmt.Sprintf("$const_q%d", i), q[i]) + } for i := 0; i < f.NbWords; i++ { f.CMOVQCC(zero, q[i]) } @@ -149,94 +159,1341 @@ func (f *FFAmd64) generateSubVec() { } -// scalarMulVec res = a * b -// func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) -func (f *FFAmd64) generateScalarMulVec() { - f.Comment("scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b") +// sumVec res = sum(a[0...n]) +func (f *FFAmd64) generateSumVec() { + f.Comment("sumVec(res, a *Element, n uint64) res = sum(a[0...n])") + + const argSize = 3 * 8 + stackSize := f.StackSize(f.NbWords*3+2, 0, 0) + registers := f.FnHeader("sumVec", stackSize, argSize, amd64.DX, amd64.AX) + defer f.AssertCleanStack(stackSize, 0) + + f.WriteLn(` + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), 
hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) + `) + + // registers & labels we need + addrA := f.Pop(®isters) + n := f.Pop(®isters) + nMod8 := f.Pop(®isters) + + loop := f.NewLabel("loop8by8") + done := f.NewLabel("done") + loopSingle := f.NewLabel("loop_single") + accumulate := f.NewLabel("accumulate") + + // AVX512 registers + Z0 := amd64.Register("Z0") + Z1 := amd64.Register("Z1") + Z2 := amd64.Register("Z2") + Z3 := amd64.Register("Z3") + Z4 := amd64.Register("Z4") + Z5 := amd64.Register("Z5") + Z6 := amd64.Register("Z6") + Z7 := amd64.Register("Z7") + Z8 := amd64.Register("Z8") + + X0 := amd64.Register("X0") + + // load arguments + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("n+16(FP)", n) + + f.Comment("initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7") + f.VXORPS(Z0, Z0, Z0) + f.VMOVDQA64(Z0, Z1) + f.VMOVDQA64(Z0, Z2) + f.VMOVDQA64(Z0, Z3) + f.VMOVDQA64(Z0, Z4) + f.VMOVDQA64(Z0, Z5) + f.VMOVDQA64(Z0, Z6) + f.VMOVDQA64(Z0, Z7) + + // note: we don't need to handle the case n==0; handled by caller already. 
+ // f.TESTQ(n, n) + // f.JEQ(done, "n == 0, we are done") + + f.LabelRegisters("n % 8", nMod8) + f.LabelRegisters("n / 8", n) + f.MOVQ(n, nMod8) + f.ANDQ("$7", nMod8) // nMod8 = n % 8 + f.SHRQ("$3", n) // len = n / 8 + + f.LABEL(loopSingle) + f.TESTQ(nMod8, nMod8) + f.JEQ(loop, "n % 8 == 0, we are going to loop over 8 by 8") + + f.VPMOVZXDQ("0("+addrA+")", Z8) + f.VPADDQ(Z8, Z0, Z0) + f.ADDQ("$32", addrA) + + f.DECQ(nMod8, "decrement nMod8") + f.JMP(loopSingle) + + f.Push(®isters, nMod8) // we don't need tmp0 + + f.LABEL(loop) + f.TESTQ(n, n) + f.JEQ(accumulate, "n == 0, we are going to accumulate") + + for i := 0; i < 8; i++ { + r := fmt.Sprintf("Z%d", i+8) + f.VPMOVZXDQ(fmt.Sprintf("%d*32("+string(addrA)+")", i), r) + } + + f.WriteLn(fmt.Sprintf("PREFETCHT0 4096(%[1]s)", addrA)) + for i := 0; i < 8; i++ { + r := fmt.Sprintf("Z%d", i) + f.VPADDQ(fmt.Sprintf("Z%d", i+8), r, r) + } + + f.Comment("increment pointers to visit next 8 elements") + f.ADDQ("$256", addrA) + f.DECQ(n, "decrement n") + f.JMP(loop) + + f.Push(®isters, n, addrA) + + f.LABEL(accumulate) + + f.Comment("accumulate the 8 Z registers into Z0") + f.VPADDQ(Z7, Z6, Z6) + f.VPADDQ(Z6, Z5, Z5) + f.VPADDQ(Z5, Z4, Z4) + f.VPADDQ(Z4, Z3, Z3) + f.VPADDQ(Z3, Z2, Z2) + f.VPADDQ(Z2, Z1, Z1) + f.VPADDQ(Z1, Z0, Z0) + + w0l := f.Pop(®isters) + w0h := f.Pop(®isters) + w1l := f.Pop(®isters) + w1h := f.Pop(®isters) + w2l := f.Pop(®isters) + w2h := f.Pop(®isters) + w3l := f.Pop(®isters) + w3h := f.Pop(®isters) + low0h := f.Pop(®isters) + low1h := f.Pop(®isters) + low2h := f.Pop(®isters) + low3h := f.Pop(®isters) + + // Propagate carries + f.Comment("carry propagation") + + f.LabelRegisters("lo(w0)", w0l) + f.LabelRegisters("hi(w0)", w0h) + f.LabelRegisters("lo(w1)", w1l) + f.LabelRegisters("hi(w1)", w1h) + f.LabelRegisters("lo(w2)", w2l) + f.LabelRegisters("hi(w2)", w2h) + f.LabelRegisters("lo(w3)", w3l) + f.LabelRegisters("hi(w3)", w3h) + + f.VMOVQ(X0, w0l) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w0h) + 
f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w1l) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w1h) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w2l) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w2h) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w3l) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, w3h) + + f.LabelRegisters("lo(hi(wo))", low0h) + f.LabelRegisters("lo(hi(w1))", low1h) + f.LabelRegisters("lo(hi(w2))", low2h) + f.LabelRegisters("lo(hi(w3))", low3h) + + type hilo struct { + hi, lo amd64.Register + } + + splitLoHi := f.Define("SPLIT_LO_HI", 2, func(args ...amd64.Register) { + lo := args[0] + hi := args[1] + f.MOVQ(hi, lo) + f.ANDQ("$0xffffffff", lo) + f.SHLQ("$32", lo) + f.SHRQ("$32", hi) + }) + + for _, v := range []hilo{{w0h, low0h}, {w1h, low1h}, {w2h, low2h}, {w3h, low3h}} { + splitLoHi(v.lo, v.hi) + } + + f.WriteLn(` + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + `) + f.XORQ(amd64.AX, amd64.AX, "clear the flags") + f.ADOXQ(low0h, w0l) + + f.ADOXQ(low1h, w1l) + f.ADCXQ(w0h, w1l) + + f.ADOXQ(low2h, w2l) + f.ADCXQ(w1h, w2l) + + f.ADOXQ(low3h, w3l) + f.ADCXQ(w2h, w3l) + + f.ADOXQ(amd64.AX, w3h) + f.ADCXQ(amd64.AX, w3h) + + r0 := w0l + r1 := w1l + r2 := w2l + r3 := w3l + r4 := w3h + + r := []amd64.Register{r0, r1, r2, r3, r4} + f.LabelRegisters("r", r...) 
+ // we don't need w0h, w1h, w2h anymore + f.Push(®isters, w0h, w1h, w2h) + // we don't need the low bits anymore + f.Push(®isters, low0h, low1h, low2h, low3h) + + // Reduce using single-word Barrett + mu := f.Pop(®isters) + + f.Comment("reduce using single-word Barrett") + f.Comment("see see Handbook of Applied Cryptography, Algorithm 14.42.") + f.LabelRegisters("mu=2^288 / q", mu) + f.MOVQ(f.mu(), mu) + f.MOVQ(r3, amd64.AX) + f.SHRQw("$32", r4, amd64.AX) + f.MULQ(mu, "high bits of res stored in DX") + + f.MULXQ(f.qAt(0), amd64.AX, mu) + f.SUBQ(amd64.AX, r0) + f.SBBQ(mu, r1) + + f.MULXQ(f.qAt(2), amd64.AX, mu) + f.SBBQ(amd64.AX, r2) + f.SBBQ(mu, r3) + f.SBBQ("$0", r4) + + f.MULXQ(f.qAt(1), amd64.AX, mu) + f.SUBQ(amd64.AX, r1) + f.SBBQ(mu, r2) + + f.MULXQ(f.qAt(3), amd64.AX, mu) + f.SBBQ(amd64.AX, r3) + f.SBBQ(mu, r4) + + // we need up to 2 conditional substractions to be < q + modReduced := f.NewLabel("modReduced") + t := f.PopN(®isters) + f.Mov(r[:4], t) // backup r0 to r3 (our result) + + // sub modulus + f.SUBQ(f.qAt(0), r0) + f.SBBQ(f.qAt(1), r1) + f.SBBQ(f.qAt(2), r2) + f.SBBQ(f.qAt(3), r3) + f.SBBQ("$0", r4) + + // if borrow, we go to mod reduced + f.JCS(modReduced) + f.Mov(r, t) + f.SUBQ(f.qAt(0), r0) + f.SBBQ(f.qAt(1), r1) + f.SBBQ(f.qAt(2), r2) + f.SBBQ(f.qAt(3), r3) + f.SBBQ("$0", r4) + + // if borrow, we skip to the end + f.JCS(modReduced) + f.Mov(r, t) + + f.LABEL(modReduced) + addrRes := mu + f.MOVQ("res+0(FP)", addrRes) + f.Mov(t, addrRes) + + f.LABEL(done) + + f.RET() + f.Push(®isters, mu) + f.Push(®isters, w0l, w1l, w2l, w3l, w3h) +} + +func (f *FFAmd64) generateInnerProduct() { + f.Comment("innerProdVec(res, a,b *Element, n uint64) res = sum(a[0...n] * b[0...n])") const argSize = 4 * 8 - const minStackSize = 7 * 8 // 2 slices (3 words each) + pointer to the scalar - stackSize := f.StackSize(f.NbWords*2+3, 2, minStackSize) - reserved := []amd64.Register{amd64.DX, amd64.AX} - registers := f.FnHeader("scalarMulVec", stackSize, argSize, reserved...) 
- defer f.AssertCleanStack(stackSize, minStackSize) - - // labels & registers we need - noAdx := f.NewLabel("noAdx") + stackSize := f.StackSize(f.NbWords*3+2, 0, 0) + registers := f.FnHeader("innerProdVec", stackSize, argSize, amd64.DX, amd64.AX) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + PX := f.Pop(®isters) + PY := f.Pop(®isters) + LEN := f.Pop(®isters) + loop := f.NewLabel("loop") done := f.NewLabel("done") + AddPP := f.NewLabel("accumulate") - t := registers.PopN(f.NbWords) - scalar := registers.PopN(f.NbWords) + // AVX512 registers + PPL := amd64.Register("Z2") + PPH := amd64.Register("Z3") + Y := amd64.Register("Z4") + LSW := amd64.Register("Z5") - addrB := registers.Pop() - addrA := registers.Pop() - addrRes := addrB - len := registers.Pop() + ACC := amd64.Register("Z16") + A0L := amd64.Register("Z16") + A1L := amd64.Register("Z17") + A2L := amd64.Register("Z18") + A3L := amd64.Register("Z19") + A4L := amd64.Register("Z20") + A5L := amd64.Register("Z21") + A6L := amd64.Register("Z22") + A7L := amd64.Register("Z23") + A0H := amd64.Register("Z24") + A1H := amd64.Register("Z25") + A2H := amd64.Register("Z26") + A3H := amd64.Register("Z27") + A4H := amd64.Register("Z28") + A5H := amd64.Register("Z29") + A6H := amd64.Register("Z30") + A7H := amd64.Register("Z31") - // check ADX instruction support - f.CMPB("·supportAdx(SB)", 1) - f.JNE(noAdx) + // load arguments + f.MOVQ("a+8(FP)", PX) + f.MOVQ("b+16(FP)", PY) + f.MOVQ("n+24(FP)", LEN) - f.MOVQ("a+8(FP)", addrA) - f.MOVQ("b+16(FP)", addrB) - f.MOVQ("n+24(FP)", len) + f.Comment("Create mask for low dword in each qword") + f.VPCMPEQB("Y0", "Y0", "Y0") + f.VPMOVZXDQ("Y0", LSW) + + // Clear accumulator registers + f.VPXORQ(A0L, A0L, A0L) + f.VMOVDQA64(A0L, A1L) + f.VMOVDQA64(A0L, A2L) + f.VMOVDQA64(A0L, A3L) + f.VMOVDQA64(A0L, A4L) + f.VMOVDQA64(A0L, A5L) + f.VMOVDQA64(A0L, A6L) + f.VMOVDQA64(A0L, A7L) + f.VMOVDQA64(A0L, A0H) + f.VMOVDQA64(A0L, A1H) + f.VMOVDQA64(A0L, A2H) + 
f.VMOVDQA64(A0L, A3H) + f.VMOVDQA64(A0L, A4H) + f.VMOVDQA64(A0L, A5H) + f.VMOVDQA64(A0L, A6H) + f.VMOVDQA64(A0L, A7H) + + // note: we don't need to handle the case n==0; handled by caller already. + f.TESTQ(LEN, LEN) + f.JEQ(done, "n == 0, we are done") + + f.LABEL(loop) + f.TESTQ(LEN, LEN) + f.JEQ(AddPP, "n == 0 we can accumulate") + + f.VPMOVZXDQ("("+PY+")", Y) + + f.ADDQ("$32", PY) + + f.Comment("we multiply and accumulate partial products of 4 bytes * 32 bytes") + + mac := f.Define("MAC", 3, func(inputs ...amd64.Register) { + opLeft := inputs[0] + lo := inputs[1] + hi := inputs[2] + + f.VPMULUDQ_BCST(opLeft, Y, PPL) + f.VPSRLQ("$32", PPL, PPH) + f.VPANDQ(LSW, PPL, PPL) + f.VPADDQ(PPL, lo, lo) + f.VPADDQ(PPH, hi, hi) + }) - // we store b, the scalar, fully in registers - f.LabelRegisters("scalar", scalar...) - f.Mov(addrB, scalar) + mac("0*4("+PX+")", A0L, A0H) + mac("1*4("+PX+")", A1L, A1H) + mac("2*4("+PX+")", A2L, A2H) + mac("3*4("+PX+")", A3L, A3H) + mac("4*4("+PX+")", A4L, A4H) + mac("5*4("+PX+")", A5L, A5H) + mac("6*4("+PX+")", A6L, A6H) + mac("7*4("+PX+")", A7L, A7H) - xat := func(i int) string { - return string(scalar[i]) + f.ADDQ("$32", PX) + + f.DECQ(LEN, "decrement n") + f.JMP(loop) + + f.Push(®isters, LEN, PX, PY) + + f.LABEL(AddPP) + f.Comment("we accumulate the partial products into 544bits in Z1:Z0") + + f.MOVQ(uint64(0x1555), amd64.AX) + f.KMOVD(amd64.AX, "K1") + + f.MOVQ(uint64(1), amd64.AX) + f.KMOVD(amd64.AX, "K2") + + // ACC starts with the value of A0L + + f.Comment("store the least significant 32 bits of ACC (starts with A0L) in Z0") + f.VALIGND_Z("$16", ACC, ACC, "K2", "Z0") + f.KSHIFTLW("$1", "K2", "K2") + + f.VPSRLQ("$32", ACC, PPL) + f.VALIGND_Z("$2", ACC, ACC, "K1", ACC) + f.VPADDQ(PPL, ACC, ACC) + + f.VPANDQ(LSW, A0H, PPL) + f.VPADDQ(PPL, ACC, ACC) + + f.VPANDQ(LSW, A1L, PPL) + f.VPADDQ(PPL, ACC, ACC) + + // Word 1 of z is ready + f.VALIGND("$15", ACC, ACC, "K2", "Z0") + f.KSHIFTLW("$1", "K2", "K2") + + f.Comment("macro to add partial 
products and store the result in Z0") + addPP := f.Define("ADDPP", 5, func(inputs ...amd64.Register) { + AxH := inputs[0] + AyL := inputs[1] + AyH := inputs[2] + AzL := inputs[3] + I := inputs[4] + f.VPSRLQ("$32", ACC, PPL) + f.VALIGND_Z("$2", ACC, ACC, "K1", ACC) + f.VPADDQ(PPL, ACC, ACC) + f.VPSRLQ("$32", AxH, AxH) + f.VPADDQ(AxH, ACC, ACC) + f.VPSRLQ("$32", AyL, AyL) + f.VPADDQ(AyL, ACC, ACC) + f.VPANDQ(LSW, AyH, PPL) + f.VPADDQ(PPL, ACC, ACC) + f.VPANDQ(LSW, AzL, PPL) + f.VPADDQ(PPL, ACC, ACC) + f.VALIGND("$16-"+I, ACC, ACC, "K2", "Z0") + f.KADDW("K2", "K2", "K2") + }) + + addPP(A0H, A1L, A1H, A2L, "2") + addPP(A1H, A2L, A2H, A3L, "3") + addPP(A2H, A3L, A3H, A4L, "4") + addPP(A3H, A4L, A4H, A5L, "5") + addPP(A4H, A5L, A5H, A6L, "6") + addPP(A5H, A6L, A6H, A7L, "7") + f.VPSRLQ("$32", ACC, PPL) + f.VALIGND_Z("$2", ACC, ACC, "K1", ACC) + f.VPADDQ(PPL, ACC, ACC) + f.VPSRLQ("$32", A6H, A6H) + f.VPADDQ(A6H, ACC, ACC) + f.VPSRLQ("$32", A7L, A7L) + f.VPADDQ(A7L, ACC, ACC) + f.VPANDQ(LSW, A7H, PPL) + f.VPADDQ(PPL, ACC, ACC) + f.VALIGND("$16-8", ACC, ACC, "K2", "Z0") + f.KSHIFTLW("$1", "K2", "K2") + + f.VPSRLQ("$32", ACC, PPL) + f.VALIGND_Z("$2", ACC, ACC, "K1", ACC) + f.VPADDQ(PPL, ACC, ACC) + f.VPSRLQ("$32", A7H, A7H) + f.VPADDQ(A7H, ACC, ACC) + f.VALIGND("$16-9", ACC, ACC, "K2", "Z0") + f.KSHIFTLW("$1", "K2", "K2") + + addPP2 := f.Define("ADDPP2", 1, func(args ...amd64.Register) { + f.VPSRLQ("$32", ACC, PPL) + f.VALIGND_Z("$2", ACC, ACC, "K1", ACC) + f.VPADDQ(PPL, ACC, ACC) + f.VALIGND("$16-"+args[0], ACC, ACC, "K2", "Z0") + f.KSHIFTLW("$1", "K2", "K2") + }) + + addPP2("10") + addPP2("11") + addPP2("12") + addPP2("13") + addPP2("14") + addPP2("15") + + f.VPSRLQ("$32", ACC, PPL) + f.VALIGND_Z("$2", ACC, ACC, "K1", ACC) + f.VPADDQ(PPL, ACC, ACC) + f.VMOVDQA64_Z(ACC, "K1", "Z1") + + T0 := f.Pop(®isters) + T1 := f.Pop(®isters) + T2 := f.Pop(®isters) + T3 := f.Pop(®isters) + T4 := f.Pop(®isters) + + f.Comment("Extract the 4 least significant qwords of Z0") + f.VMOVQ("X0", 
T1) + f.VALIGNQ("$1", "Z0", "Z1", "Z0") + f.VMOVQ("X0", T2) + f.VALIGNQ("$1", "Z0", "Z0", "Z0") + f.VMOVQ("X0", T3) + f.VALIGNQ("$1", "Z0", "Z0", "Z0") + f.VMOVQ("X0", T4) + f.VALIGNQ("$1", "Z0", "Z0", "Z0") + f.XORQ(T0, T0) + + PH := f.Pop(®isters) + PL := amd64.AX + f.MOVQ(f.qInv0(), amd64.DX) + f.MULXQ(T1, amd64.DX, PH) + f.MULXQ(f.qAt(0), PL, PH) + f.ADDQ(PL, T1) + f.ADCQ(PH, T2) + f.MULXQ(f.qAt(2), PL, PH) + f.ADCQ(PL, T3) + f.ADCQ(PH, T4) + f.ADCQ("$0", T0) + f.MULXQ(f.qAt(1), PL, PH) + f.ADDQ(PL, T2) + f.ADCQ(PH, T3) + f.MULXQ(f.qAt(3), PL, PH) + f.ADCQ(PL, T4) + f.ADCQ(PH, T0) + f.ADCQ("$0", T1) + + f.MOVQ(f.qInv0(), amd64.DX) + f.MULXQ(T2, amd64.DX, PH) + + f.MULXQ(f.qAt(0), PL, PH) + f.ADDQ(PL, T2) + f.ADCQ(PH, T3) + f.MULXQ(f.qAt(2), PL, PH) + f.ADCQ(PL, T4) + f.ADCQ(PH, T0) + f.ADCQ("$0", T1) + f.MULXQ(f.qAt(1), PL, PH) + f.ADDQ(PL, T3) + f.ADCQ(PH, T4) + f.MULXQ(f.qAt(3), PL, PH) + f.ADCQ(PL, T0) + f.ADCQ(PH, T1) + f.ADCQ("$0", T2) + + f.MOVQ(f.qInv0(), amd64.DX) + + f.MULXQ(T3, amd64.DX, PH) + + f.MULXQ(f.qAt(0), PL, PH) + f.ADDQ(PL, T3) + f.ADCQ(PH, T4) + f.MULXQ(f.qAt(2), PL, PH) + f.ADCQ(PL, T0) + f.ADCQ(PH, T1) + f.ADCQ("$0", T2) + f.MULXQ(f.qAt(1), PL, PH) + f.ADDQ(PL, T4) + f.ADCQ(PH, T0) + f.MULXQ(f.qAt(3), PL, PH) + f.ADCQ(PL, T1) + f.ADCQ(PH, T2) + f.ADCQ("$0", T3) + + f.MOVQ(f.qInv0(), amd64.DX) + + f.MULXQ(T4, amd64.DX, PH) + + f.MULXQ(f.qAt(0), PL, PH) + f.ADDQ(PL, T4) + f.ADCQ(PH, T0) + f.MULXQ(f.qAt(2), PL, PH) + f.ADCQ(PL, T1) + f.ADCQ(PH, T2) + f.ADCQ("$0", T3) + f.MULXQ(f.qAt(1), PL, PH) + f.ADDQ(PL, T0) + f.ADCQ(PH, T1) + f.MULXQ(f.qAt(3), PL, PH) + f.ADCQ(PL, T2) + f.ADCQ(PH, T3) + f.ADCQ("$0", T4) + + // Add the remaining 5 qwords (9 dwords) from zmm0 + + f.VMOVQ("X0", PL) + f.ADDQ(PL, T0) + f.VALIGNQ("$1", "Z0", "Z0", "Z0") + f.VMOVQ("X0", PL) + f.ADCQ(PL, T1) + f.VALIGNQ("$1", "Z0", "Z0", "Z0") + f.VMOVQ("X0", PL) + f.ADCQ(PL, T2) + f.VALIGNQ("$1", "Z0", "Z0", "Z0") + f.VMOVQ("X0", PL) + f.ADCQ(PL, T3) + f.VALIGNQ("$1", "Z0", 
"Z0", "Z0") + f.VMOVQ("X0", PL) + f.ADCQ(PL, T4) + + f.Comment("Barrett reduction; see Handbook of Applied Cryptography, Algorithm 14.42.") + f.MOVQ(T3, amd64.AX) + f.SHRQw("$32", T4, amd64.AX) + f.MOVQ(f.mu(), amd64.DX) + f.MULQ(amd64.DX) + + f.MULXQ(f.qAt(0), PL, PH) + f.SUBQ(PL, T0) + f.SBBQ(PH, T1) + f.MULXQ(f.qAt(2), PL, PH) + f.SBBQ(PL, T2) + f.SBBQ(PH, T3) + f.SBBQ("$0", T4) + f.MULXQ(f.qAt(1), PL, PH) + f.SUBQ(PL, T1) + f.SBBQ(PH, T2) + f.MULXQ(f.qAt(3), PL, PH) + f.SBBQ(PL, T3) + f.SBBQ(PH, T4) + + f.Comment("we need up to 2 conditional substractions to be < q") + + PZ := f.Pop(®isters) + f.MOVQ("res+0(FP)", PZ) + t := []amd64.Register{T0, T1, T2, T3} + f.Mov(t, PZ) + + // sub q + f.SUBQ(f.qAt(0), T0) + f.SBBQ(f.qAt(1), T1) + f.SBBQ(f.qAt(2), T2) + f.SBBQ(f.qAt(3), T3) + f.SBBQ("$0", T4) + + // if borrow, we go to done + f.JCS(done) + + f.Mov(t, PZ) + + f.SUBQ(f.qAt(0), T0) + f.SBBQ(f.qAt(1), T1) + f.SBBQ(f.qAt(2), T2) + f.SBBQ(f.qAt(3), T3) + f.SBBQ("$0", T4) + + f.JCS(done) + + f.Mov(t, PZ) + + f.LABEL(done) + + f.RET() +} + +func (f *FFAmd64) generateMulVec(funcName string) { + scalarMul := funcName != "mulVec" + + const argSize = 5 * 8 + stackSize := f.StackSize(6+f.NbWords, 2, 8) + reserved := make([]amd64.Register, len(mul4Registers)+2) + copy(reserved, mul4Registers) + reserved[len(mul4Registers)] = amd64.AX + reserved[len(mul4Registers)+1] = amd64.DX + registers := f.FnHeader(funcName, stackSize, argSize, reserved...) + defer f.AssertCleanStack(stackSize, 0) + + // to simplify the generated assembly, we only handle n/16 (and do blocks of 16 muls). + // that is if n%16 != 0, we let the caller (Go) handle the remaining elements. 
+ LEN := f.Pop(®isters, true) + PZ := f.Pop(®isters) + PX := f.Pop(®isters) + PY := f.Pop(®isters) + + zi := func(i int) amd64.Register { + return amd64.Register("Z" + strconv.Itoa(i)) } - f.MOVQ("res+0(FP)", addrRes) + // AVX_MUL_Q_LO: + AVX_MUL_Q_LO, err := f.DefineFn("AVX_MUL_Q_LO") + if err != nil { + AVX_MUL_Q_LO = f.Define("AVX_MUL_Q_LO", 0, func(args ...amd64.Register) { + for i := 0; i < 4; i++ { + f.VPMULUDQ_BCST(f.qAt_bcst(i), "Z9", zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + } + }) + } + + // AVX_MUL_Q_HI: + AVX_MUL_Q_HI, err := f.DefineFn("AVX_MUL_Q_HI") + if err != nil { + AVX_MUL_Q_HI = f.Define("AVX_MUL_Q_HI", 0, func(args ...amd64.Register) { + for i := 0; i < 4; i++ { + f.VPMULUDQ_BCST(f.qAt_bcst(i+4), "Z9", zi(14+i)) + f.VPADDQ(zi(14+i), zi(i+4), zi(i+4)) + } + }) + } + + SHIFT_ADD_AND, err := f.DefineFn("SHIFT_ADD_AND") + if err != nil { + SHIFT_ADD_AND = f.Define("SHIFT_ADD_AND", 4, func(args ...amd64.Register) { + in0 := args[0] + in1 := args[1] + in2 := args[2] + in3 := args[3] + f.VPSRLQ("$32", in0, in1) + f.VPADDQ(in1, in2, in2) + f.VPANDQ(in3, in2, in0) + }) + } + + // CARRY1: + CARRY1, err := f.DefineFn("CARRY1") + if err != nil { + CARRY1 = f.Define("CARRY1", 0, func(args ...amd64.Register) { + for i := 0; i < 4; i++ { + SHIFT_ADD_AND(zi(i), zi(10+i), zi(i+1), "Z8") + } + }) + } + + // CARRY2: + CARRY2, err := f.DefineFn("CARRY2") + if err != nil { + CARRY2 = f.Define("CARRY2", 0, func(args ...amd64.Register) { + for i := 0; i < 3; i++ { + SHIFT_ADD_AND(zi(i+4), zi(14+i), zi(i+5), "Z8") + } + f.VPSRLQ("$32", "Z7", "Z7") + }) + } + + // CARRY3: + CARRY3, err := f.DefineFn("CARRY3") + if err != nil { + CARRY3 = f.Define("CARRY3", 0, func(args ...amd64.Register) { + for i := 0; i < 4; i++ { + f.VPSRLQ("$32", zi(i), zi(10+i)) + f.VPANDQ("Z8", zi(i), zi(i)) + f.VPADDQ(zi(10+i), zi(i+1), zi(i+1)) + } + }) + } + + // CARRY4: + CARRY4, err := f.DefineFn("CARRY4") + if err != nil { + CARRY4 = f.Define("CARRY4", 0, func(args ...amd64.Register) { 
+ for i := 0; i < 3; i++ { + f.VPSRLQ("$32", zi(i+4), zi(14+i)) + f.VPANDQ("Z8", zi(i+4), zi(i+4)) + f.VPADDQ(zi(14+i), zi(i+5), zi(i+5)) + } + }) + } + + // we use the same registers as defined in the mul. + t := mul4Registers[:4] + f.LabelRegisters("t", t...) + y := mul4Registers[4:8] + f.LabelRegisters("y", y...) + tr := mul4Registers[8] + A := amd64.BP // note, BP is used in the mul defines. + + // reuse defines from the mul function + mulWord0, err := f.DefineFn("MUL_WORD_0") + if err != nil { + panic(err) + } + mulWordN, err := f.DefineFn("MUL_WORD_N") + if err != nil { + panic(err) + } + + zIndex := 0 + + loadInput := func() { + if scalarMul { + return + } + f.Comment(fmt.Sprintf("load input y[%d]", zIndex)) + f.Mov(PY, y, zIndex*4) + } + + mulXi := func(wordIndex int) { + f.Comment(fmt.Sprintf("z[%d] -> y * x[%d]", zIndex, wordIndex)) + if wordIndex == 0 { + mulWord0() + } else { + f.MOVQ(amd64.Register(PX.At(wordIndex)), amd64.DX) + mulWordN() + } + } + + storeOutput := func() { + scratch := []amd64.Register{A, tr, amd64.AX, amd64.DX} + f.ReduceElement(t, scratch) + + f.Comment(fmt.Sprintf("store output z[%d]", zIndex)) + f.Mov(t, PZ, 0, zIndex*4) + if zIndex == 7 { + f.ADDQ("$288", PX) + } else { + f.ADDQ("$32", PX) + f.MOVQ(amd64.Register(PX.At(0)), amd64.DX) + } + zIndex++ + } + + done := f.NewLabel("done") + loop := f.NewLabel("loop") + + f.MOVQ("res+0(FP)", PZ) + f.MOVQ("a+8(FP)", PX) + f.MOVQ("b+16(FP)", PY) + f.MOVQ("n+24(FP)", tr) + + if scalarMul { + // for scalar mul we move the scalar only once in registers. + f.Mov(PY, y) + } + + // we process 16 elements at a time, Go caller divided len by 16. 
+ f.MOVQ(tr, LEN) + + f.Comment("Create mask for low dword in each qword") + + f.VPCMPEQB("Y8", "Y8", "Y8") + f.VPMOVZXDQ("Y8", "Z8") + f.MOVQ("$0x5555", amd64.DX) + f.KMOVD(amd64.DX, "K1") f.LABEL(loop) - f.TESTQ(len, len) + // f.MOVQ(LEN, tr) + f.TESTQ(tr, tr) f.JEQ(done, "n == 0, we are done") - yat := func(i int) string { - return addrA.At(i) + f.MOVQ(amd64.Register(PX.At(0)), amd64.DX) + f.VMOVDQU64("256+0*64("+PX+")", "Z16") + f.VMOVDQU64("256+1*64("+PX+")", "Z17") + f.VMOVDQU64("256+2*64("+PX+")", "Z18") + f.VMOVDQU64("256+3*64("+PX+")", "Z19") + + loadInput() + if scalarMul { + f.VMOVDQU64("0("+PY+")", "Z24") + f.VMOVDQU64("0("+PY+")", "Z25") + f.VMOVDQU64("0("+PY+")", "Z26") + f.VMOVDQU64("0("+PY+")", "Z27") + } else { + f.VMOVDQU64("256+0*64("+PY+")", "Z24") + f.VMOVDQU64("256+1*64("+PY+")", "Z25") + f.VMOVDQU64("256+2*64("+PY+")", "Z26") + f.VMOVDQU64("256+3*64("+PY+")", "Z27") } - f.Comment("TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function") + f.Comment("Transpose and expand x and y") - f.MulADX(®isters, xat, yat, t) + // Step 1 - // registers.Push(addrA) + f.VSHUFI64X2("$0x88", "Z17", "Z16", "Z20") + f.VSHUFI64X2("$0xdd", "Z17", "Z16", "Z22") + f.VSHUFI64X2("$0x88", "Z19", "Z18", "Z21") + f.VSHUFI64X2("$0xdd", "Z19", "Z18", "Z23") - // reduce; we need at least 4 extra registers - registers.Push(amd64.AX, amd64.DX) - f.Comment("reduce t mod q") - f.Reduce(®isters, t) - f.Mov(t, addrRes) + f.VSHUFI64X2("$0x88", "Z25", "Z24", "Z28") + f.VSHUFI64X2("$0xdd", "Z25", "Z24", "Z30") + f.VSHUFI64X2("$0x88", "Z27", "Z26", "Z29") + f.VSHUFI64X2("$0xdd", "Z27", "Z26", "Z31") - f.Comment("increment pointers to visit next element") - f.ADDQ("$32", addrA) - f.ADDQ("$32", addrRes) - f.DECQ(len, "decrement n") + // Step 2 + + f.VPERMQ("$0xd8", "Z20", "Z20") + f.VPERMQ("$0xd8", "Z21", "Z21") + f.VPERMQ("$0xd8", "Z22", "Z22") + f.VPERMQ("$0xd8", "Z23", "Z23") + + mulXi(0) + + f.VPERMQ("$0xd8", "Z28", 
"Z28") + f.VPERMQ("$0xd8", "Z29", "Z29") + f.VPERMQ("$0xd8", "Z30", "Z30") + f.VPERMQ("$0xd8", "Z31", "Z31") + + // Step 3 + + for i := 20; i <= 23; i++ { + f.VSHUFI64X2("$0xd8", zi(i), zi(i), zi(i)) + } + + mulXi(1) + + for i := 28; i <= 31; i++ { + f.VSHUFI64X2("$0xd8", zi(i), zi(i), zi(i)) + } + + // Step 4 + + f.VSHUFI64X2("$0x44", "Z21", "Z20", "Z16") + f.VSHUFI64X2("$0xee", "Z21", "Z20", "Z18") + f.VSHUFI64X2("$0x44", "Z23", "Z22", "Z20") + f.VSHUFI64X2("$0xee", "Z23", "Z22", "Z22") + + mulXi(2) + f.VSHUFI64X2("$0x44", "Z29", "Z28", "Z24") + f.VSHUFI64X2("$0xee", "Z29", "Z28", "Z26") + f.VSHUFI64X2("$0x44", "Z31", "Z30", "Z28") + f.VSHUFI64X2("$0xee", "Z31", "Z30", "Z30") + + f.WriteLn("PREFETCHT0 1024(" + string(PX) + ")") + + // Step 5 + + f.VPSRLQ("$32", "Z16", "Z17") + f.VPSRLQ("$32", "Z18", "Z19") + f.VPSRLQ("$32", "Z20", "Z21") + f.VPSRLQ("$32", "Z22", "Z23") + + for i := 24; i <= 30; i += 2 { + f.VPSRLQ("$32", zi(i), zi(i+1)) + } + mulXi(3) + + for i := 16; i <= 30; i += 2 { + f.VPANDQ("Z8", zi(i), zi(i)) + } + + storeOutput() + + f.Comment("For each 256-bit input value, each zmm register now represents a 32-bit input word zero-extended to 64 bits.") + f.Comment("Multiply y by doubleword 0 of x") + + for i := 0; i < 8; i++ { + f.VPMULUDQ("Z16", zi(24+i), zi(i)) + if i == 4 { + if !scalarMul { + f.WriteLn("PREFETCHT0 1024(" + string(PY) + ")") + } + } + } + + loadInput() + + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + for i := 0; i < 4; i++ { + f.VPSRLQ("$32", zi(i), zi(10+i)) + f.VPANDQ("Z8", zi(i), zi(i)) + f.VPADDQ(zi(10+i), zi(i+1), zi(i+1)) + + } + + mulXi(0) + + for i := 0; i < 3; i++ { + f.VPSRLQ("$32", zi(4+i), zi(14+i)) + f.VPANDQ("Z8", zi(4+i), zi(4+i)) + f.VPADDQ(zi(14+i), zi(5+i), zi(5+i)) + + } + + for i := 0; i < 4; i++ { + f.VPMULUDQ_BCST(f.qAt_bcst(i), "Z9", zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + } + + mulXi(1) + + f.VPMULUDQ_BCST(f.qAt_bcst(4), "Z9", "Z14") + f.VPADDQ("Z14", "Z4", "Z4") + + f.VPMULUDQ_BCST(f.qAt_bcst(5), 
"Z9", "Z15") + f.VPADDQ("Z15", "Z5", "Z5") + + f.VPMULUDQ_BCST(f.qAt_bcst(6), "Z9", "Z16") + f.VPADDQ("Z16", "Z6", "Z6") + + f.VPMULUDQ_BCST(f.qAt_bcst(7), "Z9", "Z10") + f.VPADDQ("Z10", "Z7", "Z7") + + CARRY1() + + mulXi(2) + + for i := 0; i < 3; i++ { + SHIFT_ADD_AND(zi(4+i), zi(14+i), zi(5+i), "Z8") + } + f.VPSRLQ("$32", "Z7", "Z7") + + f.Comment("Process doubleword 1 of x") + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z17", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + + } + + mulXi(3) + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z17", zi(28+i), zi(14+i)) + f.VPADDQ(zi(14+i), zi(4+i), zi(4+i)) + + } + + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + storeOutput() + + f.Comment("Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries)") + + for i := 0; i < 3; i++ { + f.VPSRLQ("$32", zi(i), zi(10+i)) + f.VPANDQ("Z8", zi(i), zi(i)) + f.VPADDQ(zi(10+i), zi(i+1), zi(i+1)) + } + loadInput() + + f.VPSRLQ("$32", "Z3", "Z13") + f.VPANDQ("Z8", "Z3", "Z3") + f.VPADDQ("Z13", "Z4", "Z4") + + CARRY4() + mulXi(0) + + AVX_MUL_Q_LO() + + AVX_MUL_Q_HI() + mulXi(1) + + CARRY1() + + CARRY2() + mulXi(2) + + f.Comment("Process doubleword 2 of x") + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z18", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + } + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z18", zi(28+i), zi(14+i)) + f.VPADDQ(zi(14+i), zi(4+i), zi(4+i)) + + } + + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + mulXi(3) + + f.Comment("Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries)") + CARRY3() + + storeOutput() + loadInput() + + CARRY4() + + AVX_MUL_Q_LO() + + mulXi(0) + AVX_MUL_Q_HI() + + CARRY1() + CARRY2() + + f.Comment("Process doubleword 3 of x") + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z19", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + + } + + mulXi(1) + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z19", zi(28+i), zi(14+i)) + f.VPADDQ(zi(14+i), zi(4+i), zi(4+i)) 
+ + } + mulXi(2) + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + // Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries) + CARRY3() + CARRY4() + mulXi(3) + + AVX_MUL_Q_LO() + + AVX_MUL_Q_HI() + + storeOutput() + + f.Comment("Propagate carries and shift down by one dword") + CARRY1() + + CARRY2() + + loadInput() + + f.Comment("Process doubleword 4 of x") + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z20", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + + } + mulXi(0) + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z20", zi(28+i), zi(14+i)) + f.VPADDQ(zi(14+i), zi(4+i), zi(4+i)) + + } + + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + mulXi(1) + + f.Comment("Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries)") + + CARRY3() + + CARRY4() + mulXi(2) + + f.Comment("zmm7 keeps all 64 bits") + + AVX_MUL_Q_LO() + + AVX_MUL_Q_HI() + + mulXi(3) + + f.Comment("Propagate carries and shift down by one dword") + + CARRY1() + + CARRY2() + + storeOutput() + + f.Comment("Process doubleword 5 of x") + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z21", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + + } + loadInput() + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z21", zi(28+i), zi(14+i)) + f.VPADDQ(zi(14+i), zi(4+i), zi(4+i)) + + } + + mulXi(0) + + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + f.Comment("Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries)") + CARRY3() + + CARRY4() + + mulXi(1) + + AVX_MUL_Q_LO() + + AVX_MUL_Q_HI() + + mulXi(2) + + CARRY1() + + CARRY2() + + mulXi(3) + + f.Comment("Process doubleword 6 of x") + + for i := 0; i < 8; i++ { + f.VPMULUDQ("Z22", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + } + + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + storeOutput() + + f.Comment("Move high dwords to zmm10-16, add each to the corresponding low dword (propagate 32-bit carries)") + CARRY3() + loadInput() + CARRY4() 
+ + mulXi(0) + + AVX_MUL_Q_LO() + + AVX_MUL_Q_HI() + + mulXi(1) + + CARRY1() + + CARRY2() + + mulXi(2) + + f.Comment("Process doubleword 7 of x") + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z23", zi(24+i), zi(10+i)) + f.VPADDQ(zi(10+i), zi(i), zi(i)) + + } + + for i := 0; i < 4; i++ { + f.VPMULUDQ("Z23", zi(28+i), zi(14+i)) + f.VPADDQ(zi(14+i), zi(4+i), zi(4+i)) + + } + f.VPMULUDQ_BCST("qInvNeg+32(FP)", "Z0", "Z9") + + mulXi(3) + + CARRY3() + storeOutput() + CARRY4() + + loadInput() + + AVX_MUL_Q_LO() + + AVX_MUL_Q_HI() + + mulXi(0) + + CARRY1() + + CARRY2() + + mulXi(1) + + f.Comment("Conditional subtraction of the modulus") + + for i := 0; i < 8; i++ { + f.VPERMD_BCST_Z(f.qAt_bcst(i), "Z8", "K1", zi(10+i)) + } + + for i := 0; i < 8; i++ { + f.VPSUBQ(zi(10+i), zi(i), zi(10+i)) + if i > 0 { + f.VPSUBQ(zi(20+i-1), zi(10+i), zi(10+i)) + } + if i != 7 { + f.VPSRLQ("$63", zi(10+i), zi(20+i)) + f.VPANDQ("Z8", zi(10+i), zi(10+i)) + } + + } + + f.VPMOVQ2M("Z17", "K2") + f.KNOTB("K2", "K2") + + for i := 0; i < 8; i++ { + f.VMOVDQU64k(zi(10+i), "K2", zi(i)) + if i == 4 { + mulXi(2) + } + } + + f.Comment("Transpose results back") + + f.WriteLn("VALIGND $0, ·pattern1+0(SB), Z11, Z11") + f.WriteLn("VALIGND $0, ·pattern2+0(SB), Z12, Z12") + f.WriteLn("VALIGND $0, ·pattern3+0(SB), Z13, Z13") + f.WriteLn("VALIGND $0, ·pattern4+0(SB), Z14, Z14") + + for i := 0; i < 4; i++ { + f.VPSLLQ("$32", zi(2*i+1), zi(2*i+1)) + f.VPORQ(zi(2*i+1), zi(2*i), zi(i)) + } + + f.VMOVDQU64("Z0", "Z4") + f.VMOVDQU64("Z2", "Z6") + + mulXi(3) + f.VPERMT2Q("Z1", "Z11", "Z0") + f.VPERMT2Q("Z4", "Z12", "Z1") + f.VPERMT2Q("Z3", "Z11", "Z2") + f.VPERMT2Q("Z6", "Z12", "Z3") + + // Step 3 + storeOutput() + + f.VMOVDQU64("Z0", "Z4") + f.VMOVDQU64("Z1", "Z5") + f.VPERMT2Q("Z2", "Z13", "Z0") + f.VPERMT2Q("Z4", "Z14", "Z2") + f.VPERMT2Q("Z3", "Z13", "Z1") + f.VPERMT2Q("Z5", "Z14", "Z3") + + f.Comment("Save AVX-512 results") + + f.VMOVDQU64("Z0", "256+0*64("+PZ+")") + f.VMOVDQU64("Z2", "256+1*64("+PZ+")") + 
f.VMOVDQU64("Z1", "256+2*64("+PZ+")") + f.VMOVDQU64("Z3", "256+3*64("+PZ+")") + f.ADDQ("$512", PZ) + + if !scalarMul { + f.ADDQ("$512", PY) + } + + f.MOVQ(LEN, tr) + f.DECQ(tr, "decrement n") + f.MOVQ(tr, LEN) f.JMP(loop) f.LABEL(done) - f.RET() - // no ADX support - f.LABEL(noAdx) - - f.MOVQ("n+24(FP)", amd64.DX) - - f.MOVQ("res+0(FP)", amd64.AX) - f.MOVQ(amd64.AX, "(SP)") - f.MOVQ(amd64.DX, "8(SP)") // len - f.MOVQ(amd64.DX, "16(SP)") // cap - f.MOVQ("a+8(FP)", amd64.AX) - f.MOVQ(amd64.AX, "24(SP)") - f.MOVQ(amd64.DX, "32(SP)") // len - f.MOVQ(amd64.DX, "40(SP)") // cap - f.MOVQ("b+16(FP)", amd64.AX) - f.MOVQ(amd64.AX, "48(SP)") - f.WriteLn("CALL ·scalarMulVecGeneric(SB)") f.RET() + f.Push(®isters, LEN) + } diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 457a89d7d..6d24aab74 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -51,7 +51,9 @@ type FieldConfig struct { Q []uint64 QInverse []uint64 QMinusOneHalvedP []uint64 // ((q-1) / 2 ) + 1 + Mu uint64 // mu = 2^288 / q for 4.5 word barrett reduction ASM bool + ASMVector bool RSquare []uint64 One, Thirteen []uint64 LegendreExponent string // big.Int to base16 string @@ -260,6 +262,16 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // moduli that meet the condition F.NoCarry // asm code generation for moduli with more than 6 words can be optimized further F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 + F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 + + // setting Mu 2^288 / q + if F.NbWords == 4 { + _mu := big.NewInt(1) + _mu.Lsh(_mu, 288) + _mu.Div(_mu, &bModulus) + muSlice := toUint64Slice(_mu, F.NbWords) + F.Mu = muSlice[0] + } return F, nil } diff --git a/field/generator/generator.go b/field/generator/generator.go index e149ac9ae..015089767 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -22,7 +22,7 @@ import ( // // fp, _ = 
config.NewField("fp", "Element", fpModulus") // generator.GenerateFF(fp, filepath.Join(baseDir, "fp")) -func GenerateFF(F *config.FieldConfig, outputDir string) error { +func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirIncludePath string) error { // source file templates sourceFiles := []string{ element.Base, @@ -73,6 +73,8 @@ func GenerateFF(F *config.FieldConfig, outputDir string) error { } _ = os.Remove(filepath.Join(outputDir, "asm.go")) _ = os.Remove(filepath.Join(outputDir, "asm_noadx.go")) + _ = os.Remove(filepath.Join(outputDir, "avx.go")) + _ = os.Remove(filepath.Join(outputDir, "noavx.go")) funcs := template.FuncMap{} if F.UseAddChain { @@ -137,33 +139,7 @@ func GenerateFF(F *config.FieldConfig, outputDir string) error { _, _ = io.WriteString(f, "// +build !purego\n") - if err := amd64.Generate(f, F); err != nil { - _ = f.Close() - return err - } - _ = f.Close() - - // run asmfmt - // run go fmt on whole directory - cmd := exec.Command("asmfmt", "-w", pathSrc) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err - } - } - - { - pathSrc := filepath.Join(outputDir, eName+"_mul_amd64.s") - fmt.Println("generating", pathSrc) - f, err := os.Create(pathSrc) - if err != nil { - return err - } - - _, _ = io.WriteString(f, "// +build !purego\n") - - if err := amd64.GenerateMul(f, F); err != nil { + if err := amd64.GenerateFieldWrapper(f, F, asmDirBuildPath, asmDirIncludePath); err != nil { _ = f.Close() return err } @@ -234,7 +210,7 @@ func GenerateFF(F *config.FieldConfig, outputDir string) error { src := []string{ element.Asm, } - pathSrc := filepath.Join(outputDir, "asm.go") + pathSrc := filepath.Join(outputDir, "asm_adx.go") bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) copy(bavardOptsCpy, bavardOpts) bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!noadx")) @@ -256,6 +232,34 @@ func GenerateFF(F *config.FieldConfig, outputDir string) error { } } + if 
F.ASMVector {
+		// generate asm_avx.go
+		src := []string{
+			element.Avx,
+		}
+		pathSrc := filepath.Join(outputDir, "asm_avx.go")
+		bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts))
+		copy(bavardOptsCpy, bavardOpts)
+		bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!noavx"))
+		if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil {
+			return err
+		}
+	}
+
+	if F.ASMVector {
+		// generate asm_noavx.go
+		src := []string{
+			element.NoAvx,
+		}
+		pathSrc := filepath.Join(outputDir, "asm_noavx.go")
+		bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts))
+		copy(bavardOptsCpy, bavardOpts)
+		bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("noavx"))
+		if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil {
+			return err
+		}
+	}
+
 	// run go fmt on whole directory
 	cmd := exec.Command("gofmt", "-s", "-w", outputDir)
 	cmd.Stdout = os.Stdout
@@ -274,3 +278,31 @@ func shorten(input string) string {
 	}
 	return input
 }
+
+func GenerateCommonASM(nbWords int, asmDir string, hasVector bool) error {
+	os.MkdirAll(asmDir, 0755)
+	pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords))
+
+	fmt.Println("generating", pathSrc)
+	f, err := os.Create(pathSrc)
+	if err != nil {
+		return err
+	}
+
+	if err := amd64.GenerateCommonASM(f, nbWords, hasVector); err != nil {
+		_ = f.Close()
+		return err
+	}
+	_ = f.Close()
+
+	// run asmfmt
+	// run go fmt on whole directory
+	cmd := exec.Command("asmfmt", "-w", pathSrc)
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+	if err := cmd.Run(); err != nil {
+		return err
+	}
+
+	return nil
+}
diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go
index e107a0e85..85b79cfdb 100644
--- a/field/generator/generator_test.go
+++ b/field/generator/generator_test.go
@@ -21,27 +21,32 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
+	"strings"
 	"testing"
 
 	field 
"github.com/consensys/gnark-crypto/field/generator/config" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) // integration test will create modulus for various field sizes and run tests const rootDir = "integration_test" +const asmDir = "../asm" func TestIntegration(t *testing.T) { + assert := require.New(t) + os.RemoveAll(rootDir) err := os.MkdirAll(rootDir, 0700) defer os.RemoveAll(rootDir) - if err != nil { - t.Fatal(err) - } + assert.NoError(err) var bits []int for i := 64; i <= 448; i += 64 { bits = append(bits, i-3, i-2, i-1, i, i+1) } + bits = append(bits, 224, 225, 226) moduli := make(map[string]string) for _, i := range bits { @@ -73,17 +78,18 @@ func TestIntegration(t *testing.T) { moduli["e_nocarry_edge_0127"] = "170141183460469231731687303715884105727" moduli["e_nocarry_edge_1279"] = "10407932194664399081925240327364085538615262247266704805319112350403608059673360298012239441732324184842421613954281007791383566248323464908139906605677320762924129509389220345773183349661583550472959420547689811211693677147548478866962501384438260291732348885311160828538416585028255604666224831890918801847068222203140521026698435488732958028878050869736186900714720710555703168729087" + assert.NoError(GenerateCommonASM(2, asmDir, false)) + assert.NoError(GenerateCommonASM(3, asmDir, false)) + assert.NoError(GenerateCommonASM(7, asmDir, false)) + assert.NoError(GenerateCommonASM(8, asmDir, false)) + for elementName, modulus := range moduli { var fIntegration *field.FieldConfig // generate field childDir := filepath.Join(rootDir, elementName) fIntegration, err = field.NewFieldConfig("integration", elementName, modulus, false) - if err != nil { - t.Fatal(elementName, err) - } - if err = GenerateFF(fIntegration, childDir); err != nil { - t.Fatal(elementName, err) - } + assert.NoError(err) + assert.NoError(GenerateFF(fIntegration, childDir, asmDir, "../../../asm")) } // run go test @@ -91,14 +97,39 @@ func TestIntegration(t *testing.T) { if err != nil { 
t.Fatal(err) } - packageDir := filepath.Join(wd, rootDir) + string(filepath.Separator) + "..." - cmd := exec.Command("go", "test", packageDir) - if err := cmd.Run(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - t.Fatal(string(exitErr.Stderr)) - } else { - t.Fatal(err) + packageDir := filepath.Join(wd, rootDir) // + string(filepath.Separator) + "..." + + // list all subdirectories in package dir + var subDirs []string + err = filepath.Walk(packageDir, func(path string, info os.FileInfo, err error) error { + if info.IsDir() && path != packageDir { + subDirs = append(subDirs, path) } + return nil + }) + if err != nil { + t.Fatal(err) + } + + errGroup := errgroup.Group{} + + for _, subDir := range subDirs { + // run go test in parallel + errGroup.Go(func() error { + cmd := exec.Command("go", "test") + cmd.Dir = subDir + var stdouterr strings.Builder + cmd.Stdout = &stdouterr + cmd.Stderr = &stdouterr + if err := cmd.Run(); err != nil { + return fmt.Errorf("go test failed, output:\n%s\n%s", stdouterr.String(), err) + } + return nil + }) + } + + if err := errGroup.Wait(); err != nil { + t.Fatal(err) } } diff --git a/field/generator/internal/templates/element/asm.go b/field/generator/internal/templates/element/asm.go index c1027f148..ed7fac7b4 100644 --- a/field/generator/internal/templates/element/asm.go +++ b/field/generator/internal/templates/element/asm.go @@ -10,6 +10,19 @@ var ( ) ` +const Avx = ` +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) +` + +const NoAvx = ` +const supportAvx512 = false +` + // AsmNoAdx ... 
const AsmNoAdx = ` diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 7690f0657..7dc8bffc4 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -49,7 +49,7 @@ const ( {{- end}} ) -var q{{.ElementName}} = {{.ElementName}}{ +var qElement = {{.ElementName}}{ {{- range $i := $.NbWordsIndexesFull}} q{{$i}},{{end}} } @@ -68,6 +68,12 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = {{index .QInverse 0}} +{{- if eq .NbWords 4}} +// mu = 2^288 / q needed for partial Barrett reduction +const mu uint64 = {{.Mu}} +{{- end}} + + func init() { _modulus.SetString("{{.ModulusHex}}", 16) } diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 070d07ec8..6fb554c3b 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -2,6 +2,9 @@ package element // MulCIOS text book CIOS works for all modulus. // +// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" +// by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 +// // There are couple of variations to the multiplication (and squaring) algorithms. 
// // All versions are derived from the Montgomery CIOS algorithm: see @@ -126,49 +129,7 @@ const MulCIOS = ` const MulDoc = ` {{define "mul_doc noCarry"}} -// Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis -// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf -// -// The algorithm: -// -// for i=0 to N-1 -// C := 0 -// for j=0 to N-1 -// (C,t[j]) := t[j] + x[j]*y[i] + C -// (t[N+1],t[N]) := t[N] + C -// -// C := 0 -// m := t[0]*q'[0] mod D -// (C,_) := t[0] + m*q[0] -// for j=1 to N-1 -// (C,t[j-1]) := t[j] + m*q[j] + C -// -// (C,t[N-1]) := t[N] + C -// t[N] := t[N+1] + C -// -// → N is the number of machine words needed to store the modulus q -// → D is the word size. For example, on a 64-bit architecture D is 2 64 -// → x[i], y[i], q[i] is the ith word of the numbers x,y,q -// → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. -// → t is a temporary array of size N+2 -// → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number -{{- if .noCarry}} -// -// As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: -// (also described in https://eprint.iacr.org/2022/1400.pdf annex) -// -// for i=0 to N-1 -// (A,t[0]) := t[0] + x[0]*y[i] -// m := t[0]*q'[0] mod W -// C,_ := t[0] + m*q[0] -// for j=1 to N-1 -// (A,t[j]) := t[j] + x[j]*y[i] + A -// (C,t[j-1]) := t[j] + m*q[j] + C -// -// t[N-1] = C + A -// -// This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit -// of the modulus is zero (and not all of the remaining bits are set). -{{- end}} +// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" +// by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 {{ end }} ` diff --git a/field/generator/internal/templates/element/mul_nocarry.go b/field/generator/internal/templates/element/mul_nocarry.go index 0ec89f7a8..14740fd4a 100644 --- a/field/generator/internal/templates/element/mul_nocarry.go +++ b/field/generator/internal/templates/element/mul_nocarry.go @@ -1,6 +1,8 @@ package element -// MulNoCarry see https://eprint.iacr.org/2022/1400.pdf annex for more info on the algorithm +// MulNoCarry +// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" +// by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // Note that these templates are optimized for arm64 target, since x86 benefits from assembly impl. const MulNoCarry = ` {{ define "mul_nocarry" }} diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index ffa7231e1..1d1408e26 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -29,14 +29,15 @@ func reduce(res *{{.ElementName}}) //go:noescape func Butterfly(a, b *{{.ElementName}}) -{{- if eq .NbWords 4}} +{{- if .ASMVector}} // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
func (vector *Vector) Add(a, b Vector) { if len(a) != len(b) || len(a) != len(*vector) { panic("vector.Add: vectors don't have the same length") } - addVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) } //go:noescape @@ -60,11 +61,116 @@ func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { if len(a) != len(*vector) { panic("vector.ScalarMul: vectors don't have the same length") } - scalarMulVec(&(*vector)[0], &a[0], b, uint64(len(a))) + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]{{.ElementName}} + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n % blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n % blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16*7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return } //go:noescape -func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) +func sumVec(res *{{.ElementName}}, a *{{.ElementName}}, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a,b *{{.ElementName}}, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n % blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n % blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) + {{- end}} // Mul z = x * y (mod q) diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index a4fde0d05..fe3cead3c 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ 
b/field/generator/internal/templates/element/ops_purego.go @@ -50,7 +50,7 @@ func reduce(z *{{.ElementName}}) { _reduceGeneric(z) } -{{- if eq .NbWords 4}} +{{- if .ASMVector}} // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -68,6 +68,26 @@ func (vector *Vector) Sub(a, b Vector) { func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { scalarMulVecGeneric(*vector, a, b) } + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + {{- end}} // Mul z = x * y (mod q) diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 416b3f30e..d09dbeb7e 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -331,14 +331,14 @@ func init() { staticTestValues = append(staticTestValues, rSquare) // r² var e, one {{.ElementName}} one.SetOne() - e.Sub(&q{{.ElementName}}, &one) + e.Sub(&qElement, &one) staticTestValues = append(staticTestValues, e) // q - 1 e.Double(&one) staticTestValues = append(staticTestValues, e) // 2 { - a := q{{.ElementName}} + a := qElement a[0]-- staticTestValues = append(staticTestValues, a) } @@ -354,14 +354,14 @@ func init() { {{- end}} { - a := q{{.ElementName}} + a := qElement a[{{.NbWordsLastIndex}}]-- staticTestValues = append(staticTestValues, a) } {{- if ne .NbWords 1}} { - a := q{{.ElementName}} + a := qElement a[{{.NbWordsLastIndex}}]-- a[0]++ staticTestValues = append(staticTestValues, a) @@ -369,7 +369,7 @@ func init() { {{- end}} { - a := q{{.ElementName}} + a := qElement a[{{.NbWordsLastIndex}}] = 0 staticTestValues = append(staticTestValues, a) } @@ -653,8 +653,6 @@ func Test{{toTitle .ElementName}}BitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - - } @@ -728,78 +726,6 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { } -func Test{{toTitle .ElementName}}VecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected {{.ElementName}} - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i 
< N; i++ { - var expected {{.ElementName}} - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected {{.ElementName}} - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - {{template "testBinaryOp" dict "all" . "Op" "Add"}} {{template "testBinaryOp" dict "all" . 
"Op" "Sub"}} @@ -1651,8 +1577,8 @@ func gen() gopter.Gen { {{- range $i := .NbWordsIndexesFull}} genParams.NextUint64(),{{end}} } - if q{{.ElementName}}[{{.NbWordsLastIndex}}] != ^uint64(0) { - g.element[{{.NbWordsLastIndex}}] %= (q{{.ElementName}}[{{.NbWordsLastIndex}}] +1 ) + if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + g.element[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } @@ -1661,8 +1587,8 @@ func gen() gopter.Gen { {{- range $i := .NbWordsIndexesFull}} genParams.NextUint64(),{{end}} } - if q{{.ElementName}}[{{.NbWordsLastIndex}}] != ^uint64(0) { - g.element[{{.NbWordsLastIndex}}] %= (q{{.ElementName}}[{{.NbWordsLastIndex}}] +1 ) + if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + g.element[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } } @@ -1672,42 +1598,42 @@ func gen() gopter.Gen { } } +func genRandomFq(genParams *gopter.GenParameters) {{.ElementName}} { + var g {{.ElementName}} -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { + g = {{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull}} + genParams.NextUint64(),{{end}} + } - genRandomFq := func() {{.ElementName}} { - var g {{.ElementName}} + if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + g[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) + } - g = {{.ElementName}}{ - {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} - } + for !g.smallerThanModulus() { + g = {{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull}} + genParams.NextUint64(),{{end}} + } + if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + g[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) + } + } - if q{{.ElementName}}[{{.NbWordsLastIndex}}] != ^uint64(0) { - g[{{.NbWordsLastIndex}}] %= (q{{.ElementName}}[{{.NbWordsLastIndex}}] +1 ) - } + return g +} - for !g.smallerThanModulus() { - g = {{.ElementName}}{ - {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} - } - if 
q{{.ElementName}}[{{.NbWordsLastIndex}}] != ^uint64(0) { - g[{{.NbWordsLastIndex}}] %= (q{{.ElementName}}[{{.NbWordsLastIndex}}] +1 ) - } - } - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 {{- range $i := .NbWordsIndexesFull}} {{- if eq $i $.NbWordsLastIndex}} - a[{{$i}}], _ = bits.Add64(a[{{$i}}], q{{$.ElementName}}[{{$i}}], carry) + a[{{$i}}], _ = bits.Add64(a[{{$i}}], qElement[{{$i}}], carry) {{- else}} - a[{{$i}}], carry = bits.Add64(a[{{$i}}], q{{$.ElementName}}[{{$i}}], carry) + a[{{$i}}], carry = bits.Add64(a[{{$i}}], qElement[{{$i}}], carry) {{- end}} {{- end}} @@ -1715,6 +1641,15 @@ func genFull() gopter.Gen { return genResult } } + +func gen{{.ElementName}}() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + {{if $.UsingP20Inverse}} func (z *{{.ElementName}}) matchVeryBigInt(aHi uint64, aInt *big.Int) error { var modulus big.Int diff --git a/field/generator/internal/templates/element/tests_vector.go b/field/generator/internal/templates/element/tests_vector.go index 14a7c3e50..41187a2ca 100644 --- a/field/generator/internal/templates/element/tests_vector.go +++ b/field/generator/internal/templates/element/tests_vector.go @@ -9,9 +9,12 @@ import ( "sort" "reflect" "bytes" -) - + "os" + "fmt" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) func TestVectorSort(t *testing.T) { assert := require.New(t) @@ -69,8 +72,6 @@ func TestVectorEmptyRoundTrip(t *testing.T) { assert.True(reflect.DeepEqual(v3,v2)) } - - func (vector *Vector) unmarshalBinaryAsync(data []byte) error { r := bytes.NewReader(data) _, err, chErr := vector.AsyncReadFrom(r) @@ -80,4 +81,280 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { return <-chErr } + + +func TestVectorOps(t 
*testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp {{.ElementName}} + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp {{.ElementName}} + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b {{.ElementName}}) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp {{.ElementName}} + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum {{.ElementName}} + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct {{.ElementName}} + for i := 0; i < len(a); i++ { + var tmp {{.ElementName}} + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp {{.ElementName}} + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4,8,9,15,16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), 
genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + gen{{.ElementName}}(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1<<24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer {{.ElementName}} + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n:= 1<<4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := {{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull}} + genParams.NextUint64(),{{end}} + } + if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + mixer[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) + } + + + 
for !mixer.smallerThanModulus() { + mixer = {{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull}} + genParams.NextUint64(),{{end}} + } + if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + mixer[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + ` diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 8f06b54c9..6407a024d 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -174,7 +174,6 @@ func (vector Vector) String() string { return sbb.String() } - // Len is the number of elements in the collection. func (vector Vector) Len() int { return len(vector) @@ -193,7 +192,7 @@ func (vector Vector) Swap(i, j int) { {{/* For 4 elements, we have a special assembly path and copy this in ops_pure.go */}} -{{- if ne .NbWords 4}} +{{- if not .ASMVector}} // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -211,6 +210,26 @@ func (vector *Vector) Sub(a, b Vector) { func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { scalarMulVecGeneric(*vector, a, b) } + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + {{- end}} @@ -242,6 +261,32 @@ func scalarMulVecGeneric(res, a Vector, b *{{.ElementName}}) { } } +func sumVecGeneric(res *{{.ElementName}}, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *{{.ElementName}},a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp {{.ElementName}} + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/field/goff/cmd/root.go b/field/goff/cmd/root.go index 15ed94fc0..849e26da7 100644 --- a/field/goff/cmd/root.go +++ b/field/goff/cmd/root.go @@ -71,7 +71,14 @@ func cmdGenerate(cmd *cobra.Command, args []string) { fmt.Printf("\n%s\n", err.Error()) os.Exit(-1) } - if err := generator.GenerateFF(F, fOutputDir); err != nil { + + asmDir := filepath.Join(fOutputDir, "asm") + if err := generator.GenerateCommonASM(F.NbWords, asmDir, F.ASMVector); err != nil { + fmt.Printf("\n%s\n", err.Error()) + os.Exit(-1) + } + + if err := generator.GenerateFF(F, fOutputDir, asmDir, "asm/"); err != nil { fmt.Printf("\n%s\n", err.Error()) os.Exit(-1) } diff --git a/field/goldilocks/element.go b/field/goldilocks/element.go index 3afe5c447..8dd4d6991 100644 --- a/field/goldilocks/element.go +++ b/field/goldilocks/element.go @@ -412,32 +412,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [2]uint64 var D uint64 diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 339fb4ea6..454d057db 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -582,7 +582,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -652,77 +651,6 @@ func TestElementLexicographicallyLargest(t *testing.T) { } -func TestElementVecOps(t *testing.T) { - assert := require.New(t) - - const N = 7 - a := make(Vector, N) - b := make(Vector, N) - c := make(Vector, N) - for i := 0; i < N; i++ { - a[i].SetRandom() - b[i].SetRandom() - } - - // Vector addition - c.Add(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Add(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector addition failed") - } - - // Vector subtraction - c.Sub(a, b) - for i := 0; i < N; i++ { - var expected Element - expected.Sub(&a[i], &b[i]) - assert.True(c[i].Equal(&expected), "Vector subtraction failed") - } - - // Vector scaling - c.ScalarMul(a, &b[0]) - for i := 0; i < N; i++ { - var expected Element - expected.Mul(&a[i], &b[0]) - assert.True(c[i].Equal(&expected), "Vector scaling failed") - } -} - -func BenchmarkElementVecOps(b *testing.B) { - // note; to benchmark against "no asm" version, use the following - // build tag: -tags purego - const N = 1024 - a1 := make(Vector, N) - b1 := make(Vector, N) - c1 := make(Vector, N) - for i := 0; i < N; i++ { - a1[i].SetRandom() - b1[i].SetRandom() - } - - b.Run("Add", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Add(a1, b1) - } - }) - - b.Run("Sub", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.Sub(a1, b1) - } - }) - - b.Run("ScalarMul", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - c1.ScalarMul(a1, &b1[0]) - } - }) -} - func TestElementAdd(t *testing.T) 
{ t.Parallel() parameters := gopter.DefaultTestParameters() @@ -2284,32 +2212,32 @@ func gen() gopter.Gen { } } -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element - genRandomFq := func() Element { - var g Element + g = Element{ + genParams.NextUint64(), + } - g = Element{ - genParams.NextUint64(), - } + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } - if qElement[0] != ^uint64(0) { - g[0] %= (qElement[0] + 1) - } + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } + } - for !g.smallerThanModulus() { - g = Element{ - genParams.NextUint64(), - } - if qElement[0] != ^uint64(0) { - g[0] %= (qElement[0] + 1) - } - } + return g +} - return g - } - a := genRandomFq() +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) var carry uint64 a[0], _ = bits.Add64(a[0], qElement[0], carry) @@ -2318,3 +2246,11 @@ func genFull() gopter.Gen { return genResult } } + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/field/goldilocks/internal/main.go b/field/goldilocks/internal/main.go index 4f5bacd3f..e9f8749a5 100644 --- a/field/goldilocks/internal/main.go +++ b/field/goldilocks/internal/main.go @@ -14,7 +14,7 @@ func main() { if err != nil { panic(err) } - if err := generator.GenerateFF(goldilocks, "../"); err != nil { + if err := generator.GenerateFF(goldilocks, "../", "", ""); err != nil { panic(err) } fmt.Println("successfully generated goldilocks field") diff --git a/field/goldilocks/vector.go b/field/goldilocks/vector.go index 3de71afb8..7411cb7bf 100644 --- a/field/goldilocks/vector.go +++ 
b/field/goldilocks/vector.go @@ -214,6 +214,25 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -241,6 +260,32 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/field/goldilocks/vector_test.go b/field/goldilocks/vector_test.go index e1ce6992c..8d45c0f5d 100644 --- a/field/goldilocks/vector_test.go +++ b/field/goldilocks/vector_test.go @@ -18,10 +18,15 @@ package goldilocks import ( "bytes" + "fmt" "github.com/stretchr/testify/require" + "os" "reflect" "sort" "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" ) func TestVectorSort(t *testing.T) { @@ -88,3 +93,273 @@ func (vector *Vector) unmarshalBinaryAsync(data []byte) error { } return <-chErr } + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := 
func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). 
+ Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + mixer[0] %= (qElement[0] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + 
genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + mixer[0] %= (qElement[0] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). + Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/go.mod b/go.mod index 1cc1f399b..b5486971f 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,13 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.15 + github.com/consensys/bavard v0.1.22 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.26.0 + golang.org/x/sync v0.1.0 golang.org/x/sys v0.24.0 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 2c324c00b..26af63d69 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/consensys/bavard v0.1.15 h1:fxv2mg1afRMJvZgpwEgLmyr2MsQwaAYcyKf31UBHzw4= -github.com/consensys/bavard v0.1.15/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/bavard v0.1.22 h1:Uw2CGvbXSZWhqK59X0VG/zOjpTFuOMcPLStrp1ihI0A= +github.com/consensys/bavard v0.1.22/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -384,6 +384,7 @@ golang.org/x/sync 
v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/generator/main.go b/internal/generator/main.go index 389f96c2e..3d7cd9c0a 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -43,28 +43,55 @@ var bgen = bavard.NewBatchGenerator(copyrightHolder, copyrightYear, "consensys/g //go:generate go run main.go func main() { - var wg sync.WaitGroup + // first we loop through the field arithmetic we must generate. + // then, we create the common files (only once) for the assembly code. 
+ asmDirBuildPath := filepath.Join(baseDir, "field", "asm") + asmDirIncludePath := filepath.Join("../../../", "field", "asm") + + // generate common assembly files depending on field number of words + mCommon := make(map[int]bool) + mVec := make(map[int]bool) + + for i, conf := range config.Curves { + var err error + // generate base field + conf.Fp, err = field.NewFieldConfig("fp", "Element", conf.FpModulus, true) + assertNoError(err) + + conf.Fr, err = field.NewFieldConfig("fr", "Element", conf.FrModulus, !conf.Equal(config.STARK_CURVE)) + assertNoError(err) + + mCommon[conf.Fr.NbWords] = true + mCommon[conf.Fp.NbWords] = true + + if conf.Fr.ASMVector { + mVec[conf.Fr.NbWords] = true + } + if conf.Fp.ASMVector { + mVec[conf.Fp.NbWords] = true + } + + config.Curves[i] = conf + } + + for nbWords := range mCommon { + assertNoError(generator.GenerateCommonASM(nbWords, asmDirBuildPath, mVec[nbWords])) + } + + var wg sync.WaitGroup for _, conf := range config.Curves { wg.Add(1) // for each curve, generate the needed files go func(conf config.Curve) { defer wg.Done() - var err error curveDir := filepath.Join(baseDir, "ecc", conf.Name) - // generate base field - conf.Fp, err = field.NewFieldConfig("fp", "Element", conf.FpModulus, true) - assertNoError(err) - - conf.Fr, err = field.NewFieldConfig("fr", "Element", conf.FrModulus, !conf.Equal(config.STARK_CURVE)) - assertNoError(err) - conf.FpUnusedBits = 64 - (conf.Fp.NbBits % 64) - assertNoError(generator.GenerateFF(conf.Fr, filepath.Join(curveDir, "fr"))) - assertNoError(generator.GenerateFF(conf.Fp, filepath.Join(curveDir, "fp"))) + assertNoError(generator.GenerateFF(conf.Fr, filepath.Join(curveDir, "fr"), asmDirBuildPath, asmDirIncludePath)) + assertNoError(generator.GenerateFF(conf.Fp, filepath.Join(curveDir, "fp"), asmDirBuildPath, asmDirIncludePath)) // generate ecdsa assertNoError(ecdsa.Generate(conf, curveDir, bgen)) diff --git a/internal/generator/tower/asm/amd64/e2.go b/internal/generator/tower/asm/amd64/e2.go 
index 5fb3ea969..e3564861a 100644 --- a/internal/generator/tower/asm/amd64/e2.go +++ b/internal/generator/tower/asm/amd64/e2.go @@ -35,7 +35,7 @@ type Fq2Amd64 struct { // NewFq2Amd64 ... func NewFq2Amd64(w io.Writer, F *field.FieldConfig, config config.Curve) *Fq2Amd64 { return &Fq2Amd64{ - amd64.NewFFAmd64(w, F), + amd64.NewFFAmd64(w, F.NbWords), config, w, F, @@ -48,8 +48,9 @@ func (fq2 *Fq2Amd64) Generate(forceADXCheck bool) error { fq2.WriteLn("#include \"textflag.h\"") fq2.WriteLn("#include \"funcdata.h\"") + fq2.WriteLn("#include \"go_asm.h\"") - fq2.GenerateDefines() + fq2.GenerateReduceDefine() if fq2.config.Equal(config.BN254) { fq2.generateMulDefine() } @@ -174,7 +175,7 @@ func (fq2 *Fq2Amd64) generateNegE2() { // z = x - q for i := 0; i < fq2.NbWords; i++ { - fq2.MOVQ(fq2.Q[i], q) + fq2.MOVQ(fq2.F.Q[i], q) if i == 0 { fq2.SUBQ(t[i], q) } else { @@ -208,7 +209,7 @@ func (fq2 *Fq2Amd64) generateNegE2() { // z = x - q for i := 0; i < fq2.NbWords; i++ { - fq2.MOVQ(fq2.Q[i], q) + fq2.MOVQ(fq2.F.Q[i], q) if i == 0 { fq2.SUBQ(t[i], q) } else { @@ -272,7 +273,7 @@ func (fq2 *Fq2Amd64) modReduceAfterSub(registers *ramd64.Registers, zero ramd64. 
} func (fq2 *Fq2Amd64) modReduceAfterSubScratch(zero ramd64.Register, t, scratch []ramd64.Register) { - fq2.Mov(fq2.Q, scratch) + fq2.Mov(fq2.F.Q, scratch) for i := 0; i < fq2.NbWords; i++ { fq2.CMOVQCC(zero, scratch[i]) } diff --git a/internal/generator/tower/asm/amd64/e2_bn254.go b/internal/generator/tower/asm/amd64/e2_bn254.go index 821338109..c894c2c68 100644 --- a/internal/generator/tower/asm/amd64/e2_bn254.go +++ b/internal/generator/tower/asm/amd64/e2_bn254.go @@ -305,11 +305,10 @@ func (fq2 *Fq2Amd64) generateMulDefine() { return string(op2[i]) } - wd := writerDefine{fq2.w} - tw := gamd64.NewFFAmd64(&wd, fq2.F) + wd := writerDefine{fq2.w, 0, false} + tw := gamd64.NewFFAmd64(&wd, fq2.F.NbWords) _, _ = io.WriteString(fq2.w, "// this code is generated and identical to fp.Mul(...)\n") - _, _ = io.WriteString(fq2.w, "#define MUL() \\ \n") tw.MulADX(&r, xat, yat, res) } @@ -339,10 +338,25 @@ func (fq2 *Fq2Amd64) mulElement() { } type writerDefine struct { - w io.Writer + w io.Writer + cptXORQ int + first bool } func (w *writerDefine) Write(p []byte) (n int, err error) { + // TODO @gbotrel temporary hack to re-use new struct in mul; + // then if it's the first time we are here, we print the header + if strings.Contains(string(p), "mul body") { + w.first = true + n, err = io.WriteString(w.w, "#define MUL() \\ \n") + if err != nil { + return + } + } + if !w.first { + return w.w.Write(p) + } + line := string(p) line = strings.TrimSpace(line) if strings.HasPrefix(line, "//") { diff --git a/internal/generator/tower/template/fq12over6over2/amd64.fq2.go.tmpl b/internal/generator/tower/template/fq12over6over2/amd64.fq2.go.tmpl index 74efb87e1..cebb6827b 100644 --- a/internal/generator/tower/template/fq12over6over2/amd64.fq2.go.tmpl +++ b/internal/generator/tower/template/fq12over6over2/amd64.fq2.go.tmpl @@ -1,3 +1,25 @@ +import ( + "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fp" +) + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction 
+const qInvNeg uint64 = {{index .Fp.QInverse 0}} + +// Field modulus q (Fp) +const ( + {{- range $i := $.Fp.NbWordsIndexesFull}} + q{{$i}} uint64 = {{index $.Fp.Q $i}} + {{- if eq $.Fp.NbWords 1}} + q uint64 = q0 + {{- end}} + {{- end}} + ) + +var qElement = fp.Element { + {{- range $i := $.Fp.NbWordsIndexesFull}} + q{{$i}},{{end}} +} //go:noescape func addE2(res,x,y *E2)