diff --git a/include/perm16.hpp b/include/perm16.hpp index f4b3d92b..da65c158 100644 --- a/include/perm16.hpp +++ b/include/perm16.hpp @@ -73,7 +73,16 @@ struct alignas(16) PTransf16 : public Vect16 { } /** Returns a mask for the image of \c *this */ - epu8 image_mask(bool complement = false) const; + epu8 image_mask_cmpestrm(bool complement = false) const; + /** Returns a mask for the image of \c *this */ + epu8 image_mask_ref(bool complement = false) const; + epu8 image_mask(bool complement = false) const { +#ifdef SIMDE_X86_SSE4_2_NATIVE + return image_mask_cmpestrm(complement); +#else + return image_mask_ref(complement); +#endif + } /** Returns a bit mask for the image of \c *this */ uint32_t image_bitset(bool complement = false) const; /** Returns a mask for the domain of \c *this */ diff --git a/include/perm16_impl.hpp b/include/perm16_impl.hpp index 6e6bde42..44ccf42b 100644 --- a/include/perm16_impl.hpp +++ b/include/perm16_impl.hpp @@ -57,13 +57,20 @@ inline PTransf16 PTransf16::right_one() const { return domain_mask(true) | epu8id; } -inline epu8 PTransf16::image_mask(bool complement) const { #ifdef SIMDE_X86_SSE4_2_NATIVE +inline epu8 PTransf16::image_mask_cmpestrm(bool complement) const { return complement ? _mm_cmpestrm(v, 16, one().v, 16, FIND_IN_VECT) : _mm_cmpestrm(v, 16, one().v, 16, FIND_IN_VECT_COMPL); -#else +} #endif +inline epu8 PTransf16::image_mask_ref(bool complement) const { + epu8 res{}; + for (auto x : *this) + if (x != 0xFF) + res[x] = 0xFF; + return complement ? static_cast(!res) : res; } + inline uint32_t PTransf16::image_bitset(bool complement) const { return simde_mm_movemask_epi8(image_mask(complement)); } @@ -73,10 +80,9 @@ inline PTransf16 PTransf16::left_one() const { inline uint32_t PTransf16::rank_ref() const { TPUBuild::array tmp{}; static_assert(TPUBuild::size == 16, "Wrong size of EPU8 array"); - for (size_t i = 0; i < 16; i++) { - if (v[i] != 0xFF) - tmp[v[i]] = 1; - } + for (auto x : *this) + if (x != 0xFF) + tmp[x] = 1; return std::accumulate(tmp.begin(), tmp.end(), uint8_t(0)); } inline uint32_t PTransf16::rank() const { diff --git a/include/vect_generic.hpp b/include/vect_generic.hpp index 594a14ea..6c98f281 100644 --- a/include/vect_generic.hpp +++ b/include/vect_generic.hpp @@ -222,10 +222,10 @@ namespace std { template std::ostream &operator<<(std::ostream &stream, const HPCombi::VectGeneric<_Size, Expo> &v) { - stream << "[" << std::setw(2) << unsigned(v[0]); + stream << "{" << std::setw(2) << unsigned(v[0]); for (unsigned i = 1; i < _Size; ++i) stream << "," << std::setw(2) << unsigned(v[i]); - stream << "]"; + stream << "}"; return stream; } diff --git a/tests/test_epu.cpp b/tests/test_epu.cpp index 7f8fe36d..fe80eb77 100644 --- a/tests/test_epu.cpp +++ b/tests/test_epu.cpp @@ -194,18 +194,18 @@ TEST_CASE_METHOD(Fix, "Epu8::less", "[Epu8][010]") { } TEST_CASE_METHOD(Fix, "Epu8::permuted", "[Epu8][011]") { - REQUIRE(equal( + REQUIRE_THAT( permuted(epu8{0, 1, 3, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, epu8{3, 2, 5, 1, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), - epu8{2, 3, 5, 1, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})); - REQUIRE(equal( + Equals(epu8{2, 3, 5, 1, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})); + REQUIRE_THAT( permuted(epu8{3, 2, 5, 1, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, epu8{0, 1, 3, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), - epu8{3, 2, 1, 5, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})); - REQUIRE(equal( + Equals(epu8{3, 2, 1, 5, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})); + REQUIRE_THAT( permuted(epu8{3, 2, 5, 1, 4, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, epu8{2, 2, 1, 2, 3, 6, 12, 4, 5, 16, 17, 11, 12, 13, 14, 15}), - epu8{5, 5, 2, 5, 1, 6, 12, 4, 0, 3, 2, 11, 12, 13, 14, 15})); + Equals(epu8{5, 5, 2, 5, 1, 6, 12, 4, 0, 3, 2, 11, 12, 13, 14, 15})); } TEST_CASE_METHOD(Fix, "Epu8::shifted_left", "[Epu8][012]") { diff --git a/tests/test_perm16.cpp b/tests/test_perm16.cpp index d75e3146..8d3278e6 100644 --- a/tests/test_perm16.cpp +++ b/tests/test_perm16.cpp @@ -92,32 +92,55 @@ TEST_CASE("PTransf16::hash", "[PTransf16][001]") { REQUIRE(std::hash()(PTransf16({4, 5, 0}, {9, 0, 1})) != 0); } -// TODO uncomment -// TEST_CASE("PTransf16::image_mask", "[PTransf16][002]") { -// REQUIRE(equal(PTransf16({}).image_mask(), Epu8(FF)); -// REQUIRE(equal(PTransf16({}).image_mask(false), Epu8(FF)); -// REQUIRE(equal(PTransf16({}).image_mask(true), Epu8(0)); -// REQUIRE(equal(PTransf16({4, 4, 4, 4}).image_mask(), Epu8({0, 0, 0, 0}, -// FF)); REQUIRE(equal(PTransf16({4, 4, 4, 4}).image_mask(false), -// Epu8({0, 0, 0, 0}, FF)); -// REQUIRE(equal(PTransf16({4, 4, 4, 4}).image_mask(true), -// Epu8({FF, FF, FF, FF}, 0)); -// REQUIRE(equal(PTransf16(Epu8(1)).image_mask(), Epu8({0, FF}, 0)); -// REQUIRE(equal(PTransf16(Epu8(2)).image_mask(), Epu8({0, 0, FF}, 0)); -// REQUIRE(equal(PTransf16(Epu8({2, 2, 2, 0xf}, 2)).image_mask(), -// Epu8({0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0)); -// REQUIRE(equal( -// PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask(), -// Epu8({FF, 0, FF, 0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0)); -// REQUIRE(equal( -// PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, -// 2)).image_mask(false), Epu8({FF, 0, FF, 0, 0, FF, 0, 0, 0, 0, 0, 0, -// 0, 0, 0, FF}, 0)); -// REQUIRE(equal( -// PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, -// 2)).image_mask(true), Epu8({0, FF, 0, FF, FF, 0, FF, FF, FF, FF, FF, -// FF, FF, FF, FF, 0}, 0)); -// } +TEST_CASE("PTransf16::image_mask", "[PTransf16][002]") { + REQUIRE_THAT(PTransf16({}).image_mask(), Equals(Epu8(FF))); + REQUIRE_THAT(PTransf16({}).image_mask(false), Equals(Epu8(FF))); + REQUIRE_THAT(PTransf16({}).image_mask(true), Equals(Epu8(0))); + REQUIRE_THAT(PTransf16({4, 4, 4, 4}).image_mask(), + Equals(Epu8({0, 0, 0, 0}, FF))); + REQUIRE_THAT(PTransf16({4, 4, 4, 4}).image_mask(false), + Equals(Epu8({0, 0, 0, 0}, FF))); + REQUIRE_THAT(PTransf16({4, 4, 4, 4}).image_mask(true), + Equals(Epu8({FF, FF, FF, FF}, 0))); + REQUIRE_THAT(PTransf16(Epu8(1)).image_mask(), Equals(Epu8({0, FF}, 0))); + REQUIRE_THAT(PTransf16(Epu8(2)).image_mask(), Equals(Epu8({0, 0, FF}, 0))); + REQUIRE_THAT(PTransf16(Epu8({2, 2, 2, 0xf}, 2)).image_mask(), + Equals(Epu8({0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0))); + REQUIRE_THAT( + PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask(), + Equals(Epu8({FF, 0, FF, 0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0))); + REQUIRE_THAT( + PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask(false), + Equals(Epu8({FF, 0, FF, 0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0))); + REQUIRE_THAT( + PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask(true), + Equals(Epu8({0, FF, 0, FF, FF, 0, FF, FF, FF, FF, FF, FF, FF, FF, FF, 0}, 0))); +} + +TEST_CASE("PTransf16::image_mask_ref_ref", "[PTransf16][002]") { + REQUIRE_THAT(PTransf16({}).image_mask_ref(), Equals(Epu8(FF))); + REQUIRE_THAT(PTransf16({}).image_mask_ref(false), Equals(Epu8(FF))); + REQUIRE_THAT(PTransf16({}).image_mask_ref(true), Equals(Epu8(0))); + REQUIRE_THAT(PTransf16({4, 4, 4, 4}).image_mask_ref(), + Equals(Epu8({0, 0, 0, 0}, FF))); + REQUIRE_THAT(PTransf16({4, 4, 4, 4}).image_mask_ref(false), + Equals(Epu8({0, 0, 0, 0}, FF))); + REQUIRE_THAT(PTransf16({4, 4, 4, 4}).image_mask_ref(true), + Equals(Epu8({FF, FF, FF, FF}, 0))); + REQUIRE_THAT(PTransf16(Epu8(1)).image_mask_ref(), Equals(Epu8({0, FF}, 0))); + REQUIRE_THAT(PTransf16(Epu8(2)).image_mask_ref(), Equals(Epu8({0, 0, FF}, 0))); + REQUIRE_THAT(PTransf16(Epu8({2, 2, 2, 0xf}, 2)).image_mask_ref(), + Equals(Epu8({0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0))); + REQUIRE_THAT( + PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask_ref(), + Equals(Epu8({FF, 0, FF, 0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0))); + REQUIRE_THAT( + PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask_ref(false), + Equals(Epu8({FF, 0, FF, 0, 0, FF, 0, 0, 0, 0, 0, 0, 0, 0, 0, FF}, 0))); + REQUIRE_THAT( + PTransf16(Epu8({0, 2, 2, 0xf, 2, 2, 2, 2, 5, 2}, 2)).image_mask_ref(true), + Equals(Epu8({0, FF, 0, FF, FF, 0, FF, FF, FF, FF, FF, FF, FF, FF, FF, 0}, 0))); +} // TODO uncomment // TEST_CASE("PTransf16::left_one", "[PTransf16][003]") {