Skip to content

Commit

Permalink
fix: Fix an issue accumulating aggregated metric values.
Browse files Browse the repository at this point in the history
The issue was caused by the lazy evaluation of `.map()` when
collecting into an `Option`. This is now done explicitly.
  • Loading branch information
jetuk committed Mar 25, 2024
1 parent 1284937 commit 821a914
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl Network {

// Setup recorders
for (recorder, internal_state) in self.recorders.iter().zip(recorder_internal_states) {
recorder.finalise(metric_set_states, internal_state)?;
recorder.finalise(self, metric_set_states, internal_state)?;
}

Ok(())
Expand Down
198 changes: 173 additions & 25 deletions pywr-core/src/recorders/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::path::PathBuf;

/// Output the values from a [`MetricSet`] to a CSV file.
#[derive(Clone, Debug)]
pub struct CSVRecorder {
pub struct CsvShortFmtOutput {
meta: RecorderMeta,
filename: PathBuf,
metric_set_idx: MetricSetIndex,
Expand All @@ -21,17 +21,51 @@ struct Internal {
writer: csv::Writer<File>,
}

impl CSVRecorder {
impl CsvShortFmtOutput {
pub fn new<P: Into<PathBuf>>(name: &str, filename: P, metric_set_idx: MetricSetIndex) -> Self {
Self {
meta: RecorderMeta::new(name),
filename: filename.into(),
metric_set_idx,
}
}

fn write_values(
&self,
timestep: String,
metric_set_states: &[Vec<MetricSetState>],
internal: &mut Internal,
) -> Result<(), PywrError> {
let mut row = vec![timestep];

// Iterate through all the scenario's state
for ms_scenario_states in metric_set_states.iter() {
let metric_set_state = ms_scenario_states
.get(*self.metric_set_idx.deref())
.ok_or(PywrError::MetricSetIndexNotFound(self.metric_set_idx))?;

if let Some(current_values) = metric_set_state.current_values() {
let values = current_values
.iter()
.map(|v| format!("{:.2}", v.value))
.collect::<Vec<_>>();

row.extend(values);
}
}

// Only write
if row.len() > 1 {
internal
.writer
.write_record(row)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
}
Ok(())
}
}

impl Recorder for CSVRecorder {
impl Recorder for CsvShortFmtOutput {
fn meta(&self) -> &RecorderMeta {
&self.meta
}
Expand Down Expand Up @@ -125,44 +159,158 @@ impl Recorder for CSVRecorder {
None => panic!("No internal state defined when one was expected! :("),
};

let mut row = vec![timestep.date.to_string()];
self.write_values(timestep.date.to_string(), metric_set_states, internal)?;

// Iterate through all of the scenario's state
for ms_scenario_states in metric_set_states.iter() {
let metric_set_state = ms_scenario_states
.get(*self.metric_set_idx.deref())
.ok_or(PywrError::MetricSetIndexNotFound(self.metric_set_idx))?;

if let Some(current_values) = metric_set_state.current_values() {
let values = current_values
.iter()
.map(|v| format!("{:.2}", v.value))
.collect::<Vec<_>>();
Ok(())
}

row.extend(values);
fn finalise(
&self,
_network: &Network,
metric_set_states: &[Vec<MetricSetState>],
internal_state: &mut Option<Box<dyn Any>>,
) -> Result<(), PywrError> {
// This will leave the internal state with a `None` because we need to take
// ownership of the file handle in order to close it.
match internal_state.take() {
Some(mut internal) => {
if let Some(internal) = internal.downcast_mut::<Internal>() {
self.write_values("end".to_string(), metric_set_states, internal)?;
Ok(())
} else {
panic!("Internal state did not downcast to the correct type! :(");
}
}
None => panic!("No internal state defined when one was expected! :("),
}
}
}

// Only write
if row.len() > 1 {
internal
.writer
.write_record(row)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
/// Output the values from a several [`MetricSet`]s to a CSV file in long format.
///
#[derive(Clone, Debug)]
pub struct CsvLongFmtOutput {
meta: RecorderMeta,
filename: PathBuf,
metric_set_indices: Vec<MetricSetIndex>,
}

impl CsvLongFmtOutput {
pub fn new<P: Into<PathBuf>>(name: &str, filename: P, metric_set_indices: Vec<MetricSetIndex>) -> Self {
Self {
meta: RecorderMeta::new(name),
filename: filename.into(),
metric_set_indices,
}
}

fn write_values(
&self,
timestep: &str,
network: &Network,
metric_set_states: &[Vec<MetricSetState>],
internal: &mut Internal,
) -> Result<(), PywrError> {
// Iterate through all the scenario's state
for (scenario_idx, ms_scenario_states) in metric_set_states.iter().enumerate() {
for metric_set_idx in self.metric_set_indices.iter() {
let metric_set_state = ms_scenario_states
.get(*metric_set_idx.deref())
.ok_or(PywrError::MetricSetIndexNotFound(*metric_set_idx))?;

if let Some(current_values) = metric_set_state.current_values() {
let metric_set = network.get_metric_set(*metric_set_idx)?;

for (metric, value) in metric_set.iter_metrics().zip(current_values.iter()) {
let name = metric.name(network)?.to_string();
let sub_name = metric
.sub_name(network)?
.map_or_else(|| "".to_string(), |s| s.to_string());
let attribute = metric.attribute().to_string();

let row = vec![
value.start.to_string(),
format!("{}", scenario_idx),
metric_set.name().to_string(),
name,
sub_name,
attribute,
format!("{:.2}", value.value),
];

internal
.writer
.write_record(row)
.map_err(|e| PywrError::CSVError(e.to_string()))?;
}
}
}
}

Ok(())
}
}

impl Recorder for CsvLongFmtOutput {
fn meta(&self) -> &RecorderMeta {
&self.meta
}
fn setup(&self, _domain: &ModelDomain, _network: &Network) -> Result<Option<Box<(dyn Any)>>, PywrError> {
let mut writer = csv::Writer::from_path(&self.filename).map_err(|e| PywrError::CSVError(e.to_string()))?;

let header = vec![
"timestep".to_string(),
"scenario_index".to_string(),
"metric_set".to_string(),
"node".to_string(),
"sub_node".to_string(),
"attribute".to_string(),
"value".to_string(),
];

writer
.write_record(header)
.map_err(|e| PywrError::CSVError(e.to_string()))?;

let internal = Internal { writer };

Ok(Some(Box::new(internal)))
}

fn save(
&self,
timestep: &Timestep,
_scenario_indices: &[ScenarioIndex],
network: &Network,
_state: &[State],
metric_set_states: &[Vec<MetricSetState>],
internal_state: &mut Option<Box<dyn Any>>,
) -> Result<(), PywrError> {
let internal = match internal_state {
Some(internal) => match internal.downcast_mut::<Internal>() {
Some(pa) => pa,
None => panic!("Internal state did not downcast to the correct type! :("),
},
None => panic!("No internal state defined when one was expected! :("),
};

self.write_values(&timestep.date.to_string(), network, metric_set_states, internal)?;

Ok(())
}

fn finalise(
&self,
_metric_set_states: &[Vec<MetricSetState>],
network: &Network,
metric_set_states: &[Vec<MetricSetState>],
internal_state: &mut Option<Box<dyn Any>>,
) -> Result<(), PywrError> {
// This will leave the internal state with a `None` because we need to take
// ownership of the file handle in order to close it.
match internal_state.take() {
Some(internal) => {
if let Ok(_internal) = internal.downcast::<Internal>() {
Some(mut internal) => {
if let Some(internal) = internal.downcast_mut::<Internal>() {
self.write_values("end", network, metric_set_states, internal)?;
Ok(())
} else {
panic!("Internal state did not downcast to the correct type! :(");
Expand Down
1 change: 1 addition & 0 deletions pywr-core/src/recorders/hdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ impl Recorder for HDF5Recorder {

fn finalise(
&self,
_network: &Network,
_metric_set_states: &[Vec<MetricSetState>],
internal_state: &mut Option<Box<dyn Any>>,
) -> Result<(), PywrError> {
Expand Down
1 change: 1 addition & 0 deletions pywr-core/src/recorders/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ impl Recorder for MemoryRecorder {

fn finalise(
&self,
_network: &Network,
metric_set_states: &[Vec<MetricSetState>],
internal_state: &mut Option<Box<dyn Any>>,
) -> Result<(), PywrError> {
Expand Down
26 changes: 21 additions & 5 deletions pywr-core/src/recorders/metric_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,27 @@ impl MetricSet {
.as_mut()
.expect("Aggregation state expected for metric set with aggregator!");

let agg_values = values
.into_iter()
.zip(aggregation_states.iter_mut())
.map(|(value, current_state)| aggregator.append_value(current_state, value))
.collect::<Option<Vec<_>>>();
// Collect any aggregated values. This will remain empty if the aggregator yields
// no values. However, if there are values we will expect the same number of aggregated
// values as the input values / metrics.
let mut agg_values = Vec::with_capacity(values.len());
// Use a for loop instead of using an iterator because we need to execute the
// `append_value` method on all aggregators.
for (value, current_state) in values.iter().zip(aggregation_states.iter_mut()) {
if let Some(agg_value) = aggregator.append_value(current_state, *value) {
agg_values.push(agg_value);
}
}

let agg_values = if agg_values.is_empty() {
None
} else if agg_values.len() == values.len() {
Some(agg_values)
} else {
// This should never happen because the aggregator should either yield no values
// or the same number of values as the input metrics.
unreachable!("Some values were aggregated and some were not!");
};

internal_state.current_values = agg_values;
} else {
Expand Down
5 changes: 3 additions & 2 deletions pywr-core/src/recorders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod memory;
mod metric_set;
mod py;

pub use self::csv::CSVRecorder;
pub use self::csv::{CsvLongFmtOutput, CsvShortFmtOutput};
use crate::metric::{MetricF64, MetricUsize};
use crate::models::ModelDomain;
use crate::network::Network;
Expand Down Expand Up @@ -78,7 +78,7 @@ pub trait Recorder: Send + Sync {
&self,
_timestep: &Timestep,
_scenario_indices: &[ScenarioIndex],
_model: &Network,
_network: &Network,
_state: &[State],
_metric_set_states: &[Vec<MetricSetState>],
_internal_state: &mut Option<Box<dyn Any>>,
Expand All @@ -87,6 +87,7 @@ pub trait Recorder: Send + Sync {
}
fn finalise(
&self,
_network: &Network,
_metric_set_states: &[Vec<MetricSetState>],
_internal_state: &mut Option<Box<dyn Any>>,
) -> Result<(), PywrError> {
Expand Down
6 changes: 4 additions & 2 deletions pywr-schema/src/outputs/csv.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::SchemaError;
use pywr_core::recorders::CSVRecorder;
use pywr_core::recorders::{CsvLongFmtOutput, CsvShortFmtOutput};
use std::path::{Path, PathBuf};

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
Expand All @@ -21,7 +21,9 @@ impl CsvOutput {
};

let metric_set_idx = network.get_metric_set_index_by_name(&self.metric_set)?;
let recorder = CSVRecorder::new(&self.name, filename, metric_set_idx);
// let recorder = CsvShortFmtOutput::new(&self.name, filename, metric_set_idx);

let recorder = CsvLongFmtOutput::new(&self.name, filename, vec![metric_set_idx]);

network.add_recorder(Box::new(recorder))?;

Expand Down

0 comments on commit 821a914

Please sign in to comment.