Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor parameters into General, Simple and Constant #194

Merged
merged 8 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ tracing = { version = "0.1", features = ["log"] }
csv = "1.1"
hdf5 = { git = "https://github.com/aldanor/hdf5-rust.git", package = "hdf5", features = ["static", "zlib"] }
pywr-v1-schema = { git = "https://github.com/pywr/pywr-schema/", tag = "v0.13.0", package = "pywr-schema" }
chrono = { version = "0.4.34" }
chrono = { version = "0.4.34", features = ["serde"] }
schemars = { version = "0.8.16", features = ["chrono"] }
2 changes: 1 addition & 1 deletion pywr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ nalgebra = "0.32.3"
chrono = { workspace = true }
polars = { workspace = true }

pyo3 = { workspace = true, features = ["chrono"] }
pyo3 = { workspace = true, features = ["chrono", "macros"] }


rayon = "1.6.1"
Expand Down
19 changes: 8 additions & 11 deletions pywr-core/src/aggregated_node.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::metric::MetricF64;
use crate::network::Network;
use crate::node::{Constraint, ConstraintValue, FlowConstraints, NodeMeta};
use crate::node::{Constraint, FlowConstraints, NodeMeta};
use crate::state::State;
use crate::{NodeIndex, PywrError};
use std::ops::{Deref, DerefMut};
Expand Down Expand Up @@ -112,7 +112,7 @@ impl AggregatedNode {
) -> Self {
Self {
meta: NodeMeta::new(index, name, sub_name),
flow_constraints: FlowConstraints::new(),
flow_constraints: FlowConstraints::default(),
nodes: nodes.to_vec(),
factors,
}
Expand Down Expand Up @@ -174,21 +174,21 @@ impl AggregatedNode {
}
}

pub fn set_min_flow_constraint(&mut self, value: ConstraintValue) {
pub fn set_min_flow_constraint(&mut self, value: Option<MetricF64>) {
self.flow_constraints.min_flow = value;
}
pub fn get_min_flow_constraint(&self, model: &Network, state: &State) -> Result<f64, PywrError> {
self.flow_constraints.get_min_flow(model, state)
}
pub fn set_max_flow_constraint(&mut self, value: ConstraintValue) {
pub fn set_max_flow_constraint(&mut self, value: Option<MetricF64>) {
self.flow_constraints.max_flow = value;
}
pub fn get_max_flow_constraint(&self, model: &Network, state: &State) -> Result<f64, PywrError> {
self.flow_constraints.get_max_flow(model, state)
}

/// Set a constraint on a node.
pub fn set_constraint(&mut self, value: ConstraintValue, constraint: Constraint) -> Result<(), PywrError> {
pub fn set_constraint(&mut self, value: Option<MetricF64>, constraint: Constraint) -> Result<(), PywrError> {
match constraint {
Constraint::MinFlow => self.set_min_flow_constraint(value),
Constraint::MaxFlow => self.set_max_flow_constraint(value),
Expand Down Expand Up @@ -296,7 +296,6 @@ mod tests {
use crate::metric::MetricF64;
use crate::models::Model;
use crate::network::Network;
use crate::node::ConstraintValue;
use crate::recorders::AssertionRecorder;
use crate::test_utils::{default_time_domain, run_all_solvers};
use ndarray::Array2;
Expand All @@ -321,17 +320,15 @@ mod tests {
network.connect_nodes(input_node, link_node1).unwrap();
network.connect_nodes(link_node1, output_node1).unwrap();

let factors = Some(Factors::Ratio(vec![MetricF64::Constant(2.0), MetricF64::Constant(1.0)]));
let factors = Some(Factors::Ratio(vec![2.0.into(), 1.0.into()]));

let _agg_node = network.add_aggregated_node("agg-node", None, &[link_node0, link_node1], factors);

// Setup a demand on output-0
let output_node = network.get_mut_node_by_name("output", Some("0")).unwrap();
output_node
.set_max_flow_constraint(ConstraintValue::Scalar(100.0))
.unwrap();
output_node.set_max_flow_constraint(Some(100.0.into())).unwrap();

output_node.set_cost(ConstraintValue::Scalar(-10.0));
output_node.set_cost(Some((-10.0).into()));

// Set-up assertion for "input" node
let idx = network.get_node_by_name("link", Some("0")).unwrap().index();
Expand Down
6 changes: 3 additions & 3 deletions pywr-core/src/derived_metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ impl DerivedMetric {
pub fn compute(&self, network: &Network, state: &State) -> Result<f64, PywrError> {
match self {
Self::NodeProportionalVolume(idx) => {
let max_volume = network.get_node(idx)?.get_current_max_volume(network, state)?;
let max_volume = network.get_node(idx)?.get_current_max_volume(state)?;
Ok(state
.get_network_state()
.get_node_proportional_volume(idx, max_volume)?)
}
Self::VirtualStorageProportionalVolume(idx) => {
let max_volume = network.get_virtual_storage_node(idx)?.get_max_volume(network, state)?;
let max_volume = network.get_virtual_storage_node(idx)?.get_max_volume(state)?;
Ok(state
.get_network_state()
.get_virtual_storage_proportional_volume(*idx, max_volume)?)
Expand All @@ -100,7 +100,7 @@ impl DerivedMetric {
let max_volume: f64 = node
.nodes
.iter()
.map(|idx| network.get_node(idx)?.get_current_max_volume(network, state))
.map(|idx| network.get_node(idx)?.get_current_max_volume(state))
.sum::<Result<_, _>>()?;
// TODO handle divide by zero
Ok(volume / max_volume)
Expand Down
26 changes: 22 additions & 4 deletions pywr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extern crate core;
use crate::derived_metric::DerivedMetricIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::node::NodeIndex;
use crate::parameters::{InterpolationError, ParameterIndex};
use crate::parameters::{GeneralParameterIndex, InterpolationError, ParameterIndex, SimpleParameterIndex};
use crate::recorders::{AggregationError, MetricSetIndex, RecorderIndex};
use crate::state::MultiValue;
use crate::virtual_storage::VirtualStorageIndex;
Expand Down Expand Up @@ -49,9 +49,21 @@ pub enum PywrError {
ParameterIndexNotFound(ParameterIndex<f64>),
#[error("index parameter index {0} not found")]
IndexParameterIndexNotFound(ParameterIndex<usize>),
#[error("multi1 value parameter index {0} not found")]
#[error("multi-value parameter index {0} not found")]
MultiValueParameterIndexNotFound(ParameterIndex<MultiValue>),
#[error("multi1 value parameter key {0} not found")]
#[error("parameter index {0} not found")]
GeneralParameterIndexNotFound(GeneralParameterIndex<f64>),
#[error("index parameter index {0} not found")]
GeneralIndexParameterIndexNotFound(GeneralParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
GeneralMultiValueParameterIndexNotFound(GeneralParameterIndex<MultiValue>),
#[error("parameter index {0} not found")]
SimpleParameterIndexNotFound(SimpleParameterIndex<f64>),
#[error("index parameter index {0} not found")]
SimpleIndexParameterIndexNotFound(SimpleParameterIndex<usize>),
#[error("multi-value parameter index {0} not found")]
SimpleMultiValueParameterIndexNotFound(SimpleParameterIndex<MultiValue>),
#[error("multi-value parameter key {0} not found")]
MultiValueParameterKeyNotFound(String),
#[error("inter-network parameter state not initialised")]
InterNetworkParameterStateNotInitialised,
Expand All @@ -76,7 +88,9 @@ pub enum PywrError {
#[error("parameter name `{0}` already exists at index {1}")]
ParameterNameAlreadyExists(String, ParameterIndex<f64>),
#[error("index parameter name `{0}` already exists at index {1}")]
IndexParameterNameAlreadyExists(String, ParameterIndex<usize>),
IndexParameterNameAlreadyExists(String, GeneralParameterIndex<usize>),
#[error("multi-value parameter name `{0}` already exists at index {1}")]
MultiValueParameterNameAlreadyExists(String, GeneralParameterIndex<MultiValue>),
#[error("metric set name `{0}` already exists")]
MetricSetNameAlreadyExists(String),
#[error("recorder name `{0}` already exists at index {1}")]
Expand Down Expand Up @@ -161,12 +175,16 @@ pub enum PywrError {
ParameterNoInitialValue,
#[error("parameter state not found for parameter index {0}")]
ParameterStateNotFound(ParameterIndex<f64>),
#[error("parameter state not found for parameter index {0}")]
GeneralParameterStateNotFound(GeneralParameterIndex<f64>),
#[error("Could not create timestep range due to following error: {0}")]
TimestepRangeGenerationError(String),
#[error("Could not create timesteps for frequency '{0}'")]
TimestepGenerationError(String),
#[error("aggregation error: {0}")]
Aggregation(#[from] AggregationError),
#[error("cannot simplify metric")]
CannotSimplifyMetric,
}

// Python errors
Expand Down
170 changes: 144 additions & 26 deletions pywr-core/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,42 @@ use crate::edge::EdgeIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::network::Network;
use crate::node::NodeIndex;
use crate::parameters::ParameterIndex;
use crate::state::{MultiValue, State};
use crate::parameters::{GeneralParameterIndex, ParameterIndex, SimpleParameterIndex};
use crate::state::{MultiValue, SimpleParameterValues, State};
use crate::virtual_storage::VirtualStorageIndex;
use crate::PywrError;

#[derive(Clone, Debug, PartialEq)]
pub enum ConstantMetricF64 {
Constant(f64),
}

impl ConstantMetricF64 {
pub fn get_value(&self) -> Result<f64, PywrError> {
match self {
ConstantMetricF64::Constant(v) => Ok(*v),
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum SimpleMetricF64 {
ParameterValue(SimpleParameterIndex<f64>),
IndexParameterValue(SimpleParameterIndex<usize>),
MultiParameterValue((SimpleParameterIndex<MultiValue>, String)),
Constant(ConstantMetricF64),
}

impl SimpleMetricF64 {
pub fn get_value(&self, values: &SimpleParameterValues) -> Result<f64, PywrError> {
match self {
SimpleMetricF64::ParameterValue(idx) => Ok(values.get_simple_parameter_f64(*idx)?),
SimpleMetricF64::IndexParameterValue(idx) => Ok(values.get_simple_parameter_usize(*idx)? as f64),
SimpleMetricF64::MultiParameterValue((idx, key)) => Ok(values.get_simple_multi_parameter_f64(*idx, key)?),
SimpleMetricF64::Constant(m) => m.get_value(),
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub enum MetricF64 {
NodeInFlow(NodeIndex),
Expand All @@ -19,16 +50,16 @@ pub enum MetricF64 {
AggregatedNodeOutFlow(AggregatedNodeIndex),
AggregatedNodeVolume(AggregatedStorageNodeIndex),
EdgeFlow(EdgeIndex),
ParameterValue(ParameterIndex<f64>),
IndexParameterValue(ParameterIndex<usize>),
MultiParameterValue((ParameterIndex<MultiValue>, String)),
ParameterValue(GeneralParameterIndex<f64>),
IndexParameterValue(GeneralParameterIndex<usize>),
MultiParameterValue((GeneralParameterIndex<MultiValue>, String)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor point but should these 3 variants be put into a nest enum to be consistent with the simple params?

VirtualStorageVolume(VirtualStorageIndex),
MultiNodeInFlow { indices: Vec<NodeIndex>, name: String },
MultiNodeOutFlow { indices: Vec<NodeIndex>, name: String },
// TODO implement other MultiNodeXXX variants
Constant(f64),
DerivedMetric(DerivedMetricIndex),
InterNetworkTransfer(MultiNetworkTransferIndex),
Simple(SimpleMetricF64),
}

impl MetricF64 {
Expand Down Expand Up @@ -60,7 +91,7 @@ impl MetricF64 {
MetricF64::MultiParameterValue((idx, key)) => Ok(state.get_multi_parameter_value(*idx, key)?),
MetricF64::VirtualStorageVolume(idx) => Ok(state.get_network_state().get_virtual_storage_volume(idx)?),
MetricF64::DerivedMetric(idx) => state.get_derived_metric_value(*idx),
MetricF64::Constant(v) => Ok(*v),

MetricF64::AggregatedNodeVolume(idx) => {
let node = model.get_aggregated_storage_node(idx)?;
node.nodes
Expand All @@ -84,42 +115,129 @@ impl MetricF64 {
Ok(flow)
}
MetricF64::InterNetworkTransfer(idx) => state.get_inter_network_transfer_value(*idx),
MetricF64::Simple(s) => s.get_value(&state.get_simple_parameter_values()),
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub enum MetricUsize {
IndexParameterValue(ParameterIndex<usize>),
Constant(usize),
impl TryFrom<MetricF64> for SimpleMetricF64 {
type Error = PywrError;

fn try_from(value: MetricF64) -> Result<Self, Self::Error> {
match value {
MetricF64::Simple(s) => Ok(s),
_ => Err(PywrError::CannotSimplifyMetric),
}
}
}

impl MetricUsize {
pub fn get_value(&self, _network: &Network, state: &State) -> Result<usize, PywrError> {
match self {
Self::IndexParameterValue(idx) => state.get_parameter_index(*idx),
Self::Constant(i) => Ok(*i),
impl TryFrom<SimpleMetricF64> for ConstantMetricF64 {
type Error = PywrError;

fn try_from(value: SimpleMetricF64) -> Result<Self, Self::Error> {
match value {
SimpleMetricF64::Constant(c) => Ok(c),
_ => Err(PywrError::CannotSimplifyMetric),
}
}
}

impl From<f64> for ConstantMetricF64 {
fn from(v: f64) -> Self {
ConstantMetricF64::Constant(v)
}
}

pub fn name<'a>(&'a self, network: &'a Network) -> Result<&'a str, PywrError> {
match self {
Self::IndexParameterValue(idx) => network.get_index_parameter(idx).map(|p| p.name()),
Self::Constant(_) => Ok(""),
impl<T> From<T> for SimpleMetricF64
where
T: Into<ConstantMetricF64>,
{
fn from(v: T) -> Self {
SimpleMetricF64::Constant(v.into())
}
}
impl<T> From<T> for MetricF64
where
T: Into<SimpleMetricF64>,
{
fn from(v: T) -> Self {
MetricF64::Simple(v.into())
}
}

impl From<ParameterIndex<f64>> for MetricF64 {
fn from(idx: ParameterIndex<f64>) -> Self {
match idx {
ParameterIndex::General(idx) => Self::ParameterValue(idx),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::ParameterValue(idx)),
}
}
}

impl From<ParameterIndex<usize>> for MetricF64 {
fn from(idx: ParameterIndex<usize>) -> Self {
match idx {
ParameterIndex::General(idx) => Self::IndexParameterValue(idx),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::IndexParameterValue(idx)),
}
}
}

impl TryFrom<ParameterIndex<f64>> for SimpleMetricF64 {
type Error = PywrError;
fn try_from(idx: ParameterIndex<f64>) -> Result<Self, Self::Error> {
match idx {
ParameterIndex::Simple(idx) => Ok(Self::ParameterValue(idx)),
_ => Err(PywrError::CannotSimplifyMetric),
}
}
}

impl TryFrom<ParameterIndex<usize>> for SimpleMetricUsize {
type Error = PywrError;
fn try_from(idx: ParameterIndex<usize>) -> Result<Self, Self::Error> {
match idx {
ParameterIndex::Simple(idx) => Ok(Self::IndexParameterValue(idx)),
_ => Err(PywrError::CannotSimplifyMetric),
}
}
}

pub fn sub_name<'a>(&'a self, _network: &'a Network) -> Result<Option<&'a str>, PywrError> {
#[derive(Clone, Debug, PartialEq)]
pub enum SimpleMetricUsize {
IndexParameterValue(SimpleParameterIndex<usize>),
}

impl SimpleMetricUsize {
pub fn get_value(&self, values: &SimpleParameterValues) -> Result<usize, PywrError> {
match self {
Self::IndexParameterValue(_) => Ok(None),
Self::Constant(_) => Ok(None),
SimpleMetricUsize::IndexParameterValue(idx) => values.get_simple_parameter_usize(*idx),
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub enum MetricUsize {
IndexParameterValue(GeneralParameterIndex<usize>),
Simple(SimpleMetricUsize),
Constant(usize),
}

pub fn attribute(&self) -> &str {
impl MetricUsize {
pub fn get_value(&self, _network: &Network, state: &State) -> Result<usize, PywrError> {
match self {
Self::IndexParameterValue(_) => "value",
Self::Constant(_) => "value",
Self::IndexParameterValue(idx) => state.get_parameter_index(*idx),
Self::Simple(s) => s.get_value(&state.get_simple_parameter_values()),
Self::Constant(i) => Ok(*i),
}
}
}

impl From<ParameterIndex<usize>> for MetricUsize {
fn from(idx: ParameterIndex<usize>) -> Self {
match idx {
ParameterIndex::General(idx) => Self::IndexParameterValue(idx),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricUsize::IndexParameterValue(idx)),
}
}
}
Loading