Skip to content

Commit

Permalink
Add support for asynchronous solver cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
tttapa committed Mar 28, 2024
1 parent d1b58c3 commit 4a59225
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 2 deletions.
6 changes: 6 additions & 0 deletions QPALM/include/qpalm.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ QPALM_EXPORT void qpalm_warm_start(QPALMWorkspace *work,
*/
QPALM_EXPORT void qpalm_solve(QPALMWorkspace *work);

/**
* Cancel the ongoing call to @ref qpalm_solve.
*
* Thread- and signal handler-safe.
*/
QPALM_EXPORT void qpalm_cancel(QPALMWorkspace *work);

/**
* Update the settings to the new settings.
Expand Down
1 change: 1 addition & 0 deletions QPALM/include/qpalm/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ extern "C" {
#define QPALM_PRIMAL_INFEASIBLE (-3) /**< status to indicate the problem is primal infeasible */
#define QPALM_DUAL_INFEASIBLE (-4) /**< status to indicate the problem is dual infeasible */
#define QPALM_TIME_LIMIT_REACHED (-5) /**< status to indicate the problem's runtime has exceeded the specified time limit */
#define QPALM_USER_CANCELLATION (-6) /**< status to indicate the user has cancelled the solve */
#define QPALM_UNSOLVED (-10) /**< status to indicate the problem is unsolved. Only setup function has been called */
#define QPALM_ERROR (0) /**< status to indicate an error has occured (this error should automatically be printed) */

Expand Down
9 changes: 9 additions & 0 deletions QPALM/include/qpalm/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
#ifndef QPALM_TYPES_H
# define QPALM_TYPES_H

#ifdef __cplusplus
#include <atomic>
using std::atomic_bool;
#else
#include <stdatomic.h>
#endif

# ifdef __cplusplus
extern "C" {
# endif
Expand Down Expand Up @@ -311,6 +318,8 @@ typedef struct {
QPALMTimer *timer; ///< timer object
# endif // ifdef QPALM_TIMING

atomic_bool cancel; ///< Cancel the solver (from other thread or signal)

} QPALMWorkspace;


Expand Down
3 changes: 3 additions & 0 deletions QPALM/interfaces/cxx/include/qpalm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class Solver {
/// @ref get_info().
/// @see @ref ::qpalm_solve
QPALM_CXX_EXPORT void solve();
/// Cancel the ongoing call to @ref solve.
/// Thread- and signal handler-safe.
QPALM_CXX_EXPORT void cancel();

/// Get the solution computed by @ref solve().
/// @note Returns a view that is only valid as long as the solver is not
Expand Down
2 changes: 2 additions & 0 deletions QPALM/interfaces/cxx/src/qpalm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ void Solver::warm_start(std::optional<const_ref_vec_t> x,

void Solver::solve() { ::qpalm_solve(work.get()); }

void Solver::cancel() { ::qpalm_cancel(work.get()); }

SolutionView Solver::get_solution() const {
assert(work->solution);
assert(work->solution->x);
Expand Down
2 changes: 1 addition & 1 deletion QPALM/interfaces/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ include(cmake/QueryPythonForPybind11.cmake)
find_pybind11_python_first()

# Compile the example Python module
pybind11_add_module(_qpalm MODULE "qpalm.py.cpp")
pybind11_add_module(_qpalm MODULE "qpalm.py.cpp" "demangled-typename.cpp")
target_compile_features(_qpalm PRIVATE cxx_std_17)
target_link_libraries(_qpalm PRIVATE qpalm_warnings)
target_link_libraries(_qpalm PRIVATE pybind11::pybind11 qpalm qpalm_cxx)
Expand Down
62 changes: 62 additions & 0 deletions QPALM/interfaces/python/async.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include <pybind11/gil.h>
namespace py = pybind11;

#include <chrono>
#include <exception>
#include <future>
#include <iostream>
#include <tuple>
using namespace std::chrono_literals;

#include "thread-checker.hpp"

namespace qpalm {

template <class Solver, class Invoker, class... CheckedArgs>
void async_solve(bool async, bool suppress_interrupt, Solver &solver, Invoker &invoke_solver,
CheckedArgs &...checked_args) {
if (!async) {
// Invoke the solver synchronously
invoke_solver();
} else {
// Check that the user doesn't use the same solver/problem in multiple threads
ThreadChecker solver_checker{solver};
std::tuple checkers{ThreadChecker{checked_args}...};
// Invoke the solver asynchronously
auto done = std::async(std::launch::async, invoke_solver);
{
py::gil_scoped_release gil;
while (done.wait_for(50ms) != std::future_status::ready) {
py::gil_scoped_acquire gil;
// Check if Python received a signal (e.g. Ctrl+C)
if (PyErr_CheckSignals() != 0) {
// Nicely ask the solver to stop
solver.cancel();
// It should return a result soon
if (py::gil_scoped_release gil;
done.wait_for(15s) != std::future_status::ready) {
// If it doesn't, we terminate the entire program,
// because the solver uses variables local to this
// function, so we cannot safely return without
// waiting for the solver to finish.
std::cerr << "QPALM solver failed to respond to cancellation request. "
"Terminating ..."
<< std::endl;
std::terminate();
}
if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt) && suppress_interrupt)
PyErr_Clear(); // Clear the KeyboardInterrupt exception
else
throw py::error_already_set();
}
break;
}
}
}
}
}

} // namespace qpalm
20 changes: 20 additions & 0 deletions QPALM/interfaces/python/demangled-typename.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "demangled-typename.hpp"
#include <cstdlib>
#include <memory>
#ifdef __GNUC__
#include <cxxabi.h>
#endif

namespace qpalm {

std::string demangled_typename(const std::type_info &t) {
#ifdef __GNUC__
return std::unique_ptr<char, decltype(&std::free)>{
abi::__cxa_demangle(t.name(), nullptr, nullptr, nullptr), std::free}
.get();
#else
return t.name();
#endif
}

} // namespace qpalm
11 changes: 11 additions & 0 deletions QPALM/interfaces/python/demangled-typename.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <string>
#include <typeinfo>

namespace qpalm {

/// Get the pretty name of the given type as a string.
std::string demangled_typename(const std::type_info &t);

} // namespace qpalm
17 changes: 16 additions & 1 deletion QPALM/interfaces/python/qpalm.py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <Python.h>
#include <pybind11/eigen.h>
#include <pybind11/gil.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
Expand All @@ -19,6 +20,8 @@ using py::operator""_a;
#include <string>
#include <string_view>

#include "async.hpp"

/// Throw an exception if the dimensions of the matrix don't match the expected
/// dimensions @p r and @p c.
static void check_dim(const qpalm::sparse_mat_t &M, std::string_view name, qpalm::index_t r,
Expand Down Expand Up @@ -57,10 +60,12 @@ PYBIND11_MODULE(MODULE_NAME, m) {
m.attr("debug") = true;
#endif

#if 0 // not thread-safe
ladel_set_alloc_config_calloc(&PyMem_Calloc);
ladel_set_alloc_config_malloc(&PyMem_Malloc);
ladel_set_alloc_config_realloc(&PyMem_Realloc);
ladel_set_alloc_config_free(&PyMem_Free);
#endif
ladel_set_print_config_printf(&print_wrap);

py::class_<qpalm::Data>(m, "Data")
Expand Down Expand Up @@ -152,6 +157,7 @@ PYBIND11_MODULE(MODULE_NAME, m) {
info.attr("PRIMAL_INFEASIBLE") = QPALM_PRIMAL_INFEASIBLE;
info.attr("DUAL_INFEASIBLE") = QPALM_DUAL_INFEASIBLE;
info.attr("TIME_LIMIT_REACHED") = QPALM_TIME_LIMIT_REACHED;
info.attr("USER_CANCELLATION") = QPALM_USER_CANCELLATION;
info.attr("UNSOLVED") = QPALM_UNSOLVED;
info.attr("ERROR") = QPALM_ERROR;

Expand Down Expand Up @@ -233,7 +239,15 @@ PYBIND11_MODULE(MODULE_NAME, m) {
self.warm_start(x, y);
},
"x"_a = py::none(), "y"_a = py::none())
.def("solve", &qpalm::Solver::solve)
.def(
"solve",
[](qpalm::Solver &self, bool async, bool suppress_interrupt) {
auto invoke_solver = [&] { self.solve(); };
qpalm::async_solve(async, suppress_interrupt, self, invoke_solver,
*self.get_c_work_ptr());
},
"asynchronous"_a = true, "suppress_interrupt"_a = false)
.def("cancel", &qpalm::Solver::cancel)
.def_property_readonly("solution",
py::cpp_function( // https://github.com/pybind/pybind11/issues/2618
&qpalm::Solver::get_solution, py::return_value_policy::reference,
Expand All @@ -258,6 +272,7 @@ PYBIND11_MODULE(MODULE_NAME, m) {
}

static int print_wrap(const char *fmt, ...) {
py::gil_scoped_acquire gil{};
static std::vector<char> buffer(1024);
py::object write = py::module_::import("sys").attr("stdout").attr("write");
std::va_list args, args2;
Expand Down
45 changes: 45 additions & 0 deletions QPALM/interfaces/python/test/test_qpalm_cancel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import qpalm
import numpy as np
import concurrent.futures
from time import sleep
import scipy.sparse as spa


def test_qpalm_cancel():
settings = qpalm.Settings()
settings.max_iter = 20000
settings.eps_abs = 1e-200
settings.eps_rel = 0
settings.eps_rel_in = 0
settings.verbose = 1

def create_solver():
m, n = 100, 120
data = qpalm.Data(n, m)
rng = np.random.default_rng(seed=123)
Q = rng.random((n, n))
A = rng.random((m, n))
Q = Q.T @ Q
data.Q = spa.csc_array(Q)
data.A = spa.csc_array(A)
data.q = rng.random(n)
data.bmax = rng.random(m)
data.bmin = -np.inf * np.ones(m)
return qpalm.Solver(data=data, settings=settings)

solver = create_solver()

def run_solver():
solver.solve(asynchronous=True, suppress_interrupt=True)
return solver.info.status_val

with concurrent.futures.ThreadPoolExecutor() as pool:
future = pool.submit(run_solver)
sleep(0.2)
solver.cancel()
assert future.result() == qpalm.Info.USER_CANCELLATION


if __name__ == "__main__":
test_qpalm_cancel()
print("done.")
65 changes: 65 additions & 0 deletions QPALM/interfaces/python/test/test_qpalm_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from copy import deepcopy
import qpalm
import numpy as np
import concurrent.futures
import pytest
import os
import scipy.sparse as spa


def test_qpalm_threaded():
valgrind = "valgrind" in os.getenv("LD_PRELOAD", "")

settings = qpalm.Settings()
settings.max_iter = 300
settings.eps_abs = 1e-200
settings.eps_rel = 0
settings.eps_rel_in = 0
settings.verbose = 1

def create_solver():
m, n = 100, 120
data = qpalm.Data(n, m)
rng = np.random.default_rng(seed=123)
Q = rng.random((n, n))
A = rng.random((m, n))
Q = Q.T @ Q
data.Q = spa.csc_array(Q)
data.A = spa.csc_array(A)
data.q = rng.random(n)
data.bmax = rng.random(m)
data.bmin = -np.inf * np.ones(m)
return qpalm.Solver(data=data, settings=settings)

shared_solver = create_solver()

def good_experiment():
solver = create_solver()
solver.solve(asynchronous=True)
return solver.info.status_val == qpalm.Info.MAX_ITER_REACHED

def bad_experiment():
solver = shared_solver
solver.solve(asynchronous=True)
return solver.info.status_val == qpalm.Info.MAX_ITER_REACHED

def run(experiment):
N = 4 if valgrind else 200
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
futures = (pool.submit(experiment) for _ in range(N))
for future in concurrent.futures.as_completed(futures):
success = future.result()
assert success

run(good_experiment)
if not valgrind:
with pytest.raises(
RuntimeError, match=r"^Same instance of .* used in multiple threads"
) as e:
run(bad_experiment)
print(e.value)


if __name__ == "__main__":
test_qpalm_threaded()
print("done.")
52 changes: 52 additions & 0 deletions QPALM/interfaces/python/thread-checker.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include "demangled-typename.hpp"
#include <optional>
#include <set>
#include <stdexcept>

namespace qpalm {

template <class T>
const T *get_identity(const T &t) {
return std::addressof(t);
}
template <class T>
void get_identity(const T *) = delete;

template <class T>
class ThreadChecker {
using set_t = std::set<decltype(get_identity(std::declval<T>()))>;
using iterator_t = typename set_t::iterator;
static set_t set;
std::optional<iterator_t> iterator;

public:
ThreadChecker(const T &t) {
auto [iter, inserted] = set.insert(get_identity(t));
if (!inserted) {
std::string name = "instance of type " + demangled_typename(typeid(T));
throw std::runtime_error("Same " + name +
" used in multiple threads (consider making a copy or "
"creating a separate instance for each thread)");
}
iterator = iter;
}
~ThreadChecker() {
if (iterator)
set.erase(*iterator);
}
ThreadChecker(const ThreadChecker &) = delete;
ThreadChecker &operator=(const ThreadChecker &) = delete;
ThreadChecker(ThreadChecker &&o) noexcept { std::swap(this->iterator, o.iterator); }
ThreadChecker &operator=(ThreadChecker &&o) noexcept {
this->iterator = std::move(o.iterator);
o.iterator.reset();
return *this;
}
};

template <class T>
typename ThreadChecker<T>::set_t ThreadChecker<T>::set;

} // namespace qpalm
Loading

0 comments on commit 4a59225

Please sign in to comment.