refactor: Refactor OutputMetric to a general Metric.
jetuk committed Apr 15, 2024
1 parent 4c2ee90 commit 8344049
Showing 69 changed files with 1,638 additions and 1,311 deletions.
66 changes: 0 additions & 66 deletions pywr-core/src/metric.rs
@@ -86,72 +86,6 @@ impl MetricF64 {
MetricF64::InterNetworkTransfer(idx) => state.get_inter_network_transfer_value(*idx),
}
}
pub fn name<'a>(&'a self, network: &'a Network) -> Result<&'a str, PywrError> {
match self {
Self::NodeInFlow(idx) | Self::NodeOutFlow(idx) | Self::NodeVolume(idx) => {
network.get_node(idx).map(|n| n.name())
}
Self::AggregatedNodeInFlow(idx) | Self::AggregatedNodeOutFlow(idx) => {
network.get_aggregated_node(idx).map(|n| n.name())
}
Self::AggregatedNodeVolume(idx) => network.get_aggregated_storage_node(idx).map(|n| n.name()),
Self::EdgeFlow(idx) => {
let edge = network.get_edge(idx)?;
network.get_node(&edge.from_node_index).map(|n| n.name())
}
Self::ParameterValue(idx) => network.get_parameter(idx).map(|p| p.name()),
Self::IndexParameterValue(idx) => network.get_index_parameter(idx).map(|p| p.name()),
Self::MultiParameterValue((idx, _)) => network.get_multi_valued_parameter(idx).map(|p| p.name()),
Self::VirtualStorageVolume(idx) => network.get_virtual_storage_node(idx).map(|v| v.name()),
Self::MultiNodeInFlow { name, .. } | Self::MultiNodeOutFlow { name, .. } => Ok(name),
Self::Constant(_) => Ok(""),
Self::DerivedMetric(idx) => network.get_derived_metric(idx)?.name(network),
Self::InterNetworkTransfer(_) => todo!("InterNetworkTransfer name is not implemented"),
}
}

pub fn sub_name<'a>(&'a self, network: &'a Network) -> Result<Option<&'a str>, PywrError> {
match self {
Self::NodeInFlow(idx) | Self::NodeOutFlow(idx) | Self::NodeVolume(idx) => {
network.get_node(idx).map(|n| n.sub_name())
}
Self::AggregatedNodeInFlow(idx) | Self::AggregatedNodeOutFlow(idx) => {
network.get_aggregated_node(idx).map(|n| n.sub_name())
}
Self::AggregatedNodeVolume(idx) => network.get_aggregated_storage_node(idx).map(|n| n.sub_name()),
Self::EdgeFlow(idx) => {
let edge = network.get_edge(idx)?;
network.get_node(&edge.to_node_index).map(|n| Some(n.name()))
}
Self::ParameterValue(_) | Self::IndexParameterValue(_) | Self::MultiParameterValue(_) => Ok(None),
Self::VirtualStorageVolume(idx) => network.get_virtual_storage_node(idx).map(|v| v.sub_name()),
Self::MultiNodeInFlow { .. } | Self::MultiNodeOutFlow { .. } => Ok(None),
Self::Constant(_) => Ok(None),
Self::DerivedMetric(idx) => network.get_derived_metric(idx)?.sub_name(network),
Self::InterNetworkTransfer(_) => todo!("InterNetworkTransfer sub_name is not implemented"),
}
}

pub fn attribute(&self) -> &str {
match self {
Self::NodeInFlow(_) => "inflow",
Self::NodeOutFlow(_) => "outflow",
Self::NodeVolume(_) => "volume",
Self::AggregatedNodeInFlow(_) => "inflow",
Self::AggregatedNodeOutFlow(_) => "outflow",
Self::AggregatedNodeVolume(_) => "volume",
Self::EdgeFlow(_) => "edge_flow",
Self::ParameterValue(_) => "value",
Self::IndexParameterValue(_) => "value",
Self::MultiParameterValue(_) => "value",
Self::VirtualStorageVolume(_) => "volume",
Self::MultiNodeInFlow { .. } => "inflow",
Self::MultiNodeOutFlow { .. } => "outflow",
Self::Constant(_) => "value",
Self::DerivedMetric(_) => "value",
Self::InterNetworkTransfer(_) => "value",
}
}
}

#[derive(Clone, Debug, PartialEq)]
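The three methods removed above resolved a metric's output labels (name, sub-name and attribute) by looking nodes, parameters and edges up in the Network every time a recorder wrote a value. After this commit the recorders read those labels from the new OutputMetric container instead (introduced in pywr-core/src/recorders/metric_set.rs below), so no &Network is needed and the lookup cannot fail. A minimal before/after sketch, assuming the re-exported path pywr_core::recorders::OutputMetric shown later in this diff:

use pywr_core::recorders::OutputMetric;

// Before this commit (removed above): labels were recovered from the network
// at write time, and the name lookup could return a PywrError.
//   let name = metric.name(network)?;
//   let attribute = metric.attribute();

// After this commit: the schema supplies the labels when the OutputMetric is
// built, so a recorder can read them directly and infallibly.
fn labels(metric: &OutputMetric) -> (String, String) {
    (metric.name().to_string(), metric.attribute().to_string())
}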
20 changes: 3 additions & 17 deletions pywr-core/src/recorders/csv.rs
@@ -80,27 +80,21 @@ impl Recorder for CsvWideFmtOutput
let mut writer = csv::Writer::from_path(&self.filename).map_err(|e| PywrError::CSVError(e.to_string()))?;

let mut names = vec![];
let mut sub_names = vec![];
let mut attributes = vec![];

let metric_set = network.get_metric_set(self.metric_set_idx)?;

for metric in metric_set.iter_metrics() {
let name = metric.name(network)?.to_string();
let sub_name = metric
.sub_name(network)?
.map_or_else(|| "".to_string(), |s| s.to_string());
let name = metric.name().to_string();
let attribute = metric.attribute().to_string();

// Add entries for each scenario
names.push(name);
sub_names.push(sub_name);
attributes.push(attribute);
}

// These are the header rows in the CSV file; we start each
let mut header_name = vec!["node".to_string()];
let mut header_sub_name = vec!["sub-node".to_string()];
let mut header_attribute = vec!["attribute".to_string()];
let mut header_scenario = vec!["global-scenario-index".to_string()];

@@ -113,7 +107,6 @@ impl Recorder for CsvWideFmtOutput {
for scenario_index in domain.scenarios().indices().iter() {
// Repeat the names, sub-names and attributes for every scenario
header_name.extend(names.clone());
header_sub_name.extend(sub_names.clone());
header_attribute.extend(attributes.clone());
header_scenario.extend(vec![format!("{}", scenario_index.index); names.len()]);

@@ -125,9 +118,7 @@ impl Recorder for CsvWideFmtOutput {
writer
.write_record(header_name)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
writer
.write_record(header_sub_name)
.map_err(|e| PywrError::CSVError(e.to_string()))?;

writer
.write_record(header_attribute)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
@@ -200,7 +191,6 @@ pub struct CsvLongFmtRecord {
scenario_index: usize,
metric_set: String,
name: String,
sub_name: String,
attribute: String,
value: f64,
}
Expand Down Expand Up @@ -243,10 +233,7 @@ impl CsvLongFmtOutput {
let metric_set = network.get_metric_set(*metric_set_idx)?;

for (metric, value) in metric_set.iter_metrics().zip(current_values.iter()) {
let name = metric.name(network)?.to_string();
let sub_name = metric
.sub_name(network)?
.map_or_else(|| "".to_string(), |s| s.to_string());
let name = metric.name().to_string();
let attribute = metric.attribute().to_string();

let record = CsvLongFmtRecord {
@@ -255,7 +242,6 @@ impl CsvLongFmtOutput {
scenario_index: scenario_idx,
metric_set: metric_set.name().to_string(),
name,
sub_name,
attribute,
value: value.value,
};
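Both CSV recorders now take the name and attribute straight from each OutputMetric, and the wide-format writer drops its "sub-node" header row entirely. A simplified sketch of the wide-format header assembly after this change (the scenario repetition and the CSV writing itself are omitted; MetricSet and iter_metrics come from pywr_core::recorders):

use pywr_core::recorders::MetricSet;

// Simplified from the diff above: only the "node" and "attribute" header rows
// remain for the metric labels; the "sub-node" row no longer exists.
fn wide_format_headers(metric_set: &MetricSet) -> (Vec<String>, Vec<String>) {
    let mut header_name = vec!["node".to_string()];
    let mut header_attribute = vec!["attribute".to_string()];
    for metric in metric_set.iter_metrics() {
        header_name.push(metric.name().to_string());
        header_attribute.push(metric.attribute().to_string());
    }
    (header_name, header_attribute)
}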
54 changes: 30 additions & 24 deletions pywr-core/src/recorders/hdf.rs
@@ -1,4 +1,4 @@
use super::{MetricSetState, PywrError, Recorder, RecorderMeta, Timestep};
use super::{MetricSetState, OutputMetric, PywrError, Recorder, RecorderMeta, Timestep};
use crate::models::ModelDomain;
use crate::network::Network;
use crate::recorders::MetricSetIndex;
@@ -95,12 +95,7 @@ impl Recorder for HDF5Recorder {
let mut datasets = Vec::new();

for metric in metric_set.iter_metrics() {
let name = metric.name(network)?;
let sub_name = metric.sub_name(network)?;
let attribute = metric.attribute();

let ds = require_metric_dataset(root_grp, shape, name, sub_name, attribute)?;

let ds = require_metric_dataset(root_grp, shape, metric)?;
datasets.push(ds);
}

@@ -176,21 +171,35 @@ fn require_dataset<S: Into<Extents>>(parent: &Group, shape: S, name: &str) -> Re
fn require_metric_dataset<S: Into<Extents>>(
parent: &Group,
shape: S,
name: &str,
sub_name: Option<&str>,
attribute: &str,
metric: &OutputMetric,
) -> Result<hdf5::Dataset, PywrError> {
match sub_name {
None => {
let grp = require_group(parent, name)?;
require_dataset(&grp, shape, attribute)
}
Some(sn) => {
let grp = require_group(parent, name)?;
let grp = require_group(&grp, sn)?;
require_dataset(&grp, shape, attribute)
}
let grp = require_group(parent, metric.name())?;
let ds = require_dataset(&grp, shape, metric.attribute())?;

// Write the type and subtype as attributes
let ty = hdf5::types::VarLenUnicode::from_str(metric.ty()).map_err(|e| PywrError::HDF5Error(e.to_string()))?;
let attr = ds
.new_attr::<hdf5::types::VarLenUnicode>()
.shape(())
.create("pywr-type")
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;
attr.as_writer()
.write_scalar(&ty)
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;

if let Some(sub_type) = metric.sub_type() {
let sub_type =
hdf5::types::VarLenUnicode::from_str(sub_type).map_err(|e| PywrError::HDF5Error(e.to_string()))?;
let attr = ds
.new_attr::<hdf5::types::VarLenUnicode>()
.shape(())
.create("pywr-subtype")
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;
attr.as_writer()
.write_scalar(&sub_type)
.map_err(|e| PywrError::HDF5Error(e.to_string()))?;
}
Ok(ds)
}

fn require_group(parent: &Group, name: &str) -> Result<Group, PywrError> {
@@ -208,13 +217,10 @@ fn require_group(parent: &Group, name: &str) -> Result<Group, PywrError> {
fn write_pywr_metadata(file: &hdf5::File) -> Result<(), PywrError> {
let root = file.deref();

let grp = require_group(root, "pywr")?;

// Write the Pywr version as an attribute
const VERSION: &str = env!("CARGO_PKG_VERSION");
let version = hdf5::types::VarLenUnicode::from_str(VERSION).map_err(|e| PywrError::HDF5Error(e.to_string()))?;

let attr = grp
let attr = root
.new_attr::<hdf5::types::VarLenUnicode>()
.shape(())
.create("pywr-version")
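require_metric_dataset now tags every dataset with the metric's originating type and optional subtype as string attributes named "pywr-type" and "pywr-subtype", and the "pywr-version" attribute moves from a "pywr" group onto the file root. A short sketch of how a consumer might read one of those attributes back; the reader-side calls below are an assumption about the hdf5 crate's API and are not part of this diff:

fn read_pywr_type(ds: &hdf5::Dataset) -> Result<String, hdf5::Error> {
    // Assumed hdf5-crate reader API (attr + read_scalar), mirroring the
    // writer calls used in the diff above.
    let ty: hdf5::types::VarLenUnicode = ds.attr("pywr-type")?.read_scalar()?;
    Ok(ty.to_string())
}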
52 changes: 49 additions & 3 deletions pywr-core/src/recorders/metric_set.rs
@@ -9,6 +9,52 @@ use std::fmt;
use std::fmt::{Display, Formatter};
use std::ops::Deref;

/// A container for a [`MetricF64`] that retains additional information from the schema.
///
/// This is used to store the name and attribute of the metric so that it can be output in
/// a context that is relevant to the originating schema, and therefore more meaningful to the user.
#[derive(Clone, Debug)]
pub struct OutputMetric {
name: String,
attribute: String,
// The originating type of the metric (e.g. node, parameter, etc.)
ty: String,
// The originating subtype of the metric (e.g. node type, parameter type, etc.)
sub_type: Option<String>,
metric: MetricF64,
}

impl OutputMetric {
pub fn new(name: &str, attribute: &str, ty: &str, sub_type: Option<&str>, metric: MetricF64) -> Self {
Self {
name: name.to_string(),
attribute: attribute.to_string(),
ty: ty.to_string(),
sub_type: sub_type.map(|s| s.to_string()),
metric,
}
}

pub fn get_value(&self, model: &Network, state: &State) -> Result<f64, PywrError> {
self.metric.get_value(model, state)
}

pub fn name(&self) -> &str {
&self.name
}
pub fn attribute(&self) -> &str {
&self.attribute
}

pub fn ty(&self) -> &str {
&self.ty
}

pub fn sub_type(&self) -> Option<&str> {
self.sub_type.as_deref()
}
}

#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct MetricSetIndex(usize);

@@ -51,11 +97,11 @@ impl MetricSetState {
pub struct MetricSet {
name: String,
aggregator: Option<Aggregator>,
metrics: Vec<MetricF64>,
metrics: Vec<OutputMetric>,
}

impl MetricSet {
pub fn new(name: &str, aggregator: Option<Aggregator>, metrics: Vec<MetricF64>) -> Self {
pub fn new(name: &str, aggregator: Option<Aggregator>, metrics: Vec<OutputMetric>) -> Self {
Self {
name: name.to_string(),
aggregator,
@@ -67,7 +113,7 @@ impl MetricSet {
pub fn name(&self) -> &str {
&self.name
}
pub fn iter_metrics(&self) -> impl Iterator<Item = &MetricF64> + '_ {
pub fn iter_metrics(&self) -> impl Iterator<Item = &OutputMetric> + '_ {
self.metrics.iter()
}

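With the container in place, the schema layer (not part of this excerpt) labels each metric before recording: it builds an OutputMetric with the name, attribute, originating type and optional subtype, and hands the collection to MetricSet::new. A hedged construction sketch; the Constant variant's f64 payload and the exact crate paths are assumptions made for illustration:

use pywr_core::metric::MetricF64;
use pywr_core::recorders::{MetricSet, OutputMetric};

fn example_metric_set() -> MetricSet {
    // MetricF64::Constant appears in the match arms removed from metric.rs;
    // an f64 payload is assumed here purely for illustration.
    let metric = OutputMetric::new("demand", "value", "parameter", Some("Constant"), MetricF64::Constant(10.0));

    // No aggregator in this sketch, matching the Option<Aggregator> signature above.
    MetricSet::new("example-set", None, vec![metric])
}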
2 changes: 1 addition & 1 deletion pywr-core/src/recorders/mod.rs
@@ -17,7 +17,7 @@ pub use csv::{CsvLongFmtOutput, CsvLongFmtRecord, CsvWideFmtOutput};
use float_cmp::{approx_eq, ApproxEq, F64Margin};
pub use hdf::HDF5Recorder;
pub use memory::{Aggregation, AggregationError, MemoryRecorder};
pub use metric_set::{MetricSet, MetricSetIndex, MetricSetState};
pub use metric_set::{MetricSet, MetricSetIndex, MetricSetState, OutputMetric};
use ndarray::prelude::*;
use ndarray::Array2;
use std::any::Any;
2 changes: 1 addition & 1 deletion pywr-python/tests/models/aggregated-node1/expected.csv
@@ -1,5 +1,5 @@
date,input1,link1,link2,agg-node,output1
,outflow,outflow,outflow,outflow,inflow
,Outflow,Outflow,Outflow,Outflow,Inflow
2021-01-01,1,1,0,1,1
2021-01-02,2,2,0,2,2
2021-01-03,3,2,1,3,3
[Diffs for the remaining 63 changed files are not shown here.]
