From 07ad27a2f3c9afd499703a144a76363d876294d5 Mon Sep 17 00:00:00 2001 From: James Tomlinson Date: Thu, 4 Jul 2024 14:49:02 +0100 Subject: [PATCH] feat: Finalise ParameterCollection API for usize and multi. Implements the same API for all 3 types. Adds checks to ensure names are unique across all parameter variants. --- pywr-core/src/lib.rs | 8 +- pywr-core/src/metric.rs | 12 + pywr-core/src/network.rs | 32 +- pywr-core/src/parameters/mod.rs | 342 ++++++++++++++++++--- pywr-schema/src/metric.rs | 7 +- pywr-schema/src/model.rs | 1 - pywr-schema/src/nodes/piecewise_storage.rs | 8 +- 7 files changed, 327 insertions(+), 83 deletions(-) diff --git a/pywr-core/src/lib.rs b/pywr-core/src/lib.rs index 25df5d7d..2a6bcf01 100644 --- a/pywr-core/src/lib.rs +++ b/pywr-core/src/lib.rs @@ -93,12 +93,12 @@ pub enum PywrError { DerivedMetricIndexNotFound(DerivedMetricIndex), #[error("node name `{0}` already exists")] NodeNameAlreadyExists(String), - #[error("parameter name `{0}` already exists at index {1}")] - ParameterNameAlreadyExists(String, ParameterIndex), + #[error("parameter name `{0}` already exists")] + ParameterNameAlreadyExists(String), #[error("index parameter name `{0}` already exists at index {1}")] - IndexParameterNameAlreadyExists(String, GeneralParameterIndex), + IndexParameterNameAlreadyExists(String, ParameterIndex), #[error("multi-value parameter name `{0}` already exists at index {1}")] - MultiValueParameterNameAlreadyExists(String, GeneralParameterIndex), + MultiValueParameterNameAlreadyExists(String, ParameterIndex), #[error("metric set name `{0}` already exists")] MetricSetNameAlreadyExists(String), #[error("recorder name `{0}` already exists at index {1}")] diff --git a/pywr-core/src/metric.rs b/pywr-core/src/metric.rs index 620cbc63..523b6cbc 100644 --- a/pywr-core/src/metric.rs +++ b/pywr-core/src/metric.rs @@ -195,6 +195,18 @@ impl From> for MetricF64 { } } +impl From<(ParameterIndex, String)> for MetricF64 { + fn from((idx, key): (ParameterIndex, String)) -> Self { + match idx { + ParameterIndex::General(idx) => Self::MultiParameterValue((idx, key)), + ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::MultiParameterValue((idx, key))), + ParameterIndex::Const(idx) => Self::Simple(SimpleMetricF64::Constant( + ConstantMetricF64::MultiParameterValue((idx, key)), + )), + } + } +} + impl TryFrom> for SimpleMetricF64 { type Error = PywrError; fn try_from(idx: ParameterIndex) -> Result { diff --git a/pywr-core/src/network.rs b/pywr-core/src/network.rs index 88cfb5e6..f6709afc 100644 --- a/pywr-core/src/network.rs +++ b/pywr-core/src/network.rs @@ -12,7 +12,7 @@ use crate::solvers::{MultiStateSolver, Solver, SolverFeatures, SolverTimings}; use crate::state::{MultiValue, State, StateBuilder}; use crate::timestep::Timestep; use crate::virtual_storage::{VirtualStorage, VirtualStorageBuilder, VirtualStorageIndex, VirtualStorageVec}; -use crate::{parameters, recorders, GeneralParameterIndex, NodeIndex, PywrError, RecorderIndex}; +use crate::{parameters, recorders, NodeIndex, PywrError, RecorderIndex}; use rayon::prelude::*; use std::any::Any; use std::collections::HashSet; @@ -623,7 +623,7 @@ impl Network { GeneralParameterType::Index(idx) => { let p = self .parameters - .get_usize(*idx) + .get_general_usize(*idx) .ok_or(PywrError::GeneralIndexParameterIndexNotFound(*idx))?; // .. and its internal state @@ -638,7 +638,7 @@ impl Network { GeneralParameterType::Multi(idx) => { let p = self .parameters - .get_multi(*idx) + .get_general_multi(idx) .ok_or(PywrError::GeneralMultiValueParameterIndexNotFound(*idx))?; // .. and its internal state @@ -714,7 +714,7 @@ impl Network { GeneralParameterType::Index(idx) => { let p = self .parameters - .get_usize(*idx) + .get_general_usize(*idx) .ok_or(PywrError::GeneralIndexParameterIndexNotFound(*idx))?; // .. and its internal state @@ -727,7 +727,7 @@ impl Network { GeneralParameterType::Multi(idx) => { let p = self .parameters - .get_multi(*idx) + .get_general_multi(idx) .ok_or(PywrError::GeneralMultiValueParameterIndexNotFound(*idx))?; // .. and its internal state @@ -1123,21 +1123,15 @@ impl Network { } /// Get a [`Parameter`] from its index. - pub fn get_index_parameter( - &self, - index: GeneralParameterIndex, - ) -> Result<&dyn parameters::GeneralParameter, PywrError> { + pub fn get_index_parameter(&self, index: ParameterIndex) -> Result<&dyn parameters::Parameter, PywrError> { match self.parameters.get_usize(index) { Some(p) => Ok(p), - None => Err(PywrError::GeneralIndexParameterIndexNotFound(index)), + None => Err(PywrError::IndexParameterIndexNotFound(index)), } } /// Get a `IndexParameter` from a parameter's name - pub fn get_index_parameter_by_name( - &self, - name: &str, - ) -> Result<&dyn parameters::GeneralParameter, PywrError> { + pub fn get_index_parameter_by_name(&self, name: &str) -> Result<&dyn parameters::Parameter, PywrError> { match self.parameters.get_usize_by_name(name) { Some(parameter) => Ok(parameter), None => Err(PywrError::ParameterNotFound(name.to_string())), @@ -1145,7 +1139,7 @@ impl Network { } /// Get a `IndexParameterIndex` from a parameter's name - pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result, PywrError> { + pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result, PywrError> { match self.parameters.get_usize_index_by_name(name) { Some(idx) => Ok(idx), None => Err(PywrError::ParameterNotFound(name.to_string())), @@ -1155,11 +1149,11 @@ impl Network { /// Get a `MultiValueParameterIndex` from a parameter's name pub fn get_multi_valued_parameter( &self, - index: GeneralParameterIndex, - ) -> Result<&dyn parameters::GeneralParameter, PywrError> { + index: &ParameterIndex, + ) -> Result<&dyn parameters::Parameter, PywrError> { match self.parameters.get_multi(index) { Some(p) => Ok(p), - None => Err(PywrError::GeneralMultiValueParameterIndexNotFound(index)), + None => Err(PywrError::MultiValueParameterIndexNotFound(index.clone())), } } @@ -1167,7 +1161,7 @@ impl Network { pub fn get_multi_valued_parameter_index_by_name( &self, name: &str, - ) -> Result, PywrError> { + ) -> Result, PywrError> { match self.parameters.get_multi_index_by_name(name) { Some(idx) => Ok(idx), None => Err(PywrError::ParameterNotFound(name.to_string())), diff --git a/pywr-core/src/parameters/mod.rs b/pywr-core/src/parameters/mod.rs index 53470e56..9e67e914 100644 --- a/pywr-core/src/parameters/mod.rs +++ b/pywr-core/src/parameters/mod.rs @@ -863,6 +863,13 @@ impl ParameterCollection { }) } + /// Does a parameter with the given name exist in the collection. + pub fn has_name(&self, name: &str) -> bool { + self.get_f64_index_by_name(name).is_some() + || self.get_usize_index_by_name(name).is_some() + || self.get_multi_index_by_name(name).is_some() + } + /// Add a [`GeneralParameter`] parameter to the collection. /// /// This function will add attempt to simplify the parameter and add it to the simple or @@ -872,11 +879,8 @@ impl ParameterCollection { &mut self, parameter: Box>, ) -> Result, PywrError> { - if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { - return Err(PywrError::ParameterNameAlreadyExists( - parameter.meta().name.to_string(), - index, - )); + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); } match parameter.try_into_simple() { @@ -893,11 +897,8 @@ impl ParameterCollection { &mut self, parameter: Box>, ) -> Result, PywrError> { - if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { - return Err(PywrError::ParameterNameAlreadyExists( - parameter.meta().name.to_string(), - index, - )); + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); } match parameter.try_into_const() { @@ -914,11 +915,8 @@ impl ParameterCollection { } pub fn add_const_f64(&mut self, parameter: Box>) -> Result, PywrError> { - if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { - return Err(PywrError::ParameterNameAlreadyExists( - parameter.meta().name.to_string(), - index, - )); + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); } let index = ConstParameterIndex::new(self.constant_f64.len()); @@ -979,11 +977,8 @@ impl ParameterCollection { &mut self, parameter: Box>, ) -> Result, PywrError> { - if let Some(index) = self.get_usize_index_by_name(¶meter.meta().name) { - return Err(PywrError::IndexParameterNameAlreadyExists( - parameter.meta().name.to_string(), - index, - )); + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); } match parameter.try_into_simple() { @@ -1000,13 +995,9 @@ impl ParameterCollection { &mut self, parameter: Box>, ) -> Result, PywrError> { - // TODO Fix this check - // if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { - // return Err(PywrError::SimpleParameterNameAlreadyExists( - // parameter.meta().name.to_string(), - // index, - // )); - // } + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); + } let index = SimpleParameterIndex::new(self.simple_usize.len()); @@ -1015,33 +1006,75 @@ impl ParameterCollection { Ok(index) } - pub fn get_usize(&self, index: GeneralParameterIndex) -> Option<&dyn GeneralParameter> { + + pub fn add_const_usize( + &mut self, + parameter: Box>, + ) -> Result, PywrError> { + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); + } + + let index = ConstParameterIndex::new(self.constant_usize.len()); + + self.constant_usize.push(parameter); + self.constant_resolve_order.push(ConstParameterType::Index(index)); + + Ok(index.into()) + } + + pub fn get_usize(&self, index: ParameterIndex) -> Option<&dyn Parameter> { + match index { + ParameterIndex::Const(idx) => self.constant_usize.get(*idx.deref()).map(|p| p.as_parameter()), + ParameterIndex::Simple(idx) => self.simple_usize.get(*idx.deref()).map(|p| p.as_parameter()), + ParameterIndex::General(idx) => self.general_usize.get(*idx.deref()).map(|p| p.as_parameter()), + } + } + + pub fn get_general_usize(&self, index: GeneralParameterIndex) -> Option<&dyn GeneralParameter> { self.general_usize.get(*index.deref()).map(|p| p.as_ref()) } - pub fn get_usize_by_name(&self, name: &str) -> Option<&dyn GeneralParameter> { + pub fn get_usize_by_name(&self, name: &str) -> Option<&dyn Parameter> { self.general_usize .iter() .find(|p| p.meta().name == name) - .map(|p| p.as_ref()) + .map(|p| p.as_parameter()) } - pub fn get_usize_index_by_name(&self, name: &str) -> Option> { - self.general_usize + pub fn get_usize_index_by_name(&self, name: &str) -> Option> { + if let Some(idx) = self + .general_usize .iter() .position(|p| p.meta().name == name) .map(|idx| GeneralParameterIndex::new(idx)) + { + Some(idx.into()) + } else if let Some(idx) = self + .simple_usize + .iter() + .position(|p| p.meta().name == name) + .map(|idx| SimpleParameterIndex::new(idx)) + { + Some(idx.into()) + } else if let Some(idx) = self + .constant_usize + .iter() + .position(|p| p.meta().name == name) + .map(|idx| ConstParameterIndex::new(idx)) + { + Some(idx.into()) + } else { + None + } } pub fn add_general_multi( &mut self, parameter: Box>, ) -> Result, PywrError> { - if let Some(index) = self.get_multi_index_by_name(¶meter.meta().name) { - return Err(PywrError::MultiValueParameterNameAlreadyExists( - parameter.meta().name.to_string(), - index, - )); + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); } match parameter.try_into_simple() { @@ -1058,13 +1091,9 @@ impl ParameterCollection { &mut self, parameter: Box>, ) -> Result, PywrError> { - // TODO Fix this check - // if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { - // return Err(PywrError::SimpleParameterNameAlreadyExists( - // parameter.meta().name.to_string(), - // index, - // )); - // } + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); + } let index = SimpleParameterIndex::new(self.simple_multi.len()); @@ -1073,22 +1102,69 @@ impl ParameterCollection { Ok(index) } - pub fn get_multi(&self, index: GeneralParameterIndex) -> Option<&dyn GeneralParameter> { + + pub fn add_const_multi( + &mut self, + parameter: Box>, + ) -> Result, PywrError> { + if self.has_name(parameter.meta().name.as_str()) { + return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); + } + + let index = ConstParameterIndex::new(self.constant_multi.len()); + + self.constant_multi.push(parameter); + self.constant_resolve_order.push(ConstParameterType::Multi(index)); + + Ok(index) + } + pub fn get_multi(&self, index: &ParameterIndex) -> Option<&dyn Parameter> { + match index { + ParameterIndex::Const(idx) => self.constant_multi.get(*idx.deref()).map(|p| p.as_parameter()), + ParameterIndex::Simple(idx) => self.simple_multi.get(*idx.deref()).map(|p| p.as_parameter()), + ParameterIndex::General(idx) => self.general_multi.get(*idx.deref()).map(|p| p.as_parameter()), + } + } + + pub fn get_general_multi( + &self, + index: &GeneralParameterIndex, + ) -> Option<&dyn GeneralParameter> { self.general_multi.get(*index.deref()).map(|p| p.as_ref()) } - pub fn get_multi_by_name(&self, name: &str) -> Option<&dyn GeneralParameter> { + pub fn get_multi_by_name(&self, name: &str) -> Option<&dyn Parameter> { self.general_multi .iter() .find(|p| p.meta().name == name) - .map(|p| p.as_ref()) + .map(|p| p.as_parameter()) } - pub fn get_multi_index_by_name(&self, name: &str) -> Option> { - self.general_multi + pub fn get_multi_index_by_name(&self, name: &str) -> Option> { + if let Some(idx) = self + .general_multi .iter() .position(|p| p.meta().name == name) .map(|idx| GeneralParameterIndex::new(idx)) + { + Some(idx.into()) + } else if let Some(idx) = self + .simple_multi + .iter() + .position(|p| p.meta().name == name) + .map(|idx| SimpleParameterIndex::new(idx)) + { + Some(idx.into()) + } else if let Some(idx) = self + .constant_multi + .iter() + .position(|p| p.meta().name == name) + .map(|idx| ConstParameterIndex::new(idx)) + { + Some(idx.into()) + } else { + None + } } pub fn compute_simple( @@ -1293,8 +1369,14 @@ impl ParameterCollection { #[cfg(test)] mod tests { - + use super::{ + ConstParameter, GeneralParameter, Parameter, ParameterCollection, ParameterMeta, ParameterState, + SimpleParameter, + }; + use crate::scenario::ScenarioIndex; + use crate::state::{ConstParameterValues, MultiValue}; use crate::timestep::{TimestepDuration, Timestepper}; + use crate::PywrError; use chrono::NaiveDateTime; // TODO tests need re-enabling @@ -1306,6 +1388,166 @@ mod tests { Timestepper::new(start, end, duration) } + /// Parameter for testing purposes + struct TestParameter { + meta: ParameterMeta, + } + + impl Default for TestParameter { + fn default() -> Self { + Self { + meta: ParameterMeta::new("test-parameter"), + } + } + } + impl Parameter for TestParameter { + fn meta(&self) -> &ParameterMeta { + &self.meta + } + } + + impl ConstParameter for TestParameter + where + T: From, + { + fn compute( + &self, + _scenario_index: &ScenarioIndex, + _values: &ConstParameterValues, + _internal_state: &mut Option>, + ) -> Result { + Ok(T::from(1)) + } + + fn as_parameter(&self) -> &dyn Parameter { + self + } + } + + impl ConstParameter for TestParameter { + fn compute( + &self, + _scenario_index: &ScenarioIndex, + _values: &ConstParameterValues, + _internal_state: &mut Option>, + ) -> Result { + Ok(MultiValue::default()) + } + + fn as_parameter(&self) -> &dyn Parameter { + self + } + } + impl SimpleParameter for TestParameter + where + T: From, + { + fn compute( + &self, + _timestep: &crate::timestep::Timestep, + _scenario_index: &ScenarioIndex, + _values: &crate::state::SimpleParameterValues, + _internal_state: &mut Option>, + ) -> Result { + Ok(T::from(1)) + } + + fn as_parameter(&self) -> &dyn Parameter { + self + } + } + + impl SimpleParameter for TestParameter { + fn compute( + &self, + _timestep: &crate::timestep::Timestep, + _scenario_index: &ScenarioIndex, + _values: &crate::state::SimpleParameterValues, + _internal_state: &mut Option>, + ) -> Result { + Ok(MultiValue::default()) + } + + fn as_parameter(&self) -> &dyn Parameter { + self + } + } + impl GeneralParameter for TestParameter + where + T: From, + { + fn compute( + &self, + _timestep: &crate::timestep::Timestep, + _scenario_index: &ScenarioIndex, + _model: &crate::network::Network, + _state: &crate::state::State, + _internal_state: &mut Option>, + ) -> Result { + Ok(T::from(1)) + } + + fn as_parameter(&self) -> &dyn Parameter { + self + } + } + + impl GeneralParameter for TestParameter { + fn compute( + &self, + _timestep: &crate::timestep::Timestep, + _scenario_index: &ScenarioIndex, + _model: &crate::network::Network, + _state: &crate::state::State, + _internal_state: &mut Option>, + ) -> Result { + Ok(MultiValue::default()) + } + + fn as_parameter(&self) -> &dyn Parameter { + self + } + } + + /// Test naming constraints on parameter collection. + #[test] + fn test_parameter_collection_name_constraints() { + let mut collection = ParameterCollection::default(); + + let ret = collection.add_const_f64(Box::new(TestParameter::default())); + assert!(ret.is_ok()); + + assert!(collection.has_name("test-parameter")); + + // Try to add a parameter with the same name + let ret = collection.add_const_f64(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_simple_f64(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_general_f64(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_const_usize(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_simple_usize(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_general_usize(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_const_multi(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_simple_multi(Box::new(TestParameter::default())); + assert!(ret.is_err()); + + let ret = collection.add_general_multi(Box::new(TestParameter::default())); + assert!(ret.is_err()); + } + // #[test] // /// Test `ConstantParameter` returns the correct value. // fn test_constant_parameter() { diff --git a/pywr-schema/src/metric.rs b/pywr-schema/src/metric.rs index 1493a3ec..37f1012d 100644 --- a/pywr-schema/src/metric.rs +++ b/pywr-schema/src/metric.rs @@ -321,16 +321,17 @@ impl ParameterReference { match &self.key { Some(key) => { // Key given; this should be a multi-valued parameter - Ok(MetricF64::MultiParameterValue(( + Ok(( network.get_multi_valued_parameter_index_by_name(&self.name)?, key.clone(), - ))) + ) + .into()) } None => { if let Ok(idx) = network.get_parameter_index_by_name(&self.name) { Ok(idx.into()) } else if let Ok(idx) = network.get_index_parameter_index_by_name(&self.name) { - Ok(MetricF64::IndexParameterValue(idx)) + Ok(idx.into()) } else { Err(SchemaError::ParameterNotFound(self.name.to_string())) } diff --git a/pywr-schema/src/model.rs b/pywr-schema/src/model.rs index 344f8841..ac30f3c3 100644 --- a/pywr-schema/src/model.rs +++ b/pywr-schema/src/model.rs @@ -1036,7 +1036,6 @@ mod core_tests { use ndarray::{Array1, Array2, Axis}; use pywr_core::{metric::MetricF64, recorders::AssertionRecorder, solvers::ClpSolver, test_utils::run_all_solvers}; use std::path::PathBuf; - use std::str::FromStr; fn model_str() -> &'static str { include_str!("./test_models/simple1.json") diff --git a/pywr-schema/src/nodes/piecewise_storage.rs b/pywr-schema/src/nodes/piecewise_storage.rs index 84796298..fa2b92f9 100644 --- a/pywr-schema/src/nodes/piecewise_storage.rs +++ b/pywr-schema/src/nodes/piecewise_storage.rs @@ -235,7 +235,7 @@ mod tests { use crate::model::PywrModel; use crate::nodes::PiecewiseStorageNode; use ndarray::{concatenate, Array, Array2, Axis}; - use pywr_core::metric::{MetricF64, MetricUsize}; + use pywr_core::metric::MetricF64; use pywr_core::recorders::{AssertionRecorder, IndexAssertionRecorder}; use pywr_core::test_utils::run_all_solvers; @@ -353,11 +353,7 @@ mod tests { .get_index_parameter_index_by_name("storage1-drought-index") .unwrap(); - let recorder = IndexAssertionRecorder::new( - "storage1-drought-index", - MetricUsize::IndexParameterValue(idx), - expected_drought_index, - ); + let recorder = IndexAssertionRecorder::new("storage1-drought-index", idx.into(), expected_drought_index); network.add_recorder(Box::new(recorder)).unwrap(); // Test all solvers