Skip to content

Commit

Permalink
Implement Google-OR Tools backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
sukritkalra committed Sep 26, 2023
1 parent da181d8 commit 2c67dfe
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 24 deletions.
21 changes: 19 additions & 2 deletions schedulers/tetrisched/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
set(CPLEX_DIR "/opt/ibm/ILOG/CPLEX_Studio2211")
set(GUROBI_DIR "/opt/gurobi1001/linux64")
set(ORTOOLS_DIR "/usr/local/include/ortools")

if (EXISTS "${CPLEX_DIR}")
message("-- Adding CPLEX : ${CPLEX_DIR}")
Expand All @@ -45,6 +46,7 @@ else()
set(CPLEX_LINK_DIRS "")
set(CPLEX_LINK_LIBRARIES "")
endif()

if (EXISTS "${GUROBI_DIR}")
message("-- Adding GUROBI : ${GUROBI_DIR}")
add_compile_definitions(_TETRISCHED_WITH_GUROBI_)
Expand All @@ -56,15 +58,23 @@ if (EXISTS "${GUROBI_DIR}")
set(GUROBI_LINK_LIBRARIES
"gurobi_c++"
"gurobi100")

else()
message("-- Not Adding GUROBI")
set(GUROBI_INCLUDE_DIRS "")
set(GUROBI_LINK_DIRS "")
set(GUROBI_LINK_LIBRARIES "")

endif()

if (EXISTS "${ORTOOLS_DIR}")
message("-- Adding Google OR-Tools : ${ORTOOLS_DIR}")
add_compile_definitions(_TETRISCHED_WITH_ORTOOLS_)

set(ORTOOLS_LINK_LIBRARIES
"ortools")
else()
message("-- Not adding OR-Tools")
set(ORTOOLS_LINK_LIBRARIES "")
endif()

set(TETRISCHED_SOURCE
"src/Expression.cpp"
Expand All @@ -78,6 +88,9 @@ endif()
if (EXISTS "${GUROBI_DIR}")
LIST(APPEND TETRISCHED_SOURCE "src/GurobiSolver.cpp")
endif()
if (EXISTS "${ORTOOLS_DIR}")
LIST(APPEND TETRISCHED_SOURCE "src/GoogleCPSolver.cpp")
endif()

add_library(tetrisched SHARED ${TETRISCHED_SOURCE})
target_include_directories(tetrisched PRIVATE include)
Expand All @@ -93,6 +106,9 @@ target_include_directories(tetrisched PRIVATE ${GUROBI_INCLUDE_DIRS})
target_link_directories(tetrisched PRIVATE ${GUROBI_LINK_DIRS})
target_link_libraries(tetrisched PRIVATE ${GUROBI_LINK_LIBRARIES})

# Include and link Google-ORTools.
target_link_libraries(tetrisched PRIVATE ${ORTOOLS_LINK_LIBRARIES})


# add_executable(tetrisched_main ${TETRISCHED_SOURCE})
# target_include_directories(tetrisched_main PRIVATE ${CPLEX_INCLUDE_DIRS})
Expand Down Expand Up @@ -127,6 +143,7 @@ target_link_libraries(
tetrisched
${CPLEX_LINK_LIBRARIES}
${GUROBI_LINK_LIBRARIES}
${ORTOOLS_LINK_LIBRARIES}
)

include(GoogleTest)
Expand Down
50 changes: 50 additions & 0 deletions schedulers/tetrisched/include/tetrisched/GoogleCPSolver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef _TETRISCHED_GOOGLE_CP_SOLVER_HPP_
#define _TETRISCHED_GOOGLE_CP_SOLVER_HPP_

#include <variant>

#include "ortools/sat/cp_model.h"
#include "tetrisched/Solver.hpp"

namespace tetrisched {
// Import the relevant names from the ORTools namespace.
using operations_research::sat::BoolVar;
using operations_research::sat::CpModelBuilder;
using operations_research::sat::IntVar;
using operations_research::sat::LinearExpr;

class GoogleCPSolver : public Solver {
using GoogleCPVarType = std::variant<IntVar, BoolVar>;

private:
/// The SolverModelPtr associated with this GoogleCPSolver.
SolverModelPtr solverModel;
/// The ORTools model associated with this GoogleCPSolver.
std::unique_ptr<CpModelBuilder> cpModel;
/// A map from the Variable ID to the ORTools variable.
std::unordered_map<uint32_t, GoogleCPVarType> cpVariables;

/// Translates the VariablePtr into an IntVar / BoolVar.
GoogleCPVarType translateVariable(const VariablePtr& variable) const;
/// Translates the ConstraintPtr into a Constraint and adds it to the model.
operations_research::sat::Constraint translateConstraint(
const ConstraintPtr& constraint);
/// Translates the ObjectiveFunctionPtr into an Expression in Google OR-Tools.
LinearExpr translateObjectiveFunction(
const ObjectiveFunctionPtr& objectiveFunction) const;

public:
/// Create a new CP-SAT solver.
GoogleCPSolver();

/// Retrieve a pointer to the SolverModel.
SolverModelPtr getModel() override;

/// Translates the SolverModel into a CP-SAT model.
void translateModel() override;

/// Export the constructed model to the given file.
void exportModel(const std::string& fileName) override;
};
} // namespace tetrisched
#endif // _TETRISCHED_GOOGLE_CP_SOLVER_HPP_
4 changes: 4 additions & 0 deletions schedulers/tetrisched/include/tetrisched/SolverModel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class VariableT {
/// Annotate friend classes for Solvers so that they have access to internals.
friend tetrisched::CPLEXSolver;
friend tetrisched::GurobiSolver;
friend tetrisched::GoogleCPSolver;
};

// Specialize the VariableT class for Integer type.
Expand Down Expand Up @@ -147,6 +148,7 @@ class ConstraintT {
/// Annotate friend classes for Solvers so that they have access to internals.
friend tetrisched::CPLEXSolver;
friend tetrisched::GurobiSolver;
friend tetrisched::GoogleCPSolver;
};

// Specialize the Constraint class for Integer.
Expand Down Expand Up @@ -191,6 +193,7 @@ class ObjectiveFunctionT {
/// Annotate friend classes for Solvers so that they have access to internals.
friend tetrisched::CPLEXSolver;
friend tetrisched::GurobiSolver;
friend tetrisched::GoogleCPSolver;
};

// Specialize the ObjectiveFunction class for Integer.
Expand Down Expand Up @@ -242,6 +245,7 @@ class SolverModelT {
/// back to the user.
friend tetrisched::CPLEXSolver;
friend tetrisched::GurobiSolver;
friend tetrisched::GoogleCPSolver;
};

// Specialize the SolverModel class for Integer.
Expand Down
1 change: 1 addition & 0 deletions schedulers/tetrisched/include/tetrisched/Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ using ExpressionPtr = std::unique_ptr<Expression>;
/// them as friend classes in the model.
class CPLEXSolver;
class GurobiSolver;
class GoogleCPSolver;
} // namespace tetrisched

#endif
165 changes: 165 additions & 0 deletions schedulers/tetrisched/src/GoogleCPSolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#include "tetrisched/GoogleCPSolver.hpp"

namespace tetrisched {
GoogleCPSolver::GoogleCPSolver() : cpModel(new CpModelBuilder()) {}

SolverModelPtr GoogleCPSolver::getModel() {
if (!solverModel) {
solverModel = std::shared_ptr<SolverModel>(new SolverModel());
}
return solverModel;
}

GoogleCPSolver::GoogleCPVarType GoogleCPSolver::translateVariable(
const VariablePtr& variable) const {
// Check that a continuous variable is not passed in.
if (variable->variableType == VariableType::VAR_CONTINUOUS) {
throw exceptions::SolverException(
"Cannot construct a continuous variable in ORTools.");
} else if (variable->variableType == VariableType::VAR_INTEGER) {
// Check that the Variable has been given both a lower and an upper bound.
// This is required to set a Domain in ORTools.
if (!(variable->lowerBound.has_value() &&
variable->upperBound.has_value())) {
throw exceptions::SolverException(
"Cannot construct a variable without a lower and upper bound in "
"ORTools: " +
variable->toString());
}

// Construct the domain for the variable.
const operations_research::Domain domain(variable->lowerBound.value(),
variable->upperBound.value());
return cpModel->NewIntVar(domain).WithName(variable->variableName);
} else if (variable->variableType == VariableType::VAR_INDICATOR) {
// Construct the Indicator variable.
return cpModel->NewBoolVar().WithName(variable->variableName);
} else {
throw exceptions::SolverException("Cannot construct a variable of type " +
std::to_string(variable->variableType) +
" in ORTools.");
}
}

operations_research::sat::Constraint GoogleCPSolver::translateConstraint(
const ConstraintPtr& constraint) {
// TODO (Sukrit): We are currently assuming that all constraints and
// objectives are linear. We may need to support quadratic constraints.
operations_research::sat::LinearExpr constraintExpr;

// Construct all the terms.
for (const auto& [coefficient, variable] : constraint->terms) {
if (variable) {
switch (variable->variableType) {
case VariableType::VAR_INTEGER:
constraintExpr +=
coefficient * std::get<IntVar>(cpVariables.at(variable->getId()));
break;
case VariableType::VAR_INDICATOR:
constraintExpr +=
coefficient *
std::get<BoolVar>(cpVariables.at(variable->getId()));
break;
default:
throw exceptions::SolverException(
"Cannot construct a constraint with a variable of type " +
std::to_string(variable->variableType) + " in ORTools.");
}
} else {
constraintExpr += coefficient;
}
}

// Translate the constraint.
switch (constraint->constraintType) {
case ConstraintType::CONSTR_EQ:
return cpModel->AddEquality(constraintExpr, constraint->rightHandSide)
.WithName(constraint->getName());
case ConstraintType::CONSTR_GE:
return cpModel
->AddGreaterOrEqual(constraintExpr, constraint->rightHandSide)
.WithName(constraint->getName());
case ConstraintType::CONSTR_LE:
return cpModel->AddLessOrEqual(constraintExpr, constraint->rightHandSide)
.WithName(constraint->getName());
default:
throw exceptions::SolverException(
"Invalid constraint type: " +
std::to_string(constraint->constraintType));
}
}

LinearExpr GoogleCPSolver::translateObjectiveFunction(
const ObjectiveFunctionPtr& objectiveFunction) const {
LinearExpr objectiveExpr;

// Construct all the terms.
for (const auto& [coefficient, variable] : objectiveFunction->terms) {
if (variable) {
switch (variable->variableType) {
case VariableType::VAR_INTEGER:
objectiveExpr +=
coefficient * std::get<IntVar>(cpVariables.at(variable->getId()));
break;
case VariableType::VAR_INDICATOR:
objectiveExpr += coefficient *
std::get<BoolVar>(cpVariables.at(variable->getId()));
break;
default:
throw exceptions::SolverException(
"Cannot construct an objective function with a variable of "
"type " +
std::to_string(variable->variableType) + " in ORTools.");
}
} else {
objectiveExpr += coefficient;
}
}
return objectiveExpr;
}

void GoogleCPSolver::translateModel() {
if (!solverModel) {
throw tetrisched::exceptions::SolverException(
"Empty SolverModel for GurobiSolver. Nothing to translate!");
}

// Generate all the variables and keep a cache of the variable indices
// to the ORTools variables.
for (const auto& [variableId, variable] : solverModel->variables) {
TETRISCHED_DEBUG("Adding variable " << variable->getName() << "("
<< variable->getId()
<< ") to ORTools model.");
cpVariables[variableId] = translateVariable(variable);
}

// Generate all the constraints.
for (const auto& [constraintId, constraint] : solverModel->constraints) {
TETRISCHED_DEBUG("Adding constraint " << constraint->getName() << "("
<< constraint->getId()
<< ") to ORTools model.");
auto _ = translateConstraint(constraint);
}

// Translate the objective function.
auto objectiveExpr =
translateObjectiveFunction(solverModel->objectiveFunction);
switch (solverModel->objectiveFunction->objectiveType) {
case ObjectiveType::OBJ_MINIMIZE:
cpModel->Minimize(objectiveExpr);
break;
case ObjectiveType::OBJ_MAXIMIZE:
cpModel->Maximize(objectiveExpr);
break;
default:
throw exceptions::SolverException(
"Invalid objective type: " +
std::to_string(solverModel->objectiveFunction->objectiveType));
}
}

void GoogleCPSolver::exportModel(const std::string& fileName) {
cpModel->ExportToFile(fileName);
}

} // namespace tetrisched
Loading

0 comments on commit 2c67dfe

Please sign in to comment.