forked from Benny44/QPALM_vLADEL
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for asynchronous solver cancellation
- Loading branch information
Showing
15 changed files
with
317 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
#pragma once | ||
|
||
#include <pybind11/gil.h> | ||
namespace py = pybind11; | ||
|
||
#include <chrono> | ||
#include <exception> | ||
#include <future> | ||
#include <iostream> | ||
#include <tuple> | ||
using namespace std::chrono_literals; | ||
|
||
#include "thread-checker.hpp" | ||
|
||
namespace qpalm { | ||
|
||
template <class Solver, class Invoker, class... CheckedArgs> | ||
void async_solve(bool async, bool suppress_interrupt, Solver &solver, Invoker &invoke_solver, | ||
CheckedArgs &...checked_args) { | ||
if (!async) { | ||
// Invoke the solver synchronously | ||
invoke_solver(); | ||
} else { | ||
// Check that the user doesn't use the same solver/problem in multiple threads | ||
ThreadChecker solver_checker{solver}; | ||
std::tuple checkers{ThreadChecker{checked_args}...}; | ||
// Invoke the solver asynchronously | ||
auto done = std::async(std::launch::async, invoke_solver); | ||
{ | ||
py::gil_scoped_release gil; | ||
while (done.wait_for(50ms) != std::future_status::ready) { | ||
py::gil_scoped_acquire gil; | ||
// Check if Python received a signal (e.g. Ctrl+C) | ||
if (PyErr_CheckSignals() != 0) { | ||
// Nicely ask the solver to stop | ||
solver.cancel(); | ||
// It should return a result soon | ||
if (py::gil_scoped_release gil; | ||
done.wait_for(15s) != std::future_status::ready) { | ||
// If it doesn't, we terminate the entire program, | ||
// because the solver uses variables local to this | ||
// function, so we cannot safely return without | ||
// waiting for the solver to finish. | ||
std::cerr << "QPALM solver failed to respond to cancellation request. " | ||
"Terminating ..." | ||
<< std::endl; | ||
std::terminate(); | ||
} | ||
if (PyErr_Occurred()) { | ||
if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt) && suppress_interrupt) | ||
PyErr_Clear(); // Clear the KeyboardInterrupt exception | ||
else | ||
throw py::error_already_set(); | ||
} | ||
break; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
} // namespace qpalm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#include "demangled-typename.hpp" | ||
#include <cstdlib> | ||
#include <memory> | ||
#ifdef __GNUC__ | ||
#include <cxxabi.h> | ||
#endif | ||
|
||
namespace qpalm { | ||
|
||
std::string demangled_typename(const std::type_info &t) { | ||
#ifdef __GNUC__ | ||
return std::unique_ptr<char, decltype(&std::free)>{ | ||
abi::__cxa_demangle(t.name(), nullptr, nullptr, nullptr), std::free} | ||
.get(); | ||
#else | ||
return t.name(); | ||
#endif | ||
} | ||
|
||
} // namespace qpalm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <typeinfo> | ||
|
||
namespace qpalm { | ||
|
||
/// Get the pretty name of the given type as a string. | ||
std::string demangled_typename(const std::type_info &t); | ||
|
||
} // namespace qpalm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import qpalm | ||
import numpy as np | ||
import concurrent.futures | ||
from time import sleep | ||
import scipy.sparse as spa | ||
|
||
|
||
def test_qpalm_cancel(): | ||
settings = qpalm.Settings() | ||
settings.max_iter = 20000 | ||
settings.eps_abs = 1e-200 | ||
settings.eps_rel = 0 | ||
settings.eps_rel_in = 0 | ||
settings.verbose = 1 | ||
|
||
def create_solver(): | ||
m, n = 100, 120 | ||
data = qpalm.Data(n, m) | ||
rng = np.random.default_rng(seed=123) | ||
Q = rng.random((n, n)) | ||
A = rng.random((m, n)) | ||
Q = Q.T @ Q | ||
data.Q = spa.csc_array(Q) | ||
data.A = spa.csc_array(A) | ||
data.q = rng.random(n) | ||
data.bmax = rng.random(m) | ||
data.bmin = -np.inf * np.ones(m) | ||
return qpalm.Solver(data=data, settings=settings) | ||
|
||
solver = create_solver() | ||
|
||
def run_solver(): | ||
solver.solve(asynchronous=True, suppress_interrupt=True) | ||
return solver.info.status_val | ||
|
||
with concurrent.futures.ThreadPoolExecutor() as pool: | ||
future = pool.submit(run_solver) | ||
sleep(0.2) | ||
solver.cancel() | ||
assert future.result() == qpalm.Info.USER_CANCELLATION | ||
|
||
|
||
if __name__ == "__main__": | ||
test_qpalm_cancel() | ||
print("done.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from copy import deepcopy | ||
import qpalm | ||
import numpy as np | ||
import concurrent.futures | ||
import pytest | ||
import os | ||
import scipy.sparse as spa | ||
|
||
|
||
def test_qpalm_threaded(): | ||
valgrind = "valgrind" in os.getenv("LD_PRELOAD", "") | ||
|
||
settings = qpalm.Settings() | ||
settings.max_iter = 300 | ||
settings.eps_abs = 1e-200 | ||
settings.eps_rel = 0 | ||
settings.eps_rel_in = 0 | ||
settings.verbose = 1 | ||
|
||
def create_solver(): | ||
m, n = 100, 120 | ||
data = qpalm.Data(n, m) | ||
rng = np.random.default_rng(seed=123) | ||
Q = rng.random((n, n)) | ||
A = rng.random((m, n)) | ||
Q = Q.T @ Q | ||
data.Q = spa.csc_array(Q) | ||
data.A = spa.csc_array(A) | ||
data.q = rng.random(n) | ||
data.bmax = rng.random(m) | ||
data.bmin = -np.inf * np.ones(m) | ||
return qpalm.Solver(data=data, settings=settings) | ||
|
||
shared_solver = create_solver() | ||
|
||
def good_experiment(): | ||
solver = create_solver() | ||
solver.solve(asynchronous=True) | ||
return solver.info.status_val == qpalm.Info.MAX_ITER_REACHED | ||
|
||
def bad_experiment(): | ||
solver = shared_solver | ||
solver.solve(asynchronous=True) | ||
return solver.info.status_val == qpalm.Info.MAX_ITER_REACHED | ||
|
||
def run(experiment): | ||
N = 4 if valgrind else 200 | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as pool: | ||
futures = (pool.submit(experiment) for _ in range(N)) | ||
for future in concurrent.futures.as_completed(futures): | ||
success = future.result() | ||
assert success | ||
|
||
run(good_experiment) | ||
if not valgrind: | ||
with pytest.raises( | ||
RuntimeError, match=r"^Same instance of .* used in multiple threads" | ||
) as e: | ||
run(bad_experiment) | ||
print(e.value) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_qpalm_threaded() | ||
print("done.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#pragma once | ||
|
||
#include "demangled-typename.hpp" | ||
#include <optional> | ||
#include <set> | ||
#include <stdexcept> | ||
|
||
namespace qpalm { | ||
|
||
template <class T> | ||
const T *get_identity(const T &t) { | ||
return std::addressof(t); | ||
} | ||
template <class T> | ||
void get_identity(const T *) = delete; | ||
|
||
template <class T> | ||
class ThreadChecker { | ||
using set_t = std::set<decltype(get_identity(std::declval<T>()))>; | ||
using iterator_t = typename set_t::iterator; | ||
static set_t set; | ||
std::optional<iterator_t> iterator; | ||
|
||
public: | ||
ThreadChecker(const T &t) { | ||
auto [iter, inserted] = set.insert(get_identity(t)); | ||
if (!inserted) { | ||
std::string name = "instance of type " + demangled_typename(typeid(T)); | ||
throw std::runtime_error("Same " + name + | ||
" used in multiple threads (consider making a copy or " | ||
"creating a separate instance for each thread)"); | ||
} | ||
iterator = iter; | ||
} | ||
~ThreadChecker() { | ||
if (iterator) | ||
set.erase(*iterator); | ||
} | ||
ThreadChecker(const ThreadChecker &) = delete; | ||
ThreadChecker &operator=(const ThreadChecker &) = delete; | ||
ThreadChecker(ThreadChecker &&o) noexcept { std::swap(this->iterator, o.iterator); } | ||
ThreadChecker &operator=(ThreadChecker &&o) noexcept { | ||
this->iterator = std::move(o.iterator); | ||
o.iterator.reset(); | ||
return *this; | ||
} | ||
}; | ||
|
||
template <class T> | ||
typename ThreadChecker<T>::set_t ThreadChecker<T>::set; | ||
|
||
} // namespace qpalm |
Oops, something went wrong.