diff --git a/ecc/bls12-377/fr/fft/bitreverse.go b/ecc/bls12-377/fr/fft/bitreverse.go new file mode 100644 index 000000000..104ea5c44 --- /dev/null +++ b/ecc/bls12-377/fr/fft/bitreverse.go @@ -0,0 +1,574 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/bits" + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// BitReverse applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +func BitReverse(v []fr.Element) { + n := uint64(len(v)) + if bits.OnesCount64(n) != 1 { + panic("len(a) must be a power of 2") + } + + if runtime.GOARCH == "arm64" { + bitReverseNaive(v) + } else { + bitReverseCobra(v) + } +} + +// bitReverseNaive applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +func bitReverseNaive(v []fr.Element) { + n := uint64(len(v)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + iRev := bits.Reverse64(i) >> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. 
+// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). +func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). 
+ // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + 
bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < 
idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bls12-377/fr/fft/bitreverse_test.go b/ecc/bls12-377/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..0687c079d --- /dev/null +++ b/ecc/bls12-377/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +type bitReverseVariant struct { + name string + buf []fr.Element + fn func([]fr.Element) +} + +const maxSizeBitReverse = 1 << 23 + +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse}, + {name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace}, +} + +func TestBitReverse(t *testing.T) { + + // generate a random []fr.Element array of size 2**20 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. 
+ for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + } + +} + +func BenchmarkBitReverse(b *testing.B) { + // generate a random []fr.Element array of size 2**22 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range bitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} diff --git a/ecc/bls12-377/fr/fft/fft.go b/ecc/bls12-377/fr/fft/fft.go index 8c01bf23a..20cafffd7 100644 --- a/ecc/bls12-377/fr/fft/fft.go +++ b/ecc/bls12-377/fr/fft/fft.go @@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon } } -// BitReverse applies the bit-reversal permutation to a. 
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bls12-377/fr/fft/fft_test.go b/ecc/bls12-377/fr/fft/fft_test.go index c548c40ed..a443204c1 100644 --- a/ecc/bls12-377/fr/fft/fft_test.go +++ b/ecc/bls12-377/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bls12-378/fr/fft/bitreverse_test.go b/ecc/bls12-378/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..e6930a472 --- /dev/null +++ b/ecc/bls12-378/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fft
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-378/fr"
+)
+
+type bitReverseVariant struct {
+	name string
+	buf  []fr.Element
+	fn   func([]fr.Element)
+}
+
+const maxSizeBitReverse = 1 << 23
+
+var bitReverse = []bitReverseVariant{
+	{name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive},
+	{name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse},
+	{name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace},
+}
+
+func TestBitReverse(t *testing.T) {
+
+	// generate a random []fr.Element array of size 2**23 (== maxSizeBitReverse)
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// for each size, check that all the bitReverse functions fn compute the same result.
+	for size := 2; size <= maxSizeBitReverse; size <<= 1 {
+
+		// copy pol into the buffers
+		for _, data := range bitReverse {
+			copy(data.buf, pol[:size])
+		}
+
+		// compute bit reverse shuffling
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		// all bitReverse.buf should hold the same result
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+
+		// bitReverse back should be identity
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+	}
+
+}
+
+func BenchmarkBitReverse(b *testing.B) {
+	// generate a random []fr.Element array of size 2**23 (== maxSizeBitReverse)
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// copy pol into the buffers
+	for _, data := range bitReverse {
+		copy(data.buf, pol[:maxSizeBitReverse])
+	}
+
+	// benchmark for each size, each bitReverse function
+	for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 {
+		for _, data := range bitReverse {
+			b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) {
+				b.ResetTimer()
+				for j := 0; j < b.N; j++ {
+					data.fn(data.buf[:size])
+				}
+			})
+		}
+	}
+}
diff --git a/ecc/bls12-378/fr/fft/fft.go b/ecc/bls12-378/fr/fft/fft.go
index a74c8b4e8..9f1527360 100644
--- a/ecc/bls12-378/fr/fft/fft.go
+++ b/ecc/bls12-378/fr/fft/fft.go
@@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon
 	}
 }
 
-// BitReverse applies the bit-reversal permutation to a.
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bls12-378/fr/fft/fft_test.go b/ecc/bls12-378/fr/fft/fft_test.go index caa190aa8..0478dd270 100644 --- a/ecc/bls12-378/fr/fft/fft_test.go +++ b/ecc/bls12-378/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bls12-381/fr/fft/bitreverse_test.go b/ecc/bls12-381/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..8bffd1270 --- /dev/null +++ b/ecc/bls12-381/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fft
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
+)
+
+type bitReverseVariant struct {
+	name string
+	buf  []fr.Element
+	fn   func([]fr.Element)
+}
+
+const maxSizeBitReverse = 1 << 23
+
+var bitReverse = []bitReverseVariant{
+	{name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive},
+	{name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse},
+	{name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace},
+}
+
+func TestBitReverse(t *testing.T) {
+
+	// generate a random []fr.Element array of size 2**23 (== maxSizeBitReverse)
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// for each size, check that all the bitReverse functions fn compute the same result.
+	for size := 2; size <= maxSizeBitReverse; size <<= 1 {
+
+		// copy pol into the buffers
+		for _, data := range bitReverse {
+			copy(data.buf, pol[:size])
+		}
+
+		// compute bit reverse shuffling
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		// all bitReverse.buf should hold the same result
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+
+		// bitReverse back should be identity
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+	}
+
+}
+
+func BenchmarkBitReverse(b *testing.B) {
+	// generate a random []fr.Element array of size 2**23 (== maxSizeBitReverse)
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// copy pol into the buffers
+	for _, data := range bitReverse {
+		copy(data.buf, pol[:maxSizeBitReverse])
+	}
+
+	// benchmark for each size, each bitReverse function
+	for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 {
+		for _, data := range bitReverse {
+			b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) {
+				b.ResetTimer()
+				for j := 0; j < b.N; j++ {
+					data.fn(data.buf[:size])
+				}
+			})
+		}
+	}
+}
diff --git a/ecc/bls12-381/fr/fft/fft.go b/ecc/bls12-381/fr/fft/fft.go
index 443a46bde..dbc99e444 100644
--- a/ecc/bls12-381/fr/fft/fft.go
+++ b/ecc/bls12-381/fr/fft/fft.go
@@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon
 	}
 }
 
-// BitReverse applies the bit-reversal permutation to a.
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bls12-381/fr/fft/fft_test.go b/ecc/bls12-381/fr/fft/fft_test.go index 58ea21799..21851d58d 100644 --- a/ecc/bls12-381/fr/fft/fft_test.go +++ b/ecc/bls12-381/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bls24-315/fr/fft/bitreverse_test.go b/ecc/bls24-315/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..c19ccc7fa --- /dev/null +++ b/ecc/bls24-315/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +type bitReverseVariant struct { + name string + buf []fr.Element + fn func([]fr.Element) +} + +const maxSizeBitReverse = 1 << 23 + +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse}, + {name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace}, +} + +func TestBitReverse(t *testing.T) { + + // generate a random []fr.Element array of size 2**20 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. 
+ for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + } + +} + +func BenchmarkBitReverse(b *testing.B) { + // generate a random []fr.Element array of size 2**22 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range bitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} diff --git a/ecc/bls24-315/fr/fft/fft.go b/ecc/bls24-315/fr/fft/fft.go index 30fb173cf..bd2eda5fc 100644 --- a/ecc/bls24-315/fr/fft/fft.go +++ b/ecc/bls24-315/fr/fft/fft.go @@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon } } -// BitReverse applies the bit-reversal permutation to a. 
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bls24-315/fr/fft/fft_test.go b/ecc/bls24-315/fr/fft/fft_test.go index 7b94824f0..04860ec9f 100644 --- a/ecc/bls24-315/fr/fft/fft_test.go +++ b/ecc/bls24-315/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bls24-317/fr/fft/bitreverse_test.go b/ecc/bls24-317/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..d5352c4c0 --- /dev/null +++ b/ecc/bls24-317/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fft
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/consensys/gnark-crypto/ecc/bls24-317/fr"
+)
+
+type bitReverseVariant struct {
+	name string
+	buf  []fr.Element
+	fn   func([]fr.Element)
+}
+
+const maxSizeBitReverse = 1 << 23
+
+var bitReverse = []bitReverseVariant{
+	{name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive},
+	{name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse},
+	{name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace},
+}
+
+func TestBitReverse(t *testing.T) {
+
+	// generate a []fr.Element array of size maxSizeBitReverse (2**23): random start, then incremental
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// for each size, check that all the bitReverse functions fn compute the same result.
+	for size := 2; size <= maxSizeBitReverse; size <<= 1 {
+
+		// copy pol into the buffers
+		for _, data := range bitReverse {
+			copy(data.buf, pol[:size])
+		}
+
+		// compute bit reverse shuffling
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		// all bitReverse.buf should hold the same result
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+
+		// bitReverse back should be identity
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+	}
+
+}
+
+func BenchmarkBitReverse(b *testing.B) {
+	// generate a []fr.Element array of size maxSizeBitReverse (2**23): random start, then incremental
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// copy pol into the buffers
+	for _, data := range bitReverse {
+		copy(data.buf, pol[:maxSizeBitReverse])
+	}
+
+	// benchmark for each size, each bitReverse function
+	for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 {
+		for _, data := range bitReverse {
+			b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) {
+				b.ResetTimer()
+				for j := 0; j < b.N; j++ {
+					data.fn(data.buf[:size])
+				}
+			})
+		}
+	}
+}
diff --git a/ecc/bls24-317/fr/fft/fft.go b/ecc/bls24-317/fr/fft/fft.go
index cf4230ed1..5a205433b 100644
--- a/ecc/bls24-317/fr/fft/fft.go
+++ b/ecc/bls24-317/fr/fft/fft.go
@@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon
 	}
 }
 
-// BitReverse applies the bit-reversal permutation to a.
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bls24-317/fr/fft/fft_test.go b/ecc/bls24-317/fr/fft/fft_test.go index d8b62eaa8..6a5e4dfdd 100644 --- a/ecc/bls24-317/fr/fft/fft_test.go +++ b/ecc/bls24-317/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bn254/fr/fft/bitreverse_test.go b/ecc/bn254/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..8bdf34182 --- /dev/null +++ b/ecc/bn254/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fft
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
+)
+
+type bitReverseVariant struct {
+	name string
+	buf  []fr.Element
+	fn   func([]fr.Element)
+}
+
+const maxSizeBitReverse = 1 << 23
+
+var bitReverse = []bitReverseVariant{
+	{name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive},
+	{name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse},
+	{name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace},
+}
+
+func TestBitReverse(t *testing.T) {
+
+	// generate a []fr.Element array of size maxSizeBitReverse (2**23): random start, then incremental
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// for each size, check that all the bitReverse functions fn compute the same result.
+	for size := 2; size <= maxSizeBitReverse; size <<= 1 {
+
+		// copy pol into the buffers
+		for _, data := range bitReverse {
+			copy(data.buf, pol[:size])
+		}
+
+		// compute bit reverse shuffling
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		// all bitReverse.buf should hold the same result
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+
+		// bitReverse back should be identity
+		for _, data := range bitReverse {
+			data.fn(data.buf[:size])
+		}
+
+		for i := 0; i < size; i++ {
+			for j := 1; j < len(bitReverse); j++ {
+				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
+					t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
+				}
+			}
+		}
+	}
+
+}
+
+func BenchmarkBitReverse(b *testing.B) {
+	// generate a []fr.Element array of size maxSizeBitReverse (2**23): random start, then incremental
+	pol := make([]fr.Element, maxSizeBitReverse)
+	one := fr.One()
+	pol[0].SetRandom()
+	for i := 1; i < maxSizeBitReverse; i++ {
+		pol[i].Add(&pol[i-1], &one)
+	}
+
+	// copy pol into the buffers
+	for _, data := range bitReverse {
+		copy(data.buf, pol[:maxSizeBitReverse])
+	}
+
+	// benchmark for each size, each bitReverse function
+	for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 {
+		for _, data := range bitReverse {
+			b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) {
+				b.ResetTimer()
+				for j := 0; j < b.N; j++ {
+					data.fn(data.buf[:size])
+				}
+			})
+		}
+	}
+}
diff --git a/ecc/bn254/fr/fft/fft.go b/ecc/bn254/fr/fft/fft.go
index 151b7832f..a3dfe0d34 100644
--- a/ecc/bn254/fr/fft/fft.go
+++ b/ecc/bn254/fr/fft/fft.go
@@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon
 	}
 }
 
-// BitReverse applies the bit-reversal permutation to a.
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bn254/fr/fft/fft_test.go b/ecc/bn254/fr/fft/fft_test.go index b2c1e8999..ac1318130 100644 --- a/ecc/bn254/fr/fft/fft_test.go +++ b/ecc/bn254/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bw6-633/fr/fft/bitreverse_test.go b/ecc/bw6-633/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..12755a689 --- /dev/null +++ b/ecc/bw6-633/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +type bitReverseVariant struct { + name string + buf []fr.Element + fn func([]fr.Element) +} + +const maxSizeBitReverse = 1 << 23 + +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse}, + {name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace}, +} + +func TestBitReverse(t *testing.T) { + + // generate a random []fr.Element array of size 2**20 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. 
+ for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + } + +} + +func BenchmarkBitReverse(b *testing.B) { + // generate a random []fr.Element array of size 2**22 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range bitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} diff --git a/ecc/bw6-633/fr/fft/fft.go b/ecc/bw6-633/fr/fft/fft.go index 7485014f1..63ba7a84c 100644 --- a/ecc/bw6-633/fr/fft/fft.go +++ b/ecc/bw6-633/fr/fft/fft.go @@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon } } -// BitReverse applies the bit-reversal permutation to a. 
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bw6-633/fr/fft/fft_test.go b/ecc/bw6-633/fr/fft/fft_test.go index e14364680..190028756 100644 --- a/ecc/bw6-633/fr/fft/fft_test.go +++ b/ecc/bw6-633/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bw6-756/fr/fft/bitreverse_test.go b/ecc/bw6-756/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..500a2475f --- /dev/null +++ b/ecc/bw6-756/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" +) + +type bitReverseVariant struct { + name string + buf []fr.Element + fn func([]fr.Element) +} + +const maxSizeBitReverse = 1 << 23 + +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse}, + {name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace}, +} + +func TestBitReverse(t *testing.T) { + + // generate a random []fr.Element array of size 2**20 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. 
+ for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + } + +} + +func BenchmarkBitReverse(b *testing.B) { + // generate a random []fr.Element array of size 2**22 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range bitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} diff --git a/ecc/bw6-756/fr/fft/fft.go b/ecc/bw6-756/fr/fft/fft.go index 5fd501c20..0adffeb25 100644 --- a/ecc/bw6-756/fr/fft/fft.go +++ b/ecc/bw6-756/fr/fft/fft.go @@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon } } -// BitReverse applies the bit-reversal permutation to a. 
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bw6-756/fr/fft/fft_test.go b/ecc/bw6-756/fr/fft/fft_test.go index bb2bca8dc..ae23f5bcb 100644 --- a/ecc/bw6-756/fr/fft/fft_test.go +++ b/ecc/bw6-756/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). 
+func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. + // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + +// bitReverseCobraInPlace_9_21 applies the bit-reversal permutation to v. +// len(v) must be 1 << 21. 
+// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_9_21(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 21 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_22 applies the bit-reversal permutation to v. +// len(v) must be 1 << 22. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_22(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 22 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_23 applies the bit-reversal permutation to v. +// len(v) must be 1 << 23. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_23(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 23 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_24 applies the bit-reversal permutation to v. +// len(v) must be 1 << 24. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_24(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 24 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_25 applies the bit-reversal permutation to v. +// len(v) must be 1 << 25. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_25(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 25 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_26 applies the bit-reversal permutation to v. +// len(v) must be 1 << 26. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_26(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 26 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} + +// bitReverseCobraInPlace_9_27 applies the bit-reversal permutation to v. +// len(v) must be 1 << 27. +// see bitReverseCobraInPlace for more details; this function is specialized for 9, +// as it declares the t buffer and various constants statically for performance. 
+func bitReverseCobraInPlace_9_27(v []fr.Element) { + const ( + logTileSize = uint64(9) + tileSize = uint64(1) << logTileSize + logN = 27 + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> 55) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev|c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> 55) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> 55 + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> 55 + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> 55) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + +} diff --git a/ecc/bw6-761/fr/fft/bitreverse_test.go b/ecc/bw6-761/fr/fft/bitreverse_test.go new file mode 100644 index 000000000..ec150c8b9 --- /dev/null +++ b/ecc/bw6-761/fr/fft/bitreverse_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) + +type bitReverseVariant struct { + name string + buf []fr.Element + fn func([]fr.Element) +} + +const maxSizeBitReverse = 1 << 23 + +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse}, + {name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace}, +} + +func TestBitReverse(t *testing.T) { + + // generate a random []fr.Element array of size 2**20 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. 
+ for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + } + +} + +func BenchmarkBitReverse(b *testing.B) { + // generate a random []fr.Element array of size 2**22 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range bitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} diff --git a/ecc/bw6-761/fr/fft/fft.go b/ecc/bw6-761/fr/fft/fft.go index 5c05decd1..bc32d2165 100644 --- a/ecc/bw6-761/fr/fft/fft.go +++ b/ecc/bw6-761/fr/fft/fft.go @@ -235,20 +235,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon } } -// BitReverse applies the bit-reversal permutation to a. 
-// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} - // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/ecc/bw6-761/fr/fft/fft_test.go b/ecc/bw6-761/fr/fft/fft_test.go index 4e8bc8953..30b67a45b 100644 --- a/ecc/bw6-761/fr/fft/fft_test.go +++ b/ecc/bw6-761/fr/fft/fft_test.go @@ -240,26 +240,6 @@ func TestFFT(t *testing.T) { // -------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<> anyToUint64(n) + } + funcs["shl"] = func(x, n any) uint64 { + return anyToUint64(x) << anyToUint64(n) + } + funcs["logicalOr"] = func(x, y any) uint64 { + return anyToUint64(x) | anyToUint64(y) + } bavardOpts := []func(*bavard.Bavard) error{bavard.Funcs(funcs)} return bgen.GenerateWithOptions(conf, conf.Package, "./fft/template/", bavardOpts, entries...) } + +func anyToUint64(x any) uint64 { + switch v := x.(type) { + case int: + return uint64(v) + case int64: + return uint64(v) + case uint64: + return v + default: + panic("unknown type") + } +} diff --git a/internal/generator/fft/template/bitreverse.go.tmpl b/internal/generator/fft/template/bitreverse.go.tmpl new file mode 100644 index 000000000..62d777389 --- /dev/null +++ b/internal/generator/fft/template/bitreverse.go.tmpl @@ -0,0 +1,236 @@ +import ( + "math/bits" + "runtime" + {{ template "import_fr" . 
}} +) + +// BitReverse applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +func BitReverse(v []fr.Element) { + n := uint64(len(v)) + if bits.OnesCount64(n) != 1 { + panic("len(a) must be a power of 2") + } + + if runtime.GOARCH == "arm64" { + bitReverseNaive(v) + } else { + bitReverseCobra(v) + } +} + +// bitReverseNaive applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +func bitReverseNaive(v []fr.Element) { + n := uint64(len(v)) + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + iRev := bits.Reverse64(i) >> nn + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + + +// bitReverseCobraInPlace applies the bit-reversal permutation to v. +// len(v) must be a power of 2 +// This is derived from: +// +// - Towards an Optimal Bit-Reversal Permutation Program +// Larry Carter and Kang Su Gatlin, 1998 +// https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf +// +// - Practically efficient methods for performing bit-reversed +// permutation in C++11 on the x86-64 architecture +// Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 +// https://arxiv.org/pdf/1708.01873.pdf +// +// - and more specifically, constantine implementation: +// https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 +// by Mamy Ratsimbazafy (@mratsim). +// +func bitReverseCobraInPlace(v []fr.Element) { + logN := uint64(bits.Len64(uint64(len(v))) - 1) + logTileSize := deriveLogTileSize(logN) + logBLen := logN - 2*logTileSize + bLen := uint64(1) << logBLen + bShift := logBLen + logTileSize + tileSize := uint64(1) << logTileSize + + // rough idea; + // bit reversal permutation naive implementation may have some cache associativity issues, + // since we are accessing elements by strides of powers of 2. + // on large inputs, this is noticeable and can be improved by using a t buffer. 
+ // idea is for t buffer to be small enough to fit in cache. + // in the first inner loop, we copy the elements of v into t in a bit-reversed order. + // in the subsequent inner loops, accesses have much better cache locality than the naive implementation. + // hence even if we apparently do more work (swaps / copies), we are faster. + // + // on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster, + // in most cases. + // on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x). + // + // optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache; + // in practice, a common size for L1 is 64kb, a field element is 32bytes or more. + // hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts. + // + // for most sizes of interest, this tile size choice doesn't yield good results; + // we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+. + t := make([]fr.Element, tileSize*tileSize) + + + // see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf + // for a detailed explanation of the algorithm. 
+ for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev :=( bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev | c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> (64 - logTileSize) + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> (64 - logTileSize) + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } +} + + +func bitReverseCobra(v []fr.Element) { + switch len(v) { + case 1 << 21: + bitReverseCobraInPlace_9_21(v) + case 1 << 22: + bitReverseCobraInPlace_9_22(v) + case 1 << 23: + bitReverseCobraInPlace_9_23(v) + case 1 << 24: + bitReverseCobraInPlace_9_24(v) + case 1 << 25: + bitReverseCobraInPlace_9_25(v) + case 1 << 26: + bitReverseCobraInPlace_9_26(v) + case 1 << 27: + bitReverseCobraInPlace_9_27(v) + default: + if len(v) > 1<<27 { + bitReverseCobraInPlace(v) + } else { + bitReverseNaive(v) + } + } +} + + +func deriveLogTileSize(logN uint64) uint64 { + q := uint64(9) // see bitReverseCobraInPlace for more details + + for int(logN)-int(2*q) <= 0 { + q-- + } + + return q +} + + +{{bitReverseCobraInPlace 9 21}} +{{bitReverseCobraInPlace 9 22}} +{{bitReverseCobraInPlace 9 23}} +{{bitReverseCobraInPlace 9 24}} +{{bitReverseCobraInPlace 9 25}} +{{bitReverseCobraInPlace 
9 26}} +{{bitReverseCobraInPlace 9 27}} + + +{{define "bitReverseCobraInPlace logTileSize logN"}} + +// bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}} applies the bit-reversal permutation to v. +// len(v) must be 1 << {{.logN}}. +// see bitReverseCobraInPlace for more details; this function is specialized for {{.logTileSize}}, +// as it declares the t buffer and various constants statically for performance. +func bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}}(v []fr.Element) { + const ( + logTileSize = uint64({{.logTileSize}}) + tileSize = uint64(1) << logTileSize + logN = {{.logN}} + logBLen = logN - 2*logTileSize + bShift = logBLen + logTileSize + bLen = uint64(1) << logBLen + ) + + var t [tileSize * tileSize]fr.Element + {{$k := sub 64 .logTileSize}} + {{$l := .logTileSize}} + {{$tileSize := shl 1 .logTileSize}} + + for b := uint64(0); b < bLen; b++ { + + for a := uint64(0); a < tileSize; a++ { + aRev := (bits.Reverse64(a) >> {{$k}}) << logTileSize + for c := uint64(0); c < tileSize; c++ { + idx := (a << bShift) | (b << logTileSize) | c + t[aRev | c] = v[idx] + } + } + + bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize + + for c := uint64(0); c < tileSize; c++ { + cRev := ((bits.Reverse64(c) >> {{$k}}) << bShift) | bRev + for aRev := uint64(0); aRev < tileSize; aRev++ { + a := bits.Reverse64(aRev) >> {{$k}} + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idxRev], t[tIdx] = t[tIdx], v[idxRev] + } + } + } + + for a := uint64(0); a < tileSize; a++ { + aRev := bits.Reverse64(a) >> {{$k}} + for c := uint64(0); c < tileSize; c++ { + cRev := (bits.Reverse64(c) >> {{$k}}) << bShift + idx := (a << bShift) | (b << logTileSize) | c + idxRev := cRev | bRev | aRev + if idx < idxRev { + tIdx := (aRev << logTileSize) | c + v[idx], t[tIdx] = t[tIdx], v[idx] + } + } + } + } + + +} + +{{- end}} \ No newline at end of file diff --git a/internal/generator/fft/template/fft.go.tmpl 
b/internal/generator/fft/template/fft.go.tmpl index 90729bd6e..7cde1697c 100644 --- a/internal/generator/fft/template/fft.go.tmpl +++ b/internal/generator/fft/template/fft.go.tmpl @@ -218,19 +218,6 @@ func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon } } -// BitReverse applies the bit-reversal permutation to a. -// len(a) must be a power of 2 (as in every single function in this file) -func BitReverse(a []fr.Element) { - n := uint64(len(a)) - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - irev := bits.Reverse64(i) >> nn - if irev > i { - a[i], a[irev] = a[irev], a[i] - } - } -} // kerDIT8 is a kernel that process a FFT of size 8 func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { diff --git a/internal/generator/fft/template/tests/bitreverse.go.tmpl b/internal/generator/fft/template/tests/bitreverse.go.tmpl new file mode 100644 index 000000000..2f70e73bb --- /dev/null +++ b/internal/generator/fft/template/tests/bitreverse.go.tmpl @@ -0,0 +1,98 @@ +import ( + "fmt" + "testing" + + {{ template "import_fr" . }} +) + + +type bitReverseVariant struct { + name string + buf []fr.Element + fn func([]fr.Element) +} + + + +const maxSizeBitReverse = 1 << 23 + +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse}, + {name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace}, +} + +func TestBitReverse(t *testing.T) { + + // generate a random []fr.Element array of size 2**20 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. 
+ for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range bitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + } + } + } + } + +} + +func BenchmarkBitReverse(b *testing.B) { + // generate a random []fr.Element array of size 2**22 + pol := make([]fr.Element, maxSizeBitReverse) + one := fr.One() + pol[0].SetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range bitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range bitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} diff --git a/internal/generator/fft/template/tests/fft.go.tmpl b/internal/generator/fft/template/tests/fft.go.tmpl index d440aa77e..5fb9742e9 100644 --- a/internal/generator/fft/template/tests/fft.go.tmpl +++ b/internal/generator/fft/template/tests/fft.go.tmpl @@ -223,26 +223,6 @@ func TestFFT(t *testing.T) { // 
-------------------------------------------------------------------- // benches -func BenchmarkBitReverse(b *testing.B) { - - const maxSize = 1 << 20 - - pol := make([]fr.Element, maxSize) - pol[0].SetRandom() - for i := 1; i < maxSize; i++ { - pol[i] = pol[i-1] - } - - for i := 8; i < 20; i++ { - b.Run("bit reversing 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - BitReverse(pol[:1<