Skip to content

Commit

Permalink
feat: Support aggregation order in MemoryOutput.
Browse files Browse the repository at this point in the history
This addresses a todo regarding allowing the user to specify
the aggregation order for the "aggregated value" of the memory
output.

It also fixes two Clippy warnings about unused functions.
  • Loading branch information
jetuk committed Jul 16, 2024
1 parent f954435 commit c0d04d8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
26 changes: 19 additions & 7 deletions pywr-core/src/recorders/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl InternalState {
/// Aggregate over the saved data to a single value using the provided aggregation functions.
///
/// This method will first aggregation over the metrics, then over time, and finally over the scenarios.
fn aggregate_scenario_time_metric(&self, aggregation: &Aggregation) -> Result<f64, AggregationError> {
fn aggregate_metric_time_scenario(&self, aggregation: &Aggregation) -> Result<f64, AggregationError> {
let scenario_data: Vec<f64> = self
.data
.iter()
Expand All @@ -168,7 +168,7 @@ impl InternalState {
/// Aggregate over the saved data to a single value using the provided aggregation functions.
///
/// This method will first aggregation over time, then over the metrics, and finally over the scenarios.
fn aggregate_scenario_metric_time(&self, aggregation: &Aggregation) -> Result<f64, AggregationError> {
fn aggregate_time_metric_scenario(&self, aggregation: &Aggregation) -> Result<f64, AggregationError> {
let scenario_data: Vec<f64> = self
.data
.iter()
Expand All @@ -192,6 +192,13 @@ impl InternalState {
}
}

#[derive(Default, Copy, Clone)]
pub enum AggregationOrder {
#[default]
MetricTimeScenario,
TimeMetricScenario,
}

/// A recorder that saves the metric values to memory.
///
/// This recorder saves data into memory and can be used to provide aggregated data for external
Expand All @@ -204,14 +211,16 @@ pub struct MemoryRecorder {
meta: RecorderMeta,
metric_set_idx: MetricSetIndex,
aggregation: Aggregation,
order: AggregationOrder,
}

impl MemoryRecorder {
pub fn new(name: &str, metric_set_idx: MetricSetIndex, aggregation: Aggregation) -> Self {
pub fn new(name: &str, metric_set_idx: MetricSetIndex, aggregation: Aggregation, order: AggregationOrder) -> Self {
Self {
meta: RecorderMeta::new(name),
metric_set_idx,
aggregation,
order,
}
}
}
Expand Down Expand Up @@ -298,8 +307,11 @@ impl Recorder for MemoryRecorder {
None => panic!("No internal state defined when one was expected! :("),
};

// TODO allow the user to choose the order of aggregation
let agg_value = internal_state.aggregate_scenario_time_metric(&self.aggregation)?;
let agg_value = match self.order {
AggregationOrder::MetricTimeScenario => internal_state.aggregate_metric_time_scenario(&self.aggregation)?,
AggregationOrder::TimeMetricScenario => internal_state.aggregate_time_metric_scenario(&self.aggregation)?,
};

Ok(agg_value)
}
}
Expand Down Expand Up @@ -356,10 +368,10 @@ mod tests {
Some(AggregationFunction::CountFunc { func: |v: f64| v > 0.0 }),
Some(AggregationFunction::Sum),
);
let agg_value = state.aggregate_scenario_time_metric(&agg).expect("Aggregation failed");
let agg_value = state.aggregate_metric_time_scenario(&agg).expect("Aggregation failed");
assert_approx_eq!(f64, agg_value, count_non_zero_max);

let agg_value = state.aggregate_scenario_metric_time(&agg).expect("Aggregation failed");
let agg_value = state.aggregate_time_metric_scenario(&agg).expect("Aggregation failed");
assert_approx_eq!(f64, agg_value, count_non_zero_by_metric.iter().sum());
}
}
2 changes: 1 addition & 1 deletion pywr-core/src/recorders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use aggregator::{AggregationFrequency, AggregationFunction, Aggregator};
pub use csv::{CsvLongFmtOutput, CsvLongFmtRecord, CsvWideFmtOutput};
use float_cmp::{approx_eq, ApproxEq, F64Margin};
pub use hdf::HDF5Recorder;
pub use memory::{Aggregation, AggregationError, MemoryRecorder};
pub use memory::{Aggregation, AggregationError, AggregationOrder, MemoryRecorder};
pub use metric_set::{MetricSet, MetricSetIndex, MetricSetState, OutputMetric};
use ndarray::prelude::*;
use ndarray::Array2;
Expand Down
24 changes: 23 additions & 1 deletion pywr-schema/src/outputs/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,40 @@ impl From<MemoryAggregation> for pywr_core::recorders::Aggregation {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Copy, Clone, JsonSchema, PywrVisitPaths)]
pub enum MemoryAggregationOrder {
MetricTimeScenario,
TimeMetricScenario,
}

#[cfg(feature = "core")]
impl From<MemoryAggregationOrder> for pywr_core::recorders::AggregationOrder {
fn from(value: MemoryAggregationOrder) -> Self {
match value {
MemoryAggregationOrder::MetricTimeScenario => pywr_core::recorders::AggregationOrder::MetricTimeScenario,
MemoryAggregationOrder::TimeMetricScenario => pywr_core::recorders::AggregationOrder::TimeMetricScenario,
}
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitPaths)]
pub struct MemoryOutput {
pub name: String,
pub metric_set: String,
pub aggregation: MemoryAggregation,
pub order: Option<MemoryAggregationOrder>,
}

#[cfg(feature = "core")]
impl MemoryOutput {
pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result<(), SchemaError> {
let metric_set_idx = network.get_metric_set_index_by_name(&self.metric_set)?;
let recorder = MemoryRecorder::new(&self.name, metric_set_idx, self.aggregation.clone().into());
let recorder = MemoryRecorder::new(
&self.name,
metric_set_idx,
self.aggregation.clone().into(),
self.order.map(|o| o.into()).unwrap_or_default(),
);

network.add_recorder(Box::new(recorder))?;

Expand Down

0 comments on commit c0d04d8

Please sign in to comment.