diff --git a/pywr-core/benches/random_models.rs b/pywr-core/benches/random_models.rs index 718226db..7106cf9a 100644 --- a/pywr-core/benches/random_models.rs +++ b/pywr-core/benches/random_models.rs @@ -387,6 +387,9 @@ fn bench_ocl_chunks(c: &mut Criterion) { ) } +#[cfg(not(feature = "ipm-ocl"))] +fn bench_ocl_chunks(c: &mut Criterion) {} + /// Benchmark a large number of scenarios using various solvers fn bench_hyper_scenarios(c: &mut Criterion) { // Go from largest to smallest diff --git a/pywr-core/src/aggregated_node.rs b/pywr-core/src/aggregated_node.rs index ddcc3c1e..eec4f9c4 100644 --- a/pywr-core/src/aggregated_node.rs +++ b/pywr-core/src/aggregated_node.rs @@ -344,6 +344,6 @@ mod tests { let model = Model::new(default_time_domain().into(), network); - run_all_solvers(&model, &["cbc", "highs"]); + run_all_solvers(&model, &["cbc", "highs", "ipm-ocl", "ipm-simd"]); } } diff --git a/pywr-core/src/network.rs b/pywr-core/src/network.rs index f0f55ca7..9aad189d 100644 --- a/pywr-core/src/network.rs +++ b/pywr-core/src/network.rs @@ -1641,8 +1641,6 @@ mod tests { use crate::parameters::{ActivationFunction, ControlCurveInterpolatedParameter, Parameter}; use crate::recorders::AssertionRecorder; use crate::scenario::{ScenarioDomain, ScenarioGroupCollection, ScenarioIndex}; - #[cfg(feature = "clipm")] - use crate::solvers::{ClIpmF64Solver, SimdIpmF64Solver}; use crate::solvers::{ClpSolver, ClpSolverSettings}; use crate::test_utils::{run_all_solvers, simple_model, simple_storage_model}; use float_cmp::assert_approx_eq; diff --git a/pywr-core/src/solvers/ipm_ocl/mod.rs b/pywr-core/src/solvers/ipm_ocl/mod.rs index 231ecc7f..898acde7 100644 --- a/pywr-core/src/solvers/ipm_ocl/mod.rs +++ b/pywr-core/src/solvers/ipm_ocl/mod.rs @@ -354,7 +354,7 @@ impl BuiltSolver { .iter() .map(|state| { let (avail, missing) = node - .get_current_available_volume_bounds(network, state) + .get_current_available_volume_bounds(state) .expect("Volumes bounds expected for Storage nodes."); (avail / dt, missing / dt) }) @@ -574,6 +574,10 @@ pub struct ClIpmF32Solver { impl MultiStateSolver for ClIpmF32Solver { type Settings = ClIpmSolverSettings; + fn name() -> &'static str { + "ipm-ocl" + } + fn features() -> &'static [SolverFeatures] { &[] } diff --git a/pywr-core/src/solvers/ipm_simd/mod.rs b/pywr-core/src/solvers/ipm_simd/mod.rs index 75a38da7..1d9200d6 100644 --- a/pywr-core/src/solvers/ipm_simd/mod.rs +++ b/pywr-core/src/solvers/ipm_simd/mod.rs @@ -383,7 +383,7 @@ where .iter() .map(|state| { let (avail, missing) = node - .get_current_available_volume_bounds(network, state) + .get_current_available_volume_bounds(state) .expect("Volumes bounds expected for Storage nodes."); (avail / dt, missing / dt) }) @@ -611,6 +611,10 @@ where { type Settings = SimdIpmSolverSettings; + fn name() -> &'static str { + "ipm-simd" + } + fn features() -> &'static [SolverFeatures] { &[] } diff --git a/pywr-core/src/solvers/mod.rs b/pywr-core/src/solvers/mod.rs index 11d6ccaa..b93d3bb8 100644 --- a/pywr-core/src/solvers/mod.rs +++ b/pywr-core/src/solvers/mod.rs @@ -93,6 +93,8 @@ pub trait Solver: Send { pub trait MultiStateSolver: Send { type Settings; + + fn name() -> &'static str; /// An array of features that this solver provides. fn features() -> &'static [SolverFeatures]; fn setup(model: &Network, num_scenarios: usize, settings: &Self::Settings) -> Result, PywrError>; diff --git a/pywr-core/src/test_utils.rs b/pywr-core/src/test_utils.rs index 0a9f9666..8f629a0a 100644 --- a/pywr-core/src/test_utils.rs +++ b/pywr-core/src/test_utils.rs @@ -15,7 +15,7 @@ use crate::solvers::ClIpmF64Solver; use crate::solvers::HighsSolver; #[cfg(feature = "ipm-simd")] use crate::solvers::SimdIpmF64Solver; -use crate::solvers::{ClpSolver, Solver, SolverSettings}; +use crate::solvers::{ClpSolver, MultiStateSolver, Solver, SolverSettings}; use crate::timestep::{TimeDomain, TimestepDuration, Timestepper}; use crate::PywrError; use chrono::{Days, NaiveDate}; @@ -180,22 +180,10 @@ pub fn run_all_solvers(model: &Model, solvers_without_features: &[&str]) { check_features_and_run::(model, !solvers_without_features.contains(&"highs")); #[cfg(feature = "ipm-simd")] - { - if model.check_multi_scenario_solver_features::>() { - model - .run_multi_scenario::>(&Default::default()) - .expect("Failed to solve with SIMD IPM"); - } - } + check_features_and_run_multi::>(model, !solvers_without_features.contains(&"ipm-simd")); #[cfg(feature = "ipm-ocl")] - { - if model.check_multi_scenario_solver_features::() { - model - .run_multi_scenario::(&Default::default()) - .expect("Failed to solve with OpenCl IPM"); - } - } + check_features_and_run_multi::(&Default::default(), !solvers_without_features.contains(&"ipm-ocl")); } /// Check features and @@ -223,6 +211,31 @@ where } } +/// Check features and run with a multi-scenario solver +fn check_features_and_run_multi(model: &Model, expect_features: bool) +where + S: MultiStateSolver, + ::Settings: SolverSettings + Default, +{ + let has_features = model.check_multi_scenario_solver_features::(); + if expect_features { + assert!( + has_features, + "Solver `{}` was expected to have the required features", + S::name() + ); + model + .run_multi_scenario::(&Default::default()) + .expect(&format!("Failed to solve with: {}", S::name())); + } else { + assert!( + !has_features, + "Solver `{}` was not expected to have the required features", + S::name() + ); + } +} + /// Make a simple system with random inputs. fn make_simple_system( network: &mut Network, diff --git a/pywr-core/src/virtual_storage.rs b/pywr-core/src/virtual_storage.rs index f906cece..0f4e80e9 100644 --- a/pywr-core/src/virtual_storage.rs +++ b/pywr-core/src/virtual_storage.rs @@ -414,7 +414,7 @@ mod tests { let domain = default_timestepper().try_into().unwrap(); let model = Model::new(domain, network); // Test all solvers - run_all_solvers(&model, &["highs"]); + run_all_solvers(&model, &["highs", "ipm-ocl", "ipm-simd"]); } #[test] @@ -441,7 +441,7 @@ mod tests { network.add_recorder(Box::new(recorder)).unwrap(); // Test all solvers - run_all_solvers(&model, &["highs"]); + run_all_solvers(&model, &["highs", "ipm-ocl", "ipm-simd"]); } #[test] @@ -481,6 +481,6 @@ mod tests { network.add_recorder(Box::new(recorder)).unwrap(); // Test all solvers - run_all_solvers(&model, &["highs"]); + run_all_solvers(&model, &["highs", "ipm-ocl", "ipm-simd"]); } }