Skip to content

Commit

Permalink
refactor: Tidy-up nodes and parameters modules. (#252)
Browse files Browse the repository at this point in the history
Some small changes to schemas to make things more consistent.
Added additional types to the top-level module to allow them to
be used outside of the crate.
  • Loading branch information
jetuk authored Sep 27, 2024
1 parent 890d6fb commit 2385f7d
Show file tree
Hide file tree
Showing 17 changed files with 193 additions and 112 deletions.
8 changes: 6 additions & 2 deletions pywr-python/tests/models/aggregated-node1/model.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@
},
"type": "Aggregated",
"nodes": [
"link1",
"link2"
{
"name": "link1"
},
{
"name": "link2"
}
],
"max_flow": {
"type": "Constant",
Expand Down
55 changes: 50 additions & 5 deletions pywr-schema/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ impl TimeseriesReference {
}
}

/// A reference to a node with an optional attribute.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll)]
#[serde(deny_unknown_fields)]
pub struct NodeReference {
Expand Down Expand Up @@ -304,12 +305,56 @@ impl NodeReference {
}
}

impl From<String> for NodeReference {
/// A reference to a node without an attribute.
///
/// Serializes/deserializes as a map with a single `name` field. Unlike
/// `NodeReference` there is no optional attribute; where a metric or
/// attribute is required the node's default is used.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll)]
pub struct SimpleNodeReference {
    /// The name of the node
    pub name: String,
}

impl SimpleNodeReference {
    /// Create a reference to the node called `name`.
    pub fn new(name: String) -> Self {
        Self { name }
    }

    /// Resolve the referenced node and create its default metric.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::NodeNotFound`] if no node with this name
    /// exists in the schema.
    #[cfg(feature = "core")]
    pub fn load(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result<MetricF64, SchemaError> {
        // Look the node up in the schema; a missing name is a schema error.
        let node = match args.schema.get_node_by_name(&self.name) {
            Some(node) => node,
            None => return Err(SchemaError::NodeNotFound(self.name.clone())),
        };

        // `None` attribute selects the node's default metric.
        node.create_metric(network, None, args)
    }

    /// Return the default attribute of the node.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::NodeNotFound`] if no node with this name
    /// exists in the schema.
    #[cfg(feature = "core")]
    pub fn attribute(&self, args: &LoadArgs) -> Result<NodeAttribute, SchemaError> {
        let node = match args.schema.get_node_by_name(&self.name) {
            Some(node) => node,
            None => return Err(SchemaError::NodeNotFound(self.name.clone())),
        };

        Ok(node.default_metric())
    }

    /// Return the type of the referenced node.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::NodeNotFound`] if no node with this name
    /// exists in the schema.
    #[cfg(feature = "core")]
    pub fn node_type(&self, args: &LoadArgs) -> Result<NodeType, SchemaError> {
        let node = match args.schema.get_node_by_name(&self.name) {
            Some(node) => node,
            None => return Err(SchemaError::NodeNotFound(self.name.clone())),
        };

        Ok(node.node_type())
    }
}

impl From<String> for SimpleNodeReference {
    /// Build a reference from a bare node name.
    ///
    /// NOTE: the scraped diff fused the removed `NodeReference { name, attribute }`
    /// body with the added one, which is not valid Rust; this is the added
    /// (post-refactor) implementation.
    fn from(v: String) -> Self {
        SimpleNodeReference { name: v }
    }
}

Expand Down
14 changes: 8 additions & 6 deletions pywr-schema/src/nodes/annual_virtual_storage.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::error::ConversionError;
#[cfg(feature = "core")]
use crate::error::SchemaError;
use crate::metric::Metric;
use crate::metric::{Metric, SimpleNodeReference};
#[cfg(feature = "core")]
use crate::model::LoadArgs;
use crate::nodes::core::StorageInitialVolume;
Expand Down Expand Up @@ -39,7 +39,7 @@ impl Default for AnnualReset {
#[serde(deny_unknown_fields)]
pub struct AnnualVirtualStorageNode {
pub meta: NodeMeta,
pub nodes: Vec<String>,
pub nodes: Vec<SimpleNodeReference>,
pub factors: Option<Vec<f64>>,
pub max_volume: Option<Metric>,
pub min_volume: Option<Metric>,
Expand Down Expand Up @@ -74,10 +74,10 @@ impl AnnualVirtualStorageNode {
let indices = self
.nodes
.iter()
.map(|name| {
.map(|node_ref| {
args.schema
.get_node_by_name(name)
.ok_or_else(|| SchemaError::NodeNotFound(name.to_string()))?
.get_node_by_name(&node_ref.name)
.ok_or_else(|| SchemaError::NodeNotFound(node_ref.name.to_string()))?
.node_indices_for_constraints(network, args)
})
.collect::<Result<Vec<_>, _>>()?
Expand Down Expand Up @@ -188,9 +188,11 @@ impl TryFrom<AnnualVirtualStorageNodeV1> for AnnualVirtualStorageNode {
});
};

let nodes = v1.nodes.into_iter().map(|n| n.into()).collect();

let n = Self {
meta,
nodes: v1.nodes,
nodes,
factors: v1.factors,
max_volume,
min_volume,
Expand Down
34 changes: 19 additions & 15 deletions pywr-schema/src/nodes/core.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::error::ConversionError;
#[cfg(feature = "core")]
use crate::error::SchemaError;
use crate::metric::Metric;
use crate::metric::{Metric, SimpleNodeReference};
#[cfg(feature = "core")]
use crate::model::LoadArgs;
use crate::nodes::{NodeAttribute, NodeMeta};
Expand Down Expand Up @@ -725,7 +725,7 @@ pub enum Relationship {
#[serde(deny_unknown_fields)]
pub struct AggregatedNode {
pub meta: NodeMeta,
pub nodes: Vec<String>,
pub nodes: Vec<SimpleNodeReference>,
pub max_flow: Option<Metric>,
pub min_flow: Option<Metric>,
pub factors: Option<Relationship>,
Expand Down Expand Up @@ -760,10 +760,10 @@ impl AggregatedNode {
let indices = self
.nodes
.iter()
.map(|name| {
.map(|node_ref| {
args.schema
.get_node_by_name(name)
.ok_or_else(|| SchemaError::NodeNotFound(name.to_string()))?
.get_node_by_name(&node_ref.name)
.ok_or_else(|| SchemaError::NodeNotFound(node_ref.name.to_string()))?
.node_indices_for_constraints(network, args)
})
.collect::<Result<Vec<_>, _>>()?
Expand All @@ -776,11 +776,11 @@ impl AggregatedNode {
let nodes: Vec<Vec<_>> = self
.nodes
.iter()
.map(|name| {
.map(|node_ref| {
let node = args
.schema
.get_node_by_name(name)
.ok_or_else(|| SchemaError::NodeNotFound(name.to_string()))?;
.get_node_by_name(&node_ref.name)
.ok_or_else(|| SchemaError::NodeNotFound(node_ref.name.to_string()))?;
node.node_indices_for_constraints(network, args)
})
.collect::<Result<Vec<_>, _>>()?;
Expand Down Expand Up @@ -889,9 +889,11 @@ impl TryFrom<AggregatedNodeV1> for AggregatedNode {
.map(|v| v.try_into_v2_parameter(Some(&meta.name), &mut unnamed_count))
.transpose()?;

let nodes = v1.nodes.into_iter().map(|n| n.into()).collect();

let n = Self {
meta,
nodes: v1.nodes,
nodes,
max_flow,
min_flow,
factors,
Expand All @@ -904,7 +906,7 @@ impl TryFrom<AggregatedNodeV1> for AggregatedNode {
#[serde(deny_unknown_fields)]
pub struct AggregatedStorageNode {
pub meta: NodeMeta,
pub storage_nodes: Vec<String>,
pub storage_nodes: Vec<SimpleNodeReference>,
}

impl AggregatedStorageNode {
Expand Down Expand Up @@ -936,10 +938,10 @@ impl AggregatedStorageNode {
let indices = self
.storage_nodes
.iter()
.map(|name| {
.map(|node_ref| {
args.schema
.get_node_by_name(name)
.ok_or_else(|| SchemaError::NodeNotFound(name.to_string()))?
.get_node_by_name(&node_ref.name)
.ok_or_else(|| SchemaError::NodeNotFound(node_ref.name.to_string()))?
.node_indices_for_constraints(network, args)
})
.collect::<Result<Vec<_>, _>>()?
Expand All @@ -952,7 +954,7 @@ impl AggregatedStorageNode {
let nodes = self
.storage_nodes
.iter()
.map(|name| network.get_node_index_by_name(name, None))
.map(|node_ref| network.get_node_index_by_name(&node_ref.name, None))
.collect::<Result<_, _>>()?;

network.add_aggregated_storage_node(self.meta.name.as_str(), None, nodes)?;
Expand Down Expand Up @@ -993,9 +995,11 @@ impl TryFrom<AggregatedStorageNodeV1> for AggregatedStorageNode {
type Error = ConversionError;

fn try_from(v1: AggregatedStorageNodeV1) -> Result<Self, Self::Error> {
let storage_nodes = v1.storage_nodes.into_iter().map(|n| n.into()).collect();

let n = Self {
meta: v1.meta.into(),
storage_nodes: v1.storage_nodes,
storage_nodes,
};
Ok(n)
}
Expand Down
23 changes: 12 additions & 11 deletions pywr-schema/src/nodes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ use crate::metric::Metric;
#[cfg(feature = "core")]
use crate::model::LoadArgs;
use crate::model::PywrNetwork;
pub use crate::nodes::core::{
AggregatedNode, AggregatedStorageNode, CatchmentNode, InputNode, LinkNode, OutputNode, StorageNode,
};
pub use crate::nodes::delay::DelayNode;
pub use crate::nodes::river::RiverNode;
use crate::nodes::rolling_virtual_storage::RollingVirtualStorageNode;
use crate::nodes::turbine::TurbineNode;
use crate::parameters::TimeseriesV1Data;
use crate::visit::{VisitMetrics, VisitPaths};
pub use annual_virtual_storage::AnnualVirtualStorageNode;
pub use loss_link::LossLinkNode;
pub use annual_virtual_storage::{AnnualReset, AnnualVirtualStorageNode};
pub use core::{
AggregatedNode, AggregatedStorageNode, CatchmentNode, InputNode, LinkNode, OutputNode, Relationship,
StorageInitialVolume, StorageNode,
};
pub use delay::DelayNode;
pub use loss_link::{LossFactor, LossLinkNode};
pub use monthly_virtual_storage::MonthlyVirtualStorageNode;
pub use piecewise_link::{PiecewiseLinkNode, PiecewiseLinkStep};
pub use piecewise_storage::PiecewiseStorageNode;
pub use piecewise_storage::{PiecewiseStorageNode, PiecewiseStore};
#[cfg(feature = "core")]
use pywr_core::metric::MetricF64;
use pywr_schema_macros::PywrVisitAll;
Expand All @@ -43,11 +41,14 @@ use pywr_v1_schema::nodes::{
use pywr_v1_schema::parameters::{
CoreParameter as CoreParameterV1, Parameter as ParameterV1, ParameterValue as ParameterValueV1, ParameterValueType,
};
pub use river::RiverNode;
pub use river_gauge::RiverGaugeNode;
pub use river_split_with_gauge::RiverSplitWithGaugeNode;
pub use river_split_with_gauge::{RiverSplit, RiverSplitWithGaugeNode};
pub use rolling_virtual_storage::{RollingVirtualStorageNode, RollingWindow};
use schemars::JsonSchema;
use std::path::{Path, PathBuf};
use strum_macros::{Display, EnumDiscriminants, EnumString, IntoStaticStr, VariantNames};
pub use turbine::{TargetType, TurbineNode};
pub use virtual_storage::VirtualStorageNode;
pub use water_treatment_works::WaterTreatmentWorks;

Expand Down
14 changes: 8 additions & 6 deletions pywr-schema/src/nodes/monthly_virtual_storage.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::error::ConversionError;
#[cfg(feature = "core")]
use crate::error::SchemaError;
use crate::metric::Metric;
use crate::metric::{Metric, SimpleNodeReference};
#[cfg(feature = "core")]
use crate::model::LoadArgs;
use crate::nodes::core::StorageInitialVolume;
Expand Down Expand Up @@ -33,7 +33,7 @@ impl Default for NumberOfMonthsReset {
#[serde(deny_unknown_fields)]
pub struct MonthlyVirtualStorageNode {
pub meta: NodeMeta,
pub nodes: Vec<String>,
pub nodes: Vec<SimpleNodeReference>,
pub factors: Option<Vec<f64>>,
pub max_volume: Option<Metric>,
pub min_volume: Option<Metric>,
Expand Down Expand Up @@ -68,10 +68,10 @@ impl MonthlyVirtualStorageNode {
let indices = self
.nodes
.iter()
.map(|name| {
.map(|node_ref| {
args.schema
.get_node_by_name(name)
.ok_or_else(|| SchemaError::NodeNotFound(name.to_string()))?
.get_node_by_name(&node_ref.name)
.ok_or_else(|| SchemaError::NodeNotFound(node_ref.name.to_string()))?
.node_indices_for_constraints(network, args)
})
.collect::<Result<Vec<_>, _>>()?
Expand Down Expand Up @@ -179,9 +179,11 @@ impl TryFrom<MonthlyVirtualStorageNodeV1> for MonthlyVirtualStorageNode {
});
};

let nodes = v1.nodes.into_iter().map(|n| n.into()).collect();

let n = Self {
meta,
nodes: v1.nodes,
nodes,
factors: v1.factors,
max_volume,
min_volume,
Expand Down
22 changes: 14 additions & 8 deletions pywr-schema/src/nodes/river_split_with_gauge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ use pywr_schema_macros::PywrVisitAll;
use pywr_v1_schema::nodes::RiverSplitWithGaugeNode as RiverSplitWithGaugeNodeV1;
use schemars::JsonSchema;

/// One branch of a proportional river split: a flow proportion routed
/// through a named slot of the split node.
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, JsonSchema, PywrVisitAll)]
pub struct RiverSplit {
    /// Proportion of flow sent down this branch (loaded as a proportion factor).
    pub factor: Metric,
    /// Name of the slot this branch connects through.
    pub slot_name: String,
}

#[doc = svgbobdoc::transform!(
/// This is used to represent a proportional split above a minimum residual flow (MRF) at a gauging station.
///
Expand Down Expand Up @@ -40,7 +46,7 @@ pub struct RiverSplitWithGaugeNode {
pub meta: NodeMeta,
pub mrf: Option<Metric>,
pub mrf_cost: Option<Metric>,
pub splits: Vec<(Metric, String)>,
pub splits: Vec<RiverSplit>,
}

impl RiverSplitWithGaugeNode {
Expand Down Expand Up @@ -85,7 +91,7 @@ impl RiverSplitWithGaugeNode {
let i = self
.splits
.iter()
.position(|(_, s)| s == slot)
.position(|split| split.slot_name == slot)
.expect("Invalid slot name!");

vec![(self.meta.name.as_str(), Self::split_sub_name(i))]
Expand Down Expand Up @@ -164,9 +170,9 @@ impl RiverSplitWithGaugeNode {
network.set_node_max_flow(self.meta.name.as_str(), Self::mrf_sub_name(), value.into())?;
}

for (i, (factor, _)) in self.splits.iter().enumerate() {
for (i, split) in self.splits.iter().enumerate() {
// Set the factors for each split
let r = Relationship::new_proportion_factors(&[factor.load(network, args)?]);
let r = Relationship::new_proportion_factors(&[split.factor.load(network, args)?]);
network.set_aggregated_node_relationship(
self.meta.name.as_str(),
Self::split_agg_sub_name(i).as_deref(),
Expand Down Expand Up @@ -246,12 +252,12 @@ impl TryFrom<RiverSplitWithGaugeNodeV1> for RiverSplitWithGaugeNode {
.skip(1)
.zip(v1.slot_names.into_iter().skip(1))
.map(|(f, slot_name)| {
Ok((
f.try_into_v2_parameter(Some(&meta.name), &mut unnamed_count)?,
Ok(RiverSplit {
factor: f.try_into_v2_parameter(Some(&meta.name), &mut unnamed_count)?,
slot_name,
))
})
})
.collect::<Result<Vec<(Metric, String)>, Self::Error>>()?;
.collect::<Result<Vec<_>, Self::Error>>()?;

let n = Self {
meta,
Expand Down
Loading

0 comments on commit 2385f7d

Please sign in to comment.