diff --git a/pywr-core/src/lib.rs b/pywr-core/src/lib.rs index bad59d2b..67046719 100644 --- a/pywr-core/src/lib.rs +++ b/pywr-core/src/lib.rs @@ -5,7 +5,9 @@ extern crate core; use crate::derived_metric::DerivedMetricIndex; use crate::models::MultiNetworkTransferIndex; use crate::node::NodeIndex; -use crate::parameters::{GeneralParameterIndex, InterpolationError, ParameterIndex, SimpleParameterIndex}; +use crate::parameters::{ + ConstParameterIndex, GeneralParameterIndex, InterpolationError, ParameterIndex, SimpleParameterIndex, +}; use crate::recorders::{AggregationError, MetricSetIndex, RecorderIndex}; use crate::state::MultiValue; use crate::virtual_storage::VirtualStorageIndex; @@ -63,6 +65,12 @@ pub enum PywrError { SimpleIndexParameterIndexNotFound(SimpleParameterIndex), #[error("multi-value parameter index {0} not found")] SimpleMultiValueParameterIndexNotFound(SimpleParameterIndex), + #[error("parameter index {0} not found")] + ConstParameterIndexNotFound(ConstParameterIndex), + #[error("index parameter index {0} not found")] + ConstIndexParameterIndexNotFound(ConstParameterIndex), + #[error("multi-value parameter index {0} not found")] + ConstMultiValueParameterIndexNotFound(ConstParameterIndex), #[error("multi-value parameter key {0} not found")] MultiValueParameterKeyNotFound(String), #[error("inter-network parameter state not initialised")] diff --git a/pywr-core/src/metric.rs b/pywr-core/src/metric.rs index f01f7c5a..620cbc63 100644 --- a/pywr-core/src/metric.rs +++ b/pywr-core/src/metric.rs @@ -5,19 +5,25 @@ use crate::edge::EdgeIndex; use crate::models::MultiNetworkTransferIndex; use crate::network::Network; use crate::node::NodeIndex; -use crate::parameters::{GeneralParameterIndex, ParameterIndex, SimpleParameterIndex}; -use crate::state::{MultiValue, SimpleParameterValues, State}; +use crate::parameters::{ConstParameterIndex, GeneralParameterIndex, ParameterIndex, SimpleParameterIndex}; +use crate::state::{ConstParameterValues, MultiValue, SimpleParameterValues, State}; use crate::virtual_storage::VirtualStorageIndex; use crate::PywrError; #[derive(Clone, Debug, PartialEq)] pub enum ConstantMetricF64 { + ParameterValue(ConstParameterIndex), + IndexParameterValue(ConstParameterIndex), + MultiParameterValue((ConstParameterIndex, String)), Constant(f64), } impl ConstantMetricF64 { - pub fn get_value(&self) -> Result { + pub fn get_value(&self, values: &ConstParameterValues) -> Result { match self { + ConstantMetricF64::ParameterValue(idx) => Ok(values.get_const_parameter_f64(*idx)?), + ConstantMetricF64::IndexParameterValue(idx) => Ok(values.get_const_parameter_usize(*idx)? as f64), + ConstantMetricF64::MultiParameterValue((idx, key)) => Ok(values.get_const_multi_parameter_f64(*idx, key)?), ConstantMetricF64::Constant(v) => Ok(*v), } } @@ -36,7 +42,7 @@ impl SimpleMetricF64 { SimpleMetricF64::ParameterValue(idx) => Ok(values.get_simple_parameter_f64(*idx)?), SimpleMetricF64::IndexParameterValue(idx) => Ok(values.get_simple_parameter_usize(*idx)? as f64), SimpleMetricF64::MultiParameterValue((idx, key)) => Ok(values.get_simple_multi_parameter_f64(*idx, key)?), - SimpleMetricF64::Constant(m) => m.get_value(), + SimpleMetricF64::Constant(m) => m.get_value(&values.get_constant_values()), } } } @@ -170,6 +176,9 @@ impl From> for MetricF64 { match idx { ParameterIndex::General(idx) => Self::ParameterValue(idx), ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::ParameterValue(idx)), + ParameterIndex::Const(idx) => { + Self::Simple(SimpleMetricF64::Constant(ConstantMetricF64::ParameterValue(idx))) + } } } } @@ -179,6 +188,9 @@ impl From> for MetricF64 { match idx { ParameterIndex::General(idx) => Self::IndexParameterValue(idx), ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::IndexParameterValue(idx)), + ParameterIndex::Const(idx) => { + Self::Simple(SimpleMetricF64::Constant(ConstantMetricF64::IndexParameterValue(idx))) + } } } } @@ -203,15 +215,32 @@ impl TryFrom> for SimpleMetricUsize { } } +#[derive(Clone, Debug, PartialEq)] +pub enum ConstantMetricUsize { + IndexParameterValue(ConstParameterIndex), + Constant(usize), +} + +impl ConstantMetricUsize { + pub fn get_value(&self, values: &ConstParameterValues) -> Result { + match self { + ConstantMetricUsize::IndexParameterValue(idx) => values.get_const_parameter_usize(*idx), + ConstantMetricUsize::Constant(v) => Ok(*v), + } + } +} + #[derive(Clone, Debug, PartialEq)] pub enum SimpleMetricUsize { IndexParameterValue(SimpleParameterIndex), + Constant(ConstantMetricUsize), } impl SimpleMetricUsize { pub fn get_value(&self, values: &SimpleParameterValues) -> Result { match self { SimpleMetricUsize::IndexParameterValue(idx) => values.get_simple_parameter_usize(*idx), + SimpleMetricUsize::Constant(m) => m.get_value(values.get_constant_values()), } } } @@ -220,7 +249,6 @@ impl SimpleMetricUsize { pub enum MetricUsize { IndexParameterValue(GeneralParameterIndex), Simple(SimpleMetricUsize), - Constant(usize), } impl MetricUsize { @@ -228,7 +256,6 @@ impl MetricUsize { match self { Self::IndexParameterValue(idx) => state.get_parameter_index(*idx), Self::Simple(s) => s.get_value(&state.get_simple_parameter_values()), - Self::Constant(i) => Ok(*i), } } } @@ -238,6 +265,32 @@ impl From> for MetricUsize { match idx { ParameterIndex::General(idx) => Self::IndexParameterValue(idx), ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricUsize::IndexParameterValue(idx)), + ParameterIndex::Const(idx) => Self::Simple(SimpleMetricUsize::Constant( + ConstantMetricUsize::IndexParameterValue(idx), + )), } } } +impl From for ConstantMetricUsize { + fn from(v: usize) -> Self { + ConstantMetricUsize::Constant(v) + } +} + +impl From for SimpleMetricUsize +where + T: Into, +{ + fn from(v: T) -> Self { + SimpleMetricUsize::Constant(v.into()) + } +} + +impl From for MetricUsize +where + T: Into, +{ + fn from(v: T) -> Self { + MetricUsize::Simple(v.into()) + } +} diff --git a/pywr-core/src/network.rs b/pywr-core/src/network.rs index e2a42b9d..88cfb5e6 100644 --- a/pywr-core/src/network.rs +++ b/pywr-core/src/network.rs @@ -250,17 +250,19 @@ impl Network { .with_derived_metrics(self.derived_metrics.len()) .with_inter_network_transfers(num_inter_network_transfers); - let state = state_builder.build(); + let mut state = state_builder.build(); - states.push(state); - - parameter_internal_states.push(ParameterStates::from_collection( - &self.parameters, - timesteps, - scenario_index, - )?); + let mut internal_states = ParameterStates::from_collection(&self.parameters, timesteps, scenario_index)?; metric_set_internal_states.push(self.metric_sets.iter().map(|p| p.setup()).collect::>()); + + // Calculate parameters that implement `ConstParameter` + // First we update the simple parameters + self.parameters + .compute_const(scenario_index, &mut state, &mut internal_states)?; + + states.push(state); + parameter_internal_states.push(internal_states); } Ok(NetworkState { @@ -683,6 +685,9 @@ impl Network { ) -> Result<(), PywrError> { // TODO reset parameter state to zero + self.parameters + .after_simple(timestep, scenario_index, state, internal_states)?; + for c_type in &self.resolve_order { match c_type { ComponentType::Node(_) => { @@ -1337,6 +1342,16 @@ impl Network { Ok(parameter_index.into()) } + /// Add a [`parameters::ConstParameter`] to the network + pub fn add_const_parameter( + &mut self, + parameter: Box>, + ) -> Result, PywrError> { + let parameter_index = self.parameters.add_const_f64(parameter)?; + + Ok(parameter_index.into()) + } + /// Add a `parameters::IndexParameter` to the network pub fn add_index_parameter( &mut self, @@ -1716,7 +1731,7 @@ mod tests { let _node_index = network.add_input_node("input", None).unwrap(); let input_max_flow = parameters::ConstantParameter::new("my-constant", 10.0); - let parameter = network.add_simple_parameter(Box::new(input_max_flow)).unwrap(); + let parameter = network.add_const_parameter(Box::new(input_max_flow)).unwrap(); // assign the new parameter to one of the nodes. let node = network.get_mut_node_by_name("input", None).unwrap(); @@ -1959,7 +1974,7 @@ mod tests { let input_max_flow_idx = model .network_mut() - .add_simple_parameter(Box::new(input_max_flow)) + .add_const_parameter(Box::new(input_max_flow)) .unwrap(); // assign the new parameter to one of the nodes. diff --git a/pywr-core/src/parameters/constant.rs b/pywr-core/src/parameters/constant.rs index 8ff1420f..75f530e3 100644 --- a/pywr-core/src/parameters/constant.rs +++ b/pywr-core/src/parameters/constant.rs @@ -1,9 +1,9 @@ use crate::parameters::{ downcast_internal_state_mut, downcast_internal_state_ref, downcast_variable_config_ref, ActivationFunction, - Parameter, ParameterMeta, ParameterState, SimpleParameter, VariableConfig, VariableParameter, + ConstParameter, Parameter, ParameterMeta, ParameterState, VariableConfig, VariableParameter, }; use crate::scenario::ScenarioIndex; -use crate::state::SimpleParameterValues; +use crate::state::ConstParameterValues; use crate::timestep::Timestep; use crate::PywrError; @@ -58,12 +58,11 @@ impl Parameter for ConstantParameter { } // TODO this should only need to implement `ConstantParameter` when that is implemented. -impl SimpleParameter for ConstantParameter { +impl ConstParameter for ConstantParameter { fn compute( &self, - _timestep: &Timestep, _scenario_index: &ScenarioIndex, - _values: &SimpleParameterValues, + _values: &ConstParameterValues, internal_state: &mut Option>, ) -> Result { Ok(self.value(internal_state)) diff --git a/pywr-core/src/parameters/mod.rs b/pywr-core/src/parameters/mod.rs index 19a0e0b1..53470e56 100644 --- a/pywr-core/src/parameters/mod.rs +++ b/pywr-core/src/parameters/mod.rs @@ -31,7 +31,7 @@ pub use self::rhai::RhaiParameter; use super::PywrError; use crate::network::Network; use crate::scenario::ScenarioIndex; -use crate::state::{MultiValue, SimpleParameterValues, State}; +use crate::state::{ConstParameterValues, MultiValue, SimpleParameterValues, State}; use crate::timestep::Timestep; pub use activation_function::ActivationFunction; pub use aggregated::{AggFunc, AggregatedParameter}; @@ -71,6 +71,56 @@ use std::ops::Deref; pub use threshold::{Predicate, ThresholdParameter}; pub use vector::VectorParameter; +/// Simple parameter index. +/// +/// This is a wrapper around usize that is used to index parameters in the state. It is +/// generic over the type of the value that the parameter returns. +#[derive(Debug)] +pub struct ConstParameterIndex { + idx: usize, + phantom: PhantomData, +} + +// These implementations are required because the derive macro does not work well with PhantomData. +// See issue: https://github.com/rust-lang/rust/issues/26925 +impl Clone for ConstParameterIndex { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for ConstParameterIndex {} +impl PartialEq for ConstParameterIndex { + fn eq(&self, other: &Self) -> bool { + self.idx == other.idx + } +} + +impl Eq for ConstParameterIndex {} + +impl ConstParameterIndex { + pub fn new(idx: usize) -> Self { + Self { + idx, + phantom: PhantomData, + } + } +} + +impl Deref for ConstParameterIndex { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.idx + } +} + +impl Display for ConstParameterIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.idx) + } +} + /// Simple parameter index. /// /// This is a wrapper around usize that is used to index parameters in the state. It is @@ -174,6 +224,7 @@ impl Display for GeneralParameterIndex { #[derive(Debug, Copy, Clone)] pub enum ParameterIndex { + Const(ConstParameterIndex), Simple(SimpleParameterIndex), General(GeneralParameterIndex), } @@ -181,6 +232,7 @@ pub enum ParameterIndex { impl PartialEq for ParameterIndex { fn eq(&self, other: &Self) -> bool { match (self, other) { + (Self::Const(idx1), Self::Const(idx2)) => idx1 == idx2, (Self::Simple(idx1), Self::Simple(idx2)) => idx1 == idx2, (Self::General(idx1), Self::General(idx2)) => idx1 == idx2, _ => false, @@ -193,6 +245,7 @@ impl Eq for ParameterIndex {} impl Display for ParameterIndex { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { + Self::Const(idx) => write!(f, "{}", idx), Self::Simple(idx) => write!(f, "{}", idx), Self::General(idx) => write!(f, "{}", idx), } @@ -210,6 +263,12 @@ impl From> for ParameterIndex { } } +impl From> for ParameterIndex { + fn from(idx: ConstParameterIndex) -> Self { + Self::Const(idx) + } +} + /// Meta data common to all parameters. #[derive(Debug, Clone)] pub struct ParameterMeta { @@ -255,6 +314,7 @@ struct ParameterStatesByType { #[derive(Clone)] pub struct ParameterStates { + constant: ParameterStatesByType, simple: ParameterStatesByType, general: ParameterStatesByType, } @@ -266,14 +326,20 @@ impl ParameterStates { timesteps: &[Timestep], scenario_index: &ScenarioIndex, ) -> Result { + let constant = collection.const_initial_states(timesteps, scenario_index)?; let simple = collection.simple_initial_states(timesteps, scenario_index)?; let general = collection.general_initial_states(timesteps, scenario_index)?; - Ok(Self { simple, general }) + Ok(Self { + constant, + simple, + general, + }) } pub fn get_f64_state(&self, index: ParameterIndex) -> Option<&Option>> { match index { + ParameterIndex::Const(idx) => self.constant.f64.get(*idx.deref()), ParameterIndex::Simple(idx) => self.simple.f64.get(*idx.deref()), ParameterIndex::General(idx) => self.general.f64.get(*idx.deref()), } @@ -286,8 +352,13 @@ impl ParameterStates { self.simple.f64.get(*index.deref()) } + pub fn get_const_f64_state(&self, index: SimpleParameterIndex) -> Option<&Option>> { + self.constant.f64.get(*index.deref()) + } + pub fn get_mut_f64_state(&mut self, index: ParameterIndex) -> Option<&mut Option>> { match index { + ParameterIndex::Const(idx) => self.constant.f64.get_mut(*idx.deref()), ParameterIndex::Simple(idx) => self.simple.f64.get_mut(*idx.deref()), ParameterIndex::General(idx) => self.general.f64.get_mut(*idx.deref()), } @@ -305,6 +376,12 @@ impl ParameterStates { ) -> Option<&mut Option>> { self.simple.f64.get_mut(*index.deref()) } + pub fn get_const_mut_f64_state( + &mut self, + index: ConstParameterIndex, + ) -> Option<&mut Option>> { + self.constant.f64.get_mut(*index.deref()) + } pub fn get_general_mut_usize_state( &mut self, index: GeneralParameterIndex, @@ -318,6 +395,12 @@ impl ParameterStates { ) -> Option<&mut Option>> { self.simple.usize.get_mut(*index.deref()) } + pub fn get_const_mut_usize_state( + &mut self, + index: ConstParameterIndex, + ) -> Option<&mut Option>> { + self.constant.usize.get_mut(*index.deref()) + } pub fn get_general_mut_multi_state( &mut self, @@ -332,6 +415,13 @@ impl ParameterStates { ) -> Option<&mut Option>> { self.simple.multi.get_mut(*index.deref()) } + + pub fn get_const_mut_multi_state( + &mut self, + index: ConstParameterIndex, + ) -> Option<&mut Option>> { + self.constant.multi.get_mut(*index.deref()) + } } /// Helper function to downcast to internal parameter state and print a helpful panic @@ -488,6 +578,24 @@ pub trait SimpleParameter: Parameter { } fn as_parameter(&self) -> &dyn Parameter; + + fn try_into_const(&self) -> Option>> { + None + } +} + +/// A trait that defines a component that produces a value each time-step. +/// +/// The trait is generic over the type of the value produced. +pub trait ConstParameter: Parameter { + fn compute( + &self, + scenario_index: &ScenarioIndex, + values: &ConstParameterValues, + internal_state: &mut Option>, + ) -> Result; + + fn as_parameter(&self) -> &dyn Parameter; } pub enum GeneralParameterType { @@ -538,6 +646,30 @@ impl From> for SimpleParameterType { } } +pub enum ConstParameterType { + Parameter(ConstParameterIndex), + Index(ConstParameterIndex), + Multi(ConstParameterIndex), +} + +impl From> for ConstParameterType { + fn from(idx: ConstParameterIndex) -> Self { + Self::Parameter(idx) + } +} + +impl From> for ConstParameterType { + fn from(idx: ConstParameterIndex) -> Self { + Self::Index(idx) + } +} + +impl From> for ConstParameterType { + fn from(idx: ConstParameterIndex) -> Self { + Self::Multi(idx) + } +} + pub enum ParameterType { Parameter(ParameterIndex), Index(ParameterIndex), @@ -593,6 +725,9 @@ pub trait VariableParameter { #[derive(Debug, Clone, Copy)] pub struct ParameterCollectionSize { + pub const_f64: usize, + pub const_usize: usize, + pub const_multi: usize, pub simple_f64: usize, pub simple_usize: usize, pub simple_multi: usize, @@ -604,11 +739,18 @@ pub struct ParameterCollectionSize { /// A collection of parameters that return different types. #[derive(Default)] pub struct ParameterCollection { + constant_f64: Vec>>, + constant_usize: Vec>>, + constant_multi: Vec>>, + constant_resolve_order: Vec, + simple_f64: Vec>>, simple_usize: Vec>>, simple_multi: Vec>>, simple_resolve_order: Vec, + // There is no resolve order for general parameters as they are resolved at a model + // level with other component types (e.g. nodes). general_f64: Vec>>, general_usize: Vec>>, general_multi: Vec>>, @@ -617,6 +759,9 @@ pub struct ParameterCollection { impl ParameterCollection { pub fn size(&self) -> ParameterCollectionSize { ParameterCollectionSize { + const_f64: self.constant_f64.len(), + const_usize: self.constant_usize.len(), + const_multi: self.constant_multi.len(), simple_f64: self.simple_f64.len(), simple_usize: self.simple_usize.len(), simple_multi: self.simple_multi.len(), @@ -687,7 +832,42 @@ impl ParameterCollection { }) } - /// Add a new parameter to the collection. + fn const_initial_states( + &self, + timesteps: &[Timestep], + scenario_index: &ScenarioIndex, + ) -> Result { + // Get the initial internal state + let f64_states = self + .constant_f64 + .iter() + .map(|p| p.setup(timesteps, scenario_index)) + .collect::, _>>()?; + + let usize_states = self + .constant_usize + .iter() + .map(|p| p.setup(timesteps, scenario_index)) + .collect::, _>>()?; + + let multi_states = self + .constant_multi + .iter() + .map(|p| p.setup(timesteps, scenario_index)) + .collect::, _>>()?; + + Ok(ParameterStatesByType { + f64: f64_states, + usize: usize_states, + multi: multi_states, + }) + } + + /// Add a [`GeneralParameter`] parameter to the collection. + /// + /// This function will add attempt to simplify the parameter and add it to the simple or + /// constant parameter list. If the parameter cannot be simplified it will be added to the + /// general parameter list. pub fn add_general_f64( &mut self, parameter: Box>, @@ -712,25 +892,46 @@ impl ParameterCollection { pub fn add_simple_f64( &mut self, parameter: Box>, - ) -> Result, PywrError> { - // TODO Fix this check - // if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { - // return Err(PywrError::SimpleParameterNameAlreadyExists( - // parameter.meta().name.to_string(), - // index, - // )); - // } + ) -> Result, PywrError> { + if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { + return Err(PywrError::ParameterNameAlreadyExists( + parameter.meta().name.to_string(), + index, + )); + } - let index = SimpleParameterIndex::new(self.simple_f64.len()); + match parameter.try_into_const() { + Some(constant) => self.add_const_f64(constant).map(|idx| idx.into()), + None => { + let index = SimpleParameterIndex::new(self.simple_f64.len()); - self.simple_f64.push(parameter); - self.simple_resolve_order.push(SimpleParameterType::Parameter(index)); + self.simple_f64.push(parameter); + self.simple_resolve_order.push(SimpleParameterType::Parameter(index)); - Ok(index) + Ok(index.into()) + } + } + } + + pub fn add_const_f64(&mut self, parameter: Box>) -> Result, PywrError> { + if let Some(index) = self.get_f64_index_by_name(¶meter.meta().name) { + return Err(PywrError::ParameterNameAlreadyExists( + parameter.meta().name.to_string(), + index, + )); + } + + let index = ConstParameterIndex::new(self.constant_f64.len()); + + self.constant_f64.push(parameter); + self.constant_resolve_order.push(ConstParameterType::Parameter(index)); + + Ok(index.into()) } pub fn get_f64(&self, index: ParameterIndex) -> Option<&dyn Parameter> { match index { + ParameterIndex::Const(idx) => self.constant_f64.get(*idx.deref()).map(|p| p.as_parameter()), ParameterIndex::Simple(idx) => self.simple_f64.get(*idx.deref()).map(|p| p.as_parameter()), ParameterIndex::General(idx) => self.general_f64.get(*idx.deref()).map(|p| p.as_parameter()), } @@ -762,6 +963,13 @@ impl ParameterCollection { .map(|idx| SimpleParameterIndex::new(idx)) { Some(idx.into()) + } else if let Some(idx) = self + .constant_f64 + .iter() + .position(|p| p.meta().name == name) + .map(|idx| ConstParameterIndex::new(idx)) + { + Some(idx.into()) } else { None } @@ -954,6 +1162,133 @@ impl ParameterCollection { Ok(()) } + + /// Perform the after step for simple parameters. + pub fn after_simple( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + state: &mut State, + internal_states: &mut ParameterStates, + ) -> Result<(), PywrError> { + for p in &self.simple_resolve_order { + match p { + SimpleParameterType::Parameter(idx) => { + // Find the parameter itself + let p = self + .simple_f64 + .get(*idx.deref()) + .ok_or(PywrError::SimpleParameterIndexNotFound(*idx))?; + // .. and its internal state + let internal_state = internal_states + .get_simple_mut_f64_state(*idx) + .ok_or(PywrError::SimpleParameterIndexNotFound(*idx))?; + + p.after( + timestep, + scenario_index, + &state.get_simple_parameter_values(), + internal_state, + )?; + } + SimpleParameterType::Index(idx) => { + // Find the parameter itself + let p = self + .simple_usize + .get(*idx.deref()) + .ok_or(PywrError::SimpleIndexParameterIndexNotFound(*idx))?; + // .. and its internal state + let internal_state = internal_states + .get_simple_mut_usize_state(*idx) + .ok_or(PywrError::SimpleIndexParameterIndexNotFound(*idx))?; + + p.after( + timestep, + scenario_index, + &state.get_simple_parameter_values(), + internal_state, + )?; + } + SimpleParameterType::Multi(idx) => { + // Find the parameter itself + let p = self + .simple_multi + .get(*idx.deref()) + .ok_or(PywrError::SimpleMultiValueParameterIndexNotFound(*idx))?; + // .. and its internal state + let internal_state = internal_states + .get_simple_mut_multi_state(*idx) + .ok_or(PywrError::SimpleMultiValueParameterIndexNotFound(*idx))?; + + p.compute( + timestep, + scenario_index, + &state.get_simple_parameter_values(), + internal_state, + )?; + } + } + } + + Ok(()) + } + + /// Compute the constant parameters. + pub fn compute_const( + &self, + scenario_index: &ScenarioIndex, + state: &mut State, + internal_states: &mut ParameterStates, + ) -> Result<(), PywrError> { + for p in &self.constant_resolve_order { + match p { + ConstParameterType::Parameter(idx) => { + // Find the parameter itself + let p = self + .constant_f64 + .get(*idx.deref()) + .ok_or(PywrError::ConstParameterIndexNotFound(*idx))?; + // .. and its internal state + let internal_state = internal_states + .get_const_mut_f64_state(*idx) + .ok_or(PywrError::ConstParameterIndexNotFound(*idx))?; + + let value = p.compute(scenario_index, &state.get_const_parameter_values(), internal_state)?; + state.set_const_parameter_value(*idx, value)?; + } + ConstParameterType::Index(idx) => { + // Find the parameter itself + let p = self + .constant_usize + .get(*idx.deref()) + .ok_or(PywrError::ConstIndexParameterIndexNotFound(*idx))?; + // .. and its internal state + let internal_state = internal_states + .get_const_mut_usize_state(*idx) + .ok_or(PywrError::ConstIndexParameterIndexNotFound(*idx))?; + + let value = p.compute(scenario_index, &state.get_const_parameter_values(), internal_state)?; + state.set_const_parameter_index(*idx, value)?; + } + ConstParameterType::Multi(idx) => { + // Find the parameter itself + let p = self + .constant_multi + .get(*idx.deref()) + .ok_or(PywrError::ConstMultiValueParameterIndexNotFound(*idx))?; + // .. and its internal state + let internal_state = internal_states + .get_const_mut_multi_state(*idx) + .ok_or(PywrError::ConstMultiValueParameterIndexNotFound(*idx))?; + + let value = p.compute(scenario_index, &state.get_const_parameter_values(), internal_state)?; + state.set_const_multi_parameter_value(*idx, value)?; + } + } + } + + Ok(()) + } } #[cfg(test)] diff --git a/pywr-core/src/state.rs b/pywr-core/src/state.rs index 7d5a0498..f3b2eff9 100644 --- a/pywr-core/src/state.rs +++ b/pywr-core/src/state.rs @@ -3,7 +3,9 @@ use crate::edge::{Edge, EdgeIndex}; use crate::models::MultiNetworkTransferIndex; use crate::network::Network; use crate::node::{Node, NodeIndex}; -use crate::parameters::{GeneralParameterIndex, ParameterCollection, ParameterCollectionSize, SimpleParameterIndex}; +use crate::parameters::{ + ConstParameterIndex, GeneralParameterIndex, ParameterCollection, ParameterCollectionSize, SimpleParameterIndex, +}; use crate::timestep::Timestep; use crate::virtual_storage::VirtualStorageIndex; use crate::PywrError; @@ -353,10 +355,12 @@ pub struct ParameterValuesCollection { impl ParameterValuesCollection { fn get_simple_parameter_values(&self) -> SimpleParameterValues { SimpleParameterValues { - constant: ParameterValuesRef { - values: &self.constant.values, - indices: &self.constant.indices, - multi_values: &self.constant.multi_values, + constant: ConstParameterValues { + constant: ParameterValuesRef { + values: &self.constant.values, + indices: &self.constant.indices, + multi_values: &self.constant.multi_values, + }, }, simple: ParameterValuesRef { values: &self.simple.values, @@ -365,6 +369,16 @@ impl ParameterValuesCollection { }, } } + + fn get_const_parameter_values(&self) -> ConstParameterValues { + ConstParameterValues { + constant: ParameterValuesRef { + values: &self.constant.values, + indices: &self.constant.indices, + multi_values: &self.constant.multi_values, + }, + } + } } pub struct ParameterValuesRef<'a> { @@ -388,7 +402,7 @@ impl<'a> ParameterValuesRef<'a> { } pub struct SimpleParameterValues<'a> { - constant: ParameterValuesRef<'a>, + constant: ConstParameterValues<'a>, simple: ParameterValuesRef<'a>, } @@ -417,6 +431,41 @@ impl<'a> SimpleParameterValues<'a> { .ok_or(PywrError::SimpleMultiValueParameterIndexNotFound(idx)) .copied() } + + pub fn get_constant_values(&self) -> &ConstParameterValues { + &self.constant + } +} + +pub struct ConstParameterValues<'a> { + constant: ParameterValuesRef<'a>, +} + +impl<'a> ConstParameterValues<'a> { + pub fn get_const_parameter_f64(&self, idx: ConstParameterIndex) -> Result { + self.constant + .get_value(*idx.deref()) + .ok_or(PywrError::ConstParameterIndexNotFound(idx)) + .copied() + } + + pub fn get_const_parameter_usize(&self, idx: ConstParameterIndex) -> Result { + self.constant + .get_index(*idx.deref()) + .ok_or(PywrError::ConstIndexParameterIndexNotFound(idx)) + .copied() + } + + pub fn get_const_multi_parameter_f64( + &self, + idx: ConstParameterIndex, + key: &str, + ) -> Result { + self.constant + .get_multi_value(*idx.deref(), key) + .ok_or(PywrError::ConstMultiValueParameterIndexNotFound(idx)) + .copied() + } } // State of the nodes and edges @@ -684,6 +733,13 @@ impl State { }) } + pub fn set_const_parameter_value(&mut self, idx: ConstParameterIndex, value: f64) -> Result<(), PywrError> { + self.parameters.constant.set_value(*idx, value).map_err(|e| match e { + ParameterValuesError::IndexNotFound(_) => PywrError::ConstParameterIndexNotFound(idx), + ParameterValuesError::KeyNotFound(key) => PywrError::MultiValueParameterKeyNotFound(key), + }) + } + pub fn get_parameter_index(&self, idx: GeneralParameterIndex) -> Result { self.parameters.general.get_index(*idx).map_err(|e| match e { ParameterValuesError::IndexNotFound(_) => PywrError::GeneralIndexParameterIndexNotFound(idx), @@ -708,6 +764,17 @@ impl State { ParameterValuesError::KeyNotFound(key) => PywrError::MultiValueParameterKeyNotFound(key), }) } + + pub fn set_const_parameter_index( + &mut self, + idx: ConstParameterIndex, + value: usize, + ) -> Result<(), PywrError> { + self.parameters.constant.set_index(*idx, value).map_err(|e| match e { + ParameterValuesError::IndexNotFound(_) => PywrError::ConstIndexParameterIndexNotFound(idx), + ParameterValuesError::KeyNotFound(key) => PywrError::MultiValueParameterKeyNotFound(key), + }) + } pub fn get_multi_parameter_value( &self, idx: GeneralParameterIndex, @@ -747,6 +814,20 @@ impl State { }) } + pub fn set_const_multi_parameter_value( + &mut self, + idx: ConstParameterIndex, + value: MultiValue, + ) -> Result<(), PywrError> { + self.parameters + .constant + .set_multi_value(*idx, value) + .map_err(|e| match e { + ParameterValuesError::IndexNotFound(_) => PywrError::ConstMultiValueParameterIndexNotFound(idx), + ParameterValuesError::KeyNotFound(key) => PywrError::MultiValueParameterKeyNotFound(key), + }) + } + pub fn get_multi_parameter_index( &self, idx: GeneralParameterIndex, @@ -762,6 +843,10 @@ impl State { self.parameters.get_simple_parameter_values() } + pub fn get_const_parameter_values(&self) -> ConstParameterValues { + self.parameters.get_const_parameter_values() + } + pub fn set_node_volume(&mut self, idx: NodeIndex, volume: f64) -> Result<(), PywrError> { self.network.set_volume(idx, volume) } @@ -888,7 +973,12 @@ impl StateBuilder { /// Build the [`State`] from the builder. pub fn build(self) -> State { - let constant = ParameterValues::new(0, 0, 0); + let constant = ParameterValues::new( + self.num_parameters.map(|s| s.const_f64).unwrap_or(0), + self.num_parameters.map(|s| s.const_usize).unwrap_or(0), + self.num_parameters.map(|s| s.const_multi).unwrap_or(0), + ); + let simple = ParameterValues::new( self.num_parameters.map(|s| s.simple_f64).unwrap_or(0), self.num_parameters.map(|s| s.simple_usize).unwrap_or(0), diff --git a/pywr-core/src/test_utils.rs b/pywr-core/src/test_utils.rs index a393cc22..f8d7e2ac 100644 --- a/pywr-core/src/test_utils.rs +++ b/pywr-core/src/test_utils.rs @@ -69,7 +69,7 @@ pub fn simple_network(network: &mut Network, inflow_scenario_index: usize, num_i let base_demand = 10.0; let demand_factor = ConstantParameter::new("demand-factor", 1.2); - let demand_factor = network.add_simple_parameter(Box::new(demand_factor)).unwrap(); + let demand_factor = network.add_const_parameter(Box::new(demand_factor)).unwrap(); let total_demand: AggregatedParameter = AggregatedParameter::new( "total-demand", @@ -79,7 +79,7 @@ pub fn simple_network(network: &mut Network, inflow_scenario_index: usize, num_i let total_demand = network.add_parameter(Box::new(total_demand)).unwrap(); let demand_cost = ConstantParameter::new("demand-cost", -10.0); - let demand_cost = network.add_simple_parameter(Box::new(demand_cost)).unwrap(); + let demand_cost = network.add_const_parameter(Box::new(demand_cost)).unwrap(); let output_node = network.get_mut_node_by_name("output", None).unwrap(); output_node.set_max_flow_constraint(Some(total_demand.into())).unwrap(); @@ -122,10 +122,10 @@ pub fn simple_storage_model() -> Model { // Apply demand to the model // TODO convenience function for adding a constant constraint. let demand = ConstantParameter::new("demand", 10.0); - let demand = network.add_simple_parameter(Box::new(demand)).unwrap(); + let demand = network.add_const_parameter(Box::new(demand)).unwrap(); let demand_cost = ConstantParameter::new("demand-cost", -10.0); - let demand_cost = network.add_simple_parameter(Box::new(demand_cost)).unwrap(); + let demand_cost = network.add_const_parameter(Box::new(demand_cost)).unwrap(); let output_node = network.get_mut_node_by_name("output", None).unwrap(); output_node.set_max_flow_constraint(Some(demand.into())).unwrap(); diff --git a/pywr-schema/src/parameters/core.rs b/pywr-schema/src/parameters/core.rs index 3c1161a2..826808b1 100644 --- a/pywr-schema/src/parameters/core.rs +++ b/pywr-schema/src/parameters/core.rs @@ -167,7 +167,7 @@ impl ConstantParameter { args: &LoadArgs, ) -> Result, SchemaError> { let p = pywr_core::parameters::ConstantParameter::new(&self.meta.name, self.value.load(args.tables)?); - Ok(network.add_simple_parameter(Box::new(p))?) + Ok(network.add_const_parameter(Box::new(p))?) } } diff --git a/pywr-schema/src/parameters/mod.rs b/pywr-schema/src/parameters/mod.rs index a4cb7bfe..1f50a98a 100644 --- a/pywr-schema/src/parameters/mod.rs +++ b/pywr-schema/src/parameters/mod.rs @@ -815,7 +815,7 @@ impl DynamicIndexValue { impl DynamicIndexValue { pub fn load(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result { let parameter_ref = match self { - DynamicIndexValue::Constant(v) => MetricUsize::Constant(v.load(args.tables)?), + DynamicIndexValue::Constant(v) => v.load(args.tables)?.into(), DynamicIndexValue::Dynamic(v) => v.load(network, args)?.into(), }; Ok(parameter_ref)