Skip to content

Commit

Permalink
refactor: Add LoadArgs to capture schema loading arguments.
Browse files Browse the repository at this point in the history
This reduces the number of arguments to schema loading functions,
and will make maintaining those arguments easier in future.
  • Loading branch information
jetuk committed Mar 25, 2024
1 parent 1284937 commit 26cfbcd
Show file tree
Hide file tree
Showing 29 changed files with 355 additions and 810 deletions.
105 changes: 80 additions & 25 deletions pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ pub struct Scenario {
pub ensemble_names: Option<Vec<String>>,
}

#[derive(Clone)]
pub struct LoadArgs<'a> {
pub schema: &'a PywrNetwork,
pub domain: &'a ModelDomain,
pub tables: &'a LoadedTableCollection,
pub data_path: Option<&'a Path>,
pub inter_network_transfers: &'a [PywrMultiNetworkTransfer],
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Default)]
pub struct PywrNetwork {
pub nodes: Vec<Node>,
Expand Down Expand Up @@ -180,17 +189,27 @@ impl PywrNetwork {
}
}

pub fn load_tables(&self, data_path: Option<&Path>) -> Result<LoadedTableCollection, SchemaError> {
Ok(LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?)
}

pub fn build_network(
&self,
domain: &ModelDomain,
data_path: Option<&Path>,
output_path: Option<&Path>,
tables: &LoadedTableCollection,
inter_network_transfers: &[PywrMultiNetworkTransfer],
) -> Result<pywr_core::network::Network, SchemaError> {
let mut network = pywr_core::network::Network::default();

// Load all the data tables
let tables = LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?;
let args = LoadArgs {
schema: self,
domain,
tables,
data_path,
inter_network_transfers,
};

// Create all the nodes
let mut remaining_nodes = self.nodes.clone();
Expand All @@ -199,9 +218,7 @@ impl PywrNetwork {
let mut failed_nodes: Vec<Node> = Vec::new();
let n = remaining_nodes.len();
for node in remaining_nodes.into_iter() {
if let Err(e) =
node.add_to_model(&mut network, self, domain, &tables, data_path, inter_network_transfers)
{
if let Err(e) = node.add_to_model(&mut network, &args) {
// Adding the node failed!
match e {
SchemaError::PywrCore(core_err) => match core_err {
Expand Down Expand Up @@ -252,9 +269,7 @@ impl PywrNetwork {
let mut failed_parameters: Vec<Parameter> = Vec::new();
let n = remaining_parameters.len();
for parameter in remaining_parameters.into_iter() {
if let Err(e) =
parameter.add_to_model(&mut network, self, domain, &tables, data_path, inter_network_transfers)
{
if let Err(e) = parameter.add_to_model(&mut network, &args) {
// Adding the parameter failed!
match e {
SchemaError::PywrCore(core_err) => match core_err {
Expand All @@ -280,7 +295,7 @@ impl PywrNetwork {

// Apply the inline parameters & constraints to the nodes
for node in &self.nodes {
node.set_constraints(&mut network, self, domain, &tables, data_path, inter_network_transfers)?;
node.set_constraints(&mut network, &args)?;
}

// Create all of the metric sets
Expand Down Expand Up @@ -385,7 +400,10 @@ impl PywrModel {

let domain = ModelDomain::from(timestepper, scenario_collection)?;

let network = self.network.build_network(&domain, data_path, output_path, &[])?;
let tables = self.network.load_tables(data_path)?;
let network = self
.network
.build_network(&domain, data_path, output_path, &tables, &[])?;

let model = pywr_core::models::Model::new(domain, network);

Expand Down Expand Up @@ -583,15 +601,16 @@ impl PywrMultiNetworkModel {
}

let domain = ModelDomain::from(timestepper, scenario_collection)?;
let mut model = pywr_core::models::MultiNetworkModel::new(domain);
let mut schemas = Vec::with_capacity(self.networks.len());
let mut networks = Vec::with_capacity(self.networks.len());
let mut inter_network_transfers = Vec::new();
let mut schemas: Vec<(PywrNetwork, LoadedTableCollection)> = Vec::with_capacity(self.networks.len());

// First load all the networks
// These will contain any parameters that are referenced by the inter-model transfers
// Because of potential circular references, we need to load all the networks first.
for network_entry in &self.networks {
// Load the network itself
let network = match &network_entry.network {
let (network, schema, tables) = match &network_entry.network {
PywrNetworkRef::Path(path) => {
let pth = if let Some(dp) = data_path {
if path.is_relative() {
Expand All @@ -604,44 +623,80 @@ impl PywrMultiNetworkModel {
};

let network_schema = PywrNetwork::from_path(pth)?;
let tables = network_schema.load_tables(data_path)?;
let net = network_schema.build_network(
model.domain(),
&domain,
data_path,
output_path,
&tables,
&network_entry.transfers,
)?;
schemas.push(network_schema);
net

(net, network_schema, tables)
}
PywrNetworkRef::Inline(network_schema) => {
let tables = network_schema.load_tables(data_path)?;
let net = network_schema.build_network(
model.domain(),
&domain,
data_path,
output_path,
&tables,
&network_entry.transfers,
)?;
schemas.push(network_schema.clone());
net

(net, network_schema.clone(), tables)
}
};

model.add_network(&network_entry.name, network);
schemas.push((schema, tables));
networks.push((network_entry.name.clone(), network));
}

// Now load the inter-model transfers
for (to_network_idx, network_entry) in self.networks.iter().enumerate() {
for transfer in &network_entry.transfers {
let from_network_idx = model.get_network_index_by_name(&transfer.from_network)?;

// Load the metric from the "from" network
let from_network = model.network_mut(from_network_idx)?;

let (from_network_idx, from_network) = networks
.iter_mut()
.enumerate()
.find_map(|(idx, (name, net))| {
if name.as_str() == transfer.from_network.as_str() {
Some((idx, net))
} else {
None
}
})
.ok_or_else(|| SchemaError::NetworkNotFound(transfer.from_network.clone()))?;

// The transfer metric will fail to load if it is defined as an inter-model transfer itself.
let from_metric = transfer.metric.load(from_network, &schemas[from_network_idx], &[])?;
let (from_schema, from_tables) = &schemas[from_network_idx];

let args = LoadArgs {
schema: from_schema,
domain: &domain,
tables: from_tables,
data_path,
inter_network_transfers: &[],
};

let from_metric = transfer.metric.load(from_network, &args)?;

model.add_inter_network_transfer(from_network_idx, from_metric, to_network_idx, transfer.initial_value);
inter_network_transfers.push((from_network_idx, from_metric, to_network_idx, transfer.initial_value));
}
}

// Now construct the model from the loaded components
let mut model = pywr_core::models::MultiNetworkModel::new(domain);

for (name, network) in networks {
model.add_network(&name, network);
}

for (from_network_idx, from_metric, to_network_idx, initial_value) in inter_network_transfers {
model.add_inter_network_transfer(from_network_idx, from_metric, to_network_idx, initial_value);
}

Ok(model)
}
}
Expand Down
27 changes: 5 additions & 22 deletions pywr-schema/src/nodes/annual_virtual_storage.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use crate::data_tables::LoadedTableCollection;
use crate::error::{ConversionError, SchemaError};
use crate::model::PywrMultiNetworkTransfer;
use crate::model::LoadArgs;
use crate::nodes::core::StorageInitialVolume;
use crate::nodes::{NodeAttribute, NodeMeta};
use crate::parameters::{DynamicFloatValue, TryIntoV2Parameter};
use pywr_core::derived_metric::DerivedMetric;
use pywr_core::metric::MetricF64;
use pywr_core::models::ModelDomain;
use pywr_core::node::ConstraintValue;
use pywr_core::virtual_storage::VirtualStorageReset;
use pywr_schema_macros::PywrNode;
use pywr_v1_schema::nodes::AnnualVirtualStorageNode as AnnualVirtualStorageNodeV1;
use std::collections::HashMap;
use std::path::Path;

#[derive(serde::Deserialize, serde::Serialize, Clone)]
pub struct AnnualReset {
Expand Down Expand Up @@ -47,33 +44,19 @@ pub struct AnnualVirtualStorageNode {
impl AnnualVirtualStorageNode {
pub const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Volume;

pub fn add_to_model(
&self,
network: &mut pywr_core::network::Network,
schema: &crate::model::PywrNetwork,
domain: &ModelDomain,
tables: &LoadedTableCollection,
data_path: Option<&Path>,
inter_network_transfers: &[PywrMultiNetworkTransfer],
) -> Result<(), SchemaError> {
pub fn add_to_model(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result<(), SchemaError> {
let cost = match &self.cost {
Some(v) => v
.load(network, schema, domain, tables, data_path, inter_network_transfers)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::Scalar(0.0),
};

let min_volume = match &self.min_volume {
Some(v) => v
.load(network, schema, domain, tables, data_path, inter_network_transfers)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::Scalar(0.0),
};

let max_volume = match &self.max_volume {
Some(v) => v
.load(network, schema, domain, tables, data_path, inter_network_transfers)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::None,
};

Expand Down
Loading

0 comments on commit 26cfbcd

Please sign in to comment.