diff --git a/c++/itertools/itertools.hpp b/c++/itertools/itertools.hpp index d228eed..ee21091 100644 --- a/c++/itertools/itertools.hpp +++ b/c++/itertools/itertools.hpp @@ -372,9 +372,9 @@ namespace itertools { bool operator==(strided const &) const = default; private: - [[nodiscard]] std::ptrdiff_t step() const { - auto end_idx = distance(std::cbegin(x), std::cend(x)); - return (end_idx % stride == 0 ? 0 : stride - end_idx % stride); + [[nodiscard]] std::ptrdiff_t end_offset() const { + auto size = distance(std::cbegin(x), std::cend(x)); + return (size == 0) ? 0 : ((size - 1) / stride + 1) * stride; } public: @@ -382,8 +382,8 @@ namespace itertools { [[nodiscard]] const_iterator cbegin() const noexcept { return {std::cbegin(x), stride}; } [[nodiscard]] const_iterator begin() const noexcept { return cbegin(); } - [[nodiscard]] iterator end() noexcept { return {std::next(std::end(x), step()), stride}; } - [[nodiscard]] const_iterator cend() const noexcept { return {std::next(std::cend(x), step()), stride}; } + [[nodiscard]] iterator end() noexcept { return {std::next(std::begin(x), end_offset()), stride}; } + [[nodiscard]] const_iterator cend() const noexcept { return {std::next(std::cbegin(x), end_offset()), stride}; } [[nodiscard]] const_iterator end() const noexcept { return cend(); } }; diff --git a/test/c++/itertools.cpp b/test/c++/itertools.cpp index 5573d25..5f53e52 100644 --- a/test/c++/itertools.cpp +++ b/test/c++/itertools.cpp @@ -203,6 +203,8 @@ TEST(Itertools, Product_Range) { TEST(Itertools, Stride) { std::vector V1{0, 1, 2, 3, 4}; + + // simple stride for (int s = 1; s < 6; ++s) { int i = 0; int size = 0; @@ -216,9 +218,31 @@ TEST(Itertools, Stride) { // empty range std::vector V2; - int size = 0; - for ([[maybe_unused]] auto x : stride(V2, 1)) { ++size; } - EXPECT_EQ(size, 0); + int empty_size = 0; + for ([[maybe_unused]] auto x : stride(V2, 2)) { ++empty_size; } + EXPECT_EQ(empty_size, 0); + + // stride and product + for (int s = 1; s < 6; ++s) { + int idx = 0; + for (auto [x1, x2] : stride(product(V1, V1), s)) { + auto i = idx / static_cast(V1.size()); + auto j = idx - i * static_cast(V1.size()); + EXPECT_EQ(x1, i); + EXPECT_EQ(x2, j); + idx += s; + } + } + + // zip and stride + for (int s = 1; s < 6; ++s) { + int i = 0; + for (auto [x1, x2] : zip(stride(V1, s), stride(V1, s))) { + EXPECT_EQ(x1, i); + EXPECT_EQ(x2, i); + i += s; + } + } } int main(int argc, char **argv) {