Skip to content

Commit

Permalink
feat: Implement ConstParameter.
Browse files Browse the repository at this point in the history
Final parameter variant that is computed only at the beginning
of a model.
  • Loading branch information
jetuk committed Jun 17, 2024
1 parent 173ae4c commit 17e008f
Show file tree
Hide file tree
Showing 9 changed files with 550 additions and 50 deletions.
10 changes: 9 additions & 1 deletion pywr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ extern crate core;
use crate::derived_metric::DerivedMetricIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::node::NodeIndex;
use crate::parameters::{GeneralParameterIndex, InterpolationError, ParameterIndex, SimpleParameterIndex};
use crate::parameters::{
ConstParameterIndex, GeneralParameterIndex, InterpolationError, ParameterIndex, SimpleParameterIndex,
};
use crate::recorders::{AggregationError, MetricSetIndex, RecorderIndex};
use crate::state::MultiValue;
use crate::virtual_storage::VirtualStorageIndex;
Expand Down Expand Up @@ -63,6 +65,12 @@ pub enum PywrError {
SimpleIndexParameterIndexNotFound(SimpleParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
SimpleMultiValueParameterIndexNotFound(SimpleParameterIndex<MultiValue>),
#[error("parameter index {0} not found")]
ConstParameterIndexNotFound(ConstParameterIndex<f64>),
#[error("index parameter index {0} not found")]
ConstIndexParameterIndexNotFound(ConstParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
ConstMultiValueParameterIndexNotFound(ConstParameterIndex<MultiValue>),
#[error("multi-value parameter key {0} not found")]
MultiValueParameterKeyNotFound(String),
#[error("inter-network parameter state not initialised")]
Expand Down
65 changes: 59 additions & 6 deletions pywr-core/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@ use crate::edge::EdgeIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::network::Network;
use crate::node::NodeIndex;
use crate::parameters::{GeneralParameterIndex, ParameterIndex, SimpleParameterIndex};
use crate::state::{MultiValue, SimpleParameterValues, State};
use crate::parameters::{ConstParameterIndex, GeneralParameterIndex, ParameterIndex, SimpleParameterIndex};
use crate::state::{ConstParameterValues, MultiValue, SimpleParameterValues, State};
use crate::virtual_storage::VirtualStorageIndex;
use crate::PywrError;

#[derive(Clone, Debug, PartialEq)]
pub enum ConstantMetricF64 {
ParameterValue(ConstParameterIndex<f64>),
IndexParameterValue(ConstParameterIndex<usize>),
MultiParameterValue((ConstParameterIndex<MultiValue>, String)),
Constant(f64),
}

impl ConstantMetricF64 {
pub fn get_value(&self) -> Result<f64, PywrError> {
pub fn get_value(&self, values: &ConstParameterValues) -> Result<f64, PywrError> {
match self {
ConstantMetricF64::ParameterValue(idx) => Ok(values.get_const_parameter_f64(*idx)?),
ConstantMetricF64::IndexParameterValue(idx) => Ok(values.get_const_parameter_usize(*idx)? as f64),
ConstantMetricF64::MultiParameterValue((idx, key)) => Ok(values.get_const_multi_parameter_f64(*idx, key)?),
ConstantMetricF64::Constant(v) => Ok(*v),
}
}
Expand All @@ -36,7 +42,7 @@ impl SimpleMetricF64 {
SimpleMetricF64::ParameterValue(idx) => Ok(values.get_simple_parameter_f64(*idx)?),
SimpleMetricF64::IndexParameterValue(idx) => Ok(values.get_simple_parameter_usize(*idx)? as f64),
SimpleMetricF64::MultiParameterValue((idx, key)) => Ok(values.get_simple_multi_parameter_f64(*idx, key)?),
SimpleMetricF64::Constant(m) => m.get_value(),
SimpleMetricF64::Constant(m) => m.get_value(&values.get_constant_values()),
}
}
}
Expand Down Expand Up @@ -170,6 +176,9 @@ impl From<ParameterIndex<f64>> for MetricF64 {
match idx {
ParameterIndex::General(idx) => Self::ParameterValue(idx),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::ParameterValue(idx)),
ParameterIndex::Const(idx) => {
Self::Simple(SimpleMetricF64::Constant(ConstantMetricF64::ParameterValue(idx)))
}
}
}
}
Expand All @@ -179,6 +188,9 @@ impl From<ParameterIndex<usize>> for MetricF64 {
match idx {
ParameterIndex::General(idx) => Self::IndexParameterValue(idx),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::IndexParameterValue(idx)),
ParameterIndex::Const(idx) => {
Self::Simple(SimpleMetricF64::Constant(ConstantMetricF64::IndexParameterValue(idx)))
}
}
}
}
Expand All @@ -203,15 +215,32 @@ impl TryFrom<ParameterIndex<usize>> for SimpleMetricUsize {
}
}

#[derive(Clone, Debug, PartialEq)]
pub enum ConstantMetricUsize {
IndexParameterValue(ConstParameterIndex<usize>),
Constant(usize),
}

impl ConstantMetricUsize {
pub fn get_value(&self, values: &ConstParameterValues) -> Result<usize, PywrError> {
match self {
ConstantMetricUsize::IndexParameterValue(idx) => values.get_const_parameter_usize(*idx),
ConstantMetricUsize::Constant(v) => Ok(*v),
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub enum SimpleMetricUsize {
IndexParameterValue(SimpleParameterIndex<usize>),
Constant(ConstantMetricUsize),
}

impl SimpleMetricUsize {
pub fn get_value(&self, values: &SimpleParameterValues) -> Result<usize, PywrError> {
match self {
SimpleMetricUsize::IndexParameterValue(idx) => values.get_simple_parameter_usize(*idx),
SimpleMetricUsize::Constant(m) => m.get_value(values.get_constant_values()),
}
}
}
Expand All @@ -220,15 +249,13 @@ impl SimpleMetricUsize {
pub enum MetricUsize {
IndexParameterValue(GeneralParameterIndex<usize>),
Simple(SimpleMetricUsize),
Constant(usize),
}

impl MetricUsize {
pub fn get_value(&self, _network: &Network, state: &State) -> Result<usize, PywrError> {
match self {
Self::IndexParameterValue(idx) => state.get_parameter_index(*idx),
Self::Simple(s) => s.get_value(&state.get_simple_parameter_values()),
Self::Constant(i) => Ok(*i),
}
}
}
Expand All @@ -238,6 +265,32 @@ impl From<ParameterIndex<usize>> for MetricUsize {
match idx {
ParameterIndex::General(idx) => Self::IndexParameterValue(idx),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricUsize::IndexParameterValue(idx)),
ParameterIndex::Const(idx) => Self::Simple(SimpleMetricUsize::Constant(
ConstantMetricUsize::IndexParameterValue(idx),
)),
}
}
}
impl From<usize> for ConstantMetricUsize {
fn from(v: usize) -> Self {
ConstantMetricUsize::Constant(v)
}
}

impl<T> From<T> for SimpleMetricUsize
where
T: Into<ConstantMetricUsize>,
{
fn from(v: T) -> Self {
SimpleMetricUsize::Constant(v.into())
}
}

impl<T> From<T> for MetricUsize
where
T: Into<SimpleMetricUsize>,
{
fn from(v: T) -> Self {
MetricUsize::Simple(v.into())
}
}
35 changes: 25 additions & 10 deletions pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,19 @@ impl Network {
.with_derived_metrics(self.derived_metrics.len())
.with_inter_network_transfers(num_inter_network_transfers);

let state = state_builder.build();
let mut state = state_builder.build();

states.push(state);

parameter_internal_states.push(ParameterStates::from_collection(
&self.parameters,
timesteps,
scenario_index,
)?);
let mut internal_states = ParameterStates::from_collection(&self.parameters, timesteps, scenario_index)?;

metric_set_internal_states.push(self.metric_sets.iter().map(|p| p.setup()).collect::<Vec<_>>());

// Calculate parameters that implement `ConstParameter`
// First we update the simple parameters
self.parameters
.compute_const(scenario_index, &mut state, &mut internal_states)?;

states.push(state);
parameter_internal_states.push(internal_states);
}

Ok(NetworkState {
Expand Down Expand Up @@ -683,6 +685,9 @@ impl Network {
) -> Result<(), PywrError> {
// TODO reset parameter state to zero

self.parameters
.after_simple(timestep, scenario_index, state, internal_states)?;

for c_type in &self.resolve_order {
match c_type {
ComponentType::Node(_) => {
Expand Down Expand Up @@ -1337,6 +1342,16 @@ impl Network {
Ok(parameter_index.into())
}

/// Add a [`parameters::ConstParameter`] to the network
pub fn add_const_parameter(
&mut self,
parameter: Box<dyn parameters::ConstParameter<f64>>,
) -> Result<ParameterIndex<f64>, PywrError> {
let parameter_index = self.parameters.add_const_f64(parameter)?;

Ok(parameter_index.into())
}

/// Add a `parameters::IndexParameter` to the network
pub fn add_index_parameter(
&mut self,
Expand Down Expand Up @@ -1716,7 +1731,7 @@ mod tests {
let _node_index = network.add_input_node("input", None).unwrap();

let input_max_flow = parameters::ConstantParameter::new("my-constant", 10.0);
let parameter = network.add_simple_parameter(Box::new(input_max_flow)).unwrap();
let parameter = network.add_const_parameter(Box::new(input_max_flow)).unwrap();

// assign the new parameter to one of the nodes.
let node = network.get_mut_node_by_name("input", None).unwrap();
Expand Down Expand Up @@ -1959,7 +1974,7 @@ mod tests {

let input_max_flow_idx = model
.network_mut()
.add_simple_parameter(Box::new(input_max_flow))
.add_const_parameter(Box::new(input_max_flow))
.unwrap();

// assign the new parameter to one of the nodes.
Expand Down
9 changes: 4 additions & 5 deletions pywr-core/src/parameters/constant.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::parameters::{
downcast_internal_state_mut, downcast_internal_state_ref, downcast_variable_config_ref, ActivationFunction,
Parameter, ParameterMeta, ParameterState, SimpleParameter, VariableConfig, VariableParameter,
ConstParameter, Parameter, ParameterMeta, ParameterState, VariableConfig, VariableParameter,
};
use crate::scenario::ScenarioIndex;
use crate::state::SimpleParameterValues;
use crate::state::ConstParameterValues;
use crate::timestep::Timestep;
use crate::PywrError;

Expand Down Expand Up @@ -58,12 +58,11 @@ impl Parameter for ConstantParameter {
}

// TODO this should only need to implement `ConstantParameter` when that is implemented.
impl SimpleParameter<f64> for ConstantParameter {
impl ConstParameter<f64> for ConstantParameter {
fn compute(
&self,
_timestep: &Timestep,
_scenario_index: &ScenarioIndex,
_values: &SimpleParameterValues,
_values: &ConstParameterValues,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<f64, PywrError> {
Ok(self.value(internal_state))
Expand Down
Loading

0 comments on commit 17e008f

Please sign in to comment.