From 10d1722c45c49c36660603beb9201476205c2656 Mon Sep 17 00:00:00 2001 From: Victorin Date: Tue, 16 Jul 2024 18:24:33 +0200 Subject: [PATCH] 16/07 --- include/hpcombi/bmat16.hpp | 13 +++--- include/hpcombi/bmat16_impl.hpp | 83 +++++++++++++++++++++------------ include/hpcombi/bmat8.hpp | 2 +- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/include/hpcombi/bmat16.hpp b/include/hpcombi/bmat16.hpp index a69c858..f40f9af 100644 --- a/include/hpcombi/bmat16.hpp +++ b/include/hpcombi/bmat16.hpp @@ -78,6 +78,7 @@ class BMat16 { //! A constructor. //! //! This constructor initializes a matrix with 4 64 bits unsigned int + //! Each uint represents one of the four quarter (8x8 matrix). explicit BMat16(uint64_t n0, uint64_t n1, uint64_t n2, uint64_t n3) noexcept; //! A constructor. @@ -138,12 +139,12 @@ class BMat16 { //! is possible to access entries that you might not believe exist. bool operator()(size_t i, size_t j) const noexcept; - // //! Sets the (\p i, \p j)th position to \p val. - // //! - // //! This method sets the (\p i, \p j)th entry of \c this to \p val. - // //! Uses the bit twiddle for setting bits found - // //! here. - // void set(size_t i, size_t j, bool val) noexcept; + //! Sets the (\p i, \p j)th position to \p val. + //! + //! This method sets the (\p i, \p j)th entry of \c this to \p val. + //! Uses the bit twiddle for setting bits found + //! here. + void set(size_t i, size_t j, bool val) noexcept; //! Returns the array representation of \c this. //! diff --git a/include/hpcombi/bmat16_impl.hpp b/include/hpcombi/bmat16_impl.hpp index 5a1a0a1..5824ba0 100644 --- a/include/hpcombi/bmat16_impl.hpp +++ b/include/hpcombi/bmat16_impl.hpp @@ -45,18 +45,30 @@ inline BMat16::BMat16(std::vector> const &mat) noexcept { HPCOMBI_ASSERT(mat.size() <= 16); HPCOMBI_ASSERT(0 < mat.size()); std::array tmp = {0, 0, 0, 0}; - for (int i = mat.size() - 1; i >= 0; i--) { + for (int i = mat.size() - 1; i >= 0; --i) { HPCOMBI_ASSERT(mat.size() == mat[i].size()); tmp[i/4] <<= 16 - mat.size(); - for (int j = mat[i].size() - 1; j >= 0; j--) { + for (int j = mat[i].size() - 1; j >= 0; --j) { tmp[i/4] = (tmp[i/4] << 1) | mat[i][j]; } } _data = xpu64{tmp[0], tmp[1], tmp[2], tmp[3]}; } -bool BMat16::operator()(size_t i, size_t j) const noexcept { - return (_data[i/4] >> (16 * (i%4) + j)) % 2; +inline bool BMat16::operator()(size_t i, size_t j) const noexcept { + return (_data[i/4] >> (16 * (i%4) + j)) & 1; +} + +inline void BMat16::set(size_t i, size_t j, bool val) noexcept { + HPCOMBI_ASSERT(i < 16); + HPCOMBI_ASSERT(j < 16); + uint64_t a = 1; + a <<= 16 * (i%4) + j; + xpu64 mask{(i/4 == 0)*a, + (i/4 == 1)*a, + (i/4 == 2)*a, + (i/4 == 3)*a}; + _data ^= (-val ^ _data) & mask; } inline bool BMat16::operator==(BMat16 const &that) const noexcept { @@ -64,37 +76,37 @@ inline bool BMat16::operator==(BMat16 const &that) const noexcept { return simde_mm256_testz_si256(tmp, tmp); } -bool BMat16::operator<(BMat16 const &that) const noexcept { +inline bool BMat16::operator<(BMat16 const &that) const noexcept { return _data[0] < that._data[0] || (_data[0] == that._data[0] && (_data[1] < that._data[1] || (_data[1] == that._data[1] && (_data[2] < that._data[2] || (_data[2] == that._data[2] && (_data[3] < that._data[3])))))); } -bool BMat16::operator>(BMat16 const &that) const noexcept { +inline bool BMat16::operator>(BMat16 const &that) const noexcept { return _data[0] > that._data[0] || (_data[0] == that._data[0] && (_data[1] > that._data[1] || (_data[1] == that._data[1] && (_data[2] > that._data[2] || (_data[2] == that._data[2] && (_data[3] > that._data[3])))))); } -std::array, 16> BMat16::to_array() const noexcept { +inline std::array, 16> BMat16::to_array() const noexcept { xpu64 tmp = to_block(_data); uint64_t a = tmp[0], b = tmp[1], c = tmp[2], d = tmp[3]; std::array, 16> res; - for (int i = 0; i < 64; i++) { - res[i/8][i%8] = a % 2; a >>= 1; - res[i/8][8 + i%8] = b % 2; b >>= 1; - res[8 + i/8][i%8] = c % 2; c >>= 1; - res[8 + i/8][8 + i%8] = d % 2; d >>= 1; + for (size_t i = 0; i < 64; ++i) { + res[i/8][i%8] = a & 1; a >>= 1; + res[i/8][8 + i%8] = b & 1; b >>= 1; + res[8 + i/8][i%8] = c & 1; c >>= 1; + res[8 + i/8][8 + i%8] = d & 1; d >>= 1; } return res; } inline BMat16 BMat16::transpose_naive() const noexcept { uint64_t a = 0, b = 0, c = 0, d = 0; - for (int i = 7; i >= 0; i--) { - for (int j = 7; j >= 0; j--) { + for (int i = 7; i >= 0; --i) { + for (int j = 7; j >= 0; --j) { a = (a << 1) | (*this)(j, i); b = (b << 1) | (*this)(j+8, i); c = (c << 1) | (*this)(j, i+8); @@ -127,7 +139,7 @@ inline BMat16 BMat16::mult_transpose(BMat16 const &that) const noexcept { xpu16 data = simde_mm256_setzero_si256(); xpu16 diag1{0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; xpu16 diag2{0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000, 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80}; - for (int i = 0; i < 8; ++i) { + for (size_t i = 0; i < 8; ++i) { data |= ((x & y1) != zero) & diag1; data |= ((x & y2) != zero) & diag2; y1 = simde_mm256_shuffle_epi8(y1, rot); @@ -138,7 +150,7 @@ inline BMat16 BMat16::mult_transpose(BMat16 const &that) const noexcept { return BMat16(data); } -BMat16 BMat16::mult_4bmat8(BMat16 const &that) const noexcept { +inline BMat16 BMat16::mult_4bmat8(BMat16 const &that) const noexcept { BMat16 tmp = that.transpose(); xpu64 t1 = to_block(_data), t2 = to_block(tmp._data); @@ -150,12 +162,12 @@ BMat16 BMat16::mult_4bmat8(BMat16 const &that) const noexcept { (c1.mult_transpose(c2) | d1.mult_transpose(d2)).to_int()); } -BMat16 BMat16::mult_naive(BMat16 const &that) const noexcept { +inline BMat16 BMat16::mult_naive(BMat16 const &that) const noexcept { uint64_t a = 0, b = 0, c = 0, d = 0; - for (int i = 7; i >= 0; i--) { - for (int j = 7; j >= 0; j--) { + for (int i = 7; i >= 0; --i) { + for (int j = 7; j >= 0; --j) { a <<= 1; b <<= 1; c <<= 1; d <<= 1; - for (int k = 0; k < 8; k++) { + for (size_t k = 0; k < 8; ++k) { a |= ((*this)(i, k) & that(k, j)) | ((*this)(i, k + 8) & that(k + 8, j)); b |= ((*this)(i, k) & that(k, j + 8)) | ((*this)(i, k + 8) & that(k + 8, j + 8)); c |= ((*this)(i + 8, k) & that(k, j)) | ((*this)(i + 8, k + 8) & that(k + 8, j)); @@ -166,13 +178,13 @@ BMat16 BMat16::mult_naive(BMat16 const &that) const noexcept { return BMat16(a, b, c, d); } -BMat16 BMat16::mult_naive_array(BMat16 const &that) const noexcept { +inline BMat16 BMat16::mult_naive_array(BMat16 const &that) const noexcept { std::array, 16> tab1 = to_array(), tab2 = that.to_array(); uint64_t a = 0, b = 0, c = 0, d = 0; - for (int i = 7; i >= 0; i--) { - for (int j = 7; j >= 0; j--) { + for (int i = 7; i >= 0; --i) { + for (int j = 7; j >= 0; --j) { a <<= 1; b <<= 1; c <<= 1; d <<= 1; - for (int k = 0; k < 16; k++) { + for (size_t k = 0; k < 16; ++k) { a |= tab1[i][k] & tab2[k][j]; b |= tab1[i][k] & tab2[k][j + 8]; c |= tab1[i + 8][k] & tab2[k][j]; @@ -184,16 +196,27 @@ BMat16 BMat16::mult_naive_array(BMat16 const &that) const noexcept { } inline size_t BMat16::nr_rows() const noexcept{ - xpu16 tmp = _data, zero = simde_mm256_setzero_si256(); - xpu16 x = (tmp != zero); - return 0; - // return simde_mm256_popcnt_epi16(x); // To change + size_t res = 0; + for (size_t i = 0; i < 16; ++i) + if ((_data[i/4] << (16 * (i%4)) >> 48) != 0) + ++res; + return res; + + //// Vectorized version that doesn't work due to the absence of popcnt in simde + // xpu16 tmp = _data, zero = simde_mm256_setzero_si256(); + // xpu16 x = (tmp != zero); + // return simde_mm256_popcnt_epi16(x); } -inline std::vector BMat16::rows() const { // To change +inline std::vector BMat16::rows() const { std::vector rows; for (size_t i = 0; i < 16; ++i) { - uint16_t row = static_cast(_data[i/4] << (16 * (i%4)) >> 48); + uint16_t row_rev = (_data[i/4] << (16 * (3 - i%4)) >> 48); + uint16_t row = 0; + for (size_t j = 0; j < 16; ++j) { + row = (row << 1) | (row_rev & 1); + row_rev >>= 1; + } rows.push_back(row); } return rows; diff --git a/include/hpcombi/bmat8.hpp b/include/hpcombi/bmat8.hpp index 1437da3..f4132ff 100644 --- a/include/hpcombi/bmat8.hpp +++ b/include/hpcombi/bmat8.hpp @@ -133,7 +133,7 @@ class BMat8 { //! //! This method sets the (\p i, \p j)th entry of \c this to \p val. //! Uses the bit twiddle for setting bits found - //! here. + //! here. void set(size_t i, size_t j, bool val) noexcept; //! Returns the integer representation of \c this.