diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6e03cafd7..79a6a4064 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -328,7 +328,9 @@ if(WITH_TESTS) # Command tests ../tests/commands/DriveToWaypointCommandTest.cpp # Util tests - ../tests/util/UtilTest.cpp + ../tests/util/CoreTest.cpp + ../tests/util/MathTest.cpp + ../tests/util/TimeTest.cpp ../tests/util/SchedulerTest.cpp # Protocol/teleop tests ../tests/kinematics/DiffWristKinematicsTest.cpp) diff --git a/src/utils/core.cpp b/src/utils/core.cpp index 96c8e4438..f84d81ca0 100644 --- a/src/utils/core.cpp +++ b/src/utils/core.cpp @@ -1,6 +1,6 @@ #include "core.h" -namespace util{ +namespace util { bool almostEqual(double a, double b, double threshold) { return std::abs(a - b) < threshold; @@ -15,4 +15,16 @@ frozen::string freezeStr(const std::string& str) { return frozen::string(str.c_str(), str.size()); } -} \ No newline at end of file +RAIIHelper::RAIIHelper(const std::function& f) : f(f) {} + +RAIIHelper::RAIIHelper(RAIIHelper&& other) : f(std::move(other.f)) { + other.f = {}; +} + +RAIIHelper::~RAIIHelper() { + if (f) { + f(); + } +} + +} // namespace util diff --git a/src/utils/core.h b/src/utils/core.h index 1840a36bb..0f7dfb1b5 100644 --- a/src/utils/core.h +++ b/src/utils/core.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -78,4 +79,39 @@ std::tuple pairToTuple(const std::pair& pair) { return std::tuple(pair.first, pair.second); } +/** + * @brief A helper class for executing a function when leaving a scope, in RAII-style. + */ +class RAIIHelper { +public: + /** + * @brief Construct a new RAIIHelper. + * + * @param f The function to execute when this object is destructed. + */ + RAIIHelper(const std::function& f); + + RAIIHelper(const RAIIHelper&) = delete; + + /** + * @brief Move an RAIIHelper into another. + * + * The RAIIHelper being moved is guaranteed to be empty after this, i.e. will not execute + * any code when being destructed. + * + * @param other The RAIIHelper to move. + */ + RAIIHelper(RAIIHelper&& other); + + RAIIHelper& operator=(const RAIIHelper&) = delete; + + /** + * @brief Destroy the RAIIHelper object, executing the given function, if not empty. + */ + ~RAIIHelper(); + +private: + std::function f; +}; + } // namespace util diff --git a/src/utils/scheduler.h b/src/utils/scheduler.h index bced6b04d..6be1a4a53 100644 --- a/src/utils/scheduler.h +++ b/src/utils/scheduler.h @@ -1,5 +1,7 @@ #pragma once +#include "core.h" + #include #include #include @@ -299,6 +301,219 @@ class Watchdog : private impl::Notifiable { } }; +/** + * @brief An abstract class that can be overridden to run long-running tasks and encapsulate + * task-related data. + * + * Client code should use this class by deriving from it and overriding AsyncTask::task. + * + * For simpler tasks that just require a function to be run periodically, consider + * PeriodicTask. + * + * @tparam Clock The clock to use for timing. + */ +template +class AsyncTask : private virtual impl::Notifiable { +public: + /** + * @brief Construct a new task. + * + * @param name The name of this task, for logging purposes. + */ + AsyncTask(const std::optional& name = std::nullopt) + : name(name), running(false), quitting(false) {} + + AsyncTask(const AsyncTask&) = delete; + + virtual ~AsyncTask() { + stop(); + if (thread.joinable()) { + thread.join(); + } + } + + AsyncTask& operator=(const AsyncTask&) = delete; + + /** + * @brief Start the task. + * + * If the task is already running, do nothing. + */ + virtual void start() { + std::lock_guard lock(mutex); + if (!running) { + std::lock_guard threadLock(threadMutex); + if (thread.joinable()) { + thread.join(); + } + running = true; + quitting = false; + thread = std::thread(&AsyncTask::run, this); + } + } + + /** + * @brief Stop the task and wait for it to finish. + * + * If the task is not running, do nothing. + */ + virtual void stop() { + bool isRunning = false; + { + std::lock_guard lock(mutex); + if (isRunningInternal()) { + quitting = true; + isRunning = true; + } + } + if (isRunning) { + cv.notify_one(); + std::lock_guard threadLock(threadMutex); + if (thread.joinable()) { + thread.join(); + } + } + } + + /** + * @brief Check if the task is running. + * + * @return If the task is currently running. + */ + bool isRunning() { + std::lock_guard lock(mutex); + return running; + } + +protected: + /** + * @brief The long-running task, overridden by client code. + * + * If a task wants to stop itself, it can just return. + * + * @param lock The lock on the private internal state of the AsyncTask. Client code should + * generally not use this except for the wait_until_xxx methods. + */ + virtual void task(std::unique_lock& lock) = 0; + + /** + * @brief Version of AsyncTask::isRunning() that does no synchronization. + * + * This is useful if called from AsyncTask::task(), to prevent deadlocks. + * + * @return If the task is currently running. + */ + bool isRunningInternal() { + return running; + } + + /** + * @brief Wait until the specified time point, or until the task has been stopped. + * + * @param lock The lock passed to AsyncTask::task. + * @param tp The time point to wait until. + * @return true iff the task was stopped while waiting. + */ + bool wait_until(std::unique_lock& lock, + const std::chrono::time_point& tp) { + return cv.wait_until(lock, tp, [&]() { return quitting; }); + } + + /** + * @brief Wait for a given duration, or until the task has been stopped. + * + * @param lock The lock passed to AsyncTask::task. + * @param tp The duration of time to wait for. + * @return true iff the task was stopped while waiting. + */ + template + bool wait_for(std::unique_lock& lock, + const std::chrono::duration& dur) { + return wait_until(lock, Clock::now() + dur); + } + + /** + * @brief Wait until the task has been stopped. + * + * @param lock The lock passed to AsyncTask::task. + */ + void wait_until_done(std::unique_lock& lock) { + return cv.wait(lock, [&]() { return quitting; }); + } + + /** + * @brief Not for use by client code. + */ + void notify() override { + cv.notify_all(); + } + +private: + std::optional name; + bool running; + bool quitting; + std::condition_variable cv; + std::thread thread; + + // If acquiring both mutexes, acquire mutex first then threadMutex. + std::mutex mutex; // protects everything except thread + std::mutex threadMutex; // protects only thread + + friend void util::impl::notifyScheduler<>(AsyncTask&); + + void run() { + if (name.has_value()) { + loguru::set_thread_name(name->c_str()); + } + std::unique_lock lock(mutex); + // clear flags when exiting + RAIIHelper r([&]() { running = quitting = false; }); + task(lock); + } +}; + +/** + * @brief Implements a task that executes a function periodically. + * + * Note that all PeriodicTask instances are run on the same thread, so the task function should + * not block. + * + * @tparam Clock The clock to use for timing. + */ +template +class PeriodicTask : public AsyncTask, private virtual impl::Notifiable { +public: + /** + * @brief Construct a new periodic task. + * + * @param period The period at which to run the task. + * @param f The function to execute every period. + */ + PeriodicTask(const std::chrono::milliseconds& period, const std::function& f) + : AsyncTask(), period(period), f(f) {} + +protected: + void task(std::unique_lock& lock) override { + auto event = scheduler.scheduleEvent(period, f); + AsyncTask::wait_until_done(lock); + scheduler.removeEvent(event); + } + + void notify() override { + impl::notifyScheduler(scheduler); + AsyncTask::notify(); + } + +private: + std::chrono::milliseconds period; + std::function f; + + inline static PeriodicScheduler scheduler = + PeriodicScheduler("PeriodicTask_Scheduler"); + + friend void util::impl::notifyScheduler<>(PeriodicTask&); +}; + namespace impl { /** diff --git a/tests/util/CoreTest.cpp b/tests/util/CoreTest.cpp new file mode 100644 index 000000000..2839ed99a --- /dev/null +++ b/tests/util/CoreTest.cpp @@ -0,0 +1,25 @@ +#include "../../src/utils/core.h" + +#include + +using namespace util; + +TEST_CASE("Test RAIIHelper", "[util][core]") { + SECTION("Test RAIIHelper executes") { + bool called = false; + { + RAIIHelper r([&]() { called = true; }); + } + REQUIRE(called); + } + + SECTION("Test RAIIHelper moves properly") { + int i = 0; + { + RAIIHelper r([&]() { i++; }); + { RAIIHelper r2(std::move(r)); } + REQUIRE(i == 1); + } + REQUIRE(i == 1); + } +} diff --git a/tests/util/SchedulerTest.cpp b/tests/util/SchedulerTest.cpp index 3a864b8f4..ab851fb7b 100644 --- a/tests/util/SchedulerTest.cpp +++ b/tests/util/SchedulerTest.cpp @@ -97,53 +97,144 @@ TEST_CASE("Test Watchdog", "[util][scheduler]") { auto l = std::make_shared(1); Watchdog wd(100ms, [&]() { l->count_down(); }); for (int i = 0; i < 5; i++) { - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); advanceFakeClock(50ms, 5ms, wd); - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); wd.feed(); } - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); advanceFakeClock(200ms, 5ms, wd); - REQUIRE(l->wait_for(10ms)); + REQUIRE(l->wait_for(50ms)); // check that callback is not called again while starved l = std::make_shared(1); advanceFakeClock(200ms, 5ms, wd); - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); // check that callback is called after feeding and starving again wd.feed(); - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); advanceFakeClock(100ms, 5ms, wd); - REQUIRE(l->wait_for(10ms)); + REQUIRE(l->wait_for(50ms)); } SECTION("Test keep calling while starved") { auto l = std::make_shared(1); - Watchdog wd( - 100ms, [&]() { l->count_down(); }, true); + auto fn = [&]() { l->count_down(); }; + Watchdog wd(100ms, fn, true); for (int i = 0; i < 5; i++) { - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); advanceFakeClock(50ms, 5ms, wd); - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); wd.feed(); } - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); advanceFakeClock(100ms, 5ms, wd); - REQUIRE(l->wait_for(10ms)); + REQUIRE(l->wait_for(50ms)); // check that callback is repeatedly called while starved for (int i = 0; i < 3; i++) { l = std::make_shared(1); advanceFakeClock(100ms, 5ms, wd); - REQUIRE(l->wait_for(10ms)); + REQUIRE(l->wait_for(50ms)); } // check that callback is called after feeding and starving again l = std::make_shared(1); wd.feed(); - REQUIRE_FALSE(l->wait_for(10ms)); + REQUIRE_FALSE(l->wait_for(50ms)); advanceFakeClock(100ms, 5ms, wd); - REQUIRE(l->wait_for(10ms)); + REQUIRE(l->wait_for(50ms)); + } +} + +TEST_CASE("Test PeriodicTask", "[util][scheduler]") { + // NOTE: after calling start() or stop() we must force the thread to suspend in order for + // the PeriodicTask thread to start up and begin executing. + // So it should be start() -> checkLatch() -> advanceTime() -> checkLatch(). + + auto l = std::make_unique(1); + PeriodicTask pt(100ms, [&]() { l->count_down(); }); + + // check that it doesn't start upon construction + advanceFakeClock(350ms, 50ms, pt); + REQUIRE_FALSE(l->wait_for(100ms)); + REQUIRE_FALSE(pt.isRunning()); + + l = std::make_unique(3); + pt.start(); + REQUIRE(pt.isRunning()); + REQUIRE_FALSE(l->wait_for(50ms)); + advanceFakeClock(350ms, 50ms, pt); + REQUIRE(l->wait_for(50ms)); + + pt.stop(); + REQUIRE_FALSE(pt.isRunning()); + l = std::make_unique(1); + advanceFakeClock(150ms, 50ms, pt); + REQUIRE_FALSE(l->wait_for(50ms)); + + l = std::make_unique(3); + pt.start(); + REQUIRE(pt.isRunning()); + REQUIRE_FALSE(l->wait_for(50ms)); + advanceFakeClock(350ms, 50ms, pt); + REQUIRE(l->wait_for(50ms)); +} + +TEST_CASE("Test AsyncTask", "[util][scheduler]") { + // since PeriodicTask uses AsyncTask, a lot of the functionality is already tested. + SECTION("Test AsyncTask with task that terminates") { + class ExitingAsyncTask : public AsyncTask { + public: + ExitingAsyncTask(latch& l) : l_(l) {} + + protected: + void task(std::unique_lock& lock) override { + wait_for(lock, 100ms); + l_.count_down(); + } + + private: + latch& l_; + }; + + latch l(1); + ExitingAsyncTask task(l); + REQUIRE_FALSE(task.isRunning()); + task.start(); + REQUIRE(task.isRunning()); + REQUIRE_FALSE(l.wait_for(50ms)); + AsyncTask& at = task; + advanceFakeClock(150ms, 50ms, at); + REQUIRE(l.wait_for(50ms)); + } + + SECTION("Test that AsyncTask stops when killed during wait") { + class WaitingAsyncTask : public AsyncTask { + public: + WaitingAsyncTask(latch& l) : l_(l) {} + + protected: + void task(std::unique_lock& lock) override { + wait_for(lock, 100ms); + l_.count_down(); + } + + private: + latch& l_; + }; + + latch l(1); + WaitingAsyncTask task(l); + REQUIRE_FALSE(task.isRunning()); + task.start(); + REQUIRE(task.isRunning()); + REQUIRE_FALSE(l.wait_for(50ms)); + AsyncTask& at = task; + advanceFakeClock(50ms, 50ms, at); + REQUIRE_FALSE(l.wait_for(50ms)); + task.stop(); + REQUIRE(l.wait_for(50ms)); } } diff --git a/tests/util/UtilTest.cpp b/tests/util/TimeTest.cpp similarity index 83% rename from tests/util/UtilTest.cpp rename to tests/util/TimeTest.cpp index bb8468c86..b9fdacaea 100644 --- a/tests/util/UtilTest.cpp +++ b/tests/util/TimeTest.cpp @@ -8,7 +8,7 @@ using namespace Catch::literals; using namespace util; using namespace std::chrono_literals; -TEST_CASE("Test Duration To Seconds", "[util]") { +TEST_CASE("Test Duration To Seconds", "[util][time]") { auto dur = 1500ms; REQUIRE(durationToSec(dur) == 1.5_a); }