Skip to content

Commit

Permalink
Suggestion for implementing user callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
cvanaret committed Nov 20, 2024
1 parent 461693c commit 8899f7a
Show file tree
Hide file tree
Showing 17 changed files with 108 additions and 24 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ if(NOT AMPLSOLVER)
message(WARNING "Optional library amplsolver (ASL) was not found.")
else()
message(STATUS "Library amplsolver was found.")
add_executable(uno_ampl bindings/AMPL/AMPLModel.cpp bindings/AMPL/uno_ampl.cpp)
add_executable(uno_ampl bindings/AMPL/AMPLModel.cpp bindings/AMPL/AMPLUserCallbacks.cpp bindings/AMPL/uno_ampl.cpp)

target_link_libraries(uno_ampl PUBLIC uno ${AMPLSOLVER} ${CMAKE_DL_LIBS})
add_definitions("-D HAS_AMPLSOLVER")
Expand Down
20 changes: 20 additions & 0 deletions bindings/AMPL/AMPLUserCallbacks.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2024 Charlie Vanaret
// Licensed under the MIT license. See LICENSE file in the project directory for details.

#include "AMPLUserCallbacks.hpp"
#include "linear_algebra/Vector.hpp"
#include "optimization/Multipliers.hpp"

namespace uno {
AMPLUserCallbacks::AMPLUserCallbacks(): UserCallbacks() { }

void AMPLUserCallbacks::notify_acceptable_iterate(const Vector<double>& /*primals*/, const Multipliers& /*multipliers*/,
double /*objective_multiplier*/) {
}

void AMPLUserCallbacks::notify_new_primals(const Vector<double>& /*primals*/) {
}

void AMPLUserCallbacks::notify_new_multipliers(const Multipliers& /*multipliers*/) {
}
} // namespace
20 changes: 20 additions & 0 deletions bindings/AMPL/AMPLUserCallbacks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2024 Charlie Vanaret
// Licensed under the MIT license. See LICENSE file in the project directory for details.

#ifndef UNO_AMPLUSERCALLBACKS_H
#define UNO_AMPLUSERCALLBACKS_H

#include "tools/UserCallbacks.hpp"

namespace uno {
class AMPLUserCallbacks: public UserCallbacks {
public:
AMPLUserCallbacks();

void notify_acceptable_iterate(const Vector<double>& primals, const Multipliers& multipliers, double objective_multiplier) override;
void notify_new_primals(const Vector<double>& primals) override;
void notify_new_multipliers(const Multipliers& multipliers) override;
};
} // namespace

#endif //UNO_AMPLUSERCALLBACKS_H
6 changes: 5 additions & 1 deletion bindings/AMPL/uno_ampl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ingredients/constraint_relaxation_strategies/ConstraintRelaxationStrategy.hpp"
#include "ingredients/constraint_relaxation_strategies/ConstraintRelaxationStrategyFactory.hpp"
#include "AMPLModel.hpp"
#include "AMPLUserCallbacks.hpp"
#include "Uno.hpp"
#include "model/ModelFactory.hpp"
#include "options/DefaultOptions.hpp"
Expand Down Expand Up @@ -50,8 +51,11 @@ namespace uno {
auto globalization_mechanism = GlobalizationMechanismFactory::create(*constraint_relaxation_strategy, options);
Uno uno = Uno(*globalization_mechanism, options);

// create the user callbacks
AMPLUserCallbacks user_callbacks{};

// solve the instance
uno.solve(*model, initial_iterate, options);
uno.solve(*model, initial_iterate, options, user_callbacks);
// std::cout << "memory_allocation_amount = " << memory_allocation_amount << '\n';
}
catch (std::exception& exception) {
Expand Down
8 changes: 6 additions & 2 deletions uno/Uno.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "options/Options.hpp"
#include "tools/Statistics.hpp"
#include "tools/Timer.hpp"
#include "tools/UserCallbacks.hpp"

namespace uno {
Uno::Uno(GlobalizationMechanism& globalization_mechanism, const Options& options) :
Expand All @@ -30,7 +31,7 @@ namespace uno {

Level Logger::level = INFO;

void Uno::solve(const Model& model, Iterate& current_iterate, const Options& options) {
void Uno::solve(const Model& model, Iterate& current_iterate, const Options& options, UserCallbacks& user_callbacks) {
Timer timer{};
Statistics statistics = Uno::create_statistics(model, options);
WarmstartInformation warmstart_information{};
Expand All @@ -54,8 +55,11 @@ namespace uno {

// compute an acceptable iterate by solving a subproblem at the current point
warmstart_information.iterate_changed();
this->globalization_mechanism.compute_next_iterate(statistics, model, current_iterate, trial_iterate, warmstart_information);
this->globalization_mechanism.compute_next_iterate(statistics, model, current_iterate, trial_iterate, warmstart_information, user_callbacks);
termination = this->termination_criteria(trial_iterate.status, major_iterations, timer.get_duration());
user_callbacks.notify_new_primals(trial_iterate.primals);
user_callbacks.notify_new_multipliers(trial_iterate.multipliers);

// the trial iterate becomes the current iterate for the next iteration
std::swap(current_iterate, trial_iterate);
}
Expand Down
3 changes: 2 additions & 1 deletion uno/Uno.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ namespace uno {
class Options;
class Statistics;
class Timer;
class UserCallbacks;

class Uno {
public:
Uno(GlobalizationMechanism& globalization_mechanism, const Options& options);

void solve(const Model& model, Iterate& initial_iterate, const Options& options);
void solve(const Model& model, Iterate& initial_iterate, const Options& options, UserCallbacks& user_callbacks);

static std::string current_version();
static void print_available_strategies();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace uno {
class Subproblem;
template <typename IndexType, typename ElementType>
class SymmetricMatrix;
class UserCallbacks;
template <typename ElementType>
class Vector;
struct WarmstartInformation;
Expand All @@ -50,7 +51,7 @@ namespace uno {

// trial iterate acceptance
[[nodiscard]] virtual bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) = 0;
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) = 0;
[[nodiscard]] TerminationStatus check_termination(Iterate& iterate);

// primal-dual residuals
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
#include "model/Model.hpp"
#include "optimization/Iterate.hpp"
#include "optimization/WarmstartInformation.hpp"
#include "symbolic/VectorView.hpp"
#include "options/Options.hpp"
#include "symbolic/VectorView.hpp"
#include "tools/UserCallbacks.hpp"

namespace uno {
FeasibilityRestoration::FeasibilityRestoration(const Model& model, const Options& options) :
Expand Down Expand Up @@ -148,7 +149,7 @@ namespace uno {
}

bool FeasibilityRestoration::is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) {
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
// TODO pick right multipliers
this->subproblem->postprocess_iterate(this->current_problem(), trial_iterate);
this->compute_progress_measures(current_iterate, trial_iterate);
Expand Down Expand Up @@ -176,6 +177,11 @@ namespace uno {
predicted_reduction, this->current_problem().get_objective_multiplier());
}
ConstraintRelaxationStrategy::set_progress_statistics(statistics, trial_iterate);
if (accept_iterate) {
user_callbacks.notify_acceptable_iterate(trial_iterate.primals,
this->current_phase == Phase::OPTIMALITY ? trial_iterate.multipliers : trial_iterate.feasibility_multipliers,
this->current_problem().get_objective_multiplier());
}
return accept_iterate;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace uno {

// trial iterate acceptance
[[nodiscard]] bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) override;
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

// primal-dual residuals
void compute_primal_dual_residuals(Iterate& iterate) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "symbolic/VectorView.hpp"
#include "options/Options.hpp"
#include "tools/Statistics.hpp"
#include "tools/UserCallbacks.hpp"

/*
* Infeasibility detection and SQP methods for nonlinear optimization
Expand Down Expand Up @@ -233,7 +234,7 @@ namespace uno {
}

bool l1Relaxation::is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& /*warmstart_information*/) {
double step_length, WarmstartInformation& /*warmstart_information*/, UserCallbacks& user_callbacks) {
this->subproblem->postprocess_iterate(this->l1_relaxed_problem, trial_iterate);
this->compute_progress_measures(current_iterate, trial_iterate);
trial_iterate.objective_multiplier = this->l1_relaxed_problem.get_objective_multiplier();
Expand All @@ -254,6 +255,7 @@ namespace uno {
if (accept_iterate) {
this->check_exact_relaxation(trial_iterate);
// this->set_dual_residuals_statistics(statistics, trial_iterate);
user_callbacks.notify_acceptable_iterate(trial_iterate.primals, trial_iterate.multipliers, this->penalty_parameter);
}
this->set_progress_statistics(statistics, trial_iterate);
return accept_iterate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace uno {

// trial iterate acceptance
[[nodiscard]] bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) override;
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

// primal-dual residuals
void compute_primal_dual_residuals(Iterate& iterate) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ namespace uno {
}

void BacktrackingLineSearch::compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) {
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
DEBUG2 << "Current iterate\n" << current_iterate << '\n';

this->constraint_relaxation_strategy.compute_feasible_direction(statistics, current_iterate, this->direction, warmstart_information);
BacktrackingLineSearch::check_unboundedness(this->direction);
this->backtrack_along_direction(statistics, model, current_iterate, trial_iterate, warmstart_information);
this->backtrack_along_direction(statistics, model, current_iterate, trial_iterate, warmstart_information, user_callbacks);
}

// go a fraction along the direction by finding an acceptable step length
void BacktrackingLineSearch::backtrack_along_direction(Statistics& statistics, const Model& model, Iterate& current_iterate,
Iterate& trial_iterate, WarmstartInformation& warmstart_information) {
Iterate& trial_iterate, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
double step_length = 1.;
bool termination = false;
size_t number_iterations = 0;
Expand All @@ -59,7 +59,7 @@ namespace uno {
this->scale_duals_with_step_length ? step_length : 1.);

is_acceptable = this->constraint_relaxation_strategy.is_iterate_acceptable(statistics, current_iterate, trial_iterate, this->direction,
step_length, warmstart_information);
step_length, warmstart_information, user_callbacks);
this->set_statistics(statistics, trial_iterate, this->direction, step_length, number_iterations);
}
catch (const EvaluationError& e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ namespace uno {

void initialize(Statistics& statistics, Iterate& initial_iterate, const Options& options) override;
void compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) override;
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

private:
const double backtracking_ratio;
const double minimum_step_length;
const bool scale_duals_with_step_length;

void backtrack_along_direction(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information);
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks);
[[nodiscard]] bool terminate_with_small_step_length(Statistics& statistics, Iterate& trial_iterate);
[[nodiscard]] double decrease_step_length(double step_length) const;
static void check_unboundedness(const Direction& direction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace uno {
class Model;
class Options;
class Statistics;
class UserCallbacks;
struct WarmstartInformation;

class GlobalizationMechanism {
Expand All @@ -22,7 +23,7 @@ namespace uno {

virtual void initialize(Statistics& statistics, Iterate& initial_iterate, const Options& options) = 0;
virtual void compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) = 0;
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) = 0;

[[nodiscard]] size_t get_hessian_evaluation_count() const;
[[nodiscard]] size_t get_number_subproblems_solved() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace uno {
}

void TrustRegionStrategy::compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) {
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
DEBUG2 << "Current iterate\n" << current_iterate << '\n';

size_t number_iterations = 0;
Expand Down Expand Up @@ -77,7 +77,8 @@ namespace uno {
GlobalizationMechanism::assemble_trial_iterate(model, current_iterate, trial_iterate, this->direction, 1., 1.);
this->reset_active_trust_region_multipliers(model, this->direction, trial_iterate);

is_acceptable = this->is_iterate_acceptable(statistics, current_iterate, trial_iterate, this->direction, warmstart_information);
is_acceptable = this->is_iterate_acceptable(statistics, current_iterate, trial_iterate, this->direction, warmstart_information,
user_callbacks);
if (is_acceptable) {
this->constraint_relaxation_strategy.set_dual_residuals_statistics(statistics, trial_iterate);
this->reset_radius();
Expand Down Expand Up @@ -122,9 +123,9 @@ namespace uno {

// the trial iterate is accepted by the constraint relaxation strategy or if the step is small and we cannot switch to solving the feasibility problem
bool TrustRegionStrategy::is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate,
const Direction& direction, WarmstartInformation& warmstart_information) {
const Direction& direction, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
bool accept_iterate = this->constraint_relaxation_strategy.is_iterate_acceptable(statistics, current_iterate, trial_iterate, direction, 1.,
warmstart_information);
warmstart_information, user_callbacks);
this->set_statistics(statistics, trial_iterate, direction);
if (accept_iterate) {
trial_iterate.status = this->constraint_relaxation_strategy.check_termination(trial_iterate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace uno {

void initialize(Statistics& statistics, Iterate& initial_iterate, const Options& options) override;
void compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) override;
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

private:
double radius; /*!< Current trust region radius */
Expand All @@ -26,7 +26,7 @@ namespace uno {
const double tolerance;

bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
WarmstartInformation& warmstart_information);
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks);
void possibly_increase_radius(double step_norm);
void decrease_radius(double step_norm);
void decrease_radius();
Expand Down
24 changes: 24 additions & 0 deletions uno/tools/UserCallbacks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) 2024 Charlie Vanaret
// Licensed under the MIT license. See LICENSE file in the project directory for details.

#ifndef UNO_USERCALLBACKS_H
#define UNO_USERCALLBACKS_H

namespace uno {
// forward declarations
class Multipliers;
template <class ElementType>
class Vector;

class UserCallbacks {
public:
UserCallbacks() = default;
virtual ~UserCallbacks() = default;

virtual void notify_acceptable_iterate(const Vector<double>& primals, const Multipliers& multipliers, double objective_multiplier) = 0;
virtual void notify_new_primals(const Vector<double>& primals) = 0;
virtual void notify_new_multipliers(const Multipliers& multipliers) = 0;
};
} // namespace

#endif //UNO_USERCALLBACKS_H

0 comments on commit 8899f7a

Please sign in to comment.