Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor parameters into General, Simple and Constant #194

Merged
merged 8 commits into from
Jul 5, 2024
Merged
19 changes: 8 additions & 11 deletions pywr-core/src/aggregated_node.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::metric::MetricF64;
use crate::network::Network;
use crate::node::{Constraint, ConstraintValue, FlowConstraints, NodeMeta};
use crate::node::{Constraint, FlowConstraints, NodeMeta};
use crate::state::State;
use crate::{NodeIndex, PywrError};
use std::ops::{Deref, DerefMut};
Expand Down Expand Up @@ -112,7 +112,7 @@ impl AggregatedNode {
) -> Self {
Self {
meta: NodeMeta::new(index, name, sub_name),
flow_constraints: FlowConstraints::new(),
flow_constraints: FlowConstraints::default(),
nodes: nodes.to_vec(),
factors,
}
Expand Down Expand Up @@ -174,21 +174,21 @@ impl AggregatedNode {
}
}

pub fn set_min_flow_constraint(&mut self, value: ConstraintValue) {
pub fn set_min_flow_constraint(&mut self, value: Option<MetricF64>) {
self.flow_constraints.min_flow = value;
}
pub fn get_min_flow_constraint(&self, model: &Network, state: &State) -> Result<f64, PywrError> {
self.flow_constraints.get_min_flow(model, state)
}
pub fn set_max_flow_constraint(&mut self, value: ConstraintValue) {
pub fn set_max_flow_constraint(&mut self, value: Option<MetricF64>) {
self.flow_constraints.max_flow = value;
}
pub fn get_max_flow_constraint(&self, model: &Network, state: &State) -> Result<f64, PywrError> {
self.flow_constraints.get_max_flow(model, state)
}

/// Set a constraint on a node.
pub fn set_constraint(&mut self, value: ConstraintValue, constraint: Constraint) -> Result<(), PywrError> {
pub fn set_constraint(&mut self, value: Option<MetricF64>, constraint: Constraint) -> Result<(), PywrError> {
match constraint {
Constraint::MinFlow => self.set_min_flow_constraint(value),
Constraint::MaxFlow => self.set_max_flow_constraint(value),
Expand Down Expand Up @@ -296,7 +296,6 @@ mod tests {
use crate::metric::MetricF64;
use crate::models::Model;
use crate::network::Network;
use crate::node::ConstraintValue;
use crate::recorders::AssertionRecorder;
use crate::test_utils::{default_time_domain, run_all_solvers};
use ndarray::Array2;
Expand All @@ -321,17 +320,15 @@ mod tests {
network.connect_nodes(input_node, link_node1).unwrap();
network.connect_nodes(link_node1, output_node1).unwrap();

let factors = Some(Factors::Ratio(vec![MetricF64::Constant(2.0), MetricF64::Constant(1.0)]));
let factors = Some(Factors::Ratio(vec![2.0.into(), 1.0.into()]));

let _agg_node = network.add_aggregated_node("agg-node", None, &[link_node0, link_node1], factors);

// Setup a demand on output-0
let output_node = network.get_mut_node_by_name("output", Some("0")).unwrap();
output_node
.set_max_flow_constraint(ConstraintValue::Scalar(100.0))
.unwrap();
output_node.set_max_flow_constraint(Some(100.0.into())).unwrap();

output_node.set_cost(ConstraintValue::Scalar(-10.0));
output_node.set_cost(Some((-10.0).into()));

// Set-up assertion for "input" node
let idx = network.get_node_by_name("link", Some("0")).unwrap().index();
Expand Down
6 changes: 3 additions & 3 deletions pywr-core/src/derived_metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ impl DerivedMetric {
pub fn compute(&self, network: &Network, state: &State) -> Result<f64, PywrError> {
match self {
Self::NodeProportionalVolume(idx) => {
let max_volume = network.get_node(idx)?.get_current_max_volume(network, state)?;
let max_volume = network.get_node(idx)?.get_current_max_volume(state)?;
Ok(state
.get_network_state()
.get_node_proportional_volume(idx, max_volume)?)
}
Self::VirtualStorageProportionalVolume(idx) => {
let max_volume = network.get_virtual_storage_node(idx)?.get_max_volume(network, state)?;
let max_volume = network.get_virtual_storage_node(idx)?.get_max_volume(state)?;
Ok(state
.get_network_state()
.get_virtual_storage_proportional_volume(*idx, max_volume)?)
Expand All @@ -100,7 +100,7 @@ impl DerivedMetric {
let max_volume: f64 = node
.nodes
.iter()
.map(|idx| network.get_node(idx)?.get_current_max_volume(network, state))
.map(|idx| network.get_node(idx)?.get_current_max_volume(state))
.sum::<Result<_, _>>()?;
// TODO handle divide by zero
Ok(volume / max_volume)
Expand Down
36 changes: 31 additions & 5 deletions pywr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ extern crate core;
use crate::derived_metric::DerivedMetricIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::node::NodeIndex;
use crate::parameters::{InterpolationError, ParameterIndex};
use crate::parameters::{
ConstParameterIndex, GeneralParameterIndex, InterpolationError, ParameterIndex, SimpleParameterIndex,
};
use crate::recorders::{AggregationError, MetricSetIndex, RecorderIndex};
use crate::state::MultiValue;
use crate::virtual_storage::VirtualStorageIndex;
Expand Down Expand Up @@ -49,9 +51,27 @@ pub enum PywrError {
ParameterIndexNotFound(ParameterIndex<f64>),
#[error("index parameter index {0} not found")]
IndexParameterIndexNotFound(ParameterIndex<usize>),
#[error("multi1 value parameter index {0} not found")]
#[error("multi-value parameter index {0} not found")]
MultiValueParameterIndexNotFound(ParameterIndex<MultiValue>),
#[error("multi1 value parameter key {0} not found")]
#[error("parameter index {0} not found")]
GeneralParameterIndexNotFound(GeneralParameterIndex<f64>),
#[error("index parameter index {0} not found")]
GeneralIndexParameterIndexNotFound(GeneralParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
GeneralMultiValueParameterIndexNotFound(GeneralParameterIndex<MultiValue>),
#[error("parameter index {0} not found")]
SimpleParameterIndexNotFound(SimpleParameterIndex<f64>),
#[error("index parameter index {0} not found")]
SimpleIndexParameterIndexNotFound(SimpleParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
SimpleMultiValueParameterIndexNotFound(SimpleParameterIndex<MultiValue>),
#[error("parameter index {0} not found")]
ConstParameterIndexNotFound(ConstParameterIndex<f64>),
#[error("index parameter index {0} not found")]
ConstIndexParameterIndexNotFound(ConstParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
ConstMultiValueParameterIndexNotFound(ConstParameterIndex<MultiValue>),
#[error("multi-value parameter key {0} not found")]
MultiValueParameterKeyNotFound(String),
#[error("inter-network parameter state not initialised")]
InterNetworkParameterStateNotInitialised,
Expand All @@ -73,10 +93,12 @@ pub enum PywrError {
DerivedMetricIndexNotFound(DerivedMetricIndex),
#[error("node name `{0}` already exists")]
NodeNameAlreadyExists(String),
#[error("parameter name `{0}` already exists at index {1}")]
ParameterNameAlreadyExists(String, ParameterIndex<f64>),
#[error("parameter name `{0}` already exists")]
ParameterNameAlreadyExists(String),
#[error("index parameter name `{0}` already exists at index {1}")]
IndexParameterNameAlreadyExists(String, ParameterIndex<usize>),
#[error("multi-value parameter name `{0}` already exists at index {1}")]
MultiValueParameterNameAlreadyExists(String, ParameterIndex<MultiValue>),
#[error("metric set name `{0}` already exists")]
MetricSetNameAlreadyExists(String),
#[error("recorder name `{0}` already exists at index {1}")]
Expand Down Expand Up @@ -161,6 +183,8 @@ pub enum PywrError {
ParameterNoInitialValue,
#[error("parameter state not found for parameter index {0}")]
ParameterStateNotFound(ParameterIndex<f64>),
#[error("parameter state not found for parameter index {0}")]
GeneralParameterStateNotFound(GeneralParameterIndex<f64>),
#[error("Could not create timestep range due to following error: {0}")]
TimestepRangeGenerationError(String),
#[error("Could not create timesteps for frequency '{0}'")]
Expand All @@ -169,6 +193,8 @@ pub enum PywrError {
TimestepDurationMismatch,
#[error("aggregation error: {0}")]
Aggregation(#[from] AggregationError),
#[error("cannot simplify metric")]
CannotSimplifyMetric,
}

// Python errors
Expand Down
Loading