Skip to content

Commit

Permalink
fix: Fix LossLink and WTW losses. (#223)
Browse files Browse the repository at this point in the history
- Add missing aggregated nodes to LossLinkNode and WaterTreatmentWorks. This means that actually apply losses.
- Adds a shared struct for defining Gross or Net losses.
- Add new tests for both nodes, including without losses defined.
- Update some test utils for easier comparison with expected outputs.
  • Loading branch information
jetuk authored Jul 25, 2024
1 parent 5fda5e8 commit 1b96db2
Show file tree
Hide file tree
Showing 23 changed files with 653 additions and 240 deletions.
2 changes: 1 addition & 1 deletion pywr-core/src/aggregated_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,6 @@ mod tests {

let model = Model::new(default_time_domain().into(), network);

run_all_solvers(&model, &["cbc", "highs"]);
run_all_solvers(&model, &["cbc", "highs"], &[]);
}
}
24 changes: 24 additions & 0 deletions pywr-core/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ impl ConstantMetricF64 {
ConstantMetricF64::Constant(v) => Ok(*v),
}
}

/// Returns true if the constant value is a [`ConstantMetricF64::Constant`] with a value of zero.
pub fn is_constant_zero(&self) -> bool {
match self {
ConstantMetricF64::Constant(v) => *v == 0.0,
_ => false,
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum SimpleMetricF64 {
Expand All @@ -45,6 +53,14 @@ impl SimpleMetricF64 {
SimpleMetricF64::Constant(m) => m.get_value(values.get_constant_values()),
}
}

/// Returns true if the constant value is a [`ConstantMetricF64::Constant`] with a value of zero.
pub fn is_constant_zero(&self) -> bool {
match self {
SimpleMetricF64::Constant(c) => c.is_constant_zero(),
_ => false,
}
}
}

#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -122,6 +138,14 @@ impl MetricF64 {
MetricF64::Simple(s) => s.get_value(&state.get_simple_parameter_values()),
}
}

/// Returns true if the constant value is a [`ConstantMetricF64::Constant`] with a value of zero.
pub fn is_constant_zero(&self) -> bool {
match self {
MetricF64::Simple(s) => s.is_constant_zero(),
_ => false,
}
}
}

impl TryFrom<MetricF64> for SimpleMetricF64 {
Expand Down
6 changes: 3 additions & 3 deletions pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1785,7 +1785,7 @@ mod tests {
model.network_mut().add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);
}

#[test]
Expand All @@ -1809,7 +1809,7 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);
}

/// Test proportional storage derived metric.
Expand Down Expand Up @@ -1849,7 +1849,7 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);
}

#[test]
Expand Down
19 changes: 17 additions & 2 deletions pywr-core/src/recorders/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fs::File;
use std::num::NonZeroU32;
use std::ops::Deref;
use std::path::PathBuf;

Expand Down Expand Up @@ -205,14 +206,21 @@ pub struct CsvLongFmtOutput {
meta: RecorderMeta,
filename: PathBuf,
metric_set_indices: Vec<MetricSetIndex>,
decimal_places: Option<NonZeroU32>,
}

impl CsvLongFmtOutput {
pub fn new<P: Into<PathBuf>>(name: &str, filename: P, metric_set_indices: &[MetricSetIndex]) -> Self {
pub fn new<P: Into<PathBuf>>(
name: &str,
filename: P,
metric_set_indices: &[MetricSetIndex],
decimal_places: Option<NonZeroU32>,
) -> Self {
Self {
meta: RecorderMeta::new(name),
filename: filename.into(),
metric_set_indices: metric_set_indices.to_vec(),
decimal_places,
}
}

Expand All @@ -236,14 +244,21 @@ impl CsvLongFmtOutput {
let name = metric.name().to_string();
let attribute = metric.attribute().to_string();

let value_scaled = if let Some(decimal_places) = self.decimal_places {
let scale = 10.0_f64.powi(decimal_places.get() as i32);
(value.value * scale).round() / scale
} else {
value.value
};

let record = CsvLongFmtRecord {
time_start: value.start,
time_end: value.end(),
scenario_index: scenario_idx,
metric_set: metric_set.name().to_string(),
name,
attribute,
value: value.value,
value: value_scaled,
};

internal
Expand Down
2 changes: 1 addition & 1 deletion pywr-core/src/recorders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ mod tests {

let _idx = model.network_mut().add_recorder(Box::new(rec)).unwrap();
// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);

// TODO fix this with respect to the trait.
// let array = rec.data_view2().unwrap();
Expand Down
39 changes: 33 additions & 6 deletions pywr-core/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use float_cmp::{approx_eq, F64Margin};
use ndarray::{Array, Array2};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use std::path::PathBuf;

pub fn default_timestepper() -> Timestepper {
let start = NaiveDate::from_ymd_opt(2020, 1, 1)
Expand Down Expand Up @@ -163,21 +164,42 @@ pub fn run_and_assert_parameter(
let rec = AssertionRecorder::new("assert", p_idx.into(), expected_values, ulps, epsilon);

model.network_mut().add_recorder(Box::new(rec)).unwrap();
run_all_solvers(model, &[])
run_all_solvers(model, &[], &[])
}

/// A struct to hold the expected outputs for a test.
pub struct ExpectedOutputs {
actual_path: PathBuf,
expected_str: &'static str,
}

impl ExpectedOutputs {
pub fn new(actual_path: PathBuf, expected_str: &'static str) -> Self {
Self {
actual_path,
expected_str,
}
}

fn verify(&self) {
assert!(self.actual_path.exists());
let actual_str = std::fs::read_to_string(&self.actual_path).unwrap();
assert_eq!(actual_str, self.expected_str);
}
}

/// Run a model using each of the in-built solvers.
///
/// The model will only be run if the solver has the required solver features (and
/// is also enabled as a Cargo feature).
pub fn run_all_solvers(model: &Model, solvers_without_features: &[&str]) {
check_features_and_run::<ClpSolver>(model, !solvers_without_features.contains(&"clp"));
pub fn run_all_solvers(model: &Model, solvers_without_features: &[&str], expected_outputs: &[ExpectedOutputs]) {
check_features_and_run::<ClpSolver>(model, !solvers_without_features.contains(&"clp"), expected_outputs);

#[cfg(feature = "cbc")]
check_features_and_run::<CbcSolver>(model, !solvers_without_features.contains(&"cbc"));
check_features_and_run::<CbcSolver>(model, !solvers_without_features.contains(&"cbc"), expected_outputs);

#[cfg(feature = "highs")]
check_features_and_run::<HighsSolver>(model, !solvers_without_features.contains(&"highs"));
check_features_and_run::<HighsSolver>(model, !solvers_without_features.contains(&"highs"), expected_outputs);

#[cfg(feature = "ipm-simd")]
{
Expand All @@ -199,7 +221,7 @@ pub fn run_all_solvers(model: &Model, solvers_without_features: &[&str]) {
}

/// Check features and
fn check_features_and_run<S>(model: &Model, expect_features: bool)
fn check_features_and_run<S>(model: &Model, expect_features: bool, expected_outputs: &[ExpectedOutputs])
where
S: Solver,
<S as Solver>::Settings: SolverSettings + Default,
Expand All @@ -214,6 +236,11 @@ where
model
.run::<S>(&Default::default())
.unwrap_or_else(|_| panic!("Failed to solve with: {}", S::name()));

// Verify any expected outputs
for expected_output in expected_outputs {
expected_output.verify();
}
} else {
assert!(
!has_features,
Expand Down
6 changes: 3 additions & 3 deletions pywr-core/src/virtual_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ mod tests {
let domain = default_timestepper().try_into().unwrap();
let model = Model::new(domain, network);
// Test all solvers
run_all_solvers(&model, &["highs"]);
run_all_solvers(&model, &["highs"], &[]);
}

#[test]
Expand All @@ -449,7 +449,7 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &["highs"]);
run_all_solvers(&model, &["highs"], &[]);
}

#[test]
Expand Down Expand Up @@ -489,6 +489,6 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &["highs"]);
run_all_solvers(&model, &["highs"], &[]);
}
}
2 changes: 1 addition & 1 deletion pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ mod core_tests {
network.add_recorder(Box::new(rec)).unwrap();

// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);
}

/// Test that a cycle in parameter dependencies does not load.
Expand Down
2 changes: 1 addition & 1 deletion pywr-schema/src/nodes/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,6 @@ mod tests {
let schema = PywrModel::from_str(data).unwrap();
let model: pywr_core::models::Model = schema.build_model(None, None).unwrap();
// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);
}
}
2 changes: 1 addition & 1 deletion pywr-schema/src/nodes/delay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,6 @@ mod tests {
network.add_recorder(Box::new(recorder)).unwrap();

// Test all solvers
run_all_solvers(&model, &[]);
run_all_solvers(&model, &[], &[]);
}
}
Loading

0 comments on commit 1b96db2

Please sign in to comment.