diff --git a/include/marl/task.h b/include/marl/task.h index 1e7d3f4..1314242 100644 --- a/include/marl/task.h +++ b/include/marl/task.h @@ -17,14 +17,14 @@ #include "export.h" -#include +#include namespace marl { // Task is a unit of work for the scheduler. class Task { public: - using Function = std::function; + using Function = std::packaged_task; enum class Flags { None = 0, @@ -37,21 +37,17 @@ class Task { }; MARL_NO_EXPORT inline Task(); - MARL_NO_EXPORT inline Task(const Task&); + MARL_NO_EXPORT inline Task(const Task&) = delete; MARL_NO_EXPORT inline Task(Task&&); - MARL_NO_EXPORT inline Task(const Function& function, - Flags flags = Flags::None); - MARL_NO_EXPORT inline Task(Function&& function, Flags flags = Flags::None); - MARL_NO_EXPORT inline Task& operator=(const Task&); + template + MARL_NO_EXPORT inline Task(F&& function, Flags flags = Flags::None); + MARL_NO_EXPORT inline Task& operator=(const Task&) = delete; MARL_NO_EXPORT inline Task& operator=(Task&&); - MARL_NO_EXPORT inline Task& operator=(const Function&); + MARL_NO_EXPORT inline Task& operator=(const Function&) = delete; MARL_NO_EXPORT inline Task& operator=(Function&&); - // operator bool() returns true if the Task has a valid function. - MARL_NO_EXPORT inline operator bool() const; - // operator()() runs the task. - MARL_NO_EXPORT inline void operator()() const; + MARL_NO_EXPORT inline void operator()(); // is() returns true if the Task was created with the given flag. MARL_NO_EXPORT inline bool is(Flags flag) const; @@ -62,38 +58,23 @@ class Task { }; Task::Task() {} -Task::Task(const Task& o) : function(o.function), flags(o.flags) {} Task::Task(Task&& o) : function(std::move(o.function)), flags(o.flags) {} -Task::Task(const Function& function_, Flags flags_ /* = Flags::None */) - : function(function_), flags(flags_) {} -Task::Task(Function&& function_, Flags flags_ /* = Flags::None */) - : function(std::move(function_)), flags(flags_) {} -Task& Task::operator=(const Task& o) { - function = o.function; - flags = o.flags; - return *this; -} +template +Task::Task(F&& fn, Flags flags_ /* = Flags::None */) + : function(std::forward(fn)), flags(flags_) {} Task& Task::operator=(Task&& o) { function = std::move(o.function); flags = o.flags; return *this; } -Task& Task::operator=(const Function& f) { - function = f; - flags = Flags::None; - return *this; -} Task& Task::operator=(Function&& f) { function = std::move(f); flags = Flags::None; return *this; } -Task::operator bool() const { - return function.operator bool(); -} -void Task::operator()() const { +void Task::operator()() { function(); } diff --git a/src/scheduler_test.cpp b/src/scheduler_test.cpp index 64cf995..b6d7c1a 100644 --- a/src/scheduler_test.cpp +++ b/src/scheduler_test.cpp @@ -20,6 +20,7 @@ #include "marl/waitgroup.h" #include +#include TEST_F(WithoutBoundScheduler, SchedulerConstructAndDestruct) { auto scheduler = std::unique_ptr( @@ -108,6 +109,36 @@ TEST_P(WithBoundScheduler, ScheduleWithArgs) { ASSERT_EQ(got, "s: 'a string', i: 42, b: true"); } +TEST_P(WithBoundScheduler, ScheduleWithMovedCapture) { +#if __cplusplus >= 201402L // C++14 or greater + std::unique_ptr move_me(new std::string("move me")); + std::string got; + marl::WaitGroup wg(1); + marl::schedule([moved = std::move(move_me), wg, &got]() { + got = *moved; + wg.done(); + }); + wg.wait(); + ASSERT_EQ(got, "move me"); +#else + GTEST_SKIP() << "Test requires c++14 or greater"; +#endif +} + +TEST_P(WithBoundScheduler, ScheduleWithMovedArg) { + std::unique_ptr move_me(new std::string("move me")); + std::string got; + marl::WaitGroup wg(1); + marl::schedule( + [wg, &got](std::unique_ptr& str) { + got = *str; + wg.done(); + }, + std::move(move_me)); + wg.wait(); + ASSERT_EQ(got, "move me"); +} + TEST_P(WithBoundScheduler, FibersResumeOnSameThread) { marl::WaitGroup fence(1); marl::WaitGroup wg(1000);