Skip to content

Commit

Permalink
Highspy update:
Browse files Browse the repository at this point in the history
* added callback support
* added user/keyboard interrupt support
* fixed issues with deadlock on windows
* fixed MIP solution callback issue
* added resetGlobalScheduler
* released GIL for Presolve
* improved setCallback
* updated unit tests for 99% coverage
  • Loading branch information
mathgeekcoder committed Sep 16, 2024
1 parent 42f9ae6 commit aa4ae38
Show file tree
Hide file tree
Showing 3 changed files with 577 additions and 45 deletions.
74 changes: 49 additions & 25 deletions src/highs_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@ using namespace pybind11::literals;
template<typename T>
using dense_array_t = py::array_t<T, py::array::c_style | py::array::forcecast>;

// wrapper for raw pointers
template<class T>
class readonly_ptr_wrapper {
public:
readonly_ptr_wrapper() : ptr(nullptr) {}
readonly_ptr_wrapper(T* ptr) : ptr(ptr) {}
readonly_ptr_wrapper(const readonly_ptr_wrapper& other) : ptr(other.ptr) {}
T& operator*() const { return *ptr; }
T* operator->() const { return ptr; }
T* get() const { return ptr; }
T& operator[](std::size_t idx) const { return ptr[idx]; }
bool is_valid() { return ptr != nullptr; }

py::array_t<T, py::array::c_style> to_array(std::size_t size) {
return py::array_t<T, py::array::c_style>(
py::buffer_info(ptr, sizeof(T), py::format_descriptor<T>::format(), 1, {size}, {1}));
}

private:
T* ptr;
};

HighsStatus highs_passModel(Highs* h, HighsModel& model) {
return h->passModel(model);
}
Expand Down Expand Up @@ -582,24 +604,25 @@ std::tuple<HighsStatus, int> highs_getRowByName(Highs* h,
return std::make_tuple(status, row);
}

// Wrap the setCallback function. Pass a lambda wrapper around the python
// function that acquires the GIL and appropriately handle user data passed to
// the callback
// Wrap the setCallback function to appropriately handle user data.
// pybind11 automatically ensures GIL is re-acquired when the callback is called.
HighsStatus highs_setCallback(
Highs* h,
std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, py::handle)>
fn,
std::function<void(int, const std::string&, const HighsCallbackDataOut*, HighsCallbackDataIn*, py::handle)> fn,
py::handle data) {
return h->setCallback(
[fn, data](int callbackType, const std::string& msg,
const HighsCallbackDataOut* dataOut,
HighsCallbackDataIn* dataIn, void* d) {
py::gil_scoped_acquire acquire;
return fn(callbackType, msg, dataOut, dataIn,
py::handle(reinterpret_cast<PyObject*>(d)));
},
data.ptr());

if (static_cast<bool>(fn) == false)
return h->setCallback((HighsCallbackFunctionType)nullptr, nullptr);
else
return h->setCallback(
[fn, data](int callbackType,
const std::string& msg,
const HighsCallbackDataOut* dataOut,
HighsCallbackDataIn* dataIn,
void* d) {
return fn(callbackType, msg, dataOut, dataIn, py::handle(reinterpret_cast<PyObject*>(d)));
},
data.ptr());
}

PYBIND11_MODULE(_core, m) {
Expand Down Expand Up @@ -900,6 +923,7 @@ PYBIND11_MODULE(_core, m) {
.def("postsolve", &highs_postsolve)
.def("postsolve", &highs_mipPostsolve)
.def("run", &Highs::run, py::call_guard<py::gil_scoped_release>())
.def("resetGlobalScheduler", &Highs::resetGlobalScheduler)
.def("feasibilityRelaxation",
[](Highs& self, double global_lower_penalty, double global_upper_penalty, double global_rhs_penalty,
py::object local_lower_penalty, py::object local_upper_penalty, py::object local_rhs_penalty) {
Expand Down Expand Up @@ -931,7 +955,7 @@ PYBIND11_MODULE(_core, m) {
py::arg("local_upper_penalty") = py::none(),
py::arg("local_rhs_penalty") = py::none())
.def("getIis", &Highs::getIis)
.def("presolve", &Highs::presolve)
.def("presolve", &Highs::presolve, py::call_guard<py::gil_scoped_release>())
.def("writeSolution", &highs_writeSolution)
.def("readSolution", &Highs::readSolution)
.def("setOptionValue",
Expand Down Expand Up @@ -1246,6 +1270,7 @@ PYBIND11_MODULE(_core, m) {
.value("kCallbackSimplexInterrupt",
HighsCallbackType::kCallbackSimplexInterrupt)
.value("kCallbackIpmInterrupt", HighsCallbackType::kCallbackIpmInterrupt)
.value("kCallbackMipSolution", HighsCallbackType::kCallbackMipSolution)
.value("kCallbackMipImprovingSolution",
HighsCallbackType::kCallbackMipImprovingSolution)
.value("kCallbackMipLogging", HighsCallbackType::kCallbackMipLogging)
Expand All @@ -1257,6 +1282,11 @@ PYBIND11_MODULE(_core, m) {
.value("kNumCallbackType", HighsCallbackType::kNumCallbackType)
.export_values();
// Classes
py::class_<readonly_ptr_wrapper<double>>(m, "readonly_ptr_wrapper_double")
.def(py::init<double*>())
.def("__getitem__", &readonly_ptr_wrapper<double>::operator[])
.def("__bool__", &readonly_ptr_wrapper<double>::is_valid)
.def("to_array", &readonly_ptr_wrapper<double>::to_array);
py::class_<HighsCallbackDataOut>(callbacks, "HighsCallbackDataOut")
.def(py::init<>())
.def_readwrite("log_type", &HighsCallbackDataOut::log_type)
Expand All @@ -1274,15 +1304,9 @@ PYBIND11_MODULE(_core, m) {
&HighsCallbackDataOut::mip_primal_bound)
.def_readwrite("mip_dual_bound", &HighsCallbackDataOut::mip_dual_bound)
.def_readwrite("mip_gap", &HighsCallbackDataOut::mip_gap)
.def_property(
"mip_solution",
[](const HighsCallbackDataOut& self) -> py::array {
// XXX: This is clearly wrong, most likely we need to have the
// length as an input data parameter
return py::array(3, self.mip_solution);
},
[](HighsCallbackDataOut& self, dense_array_t<double> new_mip_solution) {
self.mip_solution = new_mip_solution.mutable_data();
.def_property_readonly("mip_solution",
[](const HighsCallbackDataOut& self) -> readonly_ptr_wrapper<double> {
return readonly_ptr_wrapper<double>(self.mip_solution);
});
py::class_<HighsCallbackDataIn>(callbacks, "HighsCallbackDataIn")
.def(py::init<>())
Expand Down
Loading

0 comments on commit aa4ae38

Please sign in to comment.