Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add LoadArgs to capture schema loading arguments. #148

Merged
merged 2 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 103 additions & 48 deletions pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ pub struct Scenario {
pub ensemble_names: Option<Vec<String>>,
}

#[derive(Clone)]
pub struct LoadArgs<'a> {
pub schema: &'a PywrNetwork,
pub domain: &'a ModelDomain,
pub tables: &'a LoadedTableCollection,
pub timeseries: &'a LoadedTimeseriesCollection,
pub data_path: Option<&'a Path>,
pub inter_network_transfers: &'a [PywrMultiNetworkTransfer],
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Default)]
pub struct PywrNetwork {
pub nodes: Vec<Node>,
Expand Down Expand Up @@ -183,20 +193,41 @@ impl PywrNetwork {
}
}

pub fn load_tables(&self, data_path: Option<&Path>) -> Result<LoadedTableCollection, SchemaError> {
Ok(LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?)
}

pub fn load_timeseries(
&self,
domain: &ModelDomain,
data_path: Option<&Path>,
) -> Result<LoadedTimeseriesCollection, SchemaError> {
Ok(LoadedTimeseriesCollection::from_schema(
self.timeseries.as_deref(),
domain,
data_path,
)?)
}

pub fn build_network(
&self,
domain: &ModelDomain,
data_path: Option<&Path>,
output_path: Option<&Path>,
tables: &LoadedTableCollection,
timeseries: &LoadedTimeseriesCollection,
inter_network_transfers: &[PywrMultiNetworkTransfer],
) -> Result<pywr_core::network::Network, SchemaError> {
let mut network = pywr_core::network::Network::default();

// Load all the data tables
let tables = LoadedTableCollection::from_schema(self.tables.as_deref(), data_path)?;

// Load all timeseries data
let timeseries = LoadedTimeseriesCollection::from_schema(self.timeseries.as_deref(), domain, data_path)?;
let args = LoadArgs {
schema: self,
domain,
tables,
timeseries,
data_path,
inter_network_transfers,
};

// Create all the nodes
let mut remaining_nodes = self.nodes.clone();
Expand All @@ -205,15 +236,7 @@ impl PywrNetwork {
let mut failed_nodes: Vec<Node> = Vec::new();
let n = remaining_nodes.len();
for node in remaining_nodes.into_iter() {
if let Err(e) = node.add_to_model(
&mut network,
&self,
domain,
&tables,
data_path,
inter_network_transfers,
&timeseries,
) {
if let Err(e) = node.add_to_model(&mut network, &args) {
// Adding the node failed!
match e {
SchemaError::PywrCore(core_err) => match core_err {
Expand Down Expand Up @@ -264,15 +287,7 @@ impl PywrNetwork {
let mut failed_parameters: Vec<Parameter> = Vec::new();
let n = remaining_parameters.len();
for parameter in remaining_parameters.into_iter() {
if let Err(e) = parameter.add_to_model(
&mut network,
self,
domain,
&tables,
data_path,
inter_network_transfers,
&timeseries,
) {
if let Err(e) = parameter.add_to_model(&mut network, &args) {
// Adding the parameter failed!
match e {
SchemaError::PywrCore(core_err) => match core_err {
Expand All @@ -298,15 +313,7 @@ impl PywrNetwork {

// Apply the inline parameters & constraints to the nodes
for node in &self.nodes {
node.set_constraints(
&mut network,
self,
domain,
&tables,
data_path,
inter_network_transfers,
&timeseries,
)?;
node.set_constraints(&mut network, &args)?;
}

// Create all of the metric sets
Expand Down Expand Up @@ -411,7 +418,12 @@ impl PywrModel {

let domain = ModelDomain::from(timestepper, scenario_collection)?;

let network = self.network.build_network(&domain, data_path, output_path, &[])?;
let tables = self.network.load_tables(data_path)?;
let timeseries = self.network.load_timeseries(&domain, data_path)?;

let network = self
.network
.build_network(&domain, data_path, output_path, &tables, &timeseries, &[])?;

let model = pywr_core::models::Model::new(domain, network);

Expand Down Expand Up @@ -619,15 +631,17 @@ impl PywrMultiNetworkModel {
}

let domain = ModelDomain::from(timestepper, scenario_collection)?;
let mut model = pywr_core::models::MultiNetworkModel::new(domain);
let mut schemas = Vec::with_capacity(self.networks.len());
let mut networks = Vec::with_capacity(self.networks.len());
let mut inter_network_transfers = Vec::new();
let mut schemas: Vec<(PywrNetwork, LoadedTableCollection, LoadedTimeseriesCollection)> =
Vec::with_capacity(self.networks.len());

// First load all the networks
// These will contain any parameters that are referenced by the inter-model transfers
// Because of potential circular references, we need to load all the networks first.
for network_entry in &self.networks {
// Load the network itself
let network = match &network_entry.network {
let (network, schema, tables, timeseries) = match &network_entry.network {
PywrNetworkRef::Path(path) => {
let pth = if let Some(dp) = data_path {
if path.is_relative() {
Expand All @@ -640,44 +654,85 @@ impl PywrMultiNetworkModel {
};

let network_schema = PywrNetwork::from_path(pth)?;
let tables = network_schema.load_tables(data_path)?;
let timeseries = network_schema.load_timeseries(&domain, data_path)?;
let net = network_schema.build_network(
model.domain(),
&domain,
data_path,
output_path,
&tables,
&timeseries,
&network_entry.transfers,
)?;
schemas.push(network_schema);
net

(net, network_schema, tables, timeseries)
}
PywrNetworkRef::Inline(network_schema) => {
let tables = network_schema.load_tables(data_path)?;
let timeseries = network_schema.load_timeseries(&domain, data_path)?;
let net = network_schema.build_network(
model.domain(),
&domain,
data_path,
output_path,
&tables,
&timeseries,
&network_entry.transfers,
)?;
schemas.push(network_schema.clone());
net

(net, network_schema.clone(), tables, timeseries)
}
};

model.add_network(&network_entry.name, network);
schemas.push((schema, tables, timeseries));
networks.push((network_entry.name.clone(), network));
}

// Now load the inter-model transfers
for (to_network_idx, network_entry) in self.networks.iter().enumerate() {
for transfer in &network_entry.transfers {
let from_network_idx = model.get_network_index_by_name(&transfer.from_network)?;

// Load the metric from the "from" network
let from_network = model.network_mut(from_network_idx)?;

let (from_network_idx, from_network) = networks
.iter_mut()
.enumerate()
.find_map(|(idx, (name, net))| {
if name.as_str() == transfer.from_network.as_str() {
Some((idx, net))
} else {
None
}
})
.ok_or_else(|| SchemaError::NetworkNotFound(transfer.from_network.clone()))?;

// The transfer metric will fail to load if it is defined as an inter-model transfer itself.
let from_metric = transfer.metric.load(from_network, &schemas[from_network_idx], &[])?;
let (from_schema, from_tables, from_timeseries) = &schemas[from_network_idx];

model.add_inter_network_transfer(from_network_idx, from_metric, to_network_idx, transfer.initial_value);
let args = LoadArgs {
schema: from_schema,
domain: &domain,
tables: from_tables,
timeseries: from_timeseries,
data_path,
inter_network_transfers: &[],
};

let from_metric = transfer.metric.load(from_network, &args)?;

inter_network_transfers.push((from_network_idx, from_metric, to_network_idx, transfer.initial_value));
}
}

// Now construct the model from the loaded components
let mut model = pywr_core::models::MultiNetworkModel::new(domain);

for (name, network) in networks {
model.add_network(&name, network);
}

for (from_network_idx, from_metric, to_network_idx, initial_value) in inter_network_transfers {
model.add_inter_network_transfer(from_network_idx, from_metric, to_network_idx, initial_value);
}

Ok(model)
}
}
Expand Down
53 changes: 5 additions & 48 deletions pywr-schema/src/nodes/annual_virtual_storage.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
use crate::data_tables::LoadedTableCollection;
use crate::error::{ConversionError, SchemaError};
use crate::model::PywrMultiNetworkTransfer;
use crate::model::LoadArgs;
use crate::nodes::core::StorageInitialVolume;
use crate::nodes::{NodeAttribute, NodeMeta};
use crate::parameters::{DynamicFloatValue, TryIntoV2Parameter};
use crate::timeseries::LoadedTimeseriesCollection;
use pywr_core::derived_metric::DerivedMetric;
use pywr_core::metric::MetricF64;
use pywr_core::models::ModelDomain;
use pywr_core::node::ConstraintValue;
use pywr_core::virtual_storage::VirtualStorageReset;
use pywr_schema_macros::PywrNode;
use pywr_v1_schema::nodes::AnnualVirtualStorageNode as AnnualVirtualStorageNodeV1;
use std::collections::HashMap;
use std::path::Path;

#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
pub struct AnnualReset {
Expand Down Expand Up @@ -48,58 +44,19 @@ pub struct AnnualVirtualStorageNode {
impl AnnualVirtualStorageNode {
pub const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Volume;

pub fn add_to_model(
&self,
network: &mut pywr_core::network::Network,
schema: &crate::model::PywrNetwork,
domain: &ModelDomain,
tables: &LoadedTableCollection,
data_path: Option<&Path>,
inter_network_transfers: &[PywrMultiNetworkTransfer],
timeseries: &LoadedTimeseriesCollection,
) -> Result<(), SchemaError> {
pub fn add_to_model(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result<(), SchemaError> {
let cost = match &self.cost {
Some(v) => v
.load(
network,
schema,
domain,
tables,
data_path,
inter_network_transfers,
timeseries,
)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::Scalar(0.0),
};

let min_volume = match &self.min_volume {
Some(v) => v
.load(
network,
schema,
domain,
tables,
data_path,
inter_network_transfers,
timeseries,
)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::Scalar(0.0),
};

let max_volume = match &self.max_volume {
Some(v) => v
.load(
network,
schema,
domain,
tables,
data_path,
inter_network_transfers,
timeseries,
)?
.into(),
Some(v) => v.load(network, args)?.into(),
None => ConstraintValue::None,
};

Expand Down
Loading