Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tasks to utils #300

Merged
merged 7 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ if(WITH_TESTS)
# Command tests
../tests/commands/DriveToWaypointCommandTest.cpp
# Util tests
../tests/util/UtilTest.cpp
../tests/util/CoreTest.cpp
../tests/util/MathTest.cpp
../tests/util/TimeTest.cpp
../tests/util/SchedulerTest.cpp
# Protocol/teleop tests
../tests/kinematics/DiffWristKinematicsTest.cpp)
Expand Down
16 changes: 14 additions & 2 deletions src/utils/core.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "core.h"

namespace util{
namespace util {

bool almostEqual(double a, double b, double threshold) {
return std::abs(a - b) < threshold;
Expand All @@ -15,4 +15,16 @@ frozen::string freezeStr(const std::string& str) {
return frozen::string(str.c_str(), str.size());
}

}
RAIIHelper::RAIIHelper(const std::function<void()>& f) : f(f) {}

RAIIHelper::RAIIHelper(RAIIHelper&& other) : f(std::move(other.f)) {
other.f = {};
}

RAIIHelper::~RAIIHelper() {
if (f) {
f();
}
}

} // namespace util
36 changes: 36 additions & 0 deletions src/utils/core.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
Expand Down Expand Up @@ -78,4 +79,39 @@ std::tuple<T, U> pairToTuple(const std::pair<T, U>& pair) {
return std::tuple<T, U>(pair.first, pair.second);
}

/**
* @brief A helper class for executing a function when leaving a scope, in RAII-style.
*/
class RAIIHelper {
public:
/**
* @brief Construct a new RAIIHelper.
*
* @param f The function to execute when this object is destructed.
*/
RAIIHelper(const std::function<void()>& f);

RAIIHelper(const RAIIHelper&) = delete;

/**
* @brief Move an RAIIHelper into another.
*
* The RAIIHelper being moved is guaranteed to be empty after this, i.e. will not execute
* any code when being destructed.
*
* @param other The RAIIHelper to move.
*/
RAIIHelper(RAIIHelper&& other);

RAIIHelper& operator=(const RAIIHelper&) = delete;

/**
* @brief Destroy the RAIIHelper object, executing the given function, if not empty.
*/
~RAIIHelper();

private:
std::function<void()> f;
};

} // namespace util
215 changes: 215 additions & 0 deletions src/utils/scheduler.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "core.h"

#include <chrono>
#include <condition_variable>
#include <functional>
Expand Down Expand Up @@ -299,6 +301,219 @@ class Watchdog : private impl::Notifiable {
}
};

/**
* @brief An abstract class that can be overridden to run long-running tasks and encapsulate
* task-related data.
*
* Client code should use this class by deriving from it and overriding AsyncTask::task.
*
* For simpler tasks that just require a function to be run periodically, consider
* PeriodicTask.
*
* @tparam Clock The clock to use for timing.
*/
template <typename Clock = std::chrono::steady_clock>
class AsyncTask : private virtual impl::Notifiable {
public:
/**
* @brief Construct a new task.
*
* @param name The name of this task, for logging purposes.
*/
AsyncTask(const std::optional<std::string>& name = std::nullopt)
: name(name), running(false), quitting(false) {}

AsyncTask(const AsyncTask&) = delete;

virtual ~AsyncTask() {
stop();
if (thread.joinable()) {
thread.join();
}
quinnmp marked this conversation as resolved.
Show resolved Hide resolved
}

AsyncTask& operator=(const AsyncTask&) = delete;

/**
* @brief Start the task.
*
* If the task is already running, do nothing.
*/
virtual void start() {
std::lock_guard lock(mutex);
if (!running) {
std::lock_guard threadLock(threadMutex);
if (thread.joinable()) {
thread.join();
}
running = true;
quitting = false;
thread = std::thread(&AsyncTask::run, this);
}
}

/**
* @brief Stop the task and wait for it to finish.
*
* If the task is not running, do nothing.
*/
virtual void stop() {
bool isRunning = false;
{
std::lock_guard lock(mutex);
if (isRunningInternal()) {
quitting = true;
isRunning = true;
}
}
if (isRunning) {
cv.notify_one();
std::lock_guard threadLock(threadMutex);
if (thread.joinable()) {
thread.join();
}
}
}

/**
* @brief Check if the task is running.
*
* @return If the task is currently running.
*/
bool isRunning() {
std::lock_guard lock(mutex);
return running;
}

protected:
/**
* @brief The long-running task, overridden by client code.
*
* If a task wants to stop itself, it can just return.
*
* @param lock The lock on the private internal state of the AsyncTask. Client code should
* generally not use this except for the wait_until_xxx methods.
*/
virtual void task(std::unique_lock<std::mutex>& lock) = 0;

/**
* @brief Version of AsyncTask::isRunning() that does no synchronization.
*
* This is useful if called from AsyncTask::task(), to prevent deadlocks.
*
* @return If the task is currently running.
*/
bool isRunningInternal() {
return running;
}

/**
* @brief Wait until the specified time point, or until the task has been stopped.
*
* @param lock The lock passed to AsyncTask::task.
* @param tp The time point to wait until.
* @return true iff the task was stopped while waiting.
*/
bool wait_until(std::unique_lock<std::mutex>& lock,
const std::chrono::time_point<Clock>& tp) {
return cv.wait_until(lock, tp, [&]() { return quitting; });
}

/**
* @brief Wait for a given duration, or until the task has been stopped.
*
* @param lock The lock passed to AsyncTask::task.
* @param tp The duration of time to wait for.
* @return true iff the task was stopped while waiting.
*/
template <typename Rep, typename Period>
bool wait_for(std::unique_lock<std::mutex>& lock,
const std::chrono::duration<Rep, Period>& dur) {
return wait_until(lock, Clock::now() + dur);
}

/**
* @brief Wait until the task has been stopped.
*
* @param lock The lock passed to AsyncTask::task.
*/
void wait_until_done(std::unique_lock<std::mutex>& lock) {
return cv.wait(lock, [&]() { return quitting; });
}

/**
* @brief Not for use by client code.
*/
void notify() override {
cv.notify_all();
}

private:
std::optional<std::string> name;
bool running;
bool quitting;
std::condition_variable cv;
std::thread thread;

// If acquiring both mutexes, acquire mutex first then threadMutex.
std::mutex mutex; // protects everything except thread
std::mutex threadMutex; // protects only thread

friend void util::impl::notifyScheduler<>(AsyncTask&);

void run() {
if (name.has_value()) {
loguru::set_thread_name(name->c_str());
}
std::unique_lock<std::mutex> lock(mutex);
// clear flags when exiting
RAIIHelper r([&]() { running = quitting = false; });
task(lock);
}
};

/**
* @brief Implements a task that executes a function periodically.
*
* Note that all PeriodicTask instances are run on the same thread, so the task function should
* not block.
*
* @tparam Clock The clock to use for timing.
*/
template <typename Clock = std::chrono::steady_clock>
class PeriodicTask : public AsyncTask<Clock>, private virtual impl::Notifiable {
public:
/**
* @brief Construct a new periodic task.
*
* @param period The period at which to run the task.
* @param f The function to execute every period.
*/
PeriodicTask(const std::chrono::milliseconds& period, const std::function<void()>& f)
: AsyncTask<Clock>(), period(period), f(f) {}

protected:
void task(std::unique_lock<std::mutex>& lock) override {
auto event = scheduler.scheduleEvent(period, f);
AsyncTask<Clock>::wait_until_done(lock);
scheduler.removeEvent(event);
}

void notify() override {
impl::notifyScheduler(scheduler);
AsyncTask<Clock>::notify();
}

private:
std::chrono::milliseconds period;
std::function<void()> f;

inline static PeriodicScheduler<Clock> scheduler =
PeriodicScheduler<Clock>("PeriodicTask_Scheduler");

friend void util::impl::notifyScheduler<>(PeriodicTask&);
};

namespace impl {

/**
Expand Down
25 changes: 25 additions & 0 deletions tests/util/CoreTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "../../src/utils/core.h"

#include <catch2/catch.hpp>

using namespace util;

TEST_CASE("Test RAIIHelper", "[util][core]") {
SECTION("Test RAIIHelper executes") {
bool called = false;
{
RAIIHelper r([&]() { called = true; });
}
REQUIRE(called);
}

SECTION("Test RAIIHelper moves properly") {
int i = 0;
{
RAIIHelper r([&]() { i++; });
{ RAIIHelper r2(std::move(r)); }
REQUIRE(i == 1);
}
REQUIRE(i == 1);
}
}
Loading