Skip to content

Commit

Permalink
scheduler: Support moved tast captures / arguments
Browse files Browse the repository at this point in the history
Replace the internal use of `std::function` with `std::packaged_task`.
`std::function` requires that the wrapped function is CopyConstructable, where as a `std::packaged_task` does not.
This allows the tasks to hold `std::move`'d values.

This is an API / ABI breaking change, but I believe few people would be copying `marl::Task`s.
  • Loading branch information
ben-clayton committed Mar 27, 2023
1 parent 9c689c9 commit a2ca890
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
43 changes: 12 additions & 31 deletions include/marl/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

#include "export.h"

#include <functional>
#include <future>

namespace marl {

// Task is a unit of work for the scheduler.
class Task {
public:
using Function = std::function<void()>;
using Function = std::packaged_task<void()>;

enum class Flags {
None = 0,
Expand All @@ -37,21 +37,17 @@ class Task {
};

MARL_NO_EXPORT inline Task();
MARL_NO_EXPORT inline Task(const Task&);
MARL_NO_EXPORT inline Task(const Task&) = delete;
MARL_NO_EXPORT inline Task(Task&&);
MARL_NO_EXPORT inline Task(const Function& function,
Flags flags = Flags::None);
MARL_NO_EXPORT inline Task(Function&& function, Flags flags = Flags::None);
MARL_NO_EXPORT inline Task& operator=(const Task&);
template <typename F>
MARL_NO_EXPORT inline Task(F&& function, Flags flags = Flags::None);
MARL_NO_EXPORT inline Task& operator=(const Task&) = delete;
MARL_NO_EXPORT inline Task& operator=(Task&&);
MARL_NO_EXPORT inline Task& operator=(const Function&);
MARL_NO_EXPORT inline Task& operator=(const Function&) = delete;
MARL_NO_EXPORT inline Task& operator=(Function&&);

// operator bool() returns true if the Task has a valid function.
MARL_NO_EXPORT inline operator bool() const;

// operator()() runs the task.
MARL_NO_EXPORT inline void operator()() const;
MARL_NO_EXPORT inline void operator()();

// is() returns true if the Task was created with the given flag.
MARL_NO_EXPORT inline bool is(Flags flag) const;
Expand All @@ -62,38 +58,23 @@ class Task {
};

Task::Task() {}
Task::Task(const Task& o) : function(o.function), flags(o.flags) {}
Task::Task(Task&& o) : function(std::move(o.function)), flags(o.flags) {}
Task::Task(const Function& function_, Flags flags_ /* = Flags::None */)
: function(function_), flags(flags_) {}
Task::Task(Function&& function_, Flags flags_ /* = Flags::None */)
: function(std::move(function_)), flags(flags_) {}
Task& Task::operator=(const Task& o) {
function = o.function;
flags = o.flags;
return *this;
}
template <typename F>
Task::Task(F&& fn, Flags flags_ /* = Flags::None */)
: function(std::forward<F>(fn)), flags(flags_) {}
Task& Task::operator=(Task&& o) {
function = std::move(o.function);
flags = o.flags;
return *this;
}

Task& Task::operator=(const Function& f) {
function = f;
flags = Flags::None;
return *this;
}
Task& Task::operator=(Function&& f) {
function = std::move(f);
flags = Flags::None;
return *this;
}
Task::operator bool() const {
return function.operator bool();
}

void Task::operator()() const {
void Task::operator()() {
function();
}

Expand Down
31 changes: 31 additions & 0 deletions src/scheduler_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "marl/waitgroup.h"

#include <atomic>
#include <memory>

TEST_F(WithoutBoundScheduler, SchedulerConstructAndDestruct) {
auto scheduler = std::unique_ptr<marl::Scheduler>(
Expand Down Expand Up @@ -108,6 +109,36 @@ TEST_P(WithBoundScheduler, ScheduleWithArgs) {
ASSERT_EQ(got, "s: 'a string', i: 42, b: true");
}

TEST_P(WithBoundScheduler, ScheduleWithMovedCapture) {
#if __cplusplus >= 201402L // C++14 or greater
std::unique_ptr<std::string> move_me(new std::string("move me"));
std::string got;
marl::WaitGroup wg(1);
marl::schedule([moved = std::move(move_me), wg, &got]() {
got = *moved;
wg.done();
});
wg.wait();
ASSERT_EQ(got, "move me");
#else
GTEST_SKIP() << "Test requires c++14 or greater";
#endif
}

TEST_P(WithBoundScheduler, ScheduleWithMovedArg) {
std::unique_ptr<std::string> move_me(new std::string("move me"));
std::string got;
marl::WaitGroup wg(1);
marl::schedule(
[wg, &got](std::unique_ptr<std::string>& str) {
got = *str;
wg.done();
},
std::move(move_me));
wg.wait();
ASSERT_EQ(got, "move me");
}

TEST_P(WithBoundScheduler, FibersResumeOnSameThread) {
marl::WaitGroup fence(1);
marl::WaitGroup wg(1000);
Expand Down

0 comments on commit a2ca890

Please sign in to comment.