diff --git a/pywr-core/src/recorders/memory.rs b/pywr-core/src/recorders/memory.rs index 697e1ee3..fd7399e1 100644 --- a/pywr-core/src/recorders/memory.rs +++ b/pywr-core/src/recorders/memory.rs @@ -146,7 +146,7 @@ impl InternalState { /// Aggregate over the saved data to a single value using the provided aggregation functions. /// /// This method will first aggregation over the metrics, then over time, and finally over the scenarios. - fn aggregate_scenario_time_metric(&self, aggregation: &Aggregation) -> Result { + fn aggregate_metric_time_scenario(&self, aggregation: &Aggregation) -> Result { let scenario_data: Vec = self .data .iter() @@ -168,7 +168,7 @@ impl InternalState { /// Aggregate over the saved data to a single value using the provided aggregation functions. /// /// This method will first aggregation over time, then over the metrics, and finally over the scenarios. - fn aggregate_scenario_metric_time(&self, aggregation: &Aggregation) -> Result { + fn aggregate_time_metric_scenario(&self, aggregation: &Aggregation) -> Result { let scenario_data: Vec = self .data .iter() @@ -192,6 +192,13 @@ impl InternalState { } } +#[derive(Default, Copy, Clone)] +pub enum AggregationOrder { + #[default] + MetricTimeScenario, + TimeMetricScenario, +} + /// A recorder that saves the metric values to memory. /// /// This recorder saves data into memory and can be used to provide aggregated data for external @@ -204,14 +211,16 @@ pub struct MemoryRecorder { meta: RecorderMeta, metric_set_idx: MetricSetIndex, aggregation: Aggregation, + order: AggregationOrder, } impl MemoryRecorder { - pub fn new(name: &str, metric_set_idx: MetricSetIndex, aggregation: Aggregation) -> Self { + pub fn new(name: &str, metric_set_idx: MetricSetIndex, aggregation: Aggregation, order: AggregationOrder) -> Self { Self { meta: RecorderMeta::new(name), metric_set_idx, aggregation, + order, } } } @@ -298,8 +307,11 @@ impl Recorder for MemoryRecorder { None => panic!("No internal state defined when one was expected! :("), }; - // TODO allow the user to choose the order of aggregation - let agg_value = internal_state.aggregate_scenario_time_metric(&self.aggregation)?; + let agg_value = match self.order { + AggregationOrder::MetricTimeScenario => internal_state.aggregate_metric_time_scenario(&self.aggregation)?, + AggregationOrder::TimeMetricScenario => internal_state.aggregate_time_metric_scenario(&self.aggregation)?, + }; + Ok(agg_value) } } @@ -356,10 +368,10 @@ mod tests { Some(AggregationFunction::CountFunc { func: |v: f64| v > 0.0 }), Some(AggregationFunction::Sum), ); - let agg_value = state.aggregate_scenario_time_metric(&agg).expect("Aggregation failed"); + let agg_value = state.aggregate_metric_time_scenario(&agg).expect("Aggregation failed"); assert_approx_eq!(f64, agg_value, count_non_zero_max); - let agg_value = state.aggregate_scenario_metric_time(&agg).expect("Aggregation failed"); + let agg_value = state.aggregate_time_metric_scenario(&agg).expect("Aggregation failed"); assert_approx_eq!(f64, agg_value, count_non_zero_by_metric.iter().sum()); } } diff --git a/pywr-core/src/recorders/mod.rs b/pywr-core/src/recorders/mod.rs index c8dc508d..025dd4c2 100644 --- a/pywr-core/src/recorders/mod.rs +++ b/pywr-core/src/recorders/mod.rs @@ -16,7 +16,7 @@ pub use aggregator::{AggregationFrequency, AggregationFunction, Aggregator}; pub use csv::{CsvLongFmtOutput, CsvLongFmtRecord, CsvWideFmtOutput}; use float_cmp::{approx_eq, ApproxEq, F64Margin}; pub use hdf::HDF5Recorder; -pub use memory::{Aggregation, AggregationError, MemoryRecorder}; +pub use memory::{Aggregation, AggregationError, AggregationOrder, MemoryRecorder}; pub use metric_set::{MetricSet, MetricSetIndex, MetricSetState, OutputMetric}; use ndarray::prelude::*; use ndarray::Array2; diff --git a/pywr-schema/src/outputs/memory.rs b/pywr-schema/src/outputs/memory.rs index 7c4448c7..972e2d5a 100644 --- a/pywr-schema/src/outputs/memory.rs +++ b/pywr-schema/src/outputs/memory.rs @@ -24,18 +24,40 @@ impl From for pywr_core::recorders::Aggregation { } } +#[derive(serde::Deserialize, serde::Serialize, Debug, Copy, Clone, JsonSchema, PywrVisitPaths)] +pub enum MemoryAggregationOrder { + MetricTimeScenario, + TimeMetricScenario, +} + +#[cfg(feature = "core")] +impl From for pywr_core::recorders::AggregationOrder { + fn from(value: MemoryAggregationOrder) -> Self { + match value { + MemoryAggregationOrder::MetricTimeScenario => pywr_core::recorders::AggregationOrder::MetricTimeScenario, + MemoryAggregationOrder::TimeMetricScenario => pywr_core::recorders::AggregationOrder::TimeMetricScenario, + } + } +} + #[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitPaths)] pub struct MemoryOutput { pub name: String, pub metric_set: String, pub aggregation: MemoryAggregation, + pub order: Option, } #[cfg(feature = "core")] impl MemoryOutput { pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result<(), SchemaError> { let metric_set_idx = network.get_metric_set_index_by_name(&self.metric_set)?; - let recorder = MemoryRecorder::new(&self.name, metric_set_idx, self.aggregation.clone().into()); + let recorder = MemoryRecorder::new( + &self.name, + metric_set_idx, + self.aggregation.clone().into(), + self.order.map(|o| o.into()).unwrap_or_default(), + ); network.add_recorder(Box::new(recorder))?;