Skip to content

Commit

Permalink
refactor: Add LoadArgs to capture schema loading arguments. (#148)
Browse files Browse the repository at this point in the history
This reduces the number of arguments to schema loading functions,
and will make maintaining those arguments easier in future.
  • Loading branch information
jetuk authored Mar 27, 2024
1 parent ef16934 commit d9d6361
Show file tree
Hide file tree
Showing 28 changed files with 378 additions and 1,655 deletions.
151 changes: 103 additions & 48 deletions pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ pub struct Scenario {
pub ensemble_names: Option<Vec<String>>,
}

#[derive(Clone)]
pub struct LoadArgs<'a> {
pub schema: &'a PywrNetwork,
pub domain: &'a ModelDomain,
pub tables: &'a LoadedTableCollection,
pub timeseries: &'a LoadedTimeseriesCollection,
pub data_path: Option<&'a Path>,
pub inter_network_transfers: &'a [PywrMultiNetworkTransfer],
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Default)]
pub struct PywrNetwork {
pub nodes: Vec<Node>,
Expand Down Expand Up @@ -183,20 +193,41 @@ impl PywrNetwork {
}
}

pub fn load_tables(&self, data_path: Option<&Path>) -> Result<LoadedTableCollection, SchemaError> {
Ok(LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?)
}

pub fn load_timeseries(
&self,
domain: &ModelDomain,
data_path: Option<&Path>,
) -> Result<LoadedTimeseriesCollection, SchemaError> {
Ok(LoadedTimeseriesCollection::from_schema(
self.timeseries.as_deref(),
domain,
data_path,
)?)
}

pub fn build_network(
&self,
domain: &ModelDomain,
data_path: Option<&Path>,
output_path: Option<&Path>,
tables: &LoadedTableCollection,
timeseries: &LoadedTimeseriesCollection,
inter_network_transfers: &[PywrMultiNetworkTransfer],
) -> Result<pywr_core::network::Network, SchemaError> {
let mut network = pywr_core::network::Network::default();

// Load all the data tables
let tables = LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?;

// Load all timeseries data
let timeseries = LoadedTimeseriesCollection::from_schema(self.timeseries.as_deref(), domain, data_path)?;
let args = LoadArgs {
schema: self,
domain,
tables,
timeseries,
data_path,
inter_network_transfers,
};

// Create all the nodes
let mut remaining_nodes = self.nodes.clone();
Expand All @@ -205,15 +236,7 @@ impl PywrNetwork {
let mut failed_nodes: Vec<Node> = Vec::new();
let n = remaining_nodes.len();
for node in remaining_nodes.into_iter() {
if let Err(e) = node.add_to_model(
&mut network,
&self,
domain,
&tables,
data_path,
inter_network_transfers,
&timeseries,
) {
if let Err(e) = node.add_to_model(&mut network, &args) {
// Adding the node failed!
match e {
SchemaError::PywrCore(core_err) => match core_err {
Expand Down Expand Up @@ -264,15 +287,7 @@ impl PywrNetwork {
let mut failed_parameters: Vec<Parameter> = Vec::new();
let n = remaining_parameters.len();
for parameter in remaining_parameters.into_iter() {
if let Err(e) = parameter.add_to_model(
&mut network,
self,
domain,
&tables,
data_path,
inter_network_transfers,
&timeseries,
) {
if let Err(e) = parameter.add_to_model(&mut network, &args) {
// Adding the parameter failed!
match e {
SchemaError::PywrCore(core_err) => match core_err {
Expand All @@ -298,15 +313,7 @@ impl PywrNetwork {

// Apply the inline parameters & constraints to the nodes
for node in &self.nodes {
node.set_constraints(
&mut network,
self,
domain,
&tables,
data_path,
inter_network_transfers,
&timeseries,
)?;
node.set_constraints(&mut network, &args)?;
}

// Create all of the metric sets
Expand Down Expand Up @@ -411,7 +418,12 @@ impl PywrModel {

let domain = ModelDomain::from(timestepper, scenario_collection)?;

let network = self.network.build_network(&domain, data_path, output_path, &[])?;
let tables = self.network.load_tables(data_path)?;
let timeseries = self.network.load_timeseries(&domain, data_path)?;

let network = self
.network
.build_network(&domain, data_path, output_path, &tables, &timeseries, &[])?;

let model = pywr_core::models::Model::new(domain, network);

Expand Down Expand Up @@ -619,15 +631,17 @@ impl PywrMultiNetworkModel {
}

let domain = ModelDomain::from(timestepper, scenario_collection)?;
let mut model = pywr_core::models::MultiNetworkModel::new(domain);
let mut schemas = Vec::with_capacity(self.networks.len());
let mut networks = Vec::with_capacity(self.networks.len());
let mut inter_network_transfers = Vec::new();
let mut schemas: Vec<(PywrNetwork, LoadedTableCollection, LoadedTimeseriesCollection)> =
Vec::with_capacity(self.networks.len());

// First load all the networks
// These will contain any parameters that are referenced by the inter-model transfers
// Because of potential circular references, we need to load all the networks first.
for network_entry in &self.networks {
// Load the network itself
let network = match &network_entry.network {
let (network, schema, tables, timeseries) = match &network_entry.network {
PywrNetworkRef::Path(path) => {
let pth = if let Some(dp) = data_path {
if path.is_relative() {
Expand All @@ -640,44 +654,85 @@ impl PywrMultiNetworkModel {
};

let network_schema = PywrNetwork::from_path(pth)?;
let tables = network_schema.load_tables(data_path)?;
let timeseries = network_schema.load_timeseries(&domain, data_path)?;
let net = network_schema.build_network(
model.domain(),
&domain,
data_path,
output_path,
&tables,
&timeseries,
&network_entry.transfers,
)?;
schemas.push(network_schema);
net

(net, network_schema, tables, timeseries)
}
PywrNetworkRef::Inline(network_schema) => {
let tables = network_schema.load_tables(data_path)?;
let timeseries = network_schema.load_timeseries(&domain, data_path)?;
let net = network_schema.build_network(
model.domain(),
&domain,
data_path,
output_path,
&tables,
&timeseries,
&network_entry.transfers,
)?;
schemas.push(network_schema.clone());
net

(net, network_schema.clone(), tables, timeseries)
}
};

model.add_network(&network_entry.name, network);
schemas.push((schema, tables, timeseries));
networks.push((network_entry.name.clone(), network));
}

// Now load the inter-model transfers
for (to_network_idx, network_entry) in self.networks.iter().enumerate() {
for transfer in &network_entry.transfers {
let from_network_idx = model.get_network_index_by_name(&transfer.from_network)?;

// Load the metric from the "from" network
let from_network = model.network_mut(from_network_idx)?;

let (from_network_idx, from_network) = networks
.iter_mut()
.enumerate()
.find_map(|(idx, (name, net))| {
if name.as_str() == transfer.from_network.as_str() {
Some((idx, net))
} else {
None
}
})
.ok_or_else(|| SchemaError::NetworkNotFound(transfer.from_network.clone()))?;

// The transfer metric will fail to load if it is defined as an inter-model transfer itself.
let from_metric = transfer.metric.load(from_network, &schemas[from_network_idx], &[])?;
let (from_schema, from_tables, from_timeseries) = &schemas[from_network_idx];

model.add_inter_network_transfer(from_network_idx, from_metric, to_network_idx, transfer.initial_value);
let args = LoadArgs {
schema: from_schema,
domain: &domain,
tables: from_tables,
timeseries: from_timeseries,
data_path,
inter_network_transfers: &[],
};

let from_metric = transfer.metric.load(from_network, &args)?;

inter_network_transfers.push((from_network_idx, from_metric, to_network_idx, transfer.initial_value));
}
}

// Now construct the model from the loaded components
let mut model = pywr_core::models::MultiNetworkModel::new(domain);

for (name, network) in networks {
model.add_network(&name, network);
}

for (from_network_idx, from_metric, to_network_idx, initial_value) in inter_network_transfers {
model.add_inter_network_transfer(from_network_idx, from_metric, to_network_idx, initial_value);
}

Ok(model)
}
}
Expand Down
53 changes: 5 additions & 48 deletions pywr-schema/src/nodes/annual_virtual_storage.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
use crate::data_tables::LoadedTableCollection;
use crate::error::{ConversionError, SchemaError};
use crate::model::PywrMultiNetworkTransfer;
use crate::model::LoadArgs;
use crate::nodes::core::StorageInitialVolume;
use crate::nodes::{NodeAttribute, NodeMeta};
use crate::parameters::{DynamicFloatValue, TryIntoV2Parameter};
use crate::timeseries::LoadedTimeseriesCollection;
use pywr_core::derived_metric::DerivedMetric;
use pywr_core::metric::MetricF64;
use pywr_core::models::ModelDomain;
use pywr_core::node::ConstraintValue;
use pywr_core::virtual_storage::VirtualStorageReset;
use pywr_schema_macros::PywrNode;
use pywr_v1_schema::nodes::AnnualVirtualStorageNode as AnnualVirtualStorageNodeV1;
use std::collections::HashMap;
use std::path::Path;

#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
pub struct AnnualReset {
Expand Down Expand Up @@ -48,58 +44,19 @@ pub struct AnnualVirtualStorageNode {
impl AnnualVirtualStorageNode {
pub const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Volume;

pub fn add_to_model(
&self,
network: &mut pywr_core::network::Network,
schema: &crate::model::PywrNetwork,
domain: &ModelDomain,
tables: &LoadedTableCollection,
data_path: Option<&Path>,
inter_network_transfers: &[PywrMultiNetworkTransfer],
timeseries: &LoadedTimeseriesCollection,
) -> Result<(), SchemaError> {
pub fn add_to_model(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result<(), SchemaError> {
let cost = match &self.cost {
Some(v) => v
.load(
network,
schema,
domain,
tables,
data_path,
inter_network_transfers,
timeseries,
)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::Scalar(0.0),
};

let min_volume = match &self.min_volume {
Some(v) => v
.load(
network,
schema,
domain,
tables,
data_path,
inter_network_transfers,
timeseries,
)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::Scalar(0.0),
};

let max_volume = match &self.max_volume {
Some(v) => v
.load(
network,
schema,
domain,
tables,
data_path,
inter_network_transfers,
timeseries,
)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::None,
};

Expand Down
Loading

0 comments on commit d9d6361

Please sign in to comment.