From 4a592250837d78321ff74e1b3e49972ceb3f5b2b Mon Sep 17 00:00:00 2001 From: Pieter Pas Date: Thu, 28 Mar 2024 12:21:52 +0100 Subject: [PATCH] Add support for asynchronous solver cancellation --- QPALM/include/qpalm.h | 6 ++ QPALM/include/qpalm/constants.h | 1 + QPALM/include/qpalm/types.h | 9 +++ QPALM/interfaces/cxx/include/qpalm.hpp | 3 + QPALM/interfaces/cxx/src/qpalm.cpp | 2 + QPALM/interfaces/python/CMakeLists.txt | 2 +- QPALM/interfaces/python/async.hpp | 62 ++++++++++++++++++ .../interfaces/python/demangled-typename.cpp | 20 ++++++ .../interfaces/python/demangled-typename.hpp | 11 ++++ QPALM/interfaces/python/qpalm.py.cpp | 17 ++++- .../python/test/test_qpalm_cancel.py | 45 +++++++++++++ .../python/test/test_qpalm_thread.py | 65 +++++++++++++++++++ QPALM/interfaces/python/thread-checker.hpp | 52 +++++++++++++++ QPALM/src/qpalm.c | 12 ++++ QPALM/src/util.c | 12 ++++ 15 files changed, 317 insertions(+), 2 deletions(-) create mode 100644 QPALM/interfaces/python/async.hpp create mode 100644 QPALM/interfaces/python/demangled-typename.cpp create mode 100644 QPALM/interfaces/python/demangled-typename.hpp create mode 100644 QPALM/interfaces/python/test/test_qpalm_cancel.py create mode 100644 QPALM/interfaces/python/test/test_qpalm_thread.py create mode 100644 QPALM/interfaces/python/thread-checker.hpp diff --git a/QPALM/include/qpalm.h b/QPALM/include/qpalm.h index 7fdd7d97..f88cd07b 100644 --- a/QPALM/include/qpalm.h +++ b/QPALM/include/qpalm.h @@ -97,6 +97,12 @@ QPALM_EXPORT void qpalm_warm_start(QPALMWorkspace *work, */ QPALM_EXPORT void qpalm_solve(QPALMWorkspace *work); +/** + * Cancel the ongoing call to @ref qpalm_solve. + * + * Thread- and signal handler-safe. + */ +QPALM_EXPORT void qpalm_cancel(QPALMWorkspace *work); /** * Update the settings to the new settings. 
diff --git a/QPALM/include/qpalm/constants.h b/QPALM/include/qpalm/constants.h index 66a243e9..72e44b19 100644 --- a/QPALM/include/qpalm/constants.h +++ b/QPALM/include/qpalm/constants.h @@ -33,6 +33,7 @@ extern "C" { #define QPALM_PRIMAL_INFEASIBLE (-3) /**< status to indicate the problem is primal infeasible */ #define QPALM_DUAL_INFEASIBLE (-4) /**< status to indicate the problem is dual infeasible */ #define QPALM_TIME_LIMIT_REACHED (-5) /**< status to indicate the problem's runtime has exceeded the specified time limit */ +#define QPALM_USER_CANCELLATION (-6) /**< status to indicate the user has cancelled the solve */ #define QPALM_UNSOLVED (-10) /**< status to indicate the problem is unsolved. Only setup function has been called */ #define QPALM_ERROR (0) /**< status to indicate an error has occured (this error should automatically be printed) */ diff --git a/QPALM/include/qpalm/types.h b/QPALM/include/qpalm/types.h index 98047006..916ea0aa 100644 --- a/QPALM/include/qpalm/types.h +++ b/QPALM/include/qpalm/types.h @@ -8,6 +8,13 @@ #ifndef QPALM_TYPES_H # define QPALM_TYPES_H +#ifdef __cplusplus +#include +using std::atomic_bool; +#else +#include +#endif + # ifdef __cplusplus extern "C" { # endif @@ -311,6 +318,8 @@ typedef struct { QPALMTimer *timer; ///< timer object # endif // ifdef QPALM_TIMING + atomic_bool cancel; ///< Cancel the solver (from other thread or signal) + } QPALMWorkspace; diff --git a/QPALM/interfaces/cxx/include/qpalm.hpp b/QPALM/interfaces/cxx/include/qpalm.hpp index 937e6228..bc24a57a 100644 --- a/QPALM/interfaces/cxx/include/qpalm.hpp +++ b/QPALM/interfaces/cxx/include/qpalm.hpp @@ -173,6 +173,9 @@ class Solver { /// @ref get_info(). /// @see @ref ::qpalm_solve QPALM_CXX_EXPORT void solve(); + /// Cancel the ongoing call to @ref solve. + /// Thread- and signal handler-safe. + QPALM_CXX_EXPORT void cancel(); /// Get the solution computed by @ref solve(). 
/// @note Returns a view that is only valid as long as the solver is not diff --git a/QPALM/interfaces/cxx/src/qpalm.cpp b/QPALM/interfaces/cxx/src/qpalm.cpp index 96c99662..9cd6555b 100644 --- a/QPALM/interfaces/cxx/src/qpalm.cpp +++ b/QPALM/interfaces/cxx/src/qpalm.cpp @@ -62,6 +62,8 @@ void Solver::warm_start(std::optional x, void Solver::solve() { ::qpalm_solve(work.get()); } +void Solver::cancel() { ::qpalm_cancel(work.get()); } + SolutionView Solver::get_solution() const { assert(work->solution); assert(work->solution->x); diff --git a/QPALM/interfaces/python/CMakeLists.txt b/QPALM/interfaces/python/CMakeLists.txt index a188de7f..46795373 100644 --- a/QPALM/interfaces/python/CMakeLists.txt +++ b/QPALM/interfaces/python/CMakeLists.txt @@ -22,7 +22,7 @@ include(cmake/QueryPythonForPybind11.cmake) find_pybind11_python_first() # Compile the example Python module -pybind11_add_module(_qpalm MODULE "qpalm.py.cpp") +pybind11_add_module(_qpalm MODULE "qpalm.py.cpp" "demangled-typename.cpp") target_compile_features(_qpalm PRIVATE cxx_std_17) target_link_libraries(_qpalm PRIVATE qpalm_warnings) target_link_libraries(_qpalm PRIVATE pybind11::pybind11 qpalm qpalm_cxx) diff --git a/QPALM/interfaces/python/async.hpp b/QPALM/interfaces/python/async.hpp new file mode 100644 index 00000000..f6d979dd --- /dev/null +++ b/QPALM/interfaces/python/async.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include +namespace py = pybind11; + +#include +#include +#include +#include +#include +using namespace std::chrono_literals; + +#include "thread-checker.hpp" + +namespace qpalm { + +template +void async_solve(bool async, bool suppress_interrupt, Solver &solver, Invoker &invoke_solver, + CheckedArgs &...checked_args) { + if (!async) { + // Invoke the solver synchronously + invoke_solver(); + } else { + // Check that the user doesn't use the same solver/problem in multiple threads + ThreadChecker solver_checker{solver}; + std::tuple checkers{ThreadChecker{checked_args}...}; + // Invoke the 
solver asynchronously + auto done = std::async(std::launch::async, invoke_solver); + { + py::gil_scoped_release gil; + while (done.wait_for(50ms) != std::future_status::ready) { + py::gil_scoped_acquire gil; + // Check if Python received a signal (e.g. Ctrl+C) + if (PyErr_CheckSignals() != 0) { + // Nicely ask the solver to stop + solver.cancel(); + // It should return a result soon + if (py::gil_scoped_release gil; + done.wait_for(15s) != std::future_status::ready) { + // If it doesn't, we terminate the entire program, + // because the solver uses variables local to this + // function, so we cannot safely return without + // waiting for the solver to finish. + std::cerr << "QPALM solver failed to respond to cancellation request. " + "Terminating ..." + << std::endl; + std::terminate(); + } + if (PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt) && suppress_interrupt) + PyErr_Clear(); // Clear the KeyboardInterrupt exception + else + throw py::error_already_set(); + } + break; + } + } + } + } +} + +} // namespace qpalm
diff --git a/QPALM/interfaces/python/demangled-typename.cpp b/QPALM/interfaces/python/demangled-typename.cpp
new file mode 100644
index 00000000..69a50295
--- /dev/null
+++ b/QPALM/interfaces/python/demangled-typename.cpp
@@ -0,0 +1,28 @@
+#include "demangled-typename.hpp"
+#include <cstdlib>
+#include <memory>
+#ifdef __GNUC__
+#include <cxxabi.h>
+#endif
+
+namespace qpalm {
+
+/// Return a human-readable name for the type described by @p t.
+///
+/// When __GNUC__ is defined, the mangled `t.name()` is run through
+/// `abi::__cxa_demangle`. That call returns a null pointer when demangling
+/// fails, in which case the raw `t.name()` is returned instead of building
+/// a std::string from null (undefined behavior). Without __GNUC__,
+/// `t.name()` is returned unchanged.
+std::string demangled_typename(const std::type_info &t) {
+#ifdef __GNUC__
+    // The buffer returned by __cxa_demangle is malloc'ed; free it on exit.
+    std::unique_ptr<char, decltype(&std::free)> pretty{
+        abi::__cxa_demangle(t.name(), nullptr, nullptr, nullptr), std::free};
+    return pretty ? std::string{pretty.get()} : std::string{t.name()};
+#else
+    return t.name();
+#endif
+}
+
+} // namespace qpalm
diff --git a/QPALM/interfaces/python/demangled-typename.hpp b/QPALM/interfaces/python/demangled-typename.hpp new file mode 100644 index 00000000..2f95f67d --- /dev/null +++ b/QPALM/interfaces/python/demangled-typename.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +namespace qpalm { + +/// Get the pretty name of the given type as
a string. +std::string demangled_typename(const std::type_info &t); + +} // namespace qpalm diff --git a/QPALM/interfaces/python/qpalm.py.cpp b/QPALM/interfaces/python/qpalm.py.cpp index eea75fba..56f71a48 100644 --- a/QPALM/interfaces/python/qpalm.py.cpp +++ b/QPALM/interfaces/python/qpalm.py.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include namespace py = pybind11; @@ -19,6 +20,8 @@ using py::operator""_a; #include #include +#include "async.hpp" + /// Throw an exception if the dimensions of the matrix don't match the expected /// dimensions @p r and @p c. static void check_dim(const qpalm::sparse_mat_t &M, std::string_view name, qpalm::index_t r, @@ -57,10 +60,12 @@ PYBIND11_MODULE(MODULE_NAME, m) { m.attr("debug") = true; #endif +#if 0 // not thread-safe ladel_set_alloc_config_calloc(&PyMem_Calloc); ladel_set_alloc_config_malloc(&PyMem_Malloc); ladel_set_alloc_config_realloc(&PyMem_Realloc); ladel_set_alloc_config_free(&PyMem_Free); +#endif ladel_set_print_config_printf(&print_wrap); py::class_(m, "Data") @@ -152,6 +157,7 @@ PYBIND11_MODULE(MODULE_NAME, m) { info.attr("PRIMAL_INFEASIBLE") = QPALM_PRIMAL_INFEASIBLE; info.attr("DUAL_INFEASIBLE") = QPALM_DUAL_INFEASIBLE; info.attr("TIME_LIMIT_REACHED") = QPALM_TIME_LIMIT_REACHED; + info.attr("USER_CANCELLATION") = QPALM_USER_CANCELLATION; info.attr("UNSOLVED") = QPALM_UNSOLVED; info.attr("ERROR") = QPALM_ERROR; @@ -233,7 +239,15 @@ PYBIND11_MODULE(MODULE_NAME, m) { self.warm_start(x, y); }, "x"_a = py::none(), "y"_a = py::none()) - .def("solve", &qpalm::Solver::solve) + .def( + "solve", + [](qpalm::Solver &self, bool async, bool suppress_interrupt) { + auto invoke_solver = [&] { self.solve(); }; + qpalm::async_solve(async, suppress_interrupt, self, invoke_solver, + *self.get_c_work_ptr()); + }, + "asynchronous"_a = true, "suppress_interrupt"_a = false) + .def("cancel", &qpalm::Solver::cancel) .def_property_readonly("solution", py::cpp_function( // https://github.com/pybind/pybind11/issues/2618 
&qpalm::Solver::get_solution, py::return_value_policy::reference, @@ -258,6 +272,7 @@ PYBIND11_MODULE(MODULE_NAME, m) { } static int print_wrap(const char *fmt, ...) { + py::gil_scoped_acquire gil{}; static std::vector buffer(1024); py::object write = py::module_::import("sys").attr("stdout").attr("write"); std::va_list args, args2; diff --git a/QPALM/interfaces/python/test/test_qpalm_cancel.py b/QPALM/interfaces/python/test/test_qpalm_cancel.py new file mode 100644 index 00000000..eda05c07 --- /dev/null +++ b/QPALM/interfaces/python/test/test_qpalm_cancel.py @@ -0,0 +1,45 @@ +import qpalm +import numpy as np +import concurrent.futures +from time import sleep +import scipy.sparse as spa + + +def test_qpalm_cancel(): + settings = qpalm.Settings() + settings.max_iter = 20000 + settings.eps_abs = 1e-200 + settings.eps_rel = 0 + settings.eps_rel_in = 0 + settings.verbose = 1 + + def create_solver(): + m, n = 100, 120 + data = qpalm.Data(n, m) + rng = np.random.default_rng(seed=123) + Q = rng.random((n, n)) + A = rng.random((m, n)) + Q = Q.T @ Q + data.Q = spa.csc_array(Q) + data.A = spa.csc_array(A) + data.q = rng.random(n) + data.bmax = rng.random(m) + data.bmin = -np.inf * np.ones(m) + return qpalm.Solver(data=data, settings=settings) + + solver = create_solver() + + def run_solver(): + solver.solve(asynchronous=True, suppress_interrupt=True) + return solver.info.status_val + + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(run_solver) + sleep(0.2) + solver.cancel() + assert future.result() == qpalm.Info.USER_CANCELLATION + + +if __name__ == "__main__": + test_qpalm_cancel() + print("done.") diff --git a/QPALM/interfaces/python/test/test_qpalm_thread.py b/QPALM/interfaces/python/test/test_qpalm_thread.py new file mode 100644 index 00000000..29488c2e --- /dev/null +++ b/QPALM/interfaces/python/test/test_qpalm_thread.py @@ -0,0 +1,65 @@ +from copy import deepcopy +import qpalm +import numpy as np +import concurrent.futures +import pytest 
+import os +import scipy.sparse as spa + + +def test_qpalm_threaded(): + valgrind = "valgrind" in os.getenv("LD_PRELOAD", "") + + settings = qpalm.Settings() + settings.max_iter = 300 + settings.eps_abs = 1e-200 + settings.eps_rel = 0 + settings.eps_rel_in = 0 + settings.verbose = 1 + + def create_solver(): + m, n = 100, 120 + data = qpalm.Data(n, m) + rng = np.random.default_rng(seed=123) + Q = rng.random((n, n)) + A = rng.random((m, n)) + Q = Q.T @ Q + data.Q = spa.csc_array(Q) + data.A = spa.csc_array(A) + data.q = rng.random(n) + data.bmax = rng.random(m) + data.bmin = -np.inf * np.ones(m) + return qpalm.Solver(data=data, settings=settings) + + shared_solver = create_solver() + + def good_experiment(): + solver = create_solver() + solver.solve(asynchronous=True) + return solver.info.status_val == qpalm.Info.MAX_ITER_REACHED + + def bad_experiment(): + solver = shared_solver + solver.solve(asynchronous=True) + return solver.info.status_val == qpalm.Info.MAX_ITER_REACHED + + def run(experiment): + N = 4 if valgrind else 200 + with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as pool: + futures = (pool.submit(experiment) for _ in range(N)) + for future in concurrent.futures.as_completed(futures): + success = future.result() + assert success + + run(good_experiment) + if not valgrind: + with pytest.raises( + RuntimeError, match=r"^Same instance of .* used in multiple threads" + ) as e: + run(bad_experiment) + print(e.value) + + +if __name__ == "__main__": + test_qpalm_threaded() + print("done.")
diff --git a/QPALM/interfaces/python/thread-checker.hpp b/QPALM/interfaces/python/thread-checker.hpp
new file mode 100644
index 00000000..3afb07c4
--- /dev/null
+++ b/QPALM/interfaces/python/thread-checker.hpp
@@ -0,0 +1,79 @@
+#pragma once
+
+#include "demangled-typename.hpp"
+#include <memory>
+#include <optional>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <typeinfo>
+#include <utility>
+
+namespace qpalm {
+
+/// Identity of an object: its address (robust against overloaded operator&).
+template <class T>
+const T *get_identity(const T &t) {
+    return std::addressof(t);
+}
+/// Deleted: identities are taken from objects, not from pointers to them.
+template <class T>
+void get_identity(const T *) = delete;
+
+/// Scope guard that registers the address of an object of type @p T in a
+/// per-type static set for as long as the guard lives. Creating a second
+/// guard for an address that is already registered throws, which is how
+/// concurrent use of one instance is detected.
+/// @note The registry has no locking of its own; callers must serialize
+///       construction and destruction of guards.
+template <class T>
+class ThreadChecker {
+    using set_t      = std::set<decltype(get_identity(std::declval<const T &>()))>;
+    using iterator_t = typename set_t::iterator;
+    static set_t set;                   ///< Addresses currently registered
+    std::optional<iterator_t> iterator; ///< This guard's entry in the set, if any
+
+  public:
+    /// Register @p t.
+    /// @throws std::runtime_error if @p t is already registered.
+    ThreadChecker(const T &t) {
+        auto [iter, inserted] = set.insert(get_identity(t));
+        if (!inserted) {
+            std::string name = "instance of type " + demangled_typename(typeid(T));
+            throw std::runtime_error("Same " + name +
+                                     " used in multiple threads (consider making a copy or "
+                                     "creating a separate instance for each thread)");
+        }
+        iterator = iter;
+    }
+    /// Unregister this guard's entry (nothing to do for a moved-from guard).
+    ~ThreadChecker() { release(); }
+    ThreadChecker(const ThreadChecker &)            = delete;
+    ThreadChecker &operator=(const ThreadChecker &) = delete;
+    /// Take over the entry of @p o, leaving @p o empty.
+    ThreadChecker(ThreadChecker &&o) noexcept : iterator{std::exchange(o.iterator, std::nullopt)} {}
+    /// Release our own entry first, then take over the entry of @p o.
+    /// Without the release, the overwritten entry would stay registered
+    /// forever and every later guard for that object would throw.
+    ThreadChecker &operator=(ThreadChecker &&o) noexcept {
+        if (this != &o) {
+            release();
+            iterator = std::exchange(o.iterator, std::nullopt);
+        }
+        return *this;
+    }
+
+  private:
+    /// Erase this guard's entry from the set, if it holds one.
+    void release() noexcept {
+        if (iterator) {
+            set.erase(*iterator);
+            iterator.reset();
+        }
+    }
+};
+
+template <class T>
+typename ThreadChecker<T>::set_t ThreadChecker<T>::set;
+
+} // namespace qpalm
diff --git a/QPALM/src/qpalm.c b/QPALM/src/qpalm.c index 8a7f118a..05290767 100644 --- a/QPALM/src/qpalm.c +++ b/QPALM/src/qpalm.c @@ -251,6 +251,7 @@ QPALMWorkspace* qpalm_setup(const QPALMData *data, const QPALMSettings *settings work->info->run_time = 0.0; // Total run time to zero work->info->setup_time = qpalm_toc(work->timer); // Update timer information # endif /* ifdef QPALM_TIMING */ + atomic_init(&work->cancel, 0); return work; } @@ -258,6 +259,8 @@ QPALMWorkspace* qpalm_setup(const QPALMData *data, const QPALMSettings *settings void qpalm_warm_start(QPALMWorkspace *work, const c_float *x_warm_start, const c_float *y_warm_start) { + atomic_store(&work->cancel, 0); + // If we have previously solved the problem, then reset the setup time if (work->info->status_val != QPALM_UNSOLVED) { @@ -473,6 +476,10 @@ static void qpalm_terminate_on_status(QPALMWorkspace *work, solver_common *c, so qpalm_termination(work, c, c2, iter, iter_out); } +void
qpalm_cancel(QPALMWorkspace *work) { + atomic_store(&work->cancel, 1); +} + void qpalm_solve(QPALMWorkspace *work) { #if defined(QPALM_PRINTING) && defined(_WIN32) && defined(_MSC_VER) && _MSC_VER < 1900 @@ -508,6 +515,11 @@ void qpalm_solve(QPALMWorkspace *work) return; } #endif /* ifdef QPALM_TIMING */ + if (atomic_load(&work->cancel)) + { + qpalm_terminate_on_status(work, c, c2, iter, iter_out, QPALM_USER_CANCELLATION); + return; + } /*Perform the iteration */ compute_residuals(work, c); diff --git a/QPALM/src/util.c b/QPALM/src/util.c index 28318fc7..edaa08de 100644 --- a/QPALM/src/util.c +++ b/QPALM/src/util.c @@ -81,6 +81,9 @@ void update_status(QPALMInfo *info, c_int status_val) { case QPALM_TIME_LIMIT_REACHED: c_strcpy(info->status, "time limit exceeded"); break; + case QPALM_USER_CANCELLATION: + c_strcpy(info->status, "cancelled by user"); + break; case QPALM_MAX_ITER_REACHED: c_strcpy(info->status, "maximum iterations reached"); break; @@ -178,6 +181,15 @@ void print_final_message(QPALMWorkspace *work) { qpalm_print("| dual residual : %5.4e, dual tolerance : %5.4e |\n", work->info->dua_res_norm, work->eps_dua); qpalm_print("| objective value: %+-5.4e |\n", work->info->objective); break; + case QPALM_USER_CANCELLATION: + snprintf(buf, 80,"| QPALM was cancelled. |\n"); + characters_box = strlen(buf); + qpalm_print("%s", buf); + // characters_box = qpalm_print("| QPALM was cancelled. |\n"); + qpalm_print("| primal residual: %5.4e, primal tolerance: %5.4e |\n", work->info->pri_res_norm, work->eps_pri); + qpalm_print("| dual residual : %5.4e, dual tolerance : %5.4e |\n", work->info->dua_res_norm, work->eps_dua); + qpalm_print("| objective value: %+-5.4e |\n", work->info->objective); + break; default: c_strcpy(work->info->status, "unrecognised status value"); qpalm_eprint("Unrecognised final status value %" LADEL_PRIi, work->info->status_val);