Skip to content

Commit

Permalink
WIP: Fixing tests with IPM solvers.
Browse files Browse the repository at this point in the history
  • Loading branch information
jetuk committed Jul 16, 2024
1 parent f8d05c5 commit 9b8261c
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 23 deletions.
3 changes: 3 additions & 0 deletions pywr-core/benches/random_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ fn bench_ocl_chunks(c: &mut Criterion) {
)
}

#[cfg(not(feature = "ipm-ocl"))]
fn bench_ocl_chunks(c: &mut Criterion) {}

/// Benchmark a large number of scenarios using various solvers
fn bench_hyper_scenarios(c: &mut Criterion) {
// Go from largest to smallest
Expand Down
2 changes: 1 addition & 1 deletion pywr-core/src/aggregated_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,6 @@ mod tests {

let model = Model::new(default_time_domain().into(), network);

run_all_solvers(&model, &["cbc", "highs"]);
run_all_solvers(&model, &["cbc", "highs", "ipm-ocl", "ipm-simd"]);
}
}
2 changes: 0 additions & 2 deletions pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1641,8 +1641,6 @@ mod tests {
use crate::parameters::{ActivationFunction, ControlCurveInterpolatedParameter, Parameter};
use crate::recorders::AssertionRecorder;
use crate::scenario::{ScenarioDomain, ScenarioGroupCollection, ScenarioIndex};
#[cfg(feature = "clipm")]
use crate::solvers::{ClIpmF64Solver, SimdIpmF64Solver};
use crate::solvers::{ClpSolver, ClpSolverSettings};
use crate::test_utils::{run_all_solvers, simple_model, simple_storage_model};
use float_cmp::assert_approx_eq;
Expand Down
6 changes: 5 additions & 1 deletion pywr-core/src/solvers/ipm_ocl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ impl BuiltSolver {
.iter()
.map(|state| {
let (avail, missing) = node
.get_current_available_volume_bounds(network, state)
.get_current_available_volume_bounds(state)
.expect("Volumes bounds expected for Storage nodes.");
(avail / dt, missing / dt)
})
Expand Down Expand Up @@ -574,6 +574,10 @@ pub struct ClIpmF32Solver {
impl MultiStateSolver for ClIpmF32Solver {
type Settings = ClIpmSolverSettings;

fn name() -> &'static str {
"ipm-ocl"
}

fn features() -> &'static [SolverFeatures] {
&[]
}
Expand Down
6 changes: 5 additions & 1 deletion pywr-core/src/solvers/ipm_simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ where
.iter()
.map(|state| {
let (avail, missing) = node
.get_current_available_volume_bounds(network, state)
.get_current_available_volume_bounds(state)
.expect("Volumes bounds expected for Storage nodes.");
(avail / dt, missing / dt)
})
Expand Down Expand Up @@ -611,6 +611,10 @@ where
{
type Settings = SimdIpmSolverSettings<f64, N>;

fn name() -> &'static str {
"ipm-simd"
}

fn features() -> &'static [SolverFeatures] {
&[]
}
Expand Down
2 changes: 2 additions & 0 deletions pywr-core/src/solvers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ pub trait Solver: Send {

pub trait MultiStateSolver: Send {
type Settings;

fn name() -> &'static str;
/// An array of features that this solver provides.
fn features() -> &'static [SolverFeatures];
fn setup(model: &Network, num_scenarios: usize, settings: &Self::Settings) -> Result<Box<Self>, PywrError>;
Expand Down
43 changes: 28 additions & 15 deletions pywr-core/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::solvers::ClIpmF64Solver;
use crate::solvers::HighsSolver;
#[cfg(feature = "ipm-simd")]
use crate::solvers::SimdIpmF64Solver;
use crate::solvers::{ClpSolver, Solver, SolverSettings};
use crate::solvers::{ClpSolver, MultiStateSolver, Solver, SolverSettings};
use crate::timestep::{TimeDomain, TimestepDuration, Timestepper};
use crate::PywrError;
use chrono::{Days, NaiveDate};
Expand Down Expand Up @@ -180,22 +180,10 @@ pub fn run_all_solvers(model: &Model, solvers_without_features: &[&str]) {
check_features_and_run::<HighsSolver>(model, !solvers_without_features.contains(&"highs"));

#[cfg(feature = "ipm-simd")]
{
if model.check_multi_scenario_solver_features::<SimdIpmF64Solver<4>>() {
model
.run_multi_scenario::<SimdIpmF64Solver<4>>(&Default::default())
.expect("Failed to solve with SIMD IPM");
}
}
check_features_and_run_multi::<SimdIpmF64Solver<4>>(model, !solvers_without_features.contains(&"ipm-simd"));

#[cfg(feature = "ipm-ocl")]
{
if model.check_multi_scenario_solver_features::<ClIpmF64Solver>() {
model
.run_multi_scenario::<ClIpmF64Solver>(&Default::default())
.expect("Failed to solve with OpenCl IPM");
}
}
check_features_and_run_multi::<ClIpmF64Solver>(&Default::default(), !solvers_without_features.contains(&"ipm-ocl"));
}

/// Check features and
Expand Down Expand Up @@ -223,6 +211,31 @@ where
}
}

/// Check features and run with a multi-scenario solver
fn check_features_and_run_multi<S>(model: &Model, expect_features: bool)
where
S: MultiStateSolver,
<S as MultiStateSolver>::Settings: SolverSettings + Default,
{
let has_features = model.check_multi_scenario_solver_features::<S>();
if expect_features {
assert!(
has_features,
"Solver `{}` was expected to have the required features",
S::name()
);
model
.run_multi_scenario::<S>(&Default::default())
.expect(&format!("Failed to solve with: {}", S::name()));
} else {
assert!(
!has_features,
"Solver `{}` was not expected to have the required features",
S::name()
);
}
}

/// Make a simple system with random inputs.
fn make_simple_system<R: Rng>(
network: &mut Network,
Expand Down
6 changes: 3 additions & 3 deletions pywr-core/src/virtual_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ mod tests {
let domain = default_timestepper().try_into().unwrap();
let model = Model::new(domain, network);
// Test all solvers
run_all_solvers(&model, &["highs"]);
run_all_solvers(&model, &["highs", "ipm-ocl", "ipm-simd"]);
}

#[test]
Expand All @@ -441,7 +441,7 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &["highs"]);
run_all_solvers(&model, &["highs", "ipm-ocl", "ipm-simd"]);
}

#[test]
Expand Down Expand Up @@ -481,6 +481,6 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &["highs"]);
run_all_solvers(&model, &["highs", "ipm-ocl", "ipm-simd"]);
}
}

0 comments on commit 9b8261c

Please sign in to comment.