Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of schema metrics #158

Merged
merged 3 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 0 additions & 66 deletions pywr-core/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,72 +86,6 @@ impl MetricF64 {
MetricF64::InterNetworkTransfer(idx) => state.get_inter_network_transfer_value(*idx),
}
}
pub fn name<'a>(&'a self, network: &'a Network) -> Result<&'a str, PywrError> {
match self {
Self::NodeInFlow(idx) | Self::NodeOutFlow(idx) | Self::NodeVolume(idx) => {
network.get_node(idx).map(|n| n.name())
}
Self::AggregatedNodeInFlow(idx) | Self::AggregatedNodeOutFlow(idx) => {
network.get_aggregated_node(idx).map(|n| n.name())
}
Self::AggregatedNodeVolume(idx) => network.get_aggregated_storage_node(idx).map(|n| n.name()),
Self::EdgeFlow(idx) => {
let edge = network.get_edge(idx)?;
network.get_node(&edge.from_node_index).map(|n| n.name())
}
Self::ParameterValue(idx) => network.get_parameter(idx).map(|p| p.name()),
Self::IndexParameterValue(idx) => network.get_index_parameter(idx).map(|p| p.name()),
Self::MultiParameterValue((idx, _)) => network.get_multi_valued_parameter(idx).map(|p| p.name()),
Self::VirtualStorageVolume(idx) => network.get_virtual_storage_node(idx).map(|v| v.name()),
Self::MultiNodeInFlow { name, .. } | Self::MultiNodeOutFlow { name, .. } => Ok(name),
Self::Constant(_) => Ok(""),
Self::DerivedMetric(idx) => network.get_derived_metric(idx)?.name(network),
Self::InterNetworkTransfer(_) => todo!("InterNetworkTransfer name is not implemented"),
}
}

pub fn sub_name<'a>(&'a self, network: &'a Network) -> Result<Option<&'a str>, PywrError> {
match self {
Self::NodeInFlow(idx) | Self::NodeOutFlow(idx) | Self::NodeVolume(idx) => {
network.get_node(idx).map(|n| n.sub_name())
}
Self::AggregatedNodeInFlow(idx) | Self::AggregatedNodeOutFlow(idx) => {
network.get_aggregated_node(idx).map(|n| n.sub_name())
}
Self::AggregatedNodeVolume(idx) => network.get_aggregated_storage_node(idx).map(|n| n.sub_name()),
Self::EdgeFlow(idx) => {
let edge = network.get_edge(idx)?;
network.get_node(&edge.to_node_index).map(|n| Some(n.name()))
}
Self::ParameterValue(_) | Self::IndexParameterValue(_) | Self::MultiParameterValue(_) => Ok(None),
Self::VirtualStorageVolume(idx) => network.get_virtual_storage_node(idx).map(|v| v.sub_name()),
Self::MultiNodeInFlow { .. } | Self::MultiNodeOutFlow { .. } => Ok(None),
Self::Constant(_) => Ok(None),
Self::DerivedMetric(idx) => network.get_derived_metric(idx)?.sub_name(network),
Self::InterNetworkTransfer(_) => todo!("InterNetworkTransfer sub_name is not implemented"),
}
}

pub fn attribute(&self) -> &str {
match self {
Self::NodeInFlow(_) => "inflow",
Self::NodeOutFlow(_) => "outflow",
Self::NodeVolume(_) => "volume",
Self::AggregatedNodeInFlow(_) => "inflow",
Self::AggregatedNodeOutFlow(_) => "outflow",
Self::AggregatedNodeVolume(_) => "volume",
Self::EdgeFlow(_) => "edge_flow",
Self::ParameterValue(_) => "value",
Self::IndexParameterValue(_) => "value",
Self::MultiParameterValue(_) => "value",
Self::VirtualStorageVolume(_) => "volume",
Self::MultiNodeInFlow { .. } => "inflow",
Self::MultiNodeOutFlow { .. } => "outflow",
Self::Constant(_) => "value",
Self::DerivedMetric(_) => "value",
Self::InterNetworkTransfer(_) => "value",
}
}
}

#[derive(Clone, Debug, PartialEq)]
Expand Down
20 changes: 3 additions & 17 deletions pywr-core/src/recorders/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,27 +80,21 @@ impl Recorder for CsvWideFmtOutput {
let mut writer = csv::Writer::from_path(&self.filename).map_err(|e| PywrError::CSVError(e.to_string()))?;

let mut names = vec![];
let mut sub_names = vec![];
let mut attributes = vec![];

let metric_set = network.get_metric_set(self.metric_set_idx)?;

for metric in metric_set.iter_metrics() {
let name = metric.name(network)?.to_string();
let sub_name = metric
.sub_name(network)?
.map_or_else(|| "".to_string(), |s| s.to_string());
let name = metric.name().to_string();
let attribute = metric.attribute().to_string();

// Add entries for each scenario
names.push(name);
sub_names.push(sub_name);
attributes.push(attribute);
}

// These are the header rows in the CSV file; we start each
let mut header_name = vec!["node".to_string()];
let mut header_sub_name = vec!["sub-node".to_string()];
let mut header_attribute = vec!["attribute".to_string()];
let mut header_scenario = vec!["global-scenario-index".to_string()];

Expand All @@ -113,7 +107,6 @@ impl Recorder for CsvWideFmtOutput {
for scenario_index in domain.scenarios().indices().iter() {
// Repeat the names, sub-names and attributes for every scenario
header_name.extend(names.clone());
header_sub_name.extend(sub_names.clone());
header_attribute.extend(attributes.clone());
header_scenario.extend(vec![format!("{}", scenario_index.index); names.len()]);

Expand All @@ -125,9 +118,7 @@ impl Recorder for CsvWideFmtOutput {
writer
.write_record(header_name)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
writer
.write_record(header_sub_name)
.map_err(|e| PywrError::CSVError(e.to_string()))?;

writer
.write_record(header_attribute)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
Expand Down Expand Up @@ -200,7 +191,6 @@ pub struct CsvLongFmtRecord {
scenario_index: usize,
metric_set: String,
name: String,
sub_name: String,
attribute: String,
value: f64,
}
Expand Down Expand Up @@ -243,10 +233,7 @@ impl CsvLongFmtOutput {
let metric_set = network.get_metric_set(*metric_set_idx)?;

for (metric, value) in metric_set.iter_metrics().zip(current_values.iter()) {
let name = metric.name(network)?.to_string();
let sub_name = metric
.sub_name(network)?
.map_or_else(|| "".to_string(), |s| s.to_string());
let name = metric.name().to_string();
let attribute = metric.attribute().to_string();

let record = CsvLongFmtRecord {
Expand All @@ -255,7 +242,6 @@ impl CsvLongFmtOutput {
scenario_index: scenario_idx,
metric_set: metric_set.name().to_string(),
name,
sub_name,
attribute,
value: value.value,
};
Expand Down
54 changes: 30 additions & 24 deletions pywr-core/src/recorders/hdf.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{MetricSetState, PywrError, Recorder, RecorderMeta, Timestep};
use super::{MetricSetState, OutputMetric, PywrError, Recorder, RecorderMeta, Timestep};
use crate::models::ModelDomain;
use crate::network::Network;
use crate::recorders::MetricSetIndex;
Expand Down Expand Up @@ -95,12 +95,7 @@ impl Recorder for HDF5Recorder {
let mut datasets = Vec::new();

for metric in metric_set.iter_metrics() {
let name = metric.name(network)?;
let sub_name = metric.sub_name(network)?;
let attribute = metric.attribute();

let ds = require_metric_dataset(root_grp, shape, name, sub_name, attribute)?;

let ds = require_metric_dataset(root_grp, shape, metric)?;
datasets.push(ds);
}

Expand Down Expand Up @@ -176,21 +171,35 @@ fn require_dataset<S: Into<Extents>>(parent: &Group, shape: S, name: &str) -> Re
fn require_metric_dataset<S: Into<Extents>>(
parent: &Group,
shape: S,
name: &str,
sub_name: Option<&str>,
attribute: &str,
metric: &OutputMetric,
) -> Result<hdf5::Dataset, PywrError> {
match sub_name {
None => {
let grp = require_group(parent, name)?;
require_dataset(&grp, shape, attribute)
}
Some(sn) => {
let grp = require_group(parent, name)?;
let grp = require_group(&grp, sn)?;
require_dataset(&grp, shape, attribute)
}
let grp = require_group(parent, metric.name())?;
let ds = require_dataset(&grp, shape, metric.attribute())?;

// Write the type and subtype as attributes
let ty = hdf5::types::VarLenUnicode::from_str(metric.ty()).map_err(|e| PywrError::HDF5Error(e.to_string()))?;
let attr = ds
.new_attr::<hdf5::types::VarLenUnicode>()
.shape(())
.create("pywr-type")
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;
attr.as_writer()
.write_scalar(&ty)
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;

if let Some(sub_type) = metric.sub_type() {
let sub_type =
hdf5::types::VarLenUnicode::from_str(sub_type).map_err(|e| PywrError::HDF5Error(e.to_string()))?;
let attr = ds
.new_attr::<hdf5::types::VarLenUnicode>()
.shape(())
.create("pywr-subtype")
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;
attr.as_writer()
.write_scalar(&sub_type)
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;
}
Ok(ds)
}

fn require_group(parent: &Group, name: &str) -> Result<Group, PywrError> {
Expand All @@ -208,13 +217,10 @@ fn require_group(parent: &Group, name: &str) -> Result<Group, PywrError> {
fn write_pywr_metadata(file: &hdf5::File) -> Result<(), PywrError> {
let root = file.deref();

let grp = require_group(root, "pywr")?;

// Write the Pywr version as an attribute
const VERSION: &str = env!("CARGO_PKG_VERSION");
let version = hdf5::types::VarLenUnicode::from_str(VERSION).map_err(|e| PywrError::HDF5Error(e.to_string()))?;

let attr = grp
let attr = root
.new_attr::<hdf5::types::VarLenUnicode>()
.shape(())
.create("pywr-version")
Expand Down
52 changes: 49 additions & 3 deletions pywr-core/src/recorders/metric_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,52 @@ use std::fmt;
use std::fmt::{Display, Formatter};
use std::ops::Deref;

/// A container for a [`MetricF64`] that retains additional information from the schema.
///
/// This is used to store the name and attribute of the metric so that it can be output in
/// a context that is relevant to the originating schema, and therefore more meaningful to the user.
#[derive(Clone, Debug)]
pub struct OutputMetric {
name: String,
attribute: String,
// The originating type of the metric (e.g. node, parameter, etc.)
ty: String,
// The originating subtype of the metric (e.g. node type, parameter type, etc.)
sub_type: Option<String>,
metric: MetricF64,
}

impl OutputMetric {
pub fn new(name: &str, attribute: &str, ty: &str, sub_type: Option<&str>, metric: MetricF64) -> Self {
Self {
name: name.to_string(),
attribute: attribute.to_string(),
ty: ty.to_string(),
sub_type: sub_type.map(|s| s.to_string()),
metric,
}
}

pub fn get_value(&self, model: &Network, state: &State) -> Result<f64, PywrError> {
self.metric.get_value(model, state)
}

pub fn name(&self) -> &str {
&self.name
}
pub fn attribute(&self) -> &str {
&self.attribute
}

pub fn ty(&self) -> &str {
&self.ty
}

pub fn sub_type(&self) -> Option<&str> {
self.sub_type.as_deref()
}
}

#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct MetricSetIndex(usize);

Expand Down Expand Up @@ -51,11 +97,11 @@ impl MetricSetState {
pub struct MetricSet {
name: String,
aggregator: Option<Aggregator>,
metrics: Vec<MetricF64>,
metrics: Vec<OutputMetric>,
}

impl MetricSet {
pub fn new(name: &str, aggregator: Option<Aggregator>, metrics: Vec<MetricF64>) -> Self {
pub fn new(name: &str, aggregator: Option<Aggregator>, metrics: Vec<OutputMetric>) -> Self {
Self {
name: name.to_string(),
aggregator,
Expand All @@ -67,7 +113,7 @@ impl MetricSet {
pub fn name(&self) -> &str {
&self.name
}
pub fn iter_metrics(&self) -> impl Iterator<Item = &MetricF64> + '_ {
pub fn iter_metrics(&self) -> impl Iterator<Item = &OutputMetric> + '_ {
self.metrics.iter()
}

Expand Down
2 changes: 1 addition & 1 deletion pywr-core/src/recorders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub use csv::{CsvLongFmtOutput, CsvLongFmtRecord, CsvWideFmtOutput};
use float_cmp::{approx_eq, ApproxEq, F64Margin};
pub use hdf::HDF5Recorder;
pub use memory::{Aggregation, AggregationError, MemoryRecorder};
pub use metric_set::{MetricSet, MetricSetIndex, MetricSetState};
pub use metric_set::{MetricSet, MetricSetIndex, MetricSetState, OutputMetric};
use ndarray::prelude::*;
use ndarray::Array2;
use std::any::Any;
Expand Down
2 changes: 1 addition & 1 deletion pywr-python/tests/models/aggregated-node1/expected.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
date,input1,link1,link2,agg-node,output1
,outflow,outflow,outflow,outflow,inflow
,Outflow,Outflow,Outflow,Outflow,Inflow
2021-01-01,1,1,0,1,1
2021-01-02,2,2,0,2,2
2021-01-03,3,2,1,3,3
Expand Down
Loading