Skip to content

Commit

Permalink
Fix for combination of product and stride
Browse files Browse the repository at this point in the history
  • Loading branch information
Thoemi09 authored and Wentzell committed Oct 2, 2023
1 parent a6f4d84 commit c06f2d4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
10 changes: 5 additions & 5 deletions c++/itertools/itertools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,18 +372,18 @@ namespace itertools {
bool operator==(strided const &) const = default;

private:
[[nodiscard]] std::ptrdiff_t step() const {
auto end_idx = distance(std::cbegin(x), std::cend(x));
return (end_idx % stride == 0 ? 0 : stride - end_idx % stride);
[[nodiscard]] std::ptrdiff_t end_offset() const {
auto size = distance(std::cbegin(x), std::cend(x));
return (size == 0) ? 0 : ((size - 1) / stride + 1) * stride;
}

public:
[[nodiscard]] iterator begin() noexcept { return {std::begin(x), stride}; }
[[nodiscard]] const_iterator cbegin() const noexcept { return {std::cbegin(x), stride}; }
[[nodiscard]] const_iterator begin() const noexcept { return cbegin(); }

[[nodiscard]] iterator end() noexcept { return {std::next(std::end(x), step()), stride}; }
[[nodiscard]] const_iterator cend() const noexcept { return {std::next(std::cend(x), step()), stride}; }
[[nodiscard]] iterator end() noexcept { return {std::next(std::begin(x), end_offset()), stride}; }
[[nodiscard]] const_iterator cend() const noexcept { return {std::next(std::cbegin(x), end_offset()), stride}; }
[[nodiscard]] const_iterator end() const noexcept { return cend(); }
};

Expand Down
30 changes: 27 additions & 3 deletions test/c++/itertools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ TEST(Itertools, Product_Range) {

TEST(Itertools, Stride) {
std::vector<int> V1{0, 1, 2, 3, 4};

// simple stride
for (int s = 1; s < 6; ++s) {
int i = 0;
int size = 0;
Expand All @@ -216,9 +218,31 @@ TEST(Itertools, Stride) {

// empty range
std::vector<int> V2;
int size = 0;
for ([[maybe_unused]] auto x : stride(V2, 1)) { ++size; }
EXPECT_EQ(size, 0);
int empty_size = 0;
for ([[maybe_unused]] auto x : stride(V2, 2)) { ++empty_size; }
EXPECT_EQ(empty_size, 0);

// stride and product
for (int s = 1; s < 6; ++s) {
int idx = 0;
for (auto [x1, x2] : stride(product(V1, V1), s)) {
auto i = idx / static_cast<int>(V1.size());
auto j = idx - i * static_cast<int>(V1.size());
EXPECT_EQ(x1, i);
EXPECT_EQ(x2, j);
idx += s;
}
}

// zip and stride
for (int s = 1; s < 6; ++s) {
int i = 0;
for (auto [x1, x2] : zip(stride(V1, s), stride(V1, s))) {
EXPECT_EQ(x1, i);
EXPECT_EQ(x2, i);
i += s;
}
}
}

int main(int argc, char **argv) {
Expand Down

0 comments on commit c06f2d4

Please sign in to comment.