diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 1d5bc7d9..963748d1 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -14,21 +14,21 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - with: - submodules: true + - uses: actions/checkout@v4 + with: + submodules: true - - name: Install latest mdbook - run: | - tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name') - url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz" - mkdir bin - curl -sSL $url | tar -xz --directory=bin - echo "$(pwd)/bin" >> $GITHUB_PATH + - name: Install latest mdbook + run: | + tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name') + url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz" + mkdir bin + curl -sSL $url | tar -xz --directory=bin + echo "$(pwd)/bin" >> $GITHUB_PATH - - name: Build - run: cargo build --verbose --no-default-features - - name: Run tests - run: cargo test --no-default-features - - name: Run mdbook tests - run: mdbook test ./pywr-book + - name: Build + run: cargo build --verbose --no-default-features --features highs + - name: Run tests + run: cargo test --no-default-features --features highs + - name: Run mdbook tests + run: mdbook test ./pywr-book diff --git a/Cargo.toml b/Cargo.toml index 3b42716f..da6b96ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "pywr-schema", "pywr-cli", "pywr-python", + "pywr-schema-macros", ] exclude = [ "tests/models/simple-wasm/simple-wasm-parameter" @@ -38,8 +39,8 @@ thiserror = "1.0.25" num = "0.4.0" float-cmp = "0.9.0" ndarray = "0.15.3" -polars = { version = "0.37.0", features = ["lazy", "rows", "ndarray"] } -pyo3-polars = "0.11.1" +polars = { version = "0.38.1", features = ["lazy", "rows", "ndarray"] } +pyo3-polars = "0.12.0" pyo3 = { version = "0.20.2", default-features = false } pyo3-log = "0.9.0" tracing = { version = "0.1", features = ["log"] } diff --git a/pywr-cli/src/main.rs b/pywr-cli/src/main.rs index f49c3124..4b4f3642 100644 --- a/pywr-cli/src/main.rs +++ b/pywr-cli/src/main.rs @@ -1,7 +1,8 @@ mod tracing; use crate::tracing::setup_tracing; -use anyhow::{Context, Result}; +use ::tracing::info; +use anyhow::Result; use clap::{Parser, Subcommand, ValueEnum}; #[cfg(feature = "ipm-ocl")] use pywr_core::solvers::{ClIpmF32Solver, ClIpmF64Solver, ClIpmSolverSettings}; @@ -12,7 +13,6 @@ use pywr_core::solvers::{HighsSolver, HighsSolverSettings}; use pywr_core::solvers::{SimdIpmF64Solver, SimdIpmSolverSettings}; use pywr_core::test_utils::make_random_model; use pywr_schema::model::{PywrModel, PywrMultiNetworkModel}; -use pywr_schema::ConversionError; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use std::fmt::{Display, Formatter}; @@ -50,16 +50,9 @@ impl Display for Solver { #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Cli { - // /// Optional name to operate on - // name: Option, - // - // /// Sets a custom config file - // #[arg(short, long, value_name = "FILE")] - // config: Option, - // - // /// Turn debugging information on - // #[arg(short, long, action = clap::ArgAction::Count)] - // debug: u8, + /// Turn debugging information on + #[arg(long, default_value_t = false)] + debug: bool, #[command(subcommand)] command: Option, } @@ -69,6 +62,9 @@ enum Commands { Convert { /// Path to Pywr v1.x JSON. model: PathBuf, + /// Stop if there is an error converting the model. + #[arg(short, long, default_value_t = false)] + stop_on_error: bool, }, Run { @@ -87,8 +83,6 @@ enum Commands { /// The number of threads to use in parallel simulation. #[arg(short, long, default_value_t = 1)] threads: usize, - #[arg(long, default_value_t = false)] - debug: bool, }, RunMulti { /// Path to Pywr model JSON. @@ -106,8 +100,6 @@ enum Commands { /// The number of threads to use in parallel simulation. #[arg(short, long, default_value_t = 1)] threads: usize, - #[arg(long, default_value_t = false)] - debug: bool, }, RunRandom { num_systems: usize, @@ -121,10 +113,11 @@ enum Commands { fn main() -> Result<()> { let cli = Cli::parse(); + setup_tracing(cli.debug).unwrap(); match &cli.command { Some(command) => match command { - Commands::Convert { model } => convert(model)?, + Commands::Convert { model, stop_on_error } => convert(model, *stop_on_error), Commands::Run { model, solver, @@ -132,8 +125,7 @@ fn main() -> Result<()> { output_path, parallel: _, threads: _, - debug, - } => run(model, solver, data_path.as_deref(), output_path.as_deref(), *debug), + } => run(model, solver, data_path.as_deref(), output_path.as_deref()), Commands::RunMulti { model, solver, @@ -141,8 +133,7 @@ fn main() -> Result<()> { output_path, parallel: _, threads: _, - debug, - } => run_multi(model, solver, data_path.as_deref(), output_path.as_deref(), *debug), + } => run_multi(model, solver, data_path.as_deref(), output_path.as_deref()), Commands::RunRandom { num_systems, density, @@ -156,7 +147,7 @@ fn main() -> Result<()> { Ok(()) } -fn convert(path: &Path) -> Result<()> { +fn convert(path: &Path, stop_on_error: bool) { if path.is_dir() { for entry in path.read_dir().expect("read_dir call failed").flatten() { let path = entry.path(); @@ -164,22 +155,34 @@ fn convert(path: &Path) -> Result<()> { && (path.extension().unwrap() == "json") && (!path.file_stem().unwrap().to_str().unwrap().contains("_v2")) { - v1_to_v2(&path).with_context(|| format!("Could not convert model: `{:?}`", &path))?; + v1_to_v2(&path, stop_on_error); } } } else { - v1_to_v2(path).with_context(|| format!("Could not convert model: `{:?}`", path))?; + v1_to_v2(path, stop_on_error); } - - Ok(()) } -fn v1_to_v2(path: &Path) -> std::result::Result<(), ConversionError> { - println!("Model: {}", path.display()); +fn v1_to_v2(path: &Path, stop_on_error: bool) { + info!("Model: {}", path.display()); let data = std::fs::read_to_string(path).unwrap(); + // Load the v1 schema let schema: pywr_v1_schema::PywrModel = serde_json::from_str(data.as_str()).unwrap(); - let schema_v2: PywrModel = schema.try_into()?; + // Convert to v2 schema and collect any errors + let (schema_v2, errors) = PywrModel::from_v1(schema); + + if !errors.is_empty() { + info!("Model converted with {} errors:", errors.len()); + for error in errors { + info!(" {}", error); + } + if stop_on_error { + return; + } + } else { + info!("Model converted with zero errors!"); + } // There must be a better way to do this!! let mut new_file_name = path.file_stem().unwrap().to_os_string(); @@ -189,13 +192,9 @@ fn v1_to_v2(path: &Path) -> std::result::Result<(), ConversionError> { let new_file_pth = path.parent().unwrap().join(new_file_name); std::fs::write(new_file_pth, serde_json::to_string_pretty(&schema_v2).unwrap()).unwrap(); - - Ok(()) } -fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>, debug: bool) { - setup_tracing(debug).unwrap(); - +fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>) { let data = std::fs::read_to_string(path).unwrap(); let data_path = data_path.or_else(|| path.parent()); let schema_v2: PywrModel = serde_json::from_str(data.as_str()).unwrap(); @@ -216,9 +215,7 @@ fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Opti .unwrap(); } -fn run_multi(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>, debug: bool) { - setup_tracing(debug).unwrap(); - +fn run_multi(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>) { let data = std::fs::read_to_string(path).unwrap(); let data_path = data_path.or_else(|| path.parent()); diff --git a/pywr-core/Cargo.toml b/pywr-core/Cargo.toml index 5bd95f1c..0fdc225f 100644 --- a/pywr-core/Cargo.toml +++ b/pywr-core/Cargo.toml @@ -18,20 +18,19 @@ libc = "0.2.97" thiserror = { workspace = true } ndarray = { workspace = true } num = { workspace = true } -float-cmp = { workspace = true } +float-cmp = { workspace = true } hdf5 = { workspace = true } csv = { workspace = true } clp-sys = { path = "../clp-sys", version = "0.1.0" } ipm-ocl = { path = "../ipm-ocl", optional = true } ipm-simd = { path = "../ipm-simd", optional = true } -tracing = { workspace = true } -highs-sys = { git = "https://github.com/jetuk/highs-sys", branch="fix-build-libz-linking", optional = true } -# highs-sys = { path = "../../highs-sys" } +tracing = { workspace = true } +highs-sys = { version = "1.6.2", optional = true } nalgebra = "0.32.3" chrono = { workspace = true } polars = { workspace = true } -pyo3 = { workspace = true, features = ["chrono"] } +pyo3 = { workspace = true, features = ["chrono"] } rayon = "1.6.1" diff --git a/pywr-core/src/lib.rs b/pywr-core/src/lib.rs index 301521b5..5cb2e140 100644 --- a/pywr-core/src/lib.rs +++ b/pywr-core/src/lib.rs @@ -5,8 +5,9 @@ extern crate core; use crate::derived_metric::DerivedMetricIndex; use crate::models::MultiNetworkTransferIndex; use crate::node::NodeIndex; -use crate::parameters::{IndexParameterIndex, InterpolationError, MultiValueParameterIndex, ParameterIndex}; +use crate::parameters::{InterpolationError, ParameterIndex}; use crate::recorders::{AggregationError, MetricSetIndex, RecorderIndex}; +use crate::state::MultiValue; use crate::virtual_storage::VirtualStorageIndex; use pyo3::exceptions::{PyException, PyRuntimeError}; use pyo3::{create_exception, PyErr}; @@ -45,11 +46,11 @@ pub enum PywrError { #[error("virtual storage index {0} not found")] VirtualStorageIndexNotFound(VirtualStorageIndex), #[error("parameter index {0} not found")] - ParameterIndexNotFound(ParameterIndex), + ParameterIndexNotFound(ParameterIndex), #[error("index parameter index {0} not found")] - IndexParameterIndexNotFound(IndexParameterIndex), + IndexParameterIndexNotFound(ParameterIndex), #[error("multi1 value parameter index {0} not found")] - MultiValueParameterIndexNotFound(MultiValueParameterIndex), + MultiValueParameterIndexNotFound(ParameterIndex), #[error("multi1 value parameter key {0} not found")] MultiValueParameterKeyNotFound(String), #[error("inter-network parameter state not initialised")] @@ -73,9 +74,9 @@ pub enum PywrError { #[error("node name `{0}` already exists")] NodeNameAlreadyExists(String), #[error("parameter name `{0}` already exists at index {1}")] - ParameterNameAlreadyExists(String, ParameterIndex), + ParameterNameAlreadyExists(String, ParameterIndex), #[error("index parameter name `{0}` already exists at index {1}")] - IndexParameterNameAlreadyExists(String, IndexParameterIndex), + IndexParameterNameAlreadyExists(String, ParameterIndex), #[error("metric set name `{0}` already exists")] MetricSetNameAlreadyExists(String), #[error("recorder name `{0}` already exists at index {1}")] @@ -159,7 +160,7 @@ pub enum PywrError { #[error("parameters do not provide an initial value")] ParameterNoInitialValue, #[error("parameter state not found for parameter index {0}")] - ParameterStateNotFound(ParameterIndex), + ParameterStateNotFound(ParameterIndex), #[error("Could not create timestep range due to following error: {0}")] TimestepRangeGenerationError(String), #[error("Could not create timesteps for frequency '{0}'")] diff --git a/pywr-core/src/metric.rs b/pywr-core/src/metric.rs index 82e88ea6..daa7ad0a 100644 --- a/pywr-core/src/metric.rs +++ b/pywr-core/src/metric.rs @@ -5,8 +5,8 @@ use crate::edge::EdgeIndex; use crate::models::MultiNetworkTransferIndex; use crate::network::Network; use crate::node::NodeIndex; -use crate::parameters::{IndexParameterIndex, MultiValueParameterIndex, ParameterIndex}; -use crate::state::State; +use crate::parameters::ParameterIndex; +use crate::state::{MultiValue, State}; use crate::virtual_storage::VirtualStorageIndex; use crate::PywrError; #[derive(Clone, Debug, PartialEq)] @@ -18,8 +18,8 @@ pub enum Metric { AggregatedNodeOutFlow(AggregatedNodeIndex), AggregatedNodeVolume(AggregatedStorageNodeIndex), EdgeFlow(EdgeIndex), - ParameterValue(ParameterIndex), - MultiParameterValue((MultiValueParameterIndex, String)), + ParameterValue(ParameterIndex), + MultiParameterValue((ParameterIndex, String)), VirtualStorageVolume(VirtualStorageIndex), MultiNodeInFlow { indices: Vec, name: String }, MultiNodeOutFlow { indices: Vec, name: String }, @@ -87,7 +87,7 @@ impl Metric { #[derive(Clone, Debug, PartialEq)] pub enum IndexMetric { - IndexParameterValue(IndexParameterIndex), + IndexParameterValue(ParameterIndex), Constant(usize), } diff --git a/pywr-core/src/network.rs b/pywr-core/src/network.rs index 13a44a2b..0d6a3782 100644 --- a/pywr-core/src/network.rs +++ b/pywr-core/src/network.rs @@ -5,14 +5,14 @@ use crate::edge::{EdgeIndex, EdgeVec}; use crate::metric::Metric; use crate::models::ModelDomain; use crate::node::{ConstraintValue, Node, NodeVec, StorageInitialVolume}; -use crate::parameters::{MultiValueParameterIndex, ParameterType, VariableConfig}; +use crate::parameters::{ParameterType, VariableConfig}; use crate::recorders::{MetricSet, MetricSetIndex, MetricSetState}; use crate::scenario::ScenarioIndex; use crate::solvers::{MultiStateSolver, Solver, SolverFeatures, SolverTimings}; -use crate::state::{ParameterStates, State, StateBuilder}; +use crate::state::{MultiValue, ParameterStates, State, StateBuilder}; use crate::timestep::Timestep; use crate::virtual_storage::{VirtualStorage, VirtualStorageIndex, VirtualStorageReset, VirtualStorageVec}; -use crate::{parameters, recorders, IndexParameterIndex, NodeIndex, ParameterIndex, PywrError, RecorderIndex}; +use crate::{parameters, recorders, NodeIndex, ParameterIndex, PywrError, RecorderIndex}; use rayon::prelude::*; use std::any::Any; use std::collections::HashSet; @@ -201,9 +201,9 @@ pub struct Network { aggregated_nodes: AggregatedNodeVec, aggregated_storage_nodes: AggregatedStorageNodeVec, virtual_storage_nodes: VirtualStorageVec, - parameters: Vec>, - index_parameters: Vec>, - multi_parameters: Vec>, + parameters: Vec>>, + index_parameters: Vec>>, + multi_parameters: Vec>>, derived_metrics: Vec, metric_sets: Vec, resolve_order: Vec, @@ -1089,7 +1089,7 @@ impl Network { } /// Get a `Parameter` from a parameter's name - pub fn get_parameter(&self, index: &ParameterIndex) -> Result<&dyn parameters::Parameter, PywrError> { + pub fn get_parameter(&self, index: &ParameterIndex) -> Result<&dyn parameters::Parameter, PywrError> { match self.parameters.get(*index.deref()) { Some(p) => Ok(p.as_ref()), None => Err(PywrError::ParameterIndexNotFound(*index)), @@ -1097,7 +1097,10 @@ impl Network { } /// Get a `Parameter` from a parameter's name - pub fn get_mut_parameter(&mut self, index: &ParameterIndex) -> Result<&mut dyn parameters::Parameter, PywrError> { + pub fn get_mut_parameter( + &mut self, + index: &ParameterIndex, + ) -> Result<&mut dyn parameters::Parameter, PywrError> { match self.parameters.get_mut(*index.deref()) { Some(p) => Ok(p.as_mut()), None => Err(PywrError::ParameterIndexNotFound(*index)), @@ -1105,7 +1108,7 @@ impl Network { } /// Get a `Parameter` from a parameter's name - pub fn get_parameter_by_name(&self, name: &str) -> Result<&dyn parameters::Parameter, PywrError> { + pub fn get_parameter_by_name(&self, name: &str) -> Result<&dyn parameters::Parameter, PywrError> { match self.parameters.iter().find(|p| p.name() == name) { Some(parameter) => Ok(parameter.as_ref()), None => Err(PywrError::ParameterNotFound(name.to_string())), @@ -1113,7 +1116,7 @@ impl Network { } /// Get a `ParameterIndex` from a parameter's name - pub fn get_parameter_index_by_name(&self, name: &str) -> Result { + pub fn get_parameter_index_by_name(&self, name: &str) -> Result, PywrError> { match self.parameters.iter().position(|p| p.name() == name) { Some(idx) => Ok(ParameterIndex::new(idx)), None => Err(PywrError::ParameterNotFound(name.to_string())), @@ -1121,7 +1124,7 @@ impl Network { } /// Get a `IndexParameter` from a parameter's name - pub fn get_index_parameter_by_name(&self, name: &str) -> Result<&dyn parameters::IndexParameter, PywrError> { + pub fn get_index_parameter_by_name(&self, name: &str) -> Result<&dyn parameters::Parameter, PywrError> { match self.index_parameters.iter().find(|p| p.name() == name) { Some(parameter) => Ok(parameter.as_ref()), None => Err(PywrError::ParameterNotFound(name.to_string())), @@ -1129,17 +1132,20 @@ impl Network { } /// Get a `IndexParameterIndex` from a parameter's name - pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result { + pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result, PywrError> { match self.index_parameters.iter().position(|p| p.name() == name) { - Some(idx) => Ok(IndexParameterIndex::new(idx)), + Some(idx) => Ok(ParameterIndex::new(idx)), None => Err(PywrError::ParameterNotFound(name.to_string())), } } /// Get a `MultiValueParameterIndex` from a parameter's name - pub fn get_multi_valued_parameter_index_by_name(&self, name: &str) -> Result { + pub fn get_multi_valued_parameter_index_by_name( + &self, + name: &str, + ) -> Result, PywrError> { match self.multi_parameters.iter().position(|p| p.name() == name) { - Some(idx) => Ok(MultiValueParameterIndex::new(idx)), + Some(idx) => Ok(ParameterIndex::new(idx)), None => Err(PywrError::ParameterNotFound(name.to_string())), } } @@ -1311,7 +1317,10 @@ impl Network { } /// Add a `parameters::Parameter` to the network - pub fn add_parameter(&mut self, parameter: Box) -> Result { + pub fn add_parameter( + &mut self, + parameter: Box>, + ) -> Result, PywrError> { if let Ok(idx) = self.get_parameter_index_by_name(¶meter.meta().name) { return Err(PywrError::ParameterNameAlreadyExists( parameter.meta().name.to_string(), @@ -1332,8 +1341,8 @@ impl Network { /// Add a `parameters::IndexParameter` to the network pub fn add_index_parameter( &mut self, - index_parameter: Box, - ) -> Result { + index_parameter: Box>, + ) -> Result, PywrError> { if let Ok(idx) = self.get_index_parameter_index_by_name(&index_parameter.meta().name) { return Err(PywrError::IndexParameterNameAlreadyExists( index_parameter.meta().name.to_string(), @@ -1341,7 +1350,7 @@ impl Network { )); } - let parameter_index = IndexParameterIndex::new(self.index_parameters.len()); + let parameter_index = ParameterIndex::new(self.index_parameters.len()); self.index_parameters.push(index_parameter); // .. and add it to the resolve order @@ -1353,8 +1362,8 @@ impl Network { /// Add a `parameters::MultiValueParameter` to the network pub fn add_multi_value_parameter( &mut self, - parameter: Box, - ) -> Result { + parameter: Box>, + ) -> Result, PywrError> { if let Ok(idx) = self.get_parameter_index_by_name(¶meter.meta().name) { return Err(PywrError::ParameterNameAlreadyExists( parameter.meta().name.to_string(), @@ -1362,7 +1371,7 @@ impl Network { )); } - let parameter_index = MultiValueParameterIndex::new(self.multi_parameters.len()); + let parameter_index = ParameterIndex::new(self.multi_parameters.len()); // Add the parameter ... self.multi_parameters.push(parameter); @@ -1451,7 +1460,7 @@ impl Network { /// This will update the internal state of the parameter with the new values for all scenarios. pub fn set_f64_parameter_variable_values( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, values: &[f64], variable_config: &dyn VariableConfig, state: &mut NetworkState, @@ -1481,7 +1490,7 @@ impl Network { /// Only the internal state of the parameter for the given scenario will be updated. pub fn set_f64_parameter_variable_values_for_scenario( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, scenario_index: ScenarioIndex, values: &[f64], variable_config: &dyn VariableConfig, @@ -1505,7 +1514,7 @@ impl Network { /// Return a vector of the current values of active variable parameters. pub fn get_f64_parameter_variable_values_for_scenario( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, scenario_index: ScenarioIndex, state: &NetworkState, ) -> Result>, PywrError> { @@ -1527,7 +1536,7 @@ impl Network { pub fn get_f64_parameter_variable_values( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, state: &NetworkState, ) -> Result>>, PywrError> { match self.parameters.get(*parameter_index.deref()) { @@ -1557,7 +1566,7 @@ impl Network { /// This will update the internal state of the parameter with the new values for scenarios. pub fn set_u32_parameter_variable_values( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, values: &[u32], variable_config: &dyn VariableConfig, state: &mut NetworkState, @@ -1587,7 +1596,7 @@ impl Network { /// Only the internal state of the parameter for the given scenario will be updated. pub fn set_u32_parameter_variable_values_for_scenario( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, scenario_index: ScenarioIndex, values: &[u32], variable_config: &dyn VariableConfig, @@ -1611,7 +1620,7 @@ impl Network { /// Return a vector of the current values of active variable parameters. pub fn get_u32_parameter_variable_values_for_scenario( &self, - parameter_index: ParameterIndex, + parameter_index: ParameterIndex, scenario_index: ScenarioIndex, state: &NetworkState, ) -> Result>, PywrError> { diff --git a/pywr-core/src/parameters/aggregated.rs b/pywr-core/src/parameters/aggregated.rs index 50e7349c..21bf4630 100644 --- a/pywr-core/src/parameters/aggregated.rs +++ b/pywr-core/src/parameters/aggregated.rs @@ -5,7 +5,6 @@ use crate::parameters::{Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; -use std::any::Any; use std::str::FromStr; pub enum AggFunc { @@ -47,10 +46,7 @@ impl AggregatedParameter { } } -impl Parameter for AggregatedParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for AggregatedParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/aggregated_index.rs b/pywr-core/src/parameters/aggregated_index.rs index e4ec3bf7..f3d0e2bf 100644 --- a/pywr-core/src/parameters/aggregated_index.rs +++ b/pywr-core/src/parameters/aggregated_index.rs @@ -2,7 +2,7 @@ /// use super::PywrError; use crate::network::Network; -use crate::parameters::{IndexParameter, IndexValue, ParameterMeta}; +use crate::parameters::{IndexValue, Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; @@ -49,7 +49,7 @@ impl AggregatedIndexParameter { } } -impl IndexParameter for AggregatedIndexParameter { +impl Parameter for AggregatedIndexParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/array.rs b/pywr-core/src/parameters/array.rs index 0cb69f38..afde690a 100644 --- a/pywr-core/src/parameters/array.rs +++ b/pywr-core/src/parameters/array.rs @@ -5,7 +5,6 @@ use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; use ndarray::{Array1, Array2, Axis}; -use std::any::Any; pub struct Array1Parameter { meta: ParameterMeta, @@ -23,10 +22,7 @@ impl Array1Parameter { } } -impl Parameter for Array1Parameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for Array1Parameter { fn meta(&self) -> &ParameterMeta { &self.meta } @@ -66,10 +62,7 @@ impl Array2Parameter { } } -impl Parameter for Array2Parameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for Array2Parameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/asymmetric.rs b/pywr-core/src/parameters/asymmetric.rs index 775fa535..b35bc7c3 100644 --- a/pywr-core/src/parameters/asymmetric.rs +++ b/pywr-core/src/parameters/asymmetric.rs @@ -1,5 +1,5 @@ use crate::network::Network; -use crate::parameters::{downcast_internal_state_mut, IndexParameter, IndexValue, ParameterMeta}; +use crate::parameters::{downcast_internal_state_mut, IndexValue, Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; @@ -21,7 +21,7 @@ impl AsymmetricSwitchIndexParameter { } } -impl IndexParameter for AsymmetricSwitchIndexParameter { +impl Parameter for AsymmetricSwitchIndexParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/constant.rs b/pywr-core/src/parameters/constant.rs index 658ee86f..74348db7 100644 --- a/pywr-core/src/parameters/constant.rs +++ b/pywr-core/src/parameters/constant.rs @@ -7,7 +7,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct ConstantParameter { meta: ParameterMeta, @@ -37,11 +36,7 @@ impl ConstantParameter { } } -impl Parameter for ConstantParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } - +impl Parameter for ConstantParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/control_curves/apportion.rs b/pywr-core/src/parameters/control_curves/apportion.rs index 5aaaffe1..58938194 100644 --- a/pywr-core/src/parameters/control_curves/apportion.rs +++ b/pywr-core/src/parameters/control_curves/apportion.rs @@ -1,6 +1,6 @@ use crate::metric::Metric; use crate::network::Network; -use crate::parameters::{MultiValueParameter, ParameterMeta}; +use crate::parameters::{Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{MultiValue, ParameterState, State}; use crate::timestep::Timestep; @@ -31,7 +31,7 @@ impl ApportionParameter { } } -impl MultiValueParameter for ApportionParameter { +impl Parameter for ApportionParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/control_curves/index.rs b/pywr-core/src/parameters/control_curves/index.rs index 34d157a3..7616883e 100644 --- a/pywr-core/src/parameters/control_curves/index.rs +++ b/pywr-core/src/parameters/control_curves/index.rs @@ -1,6 +1,6 @@ use crate::metric::Metric; use crate::network::Network; -use crate::parameters::{IndexParameter, ParameterMeta}; +use crate::parameters::{Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; @@ -22,7 +22,7 @@ impl ControlCurveIndexParameter { } } -impl IndexParameter for ControlCurveIndexParameter { +impl Parameter for ControlCurveIndexParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/control_curves/interpolated.rs b/pywr-core/src/parameters/control_curves/interpolated.rs index 15f48e31..94e02407 100644 --- a/pywr-core/src/parameters/control_curves/interpolated.rs +++ b/pywr-core/src/parameters/control_curves/interpolated.rs @@ -6,7 +6,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct ControlCurveInterpolatedParameter { meta: ParameterMeta, @@ -26,10 +25,7 @@ impl ControlCurveInterpolatedParameter { } } -impl Parameter for ControlCurveInterpolatedParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for ControlCurveInterpolatedParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/control_curves/piecewise.rs b/pywr-core/src/parameters/control_curves/piecewise.rs index a2e3522e..eb0c3f25 100644 --- a/pywr-core/src/parameters/control_curves/piecewise.rs +++ b/pywr-core/src/parameters/control_curves/piecewise.rs @@ -6,7 +6,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct PiecewiseInterpolatedParameter { meta: ParameterMeta, @@ -37,10 +36,7 @@ impl PiecewiseInterpolatedParameter { } } -impl Parameter for PiecewiseInterpolatedParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for PiecewiseInterpolatedParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/control_curves/simple.rs b/pywr-core/src/parameters/control_curves/simple.rs index a59553d2..1c76b180 100644 --- a/pywr-core/src/parameters/control_curves/simple.rs +++ b/pywr-core/src/parameters/control_curves/simple.rs @@ -5,7 +5,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct ControlCurveParameter { meta: ParameterMeta, @@ -25,10 +24,7 @@ impl ControlCurveParameter { } } -impl Parameter for ControlCurveParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for ControlCurveParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/control_curves/volume_between.rs b/pywr-core/src/parameters/control_curves/volume_between.rs index f1f11880..67eae4cd 100644 --- a/pywr-core/src/parameters/control_curves/volume_between.rs +++ b/pywr-core/src/parameters/control_curves/volume_between.rs @@ -5,7 +5,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; /// A parameter that returns the volume that is the proportion between two control curves pub struct VolumeBetweenControlCurvesParameter { @@ -26,11 +25,7 @@ impl VolumeBetweenControlCurvesParameter { } } -impl Parameter for VolumeBetweenControlCurvesParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } - +impl Parameter for VolumeBetweenControlCurvesParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/delay.rs b/pywr-core/src/parameters/delay.rs index 6ca3dc85..edf0e5be 100644 --- a/pywr-core/src/parameters/delay.rs +++ b/pywr-core/src/parameters/delay.rs @@ -5,7 +5,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; use std::collections::VecDeque; pub struct DelayParameter { @@ -26,10 +25,7 @@ impl DelayParameter { } } -impl Parameter for DelayParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for DelayParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/discount_factor.rs b/pywr-core/src/parameters/discount_factor.rs index 4994ea91..d1231efe 100644 --- a/pywr-core/src/parameters/discount_factor.rs +++ b/pywr-core/src/parameters/discount_factor.rs @@ -6,7 +6,6 @@ use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; use chrono::Datelike; -use std::any::Any; pub struct DiscountFactorParameter { meta: ParameterMeta, @@ -24,10 +23,7 @@ impl DiscountFactorParameter { } } -impl Parameter for DiscountFactorParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for DiscountFactorParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/division.rs b/pywr-core/src/parameters/division.rs index 4a3a749b..f9630ae8 100644 --- a/pywr-core/src/parameters/division.rs +++ b/pywr-core/src/parameters/division.rs @@ -6,7 +6,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError::InvalidMetricValue; -use std::any::Any; pub struct DivisionParameter { meta: ParameterMeta, @@ -24,10 +23,7 @@ impl DivisionParameter { } } -impl Parameter for DivisionParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for DivisionParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/indexed_array.rs b/pywr-core/src/parameters/indexed_array.rs index d83a0361..f33e2ede 100644 --- a/pywr-core/src/parameters/indexed_array.rs +++ b/pywr-core/src/parameters/indexed_array.rs @@ -5,7 +5,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct IndexedArrayParameter { meta: ParameterMeta, @@ -23,10 +22,7 @@ impl IndexedArrayParameter { } } -impl Parameter for IndexedArrayParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for IndexedArrayParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/interpolated.rs b/pywr-core/src/parameters/interpolated.rs index 25111aeb..059722c5 100644 --- a/pywr-core/src/parameters/interpolated.rs +++ b/pywr-core/src/parameters/interpolated.rs @@ -6,7 +6,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; /// A parameter that interpolates a value to a function with given discrete data points. pub struct InterpolatedParameter { @@ -27,10 +26,7 @@ impl InterpolatedParameter { } } -impl Parameter for InterpolatedParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for InterpolatedParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/max.rs b/pywr-core/src/parameters/max.rs index 2074d4b9..b424303b 100644 --- a/pywr-core/src/parameters/max.rs +++ b/pywr-core/src/parameters/max.rs @@ -2,8 +2,6 @@ use crate::metric::Metric; use crate::network::Network; use crate::parameters::{Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; -use std::any::Any; - use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; @@ -24,10 +22,7 @@ impl MaxParameter { } } -impl Parameter for MaxParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for MaxParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/min.rs b/pywr-core/src/parameters/min.rs index 62c3e1ca..136ec6f8 100644 --- a/pywr-core/src/parameters/min.rs +++ b/pywr-core/src/parameters/min.rs @@ -2,8 +2,6 @@ use crate::metric::Metric; use crate::network::Network; use crate::parameters::{Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; -use std::any::Any; - use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; @@ -24,10 +22,7 @@ impl MinParameter { } } -impl Parameter for MinParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for MinParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/mod.rs b/pywr-core/src/parameters/mod.rs index 03059294..7759a838 100644 --- a/pywr-core/src/parameters/mod.rs +++ b/pywr-core/src/parameters/mod.rs @@ -61,76 +61,59 @@ pub use profiles::{ pub use py::PyParameter; use std::fmt; use std::fmt::{Display, Formatter}; +use std::marker::PhantomData; use std::ops::Deref; pub use threshold::{Predicate, ThresholdParameter}; pub use vector::VectorParameter; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct ParameterIndex(usize); - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct IndexParameterIndex(usize); - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct MultiValueParameterIndex(usize); - -impl ParameterIndex { - pub fn new(idx: usize) -> Self { - Self(idx) - } +/// Generic parameter index. +/// +/// This is a wrapper around usize that is used to index parameters in the state. It is +/// generic over the type of the value that the parameter returns. +#[derive(Debug)] +pub struct ParameterIndex { + idx: usize, + phantom: PhantomData, } -impl IndexParameterIndex { - pub fn new(idx: usize) -> Self { - Self(idx) +// These implementations are required because the derive macro does not work well with PhantomData. +// See issue: https://github.com/rust-lang/rust/issues/26925 +impl Clone for ParameterIndex { + fn clone(&self) -> Self { + *self } } -impl MultiValueParameterIndex { - pub fn new(idx: usize) -> Self { - Self(idx) - } -} +impl Copy for ParameterIndex {} -impl Deref for ParameterIndex { - type Target = usize; - - fn deref(&self) -> &Self::Target { - &self.0 +impl PartialEq for ParameterIndex { + fn eq(&self, other: &Self) -> bool { + self.idx == other.idx } } -impl Deref for IndexParameterIndex { - type Target = usize; +impl Eq for ParameterIndex {} - fn deref(&self) -> &Self::Target { - &self.0 +impl ParameterIndex { + pub fn new(idx: usize) -> Self { + Self { + idx, + phantom: PhantomData, + } } } -impl Deref for MultiValueParameterIndex { +impl Deref for ParameterIndex { type Target = usize; fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Display for ParameterIndex { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) + &self.idx } } -impl Display for IndexParameterIndex { +impl Display for ParameterIndex { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl Display for MultiValueParameterIndex { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.idx) } } @@ -202,9 +185,10 @@ pub fn downcast_variable_config_ref(variable_config: &dyn VariableCo } } -// TODO It might be possible to make these three traits into a single generic trait -pub trait Parameter: Send + Sync { - fn as_any_mut(&mut self) -> &mut dyn Any; +/// A trait that defines a component that produces a value each time-step. +/// +/// The trait is generic over the type of the value produced. +pub trait Parameter: Send + Sync { fn meta(&self) -> &ParameterMeta; fn name(&self) -> &str { self.meta().name.as_str() @@ -225,7 +209,7 @@ pub trait Parameter: Send + Sync { model: &Network, state: &State, internal_state: &mut Option>, - ) -> Result; + ) -> Result; fn after( &self, @@ -269,79 +253,10 @@ pub trait Parameter: Send + Sync { } } -pub trait IndexParameter: Send + Sync { - fn meta(&self) -> &ParameterMeta; - fn name(&self) -> &str { - self.meta().name.as_str() - } - - fn setup( - &self, - _timesteps: &[Timestep], - _scenario_index: &ScenarioIndex, - ) -> Result>, PywrError> { - Ok(None) - } - - fn compute( - &self, - timestep: &Timestep, - scenario_index: &ScenarioIndex, - model: &Network, - state: &State, - internal_state: &mut Option>, - ) -> Result; - - fn after( - &self, - #[allow(unused_variables)] timestep: &Timestep, - #[allow(unused_variables)] scenario_index: &ScenarioIndex, - #[allow(unused_variables)] model: &Network, - #[allow(unused_variables)] state: &State, - #[allow(unused_variables)] internal_state: &mut Option>, - ) -> Result<(), PywrError> { - Ok(()) - } -} - -pub trait MultiValueParameter: Send + Sync { - fn meta(&self) -> &ParameterMeta; - fn name(&self) -> &str { - self.meta().name.as_str() - } - fn setup( - &self, - #[allow(unused_variables)] timesteps: &[Timestep], - #[allow(unused_variables)] scenario_index: &ScenarioIndex, - ) -> Result>, PywrError> { - Ok(None) - } - - fn compute( - &self, - timestep: &Timestep, - scenario_index: &ScenarioIndex, - model: &Network, - state: &State, - internal_state: &mut Option>, - ) -> Result; - - fn after( - &self, - #[allow(unused_variables)] timestep: &Timestep, - #[allow(unused_variables)] scenario_index: &ScenarioIndex, - #[allow(unused_variables)] model: &Network, - #[allow(unused_variables)] state: &State, - #[allow(unused_variables)] internal_state: &mut Option>, - ) -> Result<(), PywrError> { - Ok(()) - } -} - #[derive(Copy, Clone)] pub enum IndexValue { Constant(usize), - Dynamic(IndexParameterIndex), + Dynamic(ParameterIndex), } impl IndexValue { @@ -354,9 +269,9 @@ impl IndexValue { } pub enum ParameterType { - Parameter(ParameterIndex), - Index(IndexParameterIndex), - Multi(MultiValueParameterIndex), + Parameter(ParameterIndex), + Index(ParameterIndex), + Multi(ParameterIndex), } /// A parameter that can be optimised. diff --git a/pywr-core/src/parameters/negative.rs b/pywr-core/src/parameters/negative.rs index f44621a7..fef811c9 100644 --- a/pywr-core/src/parameters/negative.rs +++ b/pywr-core/src/parameters/negative.rs @@ -5,7 +5,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct NegativeParameter { meta: ParameterMeta, @@ -21,10 +20,7 @@ impl NegativeParameter { } } -impl Parameter for NegativeParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for NegativeParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/offset.rs b/pywr-core/src/parameters/offset.rs index 65c8c1af..36a7d31f 100644 --- a/pywr-core/src/parameters/offset.rs +++ b/pywr-core/src/parameters/offset.rs @@ -5,8 +5,6 @@ use crate::parameters::{ Parameter, ParameterMeta, VariableConfig, VariableParameter, }; use crate::scenario::ScenarioIndex; -use std::any::Any; - use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; @@ -41,10 +39,7 @@ impl OffsetParameter { } } -impl Parameter for OffsetParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for OffsetParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/polynomial.rs b/pywr-core/src/parameters/polynomial.rs index 16d82126..46320111 100644 --- a/pywr-core/src/parameters/polynomial.rs +++ b/pywr-core/src/parameters/polynomial.rs @@ -5,7 +5,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct Polynomial1DParameter { meta: ParameterMeta, @@ -27,10 +26,7 @@ impl Polynomial1DParameter { } } -impl Parameter for Polynomial1DParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for Polynomial1DParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/profiles/daily.rs b/pywr-core/src/parameters/profiles/daily.rs index 6f2411d2..c13d14c8 100644 --- a/pywr-core/src/parameters/profiles/daily.rs +++ b/pywr-core/src/parameters/profiles/daily.rs @@ -5,7 +5,6 @@ use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; use chrono::Datelike; -use std::any::Any; pub struct DailyProfileParameter { meta: ParameterMeta, @@ -21,10 +20,7 @@ impl DailyProfileParameter { } } -impl Parameter for DailyProfileParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for DailyProfileParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/profiles/monthly.rs b/pywr-core/src/parameters/profiles/monthly.rs index 60b49679..3bade16b 100644 --- a/pywr-core/src/parameters/profiles/monthly.rs +++ b/pywr-core/src/parameters/profiles/monthly.rs @@ -1,11 +1,10 @@ use crate::network::Network; -use crate::parameters::{Parameter, ParameterIndex, ParameterMeta}; +use crate::parameters::{Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use chrono::{Datelike, NaiveDateTime}; -use std::any::Any; +use chrono::{Datelike, NaiveDateTime, Timelike}; #[derive(Copy, Clone)] pub enum MonthlyInterpDay { @@ -49,7 +48,9 @@ fn interpolate_first(date: &NaiveDateTime, first_value: f64, last_value: f64) -> } else if date.day() > days_in_month { last_value } else { - first_value + (last_value - first_value) * (date.day() - 1) as f64 / days_in_month as f64 + first_value + + (last_value - first_value) * (date.day() as f64 + date.num_seconds_from_midnight() as f64 / 86400.0 - 1.0) + / days_in_month as f64 } } @@ -63,14 +64,13 @@ fn interpolate_last(date: &NaiveDateTime, first_value: f64, last_value: f64) -> } else if date.day() >= days_in_month { last_value } else { - first_value + (last_value - first_value) * date.day() as f64 / days_in_month as f64 + first_value + + (last_value - first_value) * (date.day() as f64 + date.num_seconds_from_midnight() as f64 / 86400.0) + / days_in_month as f64 } } -impl Parameter for MonthlyProfileParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for MonthlyProfileParameter { fn meta(&self) -> &ParameterMeta { &self.meta } @@ -105,21 +105,3 @@ impl Parameter for MonthlyProfileParameter { Ok(v) } } - -// TODO this is a proof-of-concept of a external "variable" -#[allow(dead_code)] -pub struct MonthlyProfileVariable { - index: ParameterIndex, -} - -#[allow(dead_code)] -impl MonthlyProfileVariable { - fn update(&self, model: &mut Network, new_values: &[f64]) { - let p = model.get_mut_parameter(&self.index).unwrap(); - - let profile = p.as_any_mut().downcast_mut::().unwrap(); - - // This panics if the slices are different lengths! - profile.values.copy_from_slice(new_values); - } -} diff --git a/pywr-core/src/parameters/profiles/rbf.rs b/pywr-core/src/parameters/profiles/rbf.rs index e92d935f..a3d65876 100644 --- a/pywr-core/src/parameters/profiles/rbf.rs +++ b/pywr-core/src/parameters/profiles/rbf.rs @@ -9,7 +9,6 @@ use crate::timestep::Timestep; use crate::PywrError; use chrono::Datelike; use nalgebra::DMatrix; -use std::any::Any; pub struct RbfProfileVariableConfig { days_of_year_range: Option, @@ -101,11 +100,7 @@ impl RbfProfileParameter { } } -impl Parameter for RbfProfileParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } - +impl Parameter for RbfProfileParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/profiles/uniform_drawdown.rs b/pywr-core/src/parameters/profiles/uniform_drawdown.rs index a34d4439..12f619b8 100644 --- a/pywr-core/src/parameters/profiles/uniform_drawdown.rs +++ b/pywr-core/src/parameters/profiles/uniform_drawdown.rs @@ -5,7 +5,6 @@ use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; use chrono::{Datelike, NaiveDate}; -use std::any::Any; fn is_leap_year(year: i32) -> bool { (year % 4 == 0) & ((year % 100 != 0) | (year % 400 == 0)) @@ -32,10 +31,7 @@ impl UniformDrawdownProfileParameter { } } -impl Parameter for UniformDrawdownProfileParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for UniformDrawdownProfileParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/profiles/weekly.rs b/pywr-core/src/parameters/profiles/weekly.rs index 4eadb5b4..7bb6eb6f 100644 --- a/pywr-core/src/parameters/profiles/weekly.rs +++ b/pywr-core/src/parameters/profiles/weekly.rs @@ -4,8 +4,7 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use chrono::{Datelike, NaiveDate}; -use std::any::Any; +use chrono::{Datelike, NaiveDate, NaiveDateTime, Timelike}; use thiserror::Error; pub enum WeeklyInterpDay { @@ -20,17 +19,22 @@ pub enum WeeklyProfileValues { } impl WeeklyProfileValues { - /// Get the week position in a calendar year from date. The position starts from 0 on the - /// first week day and ends with 1 on the last week day. - fn current_pos(&self, date: &NaiveDate) -> f64 { - let current_day = date.ordinal(); - (current_day - 1) as f64 / 7.0 + /// Get the week position in a calendar year from date. In the first year week, the position + /// starts from 0 on the first week day and ends with 1 on the last day. Seconds may be + /// included in the position by setting with_seconds to true. + fn current_pos(&self, date: &NaiveDateTime, with_seconds: bool) -> f64 { + let mut current_day = date.ordinal() as f64; + if with_seconds { + let seconds_in_day = date.num_seconds_from_midnight() as f64 / 86400.0; + current_day += seconds_in_day; + } + (current_day - 1.0) / 7.0 } /// Get the week index from the provided date - fn current_index(&self, date: &NaiveDate) -> usize { + fn current_index(&self, date: &NaiveDateTime) -> usize { let current_day = date.ordinal(); - let current_pos = self.current_pos(date) as usize; + let current_pos = self.current_pos(date, false) as usize; // if year is leap the last week starts on the 365th day let is_leap_year = NaiveDate::from_ymd_opt(date.year(), 1, 1).unwrap().leap_year(); @@ -55,7 +59,7 @@ impl WeeklyProfileValues { } /// Get the value corresponding to the week index for the provided date - fn current(&self, date: &NaiveDate) -> f64 { + fn current(&self, date: &NaiveDateTime) -> f64 { // The current_index function always returns and index between 0 and // 52 (for Self::FiftyTwo) or 53 (Self::FiftyThree). This ensures // that the index is always in range in the value array below @@ -70,7 +74,7 @@ impl WeeklyProfileValues { /// Get the next week's value based on the week index of the provided date. If the current /// week is larger than the array length, the value corresponding to the first week is /// returned. - fn next(&self, date: &NaiveDate) -> f64 { + fn next(&self, date: &NaiveDateTime) -> f64 { let current_week_index = self.current_index(date); match self { @@ -93,7 +97,7 @@ impl WeeklyProfileValues { /// Get the previous week's value based on the week index of the provided date. If the /// current week index is 0 than the last array value is returned. - fn prev(&self, date: &NaiveDate) -> f64 { + fn prev(&self, date: &NaiveDateTime) -> f64 { let current_week_index = self.current_index(date); match self { @@ -116,8 +120,8 @@ impl WeeklyProfileValues { /// Find the value corresponding to the given date by linearly interpolating between two /// consecutive week's values. - fn interpolate(&self, date: &NaiveDate, first_value: f64, last_value: f64) -> f64 { - let current_pos = self.current_pos(date); + fn interpolate(&self, date: &NaiveDateTime, first_value: f64, last_value: f64) -> f64 { + let current_pos = self.current_pos(date, true); let week_delta = current_pos - current_pos.floor(); first_value + (last_value - first_value) * week_delta } @@ -126,7 +130,7 @@ impl WeeklyProfileValues { /// interpolated profile, the upper boundary in the 52nd and 53rd week is the same when /// WeeklyInterpDay is First (i.e. the value on 1st January). When WeeklyInterpDay is Last the /// 1st and last week will share the same lower bound (i.e. the value on the last week). - fn value(&self, date: &NaiveDate, interp_day: &Option) -> f64 { + fn value(&self, date: &NaiveDateTime, interp_day: &Option) -> f64 { match interp_day { None => self.current(date), Some(interp_day) => match interp_day { @@ -180,10 +184,7 @@ impl WeeklyProfileParameter { } } -impl Parameter for WeeklyProfileParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for WeeklyProfileParameter { fn meta(&self) -> &ParameterMeta { &self.meta } @@ -195,7 +196,7 @@ impl Parameter for WeeklyProfileParameter { _state: &State, _internal_state: &mut Option>, ) -> Result { - Ok(self.values.value(×tep.date.date(), &self.interp_day)) + Ok(self.values.value(×tep.date, &self.interp_day)) } } @@ -204,16 +205,26 @@ mod tests { use crate::parameters::profiles::weekly::{WeeklyInterpDay, WeeklyProfileValues}; use crate::test_utils::assert_approx_array_eq; use chrono::{Datelike, NaiveDate, TimeDelta}; + use float_cmp::{assert_approx_eq, F64Margin}; /// Build a time-series from the weekly profile fn collect(week_size: &WeeklyProfileValues, interp_day: Option) -> Vec { - let dt0 = NaiveDate::from_ymd_opt(2020, 1, 1).unwrap(); - let dt1 = NaiveDate::from_ymd_opt(2020, 12, 31).unwrap(); + let dt0 = NaiveDate::from_ymd_opt(2020, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let dt1 = NaiveDate::from_ymd_opt(2020, 12, 31) + .unwrap() + .and_hms_opt(23, 59, 59) + .unwrap(); let mut dt = dt0; let mut data: Vec = Vec::new(); while dt <= dt1 { - let date = NaiveDate::from_ymd_opt(dt.year(), dt.month(), dt.day()).unwrap(); + let date = NaiveDate::from_ymd_opt(dt.year(), dt.month(), dt.day()) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); let value = week_size.value(&date, &interp_day); data.push(value); @@ -451,4 +462,31 @@ mod tests { let values_interp_none = collect(&week_size, Some(WeeklyInterpDay::Last)); assert_approx_array_eq(&values_interp_none, &expected_values_interp_last); } + + /// Test the interpolation with the time + #[test] + fn test_time_interpolation() { + let profile = [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, + 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, + 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, + ]; + let week_size = WeeklyProfileValues::FiftyTwo(profile); + + let t0 = NaiveDate::from_ymd_opt(2016, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + assert_eq!(week_size.interpolate(&t0, 0.0, 1.0), 0.0); + + let t0 = NaiveDate::from_ymd_opt(2016, 1, 7) + .unwrap() + .and_hms_opt(12, 00, 00) + .unwrap(); + let margins = F64Margin { + epsilon: 2.0, + ulps: (f64::EPSILON * 2.0) as i64, + }; + assert_approx_eq!(f64, week_size.interpolate(&t0, 0.0, 1.0), 1.928571429, margins); + } } diff --git a/pywr-core/src/parameters/py.rs b/pywr-core/src/parameters/py.rs index 28714beb..5c9b97a1 100644 --- a/pywr-core/src/parameters/py.rs +++ b/pywr-core/src/parameters/py.rs @@ -1,12 +1,11 @@ use super::{IndexValue, Parameter, ParameterMeta, PywrError, Timestep}; use crate::metric::Metric; use crate::network::Network; -use crate::parameters::{downcast_internal_state_mut, MultiValueParameter}; +use crate::parameters::downcast_internal_state_mut; use crate::scenario::ScenarioIndex; use crate::state::{MultiValue, ParameterState, State}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyDict, PyFloat, PyLong, PyTuple}; -use std::any::Any; use std::collections::HashMap; pub struct PyParameter { @@ -67,21 +66,8 @@ impl PyParameter { Ok(index_values.into_py_dict(py)) } -} -impl Parameter for PyParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } - fn meta(&self) -> &ParameterMeta { - &self.meta - } - - fn setup( - &self, - _timesteps: &[Timestep], - _scenario_index: &ScenarioIndex, - ) -> Result>, PywrError> { + fn setup(&self) -> Result>, PywrError> { pyo3::prepare_freethreaded_python(); let user_obj: PyObject = Python::with_gil(|py| -> PyResult { @@ -96,26 +82,20 @@ impl Parameter for PyParameter { Ok(Some(internal.into_boxed_any())) } - // fn before(&self, internal_state: &mut Option>) -> Result<(), PywrError> { - // let internal = downcast_internal_state::(internal_state); - // - // Python::with_gil(|py| internal.user_obj.call_method0(py, "before")) - // .map_err(|e| PywrError::PythonError(e.to_string()))?; - // - // Ok(()) - // } - - fn compute( + fn compute( &self, timestep: &Timestep, scenario_index: &ScenarioIndex, model: &Network, state: &State, internal_state: &mut Option>, - ) -> Result { + ) -> Result + where + T: for<'a> FromPyObject<'a>, + { let internal = downcast_internal_state_mut::(internal_state); - let value: f64 = Python::with_gil(|py| { + let value: T = Python::with_gil(|py| { let date = timestep.date.into_py(py); let si = scenario_index.index.into_py(py); @@ -164,7 +144,7 @@ impl Parameter for PyParameter { } } -impl MultiValueParameter for PyParameter { +impl Parameter for PyParameter { fn meta(&self) -> &ParameterMeta { &self.meta } @@ -174,18 +154,79 @@ impl MultiValueParameter for PyParameter { _timesteps: &[Timestep], _scenario_index: &ScenarioIndex, ) -> Result>, PywrError> { - pyo3::prepare_freethreaded_python(); + self.setup() + } - let user_obj: PyObject = Python::with_gil(|py| -> PyResult { - let args = self.args.as_ref(py); - let kwargs = self.kwargs.as_ref(py); - self.object.call(py, args, Some(kwargs)) - }) - .unwrap(); + fn compute( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result { + self.compute(timestep, scenario_index, model, state, internal_state) + } - let internal = Internal { user_obj }; + fn after( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result<(), PywrError> { + self.after(timestep, scenario_index, model, state, internal_state) + } +} - Ok(Some(internal.into_boxed_any())) +impl Parameter for PyParameter { + fn meta(&self) -> &ParameterMeta { + &self.meta + } + + fn setup( + &self, + _timesteps: &[Timestep], + _scenario_index: &ScenarioIndex, + ) -> Result>, PywrError> { + self.setup() + } + + fn compute( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result { + self.compute(timestep, scenario_index, model, state, internal_state) + } + + fn after( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result<(), PywrError> { + self.after(timestep, scenario_index, model, state, internal_state) + } +} + +impl Parameter for PyParameter { + fn meta(&self) -> &ParameterMeta { + &self.meta + } + + fn setup( + &self, + _timesteps: &[Timestep], + _scenario_index: &ScenarioIndex, + ) -> Result>, PywrError> { + self.setup() } // fn before(&self, internal_state: &mut Option>) -> Result<(), PywrError> { @@ -261,27 +302,7 @@ impl MultiValueParameter for PyParameter { state: &State, internal_state: &mut Option>, ) -> Result<(), PywrError> { - let internal = downcast_internal_state_mut::(internal_state); - - Python::with_gil(|py| { - // Only do this if the object has an "after" method defined. - if internal.user_obj.getattr(py, "after").is_ok() { - let date = timestep.date.into_py(py); - - let si = scenario_index.index.into_py(py); - - let metric_dict = self.get_metrics_dict(model, state, py)?; - let index_dict = self.get_indices_dict(state, py)?; - - let args = PyTuple::new(py, [date.as_ref(py), si.as_ref(py), metric_dict, index_dict]); - - internal.user_obj.call_method1(py, "after", args)?; - } - Ok(()) - }) - .map_err(|e: PyErr| PywrError::PythonError(e.to_string()))?; - - Ok(()) + self.after(timestep, scenario_index, model, state, internal_state) } } @@ -343,7 +364,7 @@ class MyParameter: let mut internal_p_states: Vec<_> = scenario_indices .iter() - .map(|si| Parameter::setup(¶m, ×teps, si).expect("Could not setup the PyParameter")) + .map(|si| Parameter::::setup(¶m, ×teps, si).expect("Could not setup the PyParameter")) .collect(); let model = Network::default(); @@ -412,14 +433,14 @@ class MyParameter: let mut internal_p_states: Vec<_> = scenario_indices .iter() - .map(|si| MultiValueParameter::setup(¶m, ×teps, si).expect("Could not setup the PyParameter")) + .map(|si| Parameter::::setup(¶m, ×teps, si).expect("Could not setup the PyParameter")) .collect(); let model = Network::default(); for ts in timesteps { for (si, internal) in scenario_indices.iter().zip(internal_p_states.iter_mut()) { - let value = MultiValueParameter::compute(¶m, ts, si, &model, &state, internal).unwrap(); + let value = Parameter::::compute(¶m, ts, si, &model, &state, internal).unwrap(); assert_approx_eq!(f64, *value.get_value("a-float").unwrap(), std::f64::consts::PI); diff --git a/pywr-core/src/parameters/rhai.rs b/pywr-core/src/parameters/rhai.rs index 8fcfbb5f..636e0a28 100644 --- a/pywr-core/src/parameters/rhai.rs +++ b/pywr-core/src/parameters/rhai.rs @@ -6,7 +6,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use chrono::Datelike; use rhai::{Dynamic, Engine, Map, Scope, AST}; -use std::any::Any; use std::collections::HashMap; pub struct RhaiParameter { @@ -55,10 +54,7 @@ impl RhaiParameter { } } -impl Parameter for RhaiParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for RhaiParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/threshold.rs b/pywr-core/src/parameters/threshold.rs index 6d02fc05..078186b8 100644 --- a/pywr-core/src/parameters/threshold.rs +++ b/pywr-core/src/parameters/threshold.rs @@ -1,6 +1,6 @@ use crate::metric::Metric; use crate::network::Network; -use crate::parameters::{downcast_internal_state_mut, IndexParameter, ParameterMeta}; +use crate::parameters::{downcast_internal_state_mut, Parameter, ParameterMeta}; use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; @@ -50,7 +50,7 @@ impl ThresholdParameter { } } -impl IndexParameter for ThresholdParameter { +impl Parameter for ThresholdParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/parameters/vector.rs b/pywr-core/src/parameters/vector.rs index 7a9263a0..d1c0e1fc 100644 --- a/pywr-core/src/parameters/vector.rs +++ b/pywr-core/src/parameters/vector.rs @@ -4,7 +4,6 @@ use crate::scenario::ScenarioIndex; use crate::state::{ParameterState, State}; use crate::timestep::Timestep; use crate::PywrError; -use std::any::Any; pub struct VectorParameter { meta: ParameterMeta, @@ -20,10 +19,7 @@ impl VectorParameter { } } -impl Parameter for VectorParameter { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } +impl Parameter for VectorParameter { fn meta(&self) -> &ParameterMeta { &self.meta } diff --git a/pywr-core/src/solvers/builder.rs b/pywr-core/src/solvers/builder.rs index e644076e..2bdc2715 100644 --- a/pywr-core/src/solvers/builder.rs +++ b/pywr-core/src/solvers/builder.rs @@ -293,6 +293,14 @@ where I::from(self.builder.col_upper.len()).unwrap() } + pub fn num_rows(&self) -> I { + I::from(self.builder.row_upper.len()).unwrap() + } + + pub fn num_non_zero(&self) -> I { + I::from(self.builder.elements.len()).unwrap() + } + pub fn col_lower(&self) -> &[f64] { &self.builder.col_lower } @@ -313,6 +321,10 @@ where &self.builder.row_upper } + pub fn row_mask(&self) -> &[I] { + &self.builder.row_mask + } + pub fn row_starts(&self) -> &[I] { &self.builder.row_starts } diff --git a/pywr-core/src/solvers/highs/settings.rs b/pywr-core/src/solvers/highs/settings.rs index 939bc26f..6f3e01aa 100644 --- a/pywr-core/src/solvers/highs/settings.rs +++ b/pywr-core/src/solvers/highs/settings.rs @@ -39,14 +39,11 @@ impl HighsSolverSettings { /// /// ``` /// use std::num::NonZeroUsize; -/// use pywr::solvers::ClpSolverSettingsBuilder; +/// use pywr_core::solvers::HighsSolverSettingsBuilder; /// // Settings with parallel enabled and 4 threads. -/// let settings = ClpSolverSettingsBuilder::default().parallel().threads(4).build(); -/// -/// let mut builder = ClpSolverSettingsBuilder::default(); -/// builder.chunk_size(NonZeroUsize::new(1024).unwrap()); -/// let settings = builder.build(); +/// let settings = HighsSolverSettingsBuilder::default().parallel().threads(4).build(); /// +/// let mut builder = HighsSolverSettingsBuilder::default(); /// builder.parallel(); /// let settings = builder.build(); /// @@ -97,6 +94,6 @@ mod tests { }; let settings_from_builder = HighsSolverSettingsBuilder::default().parallel().build(); - assert_eq!(settings_from_builder, settings_from_builder); + assert_eq!(settings_from_builder, settings); } } diff --git a/pywr-core/src/state.rs b/pywr-core/src/state.rs index 6f896b6e..194beeae 100644 --- a/pywr-core/src/state.rs +++ b/pywr-core/src/state.rs @@ -3,7 +3,7 @@ use crate::edge::{Edge, EdgeIndex}; use crate::models::MultiNetworkTransferIndex; use crate::network::Network; use crate::node::{Node, NodeIndex}; -use crate::parameters::{IndexParameterIndex, MultiValueParameterIndex, ParameterIndex}; +use crate::parameters::ParameterIndex; use crate::timestep::Timestep; use crate::virtual_storage::VirtualStorageIndex; use crate::PywrError; @@ -271,27 +271,30 @@ impl ParameterStates { } } - pub fn get_value_state(&self, index: ParameterIndex) -> Option<&Option>> { + pub fn get_value_state(&self, index: ParameterIndex) -> Option<&Option>> { self.values.get(*index.deref()) } - pub fn get_mut_value_state(&mut self, index: ParameterIndex) -> Option<&mut Option>> { + pub fn get_mut_value_state(&mut self, index: ParameterIndex) -> Option<&mut Option>> { self.values.get_mut(*index.deref()) } - pub fn get_mut_index_state(&mut self, index: IndexParameterIndex) -> Option<&mut Option>> { + pub fn get_mut_index_state( + &mut self, + index: ParameterIndex, + ) -> Option<&mut Option>> { self.indices.get_mut(*index.deref()) } pub fn get_mut_multi_state( &mut self, - index: MultiValueParameterIndex, + index: ParameterIndex, ) -> Option<&mut Option>> { self.multi.get_mut(*index.deref()) } } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct MultiValue { values: HashMap, indices: HashMap, @@ -328,14 +331,14 @@ impl ParameterValues { } } - fn get_value(&self, idx: ParameterIndex) -> Result { + fn get_value(&self, idx: ParameterIndex) -> Result { match self.values.get(*idx.deref()) { Some(s) => Ok(*s), None => Err(PywrError::ParameterIndexNotFound(idx)), } } - fn set_value(&mut self, idx: ParameterIndex, value: f64) -> Result<(), PywrError> { + fn set_value(&mut self, idx: ParameterIndex, value: f64) -> Result<(), PywrError> { match self.values.get_mut(*idx.deref()) { Some(s) => { *s = value; @@ -345,14 +348,14 @@ impl ParameterValues { } } - fn get_index(&self, idx: IndexParameterIndex) -> Result { + fn get_index(&self, idx: ParameterIndex) -> Result { match self.indices.get(*idx.deref()) { Some(s) => Ok(*s), None => Err(PywrError::IndexParameterIndexNotFound(idx)), } } - fn set_index(&mut self, idx: IndexParameterIndex, value: usize) -> Result<(), PywrError> { + fn set_index(&mut self, idx: ParameterIndex, value: usize) -> Result<(), PywrError> { match self.indices.get_mut(*idx.deref()) { Some(s) => { *s = value; @@ -362,7 +365,7 @@ impl ParameterValues { } } - fn get_multi_value(&self, idx: MultiValueParameterIndex, key: &str) -> Result { + fn get_multi_value(&self, idx: ParameterIndex, key: &str) -> Result { match self.multi_values.get(*idx.deref()) { Some(s) => match s.get_value(key) { Some(v) => Ok(*v), @@ -372,7 +375,7 @@ impl ParameterValues { } } - fn set_multi_value(&mut self, idx: MultiValueParameterIndex, value: MultiValue) -> Result<(), PywrError> { + fn set_multi_value(&mut self, idx: ParameterIndex, value: MultiValue) -> Result<(), PywrError> { match self.multi_values.get_mut(*idx.deref()) { Some(s) => { *s = value; @@ -382,7 +385,7 @@ impl ParameterValues { } } - fn get_multi_index(&self, idx: MultiValueParameterIndex, key: &str) -> Result { + fn get_multi_index(&self, idx: ParameterIndex, key: &str) -> Result { match self.multi_values.get(*idx.deref()) { Some(s) => match s.get_index(key) { Some(v) => Ok(*v), @@ -637,35 +640,35 @@ impl State { &mut self.network } - pub fn get_parameter_value(&self, idx: ParameterIndex) -> Result { + pub fn get_parameter_value(&self, idx: ParameterIndex) -> Result { self.parameters.get_value(idx) } - pub fn set_parameter_value(&mut self, idx: ParameterIndex, value: f64) -> Result<(), PywrError> { + pub fn set_parameter_value(&mut self, idx: ParameterIndex, value: f64) -> Result<(), PywrError> { self.parameters.set_value(idx, value) } - pub fn get_parameter_index(&self, idx: IndexParameterIndex) -> Result { + pub fn get_parameter_index(&self, idx: ParameterIndex) -> Result { self.parameters.get_index(idx) } - pub fn set_parameter_index(&mut self, idx: IndexParameterIndex, value: usize) -> Result<(), PywrError> { + pub fn set_parameter_index(&mut self, idx: ParameterIndex, value: usize) -> Result<(), PywrError> { self.parameters.set_index(idx, value) } - pub fn get_multi_parameter_value(&self, idx: MultiValueParameterIndex, key: &str) -> Result { + pub fn get_multi_parameter_value(&self, idx: ParameterIndex, key: &str) -> Result { self.parameters.get_multi_value(idx, key) } pub fn set_multi_parameter_value( &mut self, - idx: MultiValueParameterIndex, + idx: ParameterIndex, value: MultiValue, ) -> Result<(), PywrError> { self.parameters.set_multi_value(idx, value) } - pub fn get_multi_parameter_index(&self, idx: MultiValueParameterIndex, key: &str) -> Result { + pub fn get_multi_parameter_index(&self, idx: ParameterIndex, key: &str) -> Result { self.parameters.get_multi_index(idx, key) } diff --git a/pywr-core/src/test_utils.rs b/pywr-core/src/test_utils.rs index 284825e4..941e4d3d 100644 --- a/pywr-core/src/test_utils.rs +++ b/pywr-core/src/test_utils.rs @@ -158,7 +158,7 @@ pub fn simple_storage_model() -> Model { /// See [`AssertionRecorder`] for more information. pub fn run_and_assert_parameter( model: &mut Model, - parameter: Box, + parameter: Box>, expected_values: Array2, ulps: Option, epsilon: Option, diff --git a/pywr-schema-macros/Cargo.toml b/pywr-schema-macros/Cargo.toml new file mode 100644 index 00000000..f9046ed6 --- /dev/null +++ b/pywr-schema-macros/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pywr-schema-macros" +version = "2.0.0-dev" +edition = "2021" +rust-version = "1.60" +description = "A generalised water resource allocation model." +readme = "../README.md" +repository = "https://github.com/pywr/pywr-next/" +license = "MIT OR Apache-2.0" +keywords = ["water", "modelling"] +categories = ["science", "simulation"] + +[lib] +name = "pywr_schema_macros" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +syn = "2.0.52" +quote = "1.0.35" + + diff --git a/pywr-schema-macros/src/lib.rs b/pywr-schema-macros/src/lib.rs new file mode 100644 index 00000000..827560d8 --- /dev/null +++ b/pywr-schema-macros/src/lib.rs @@ -0,0 +1,196 @@ +use proc_macro::TokenStream; +use quote::quote; + +/// A derive macro for Pywr nodes that implements `parameters` and `parameters_mut` methods. +#[proc_macro_derive(PywrNode)] +pub fn pywr_node_macro(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let input = syn::parse_macro_input!(input as syn::DeriveInput); + impl_parameter_references_derive(&input) +} + +enum PywrField { + Optional(syn::Ident), + Required(syn::Ident), +} + +/// Generates a [`TokenStream`] containing the implementation of two methods, `parameters` +/// and `parameters_mut`, for the given struct. +/// +/// Both method returns a [`HashMap`] of parameter names to [`DynamicFloatValue`]. This +/// is intended to be used for nodes and parameter structs in the Pywr schema. +fn impl_parameter_references_derive(ast: &syn::DeriveInput) -> TokenStream { + // Name of the node type + let name = &ast.ident; + + if let syn::Data::Struct(data) = &ast.data { + // Only apply this to structs + + // Help struct for capturing parameter fields and whether they are optional. + struct ParamField { + field_name: syn::Ident, + optional: bool, + } + + // Iterate through all fields of the struct. Try to find fields that reference + // parameters (e.g. `Option` or `ParameterValue`). + let parameter_fields: Vec = data + .fields + .iter() + .filter_map(|field| { + let field_ident = field.ident.as_ref()?; + // Identify optional fields + match type_to_ident(&field.ty) { + Some(PywrField::Optional(ident)) => { + // If optional and a parameter identifier then add to the list + is_parameter_ident(&ident).then_some(ParamField { + field_name: field_ident.clone(), + optional: true, + }) + } + Some(PywrField::Required(ident)) => { + // Otherwise, if a parameter identifier then add to the list + is_parameter_ident(&ident).then_some(ParamField { + field_name: field_ident.clone(), + optional: false, + }) + } + None => None, // All other fields are ignored. + } + }) + .collect(); + + // Insert statements for non-mutable version + let inserts = parameter_fields + .iter() + .map(|param_field| { + let ident = ¶m_field.field_name; + let key = ident.to_string(); + if param_field.optional { + quote! { + if let Some(p) = &self.#ident { + attributes.insert(#key, p.into()); + } + } + } else { + quote! { + let #ident = &self.#ident; + attributes.insert(#key, #ident.into()); + } + } + }) + .collect::>(); + + // Insert statements for mutable version + let inserts_mut = parameter_fields + .iter() + .map(|param_field| { + let ident = ¶m_field.field_name; + let key = ident.to_string(); + if param_field.optional { + quote! { + if let Some(p) = &mut self.#ident { + attributes.insert(#key, p.into()); + } + } + } else { + quote! { + let #ident = &mut self.#ident; + attributes.insert(#key, #ident.into()); + } + } + }) + .collect::>(); + + // Create the two parameter methods using the insert statements + let expanded = quote! { + impl #name { + pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> { + let mut attributes = HashMap::new(); + #( + #inserts + )* + attributes + } + + pub fn parameters_mut(&mut self) -> HashMap<&str, &mut DynamicFloatValue> { + let mut attributes = HashMap::new(); + #( + #inserts_mut + )* + attributes + } + } + }; + + // Hand the output tokens back to the compiler. + TokenStream::from(expanded) + } else { + panic!("Only structs are supported for #[derive(PywrNode)]") + } +} + +/// Returns the last segment of a type path as an identifier +fn type_to_ident(ty: &syn::Type) -> Option { + match ty { + // Match type's that are a path and not a self type. + syn::Type::Path(type_path) if type_path.qself.is_none() => { + // Match on the last segment + match type_path.path.segments.last() { + Some(last_segment) => { + let ident = &last_segment.ident; + + if ident == "Option" { + // The last segment is an Option, now we need to parse the argument + // I.e. the bit in inside the angle brackets. + let first_arg = match &last_segment.arguments { + syn::PathArguments::AngleBracketed(params) => params.args.first(), + _ => None, + }; + + // Find type arguments; ignore others + let arg_ty = match first_arg { + Some(syn::GenericArgument::Type(ty)) => Some(ty), + _ => None, + }; + + // Match on path types that are no self types. + let arg_type_path = match arg_ty { + Some(ty) => match ty { + syn::Type::Path(type_path) if type_path.qself.is_none() => Some(type_path), + _ => None, + }, + None => None, + }; + + // Get the last segment of the path + let last_segment = match arg_type_path { + Some(type_path) => type_path.path.segments.last(), + None => None, + }; + + // Finally, if there's a last segment return this as an optional `PywrField` + match last_segment { + Some(last_segment) => { + let ident = &last_segment.ident; + Some(PywrField::Optional(ident.clone())) + } + None => None, + } + } else { + // Otherwise, assume this a simple required field + Some(PywrField::Required(ident.clone())) + } + } + None => None, + } + } + _ => None, + } +} + +fn is_parameter_ident(ident: &syn::Ident) -> bool { + // TODO this currenty omits more complex attributes, such as `factors` for AggregatedNode + // and steps for PiecewiseLinks, that can internally contain `DynamicFloatValue` fields + ident == "DynamicFloatValue" +} diff --git a/pywr-schema/Cargo.toml b/pywr-schema/Cargo.toml index 40428a82..aa6a529b 100644 --- a/pywr-schema/Cargo.toml +++ b/pywr-schema/Cargo.toml @@ -33,6 +33,7 @@ thiserror = { workspace = true } pywr-v1-schema = { workspace = true } pywr-core = { path="../pywr-core" } chrono = { workspace = true, features = ["serde"] } +pywr-schema-macros = { path = "../pywr-schema-macros" } [dev-dependencies] tempfile = "3.3.0" diff --git a/pywr-schema/src/model.rs b/pywr-schema/src/model.rs index f33b7de9..3f0850d6 100644 --- a/pywr-schema/src/model.rs +++ b/pywr-schema/src/model.rs @@ -20,6 +20,16 @@ pub struct Metadata { pub minimum_version: Option, } +impl Default for Metadata { + fn default() -> Self { + Self { + title: "Untitled model".to_string(), + description: None, + minimum_version: None, + } + } +} + impl TryFrom for Metadata { type Error = ConversionError; @@ -64,6 +74,16 @@ pub struct Timestepper { pub timestep: Timestep, } +impl Default for Timestepper { + fn default() -> Self { + Self { + start: DateType::Date(NaiveDate::from_ymd_opt(2000, 1, 1).expect("Invalid date")), + end: DateType::Date(NaiveDate::from_ymd_opt(2000, 12, 31).expect("Invalid date")), + timestep: Timestep::Days(1), + } + } +} + impl TryFrom for Timestepper { type Error = ConversionError; @@ -371,20 +391,38 @@ impl PywrModel { Ok(model) } -} - -impl TryFrom for PywrModel { - type Error = ConversionError; - fn try_from(v1: pywr_v1_schema::PywrModel) -> Result { - let metadata = v1.metadata.try_into()?; - let timestepper = v1.timestepper.try_into()?; + /// Convert a v1 model to a v2 model. + /// + /// This function is used to convert a v1 model to a v2 model. The conversion is not always + /// possible and may result in errors. The errors are returned as a vector of [`ConversionError`]s. + /// alongside the (partially) converted model. This may result in a model that will not + /// function as expected. The user should check the errors and the converted model to ensure + /// that the conversion has been successful. + pub fn from_v1(v1: pywr_v1_schema::PywrModel) -> (Self, Vec) { + let mut errors = Vec::new(); + + let metadata = v1.metadata.try_into().unwrap_or_else(|e| { + errors.push(e); + Metadata::default() + }); + + let timestepper = v1.timestepper.try_into().unwrap_or_else(|e| { + errors.push(e); + Timestepper::default() + }); let nodes = v1 .nodes .into_iter() - .map(|n| n.try_into()) - .collect::, _>>()?; + .filter_map(|n| match n.try_into() { + Ok(n) => Some(n), + Err(e) => { + errors.push(e); + None + } + }) + .collect::>(); let edges = v1.edges.into_iter().map(|e| e.into()).collect(); @@ -393,8 +431,14 @@ impl TryFrom for PywrModel { Some( v1_parameters .into_iter() - .map(|p| p.try_into_v2_parameter(None, &mut unnamed_count)) - .collect::, _>>()?, + .filter_map(|p| match p.try_into_v2_parameter(None, &mut unnamed_count) { + Ok(p) => Some(p), + Err(e) => { + errors.push(e); + None + } + }) + .collect::>(), ) } else { None @@ -413,12 +457,15 @@ impl TryFrom for PywrModel { outputs, }; - Ok(Self { - metadata, - timestepper, - scenarios: None, - network, - }) + ( + Self { + metadata, + timestepper, + scenarios: None, + network, + }, + errors, + ) } } diff --git a/pywr-schema/src/nodes/annual_virtual_storage.rs b/pywr-schema/src/nodes/annual_virtual_storage.rs index 00205506..7965af51 100644 --- a/pywr-schema/src/nodes/annual_virtual_storage.rs +++ b/pywr-schema/src/nodes/annual_virtual_storage.rs @@ -9,7 +9,9 @@ use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; use pywr_core::node::ConstraintValue; use pywr_core::virtual_storage::VirtualStorageReset; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::AnnualVirtualStorageNode as AnnualVirtualStorageNodeV1; +use std::collections::HashMap; use std::path::Path; #[derive(serde::Deserialize, serde::Serialize, Clone)] @@ -29,7 +31,7 @@ impl Default for AnnualReset { } } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct AnnualVirtualStorageNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/core.rs b/pywr-schema/src/nodes/core.rs index 3f2d3a3f..d58ecf73 100644 --- a/pywr-schema/src/nodes/core.rs +++ b/pywr-schema/src/nodes/core.rs @@ -7,6 +7,7 @@ use pywr_core::derived_metric::DerivedMetric; use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; use pywr_core::node::{ConstraintValue, StorageInitialVolume as CoreStorageInitialVolume}; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::{ AggregatedNode as AggregatedNodeV1, AggregatedStorageNode as AggregatedStorageNodeV1, CatchmentNode as CatchmentNodeV1, InputNode as InputNodeV1, LinkNode as LinkNodeV1, OutputNode as OutputNodeV1, @@ -15,7 +16,7 @@ use pywr_v1_schema::nodes::{ use std::collections::HashMap; use std::path::Path; -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct InputNode { #[serde(flatten)] pub meta: NodeMeta, @@ -27,21 +28,6 @@ pub struct InputNode { impl InputNode { pub const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Outflow; - pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> { - let mut attributes = HashMap::new(); - if let Some(p) = &self.max_flow { - attributes.insert("max_flow", p); - } - if let Some(p) = &self.min_flow { - attributes.insert("min_flow", p); - } - if let Some(p) = &self.cost { - attributes.insert("cost", p); - } - - attributes - } - pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result<(), SchemaError> { network.add_input_node(self.meta.name.as_str(), None)?; Ok(()) @@ -137,7 +123,7 @@ impl TryFrom for InputNode { } } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct LinkNode { #[serde(flatten)] pub meta: NodeMeta, @@ -149,21 +135,6 @@ pub struct LinkNode { impl LinkNode { const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Outflow; - pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> { - let mut attributes = HashMap::new(); - if let Some(p) = &self.max_flow { - attributes.insert("max_flow", p); - } - if let Some(p) = &self.min_flow { - attributes.insert("min_flow", p); - } - if let Some(p) = &self.cost { - attributes.insert("cost", p); - } - - attributes - } - pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result<(), SchemaError> { network.add_link_node(self.meta.name.as_str(), None)?; Ok(()) @@ -259,7 +230,7 @@ impl TryFrom for LinkNode { } } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct OutputNode { #[serde(flatten)] pub meta: NodeMeta, @@ -271,21 +242,6 @@ pub struct OutputNode { impl OutputNode { const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Inflow; - pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> { - let mut attributes = HashMap::new(); - if let Some(p) = &self.max_flow { - attributes.insert("max_flow", p); - } - if let Some(p) = &self.min_flow { - attributes.insert("min_flow", p); - } - if let Some(p) = &self.cost { - attributes.insert("cost", p); - } - - attributes - } - pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result<(), SchemaError> { network.add_output_node(self.meta.name.as_str(), None)?; Ok(()) @@ -407,7 +363,7 @@ impl From for CoreStorageInitialVolume { } } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default, Debug)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, Debug, PywrNode)] pub struct StorageNode { #[serde(flatten)] pub meta: NodeMeta, @@ -420,21 +376,6 @@ pub struct StorageNode { impl StorageNode { const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Volume; - pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> { - let mut attributes = HashMap::new(); - // if let Some(p) = &self.max_volume { - // attributes.insert("max_volume", p); - // } - // if let Some(p) = &self.min_volume { - // attributes.insert("min_volume", p); - // } - if let Some(p) = &self.cost { - attributes.insert("cost", p); - } - - attributes - } - pub fn add_to_model( &self, network: &mut pywr_core::network::Network, @@ -635,7 +576,7 @@ impl TryFrom for StorageNode { /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct CatchmentNode { #[serde(flatten)] pub meta: NodeMeta, @@ -735,7 +676,7 @@ pub enum Factors { Ratio { factors: Vec }, } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct AggregatedNode { #[serde(flatten)] pub meta: NodeMeta, @@ -877,7 +818,7 @@ impl TryFrom for AggregatedNode { } } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct AggregatedStorageNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/delay.rs b/pywr-schema/src/nodes/delay.rs index 78335ed3..e2d578e7 100644 --- a/pywr-schema/src/nodes/delay.rs +++ b/pywr-schema/src/nodes/delay.rs @@ -1,9 +1,11 @@ use crate::data_tables::LoadedTableCollection; use crate::error::{ConversionError, SchemaError}; use crate::nodes::{NodeAttribute, NodeMeta}; -use crate::parameters::ConstantValue; +use crate::parameters::{ConstantValue, DynamicFloatValue}; use pywr_core::metric::Metric; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::DelayNode as DelayNodeV1; +use std::collections::HashMap; #[doc = svgbobdoc::transform!( /// This node is used to introduce a delay between flows entering and leaving the node. @@ -24,7 +26,7 @@ use pywr_v1_schema::nodes::DelayNode as DelayNodeV1; /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct DelayNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/loss_link.rs b/pywr-schema/src/nodes/loss_link.rs index f9280f87..095e9920 100644 --- a/pywr-schema/src/nodes/loss_link.rs +++ b/pywr-schema/src/nodes/loss_link.rs @@ -5,7 +5,9 @@ use crate::nodes::{NodeAttribute, NodeMeta}; use crate::parameters::{DynamicFloatValue, TryIntoV2Parameter}; use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::LossLinkNode as LossLinkNodeV1; +use std::collections::HashMap; use std::path::Path; #[doc = svgbobdoc::transform!( @@ -25,7 +27,7 @@ use std::path::Path; /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct LossLinkNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/mod.rs b/pywr-schema/src/nodes/mod.rs index 1b0823c6..ccf574b1 100644 --- a/pywr-schema/src/nodes/mod.rs +++ b/pywr-schema/src/nodes/mod.rs @@ -300,7 +300,45 @@ impl Node { Node::Link(n) => n.parameters(), Node::Output(n) => n.parameters(), Node::Storage(n) => n.parameters(), - _ => HashMap::new(), // TODO complete + Node::Catchment(n) => n.parameters(), + Node::RiverGauge(n) => n.parameters(), + Node::LossLink(n) => n.parameters(), + Node::River(n) => n.parameters(), + Node::RiverSplitWithGauge(n) => n.parameters(), + Node::WaterTreatmentWorks(n) => n.parameters(), + Node::Aggregated(n) => n.parameters(), + Node::AggregatedStorage(n) => n.parameters(), + Node::VirtualStorage(n) => n.parameters(), + Node::AnnualVirtualStorage(n) => n.parameters(), + Node::PiecewiseLink(n) => n.parameters(), + Node::PiecewiseStorage(n) => n.parameters(), + Node::Delay(n) => n.parameters(), + Node::MonthlyVirtualStorage(n) => n.parameters(), + Node::RollingVirtualStorage(n) => n.parameters(), + } + } + + pub fn parameters_mut(&mut self) -> HashMap<&str, &mut DynamicFloatValue> { + match self { + Node::Input(n) => n.parameters_mut(), + Node::Link(n) => n.parameters_mut(), + Node::Output(n) => n.parameters_mut(), + Node::Storage(n) => n.parameters_mut(), + Node::Catchment(n) => n.parameters_mut(), + Node::RiverGauge(n) => n.parameters_mut(), + Node::LossLink(n) => n.parameters_mut(), + Node::River(n) => n.parameters_mut(), + Node::RiverSplitWithGauge(n) => n.parameters_mut(), + Node::WaterTreatmentWorks(n) => n.parameters_mut(), + Node::Aggregated(n) => n.parameters_mut(), + Node::AggregatedStorage(n) => n.parameters_mut(), + Node::VirtualStorage(n) => n.parameters_mut(), + Node::AnnualVirtualStorage(n) => n.parameters_mut(), + Node::PiecewiseLink(n) => n.parameters_mut(), + Node::PiecewiseStorage(n) => n.parameters_mut(), + Node::Delay(n) => n.parameters_mut(), + Node::MonthlyVirtualStorage(n) => n.parameters_mut(), + Node::RollingVirtualStorage(n) => n.parameters_mut(), } } diff --git a/pywr-schema/src/nodes/monthly_virtual_storage.rs b/pywr-schema/src/nodes/monthly_virtual_storage.rs index 8673c003..a58a34df 100644 --- a/pywr-schema/src/nodes/monthly_virtual_storage.rs +++ b/pywr-schema/src/nodes/monthly_virtual_storage.rs @@ -9,7 +9,9 @@ use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; use pywr_core::node::ConstraintValue; use pywr_core::virtual_storage::VirtualStorageReset; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::MonthlyVirtualStorageNode as MonthlyVirtualStorageNodeV1; +use std::collections::HashMap; use std::path::Path; #[derive(serde::Deserialize, serde::Serialize, Clone)] @@ -23,7 +25,7 @@ impl Default for NumberOfMonthsReset { } } -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct MonthlyVirtualStorageNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/piecewise_link.rs b/pywr-schema/src/nodes/piecewise_link.rs index 02587224..13e42b89 100644 --- a/pywr-schema/src/nodes/piecewise_link.rs +++ b/pywr-schema/src/nodes/piecewise_link.rs @@ -5,7 +5,9 @@ use crate::nodes::{NodeAttribute, NodeMeta}; use crate::parameters::{DynamicFloatValue, TryIntoV2Parameter}; use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::PiecewiseLinkNode as PiecewiseLinkNodeV1; +use std::collections::HashMap; use std::path::Path; #[derive(serde::Deserialize, serde::Serialize, Clone)] @@ -37,7 +39,7 @@ pub struct PiecewiseLinkStep { /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct PiecewiseLinkNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/piecewise_storage.rs b/pywr-schema/src/nodes/piecewise_storage.rs index f57b3d8c..2b059d98 100644 --- a/pywr-schema/src/nodes/piecewise_storage.rs +++ b/pywr-schema/src/nodes/piecewise_storage.rs @@ -8,6 +8,8 @@ use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; use pywr_core::node::{ConstraintValue, StorageInitialVolume}; use pywr_core::parameters::VolumeBetweenControlCurvesParameter; +use pywr_schema_macros::PywrNode; +use std::collections::HashMap; use std::path::Path; #[derive(serde::Deserialize, serde::Serialize, Clone)] @@ -42,7 +44,7 @@ pub struct PiecewiseStore { /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct PiecewiseStorageNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/river.rs b/pywr-schema/src/nodes/river.rs index 55587875..db6f4a28 100644 --- a/pywr-schema/src/nodes/river.rs +++ b/pywr-schema/src/nodes/river.rs @@ -2,10 +2,11 @@ use crate::error::{ConversionError, SchemaError}; use crate::nodes::{NodeAttribute, NodeMeta}; use crate::parameters::DynamicFloatValue; use pywr_core::metric::Metric; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::LinkNode as LinkNodeV1; use std::collections::HashMap; -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct RiverNode { #[serde(flatten)] pub meta: NodeMeta, @@ -14,10 +15,6 @@ pub struct RiverNode { impl RiverNode { const DEFAULT_ATTRIBUTE: NodeAttribute = NodeAttribute::Outflow; - pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> { - HashMap::new() - } - pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result<(), SchemaError> { network.add_link_node(self.meta.name.as_str(), None)?; Ok(()) diff --git a/pywr-schema/src/nodes/river_gauge.rs b/pywr-schema/src/nodes/river_gauge.rs index 58921b52..d820aa39 100644 --- a/pywr-schema/src/nodes/river_gauge.rs +++ b/pywr-schema/src/nodes/river_gauge.rs @@ -5,7 +5,9 @@ use crate::nodes::{NodeAttribute, NodeMeta}; use crate::parameters::{DynamicFloatValue, TryIntoV2Parameter}; use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::RiverGaugeNode as RiverGaugeNodeV1; +use std::collections::HashMap; use std::path::Path; #[doc = svgbobdoc::transform!( @@ -23,7 +25,7 @@ use std::path::Path; /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct RiverGaugeNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/river_split_with_gauge.rs b/pywr-schema/src/nodes/river_split_with_gauge.rs index 8195f8a6..22656123 100644 --- a/pywr-schema/src/nodes/river_split_with_gauge.rs +++ b/pywr-schema/src/nodes/river_split_with_gauge.rs @@ -7,7 +7,9 @@ use pywr_core::aggregated_node::Factors; use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; use pywr_core::node::NodeIndex; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::RiverSplitWithGaugeNode as RiverSplitWithGaugeNodeV1; +use std::collections::HashMap; use std::path::Path; #[doc = svgbobdoc::transform!( @@ -32,7 +34,7 @@ use std::path::Path; /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct RiverSplitWithGaugeNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/rolling_virtual_storage.rs b/pywr-schema/src/nodes/rolling_virtual_storage.rs index 7665f847..6a17b8dd 100644 --- a/pywr-schema/src/nodes/rolling_virtual_storage.rs +++ b/pywr-schema/src/nodes/rolling_virtual_storage.rs @@ -9,7 +9,9 @@ use pywr_core::models::ModelDomain; use pywr_core::node::{ConstraintValue, StorageInitialVolume}; use pywr_core::timestep::TimeDomain; use pywr_core::virtual_storage::VirtualStorageReset; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::RollingVirtualStorageNode as RollingVirtualStorageNodeV1; +use std::collections::HashMap; use std::num::NonZeroUsize; use std::path::Path; @@ -61,7 +63,7 @@ impl RollingWindow { /// The rolling virtual storage node is useful for representing rolling licences. For example, a 30-day or 90-day /// licence on a water abstraction. /// -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct RollingVirtualStorageNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/virtual_storage.rs b/pywr-schema/src/nodes/virtual_storage.rs index 2c247117..75f7dbb3 100644 --- a/pywr-schema/src/nodes/virtual_storage.rs +++ b/pywr-schema/src/nodes/virtual_storage.rs @@ -9,10 +9,12 @@ use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; use pywr_core::node::ConstraintValue; use pywr_core::virtual_storage::VirtualStorageReset; +use pywr_schema_macros::PywrNode; use pywr_v1_schema::nodes::VirtualStorageNode as VirtualStorageNodeV1; +use std::collections::HashMap; use std::path::Path; -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct VirtualStorageNode { #[serde(flatten)] pub meta: NodeMeta, diff --git a/pywr-schema/src/nodes/water_treatment_works.rs b/pywr-schema/src/nodes/water_treatment_works.rs index 3d2b65eb..bc9920fc 100644 --- a/pywr-schema/src/nodes/water_treatment_works.rs +++ b/pywr-schema/src/nodes/water_treatment_works.rs @@ -7,6 +7,8 @@ use num::Zero; use pywr_core::aggregated_node::Factors; use pywr_core::metric::Metric; use pywr_core::models::ModelDomain; +use pywr_schema_macros::PywrNode; +use std::collections::HashMap; use std::path::Path; #[doc = svgbobdoc::transform!( @@ -36,7 +38,7 @@ use std::path::Path; /// ``` /// )] -#[derive(serde::Deserialize, serde::Serialize, Clone, Default)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)] pub struct WaterTreatmentWorks { /// Node metadata #[serde(flatten)] diff --git a/pywr-schema/src/parameters/aggregated.rs b/pywr-schema/src/parameters/aggregated.rs index 945a1543..022449c6 100644 --- a/pywr-schema/src/parameters/aggregated.rs +++ b/pywr-schema/src/parameters/aggregated.rs @@ -6,7 +6,7 @@ use crate::parameters::{ TryIntoV2Parameter, }; use pywr_core::models::ModelDomain; -use pywr_core::parameters::{IndexParameterIndex, ParameterIndex}; +use pywr_core::parameters::ParameterIndex; use pywr_v1_schema::parameters::{ AggFunc as AggFuncV1, AggregatedIndexParameter as AggregatedIndexParameterV1, AggregatedParameter as AggregatedParameterV1, IndexAggFunc as IndexAggFuncV1, @@ -99,7 +99,7 @@ impl AggregatedParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metrics = self .metrics .iter() @@ -206,7 +206,7 @@ impl AggregatedIndexParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let parameters = self .parameters .iter() diff --git a/pywr-schema/src/parameters/asymmetric_switch.rs b/pywr-schema/src/parameters/asymmetric_switch.rs index 554b4506..3943e2c3 100644 --- a/pywr-schema/src/parameters/asymmetric_switch.rs +++ b/pywr-schema/src/parameters/asymmetric_switch.rs @@ -5,7 +5,7 @@ use crate::parameters::{ DynamicFloatValueType, DynamicIndexValue, IntoV2Parameter, ParameterMeta, TryFromV1Parameter, TryIntoV2Parameter, }; use pywr_core::models::ModelDomain; -use pywr_core::parameters::IndexParameterIndex; +use pywr_core::parameters::ParameterIndex; use pywr_v1_schema::parameters::AsymmetricSwitchIndexParameter as AsymmetricSwitchIndexParameterV1; use std::collections::HashMap; use std::path::Path; @@ -34,7 +34,7 @@ impl AsymmetricSwitchIndexParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let on_index_parameter = self.on_index_parameter .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/control_curves.rs b/pywr-schema/src/parameters/control_curves.rs index 9e3477b2..a7f722d0 100644 --- a/pywr-schema/src/parameters/control_curves.rs +++ b/pywr-schema/src/parameters/control_curves.rs @@ -6,7 +6,7 @@ use crate::parameters::{ DynamicFloatValue, IntoV2Parameter, NodeReference, ParameterMeta, TryFromV1Parameter, TryIntoV2Parameter, }; use pywr_core::models::ModelDomain; -use pywr_core::parameters::{IndexParameterIndex, ParameterIndex}; +use pywr_core::parameters::ParameterIndex; use pywr_v1_schema::parameters::{ ControlCurveIndexParameter as ControlCurveIndexParameterV1, ControlCurveInterpolatedParameter as ControlCurveInterpolatedParameterV1, @@ -33,7 +33,7 @@ impl ControlCurveInterpolatedParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metric = self.storage_node.load(network, schema)?; let control_curves = self @@ -136,7 +136,7 @@ impl ControlCurveIndexParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metric = self.storage_node.load(network, schema)?; let control_curves = self @@ -247,7 +247,7 @@ impl ControlCurveParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metric = self.storage_node.load(network, schema)?; let control_curves = self @@ -341,7 +341,7 @@ impl ControlCurvePiecewiseInterpolatedParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metric = self.storage_node.load(network, schema)?; let control_curves = self diff --git a/pywr-schema/src/parameters/core.rs b/pywr-schema/src/parameters/core.rs index a2f0c456..7eac46e0 100644 --- a/pywr-schema/src/parameters/core.rs +++ b/pywr-schema/src/parameters/core.rs @@ -170,7 +170,7 @@ impl ConstantParameter { &self, network: &mut pywr_core::network::Network, tables: &LoadedTableCollection, - ) -> Result { + ) -> Result, SchemaError> { let p = pywr_core::parameters::ConstantParameter::new(&self.meta.name, self.value.load(tables)?); Ok(network.add_parameter(Box::new(p))?) } @@ -226,7 +226,7 @@ impl MaxParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let idx = self .parameter .load(network, schema, domain, tables, data_path, inter_network_transfers)?; @@ -301,7 +301,7 @@ impl DivisionParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let n = self .numerator .load(network, schema, domain, tables, data_path, inter_network_transfers)?; @@ -376,7 +376,7 @@ impl MinParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let idx = self .parameter .load(network, schema, domain, tables, data_path, inter_network_transfers)?; @@ -433,7 +433,7 @@ impl NegativeParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let idx = self .parameter .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/data_frame.rs b/pywr-schema/src/parameters/data_frame.rs index 57be6a7d..d0ce6a9e 100644 --- a/pywr-schema/src/parameters/data_frame.rs +++ b/pywr-schema/src/parameters/data_frame.rs @@ -75,7 +75,7 @@ impl DataFrameParameter { network: &mut pywr_core::network::Network, domain: &ModelDomain, data_path: Option<&Path>, - ) -> Result { + ) -> Result, SchemaError> { // Handle the case of an optional data path with a relative url. let pth = if let Some(dp) = data_path { if self.url.is_relative() { diff --git a/pywr-schema/src/parameters/delay.rs b/pywr-schema/src/parameters/delay.rs index 79021b85..de891c36 100644 --- a/pywr-schema/src/parameters/delay.rs +++ b/pywr-schema/src/parameters/delay.rs @@ -39,7 +39,7 @@ impl DelayParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metric = self .metric .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/discount_factor.rs b/pywr-schema/src/parameters/discount_factor.rs index bef78112..0ab9a321 100644 --- a/pywr-schema/src/parameters/discount_factor.rs +++ b/pywr-schema/src/parameters/discount_factor.rs @@ -40,7 +40,7 @@ impl DiscountFactorParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let discount_rate = self.discount_rate .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/indexed_array.rs b/pywr-schema/src/parameters/indexed_array.rs index 3534db62..af1fb583 100644 --- a/pywr-schema/src/parameters/indexed_array.rs +++ b/pywr-schema/src/parameters/indexed_array.rs @@ -42,7 +42,7 @@ impl IndexedArrayParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let index_parameter = self.index_parameter .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/interpolated.rs b/pywr-schema/src/parameters/interpolated.rs index 68bf67d2..b8de7960 100644 --- a/pywr-schema/src/parameters/interpolated.rs +++ b/pywr-schema/src/parameters/interpolated.rs @@ -58,7 +58,7 @@ impl InterpolatedParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let x = self .x .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/mod.rs b/pywr-schema/src/parameters/mod.rs index b65e3afd..803d4222 100644 --- a/pywr-schema/src/parameters/mod.rs +++ b/pywr-schema/src/parameters/mod.rs @@ -42,7 +42,7 @@ pub use super::parameters::profiles::{ DailyProfileParameter, MonthlyProfileParameter, RadialBasisFunction, RbfProfileParameter, RbfProfileVariableSettings, UniformDrawdownProfileParameter, WeeklyProfileParameter, }; -pub use super::parameters::python::PythonParameter; +pub use super::parameters::python::{PythonModule, PythonParameter, PythonReturnType}; pub use super::parameters::tables::TablesArrayParameter; pub use super::parameters::thresholds::ParameterThresholdParameter; use crate::error::{ConversionError, SchemaError}; @@ -55,7 +55,7 @@ use crate::parameters::interpolated::InterpolatedParameter; pub use offset::OffsetParameter; use pywr_core::metric::Metric; use pywr_core::models::{ModelDomain, MultiNetworkTransferIndex}; -use pywr_core::parameters::{IndexParameterIndex, IndexValue, ParameterType}; +use pywr_core::parameters::{IndexValue, ParameterIndex, ParameterType}; use pywr_v1_schema::parameters::{ CoreParameter, ExternalDataRef as ExternalDataRefV1, Parameter as ParameterV1, ParameterMeta as ParameterMetaV1, ParameterValue as ParameterValueV1, TableIndex as TableIndexV1, TableIndexEntry as TableIndexEntryV1, @@ -733,7 +733,7 @@ impl ParameterIndexValue { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { match self { Self::Reference(name) => { // This should be an existing parameter diff --git a/pywr-schema/src/parameters/offset.rs b/pywr-schema/src/parameters/offset.rs index 43d58440..3e35df93 100644 --- a/pywr-schema/src/parameters/offset.rs +++ b/pywr-schema/src/parameters/offset.rs @@ -55,7 +55,7 @@ impl OffsetParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let idx = self .metric .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/parameters/polynomial.rs b/pywr-schema/src/parameters/polynomial.rs index 377058b3..674c35aa 100644 --- a/pywr-schema/src/parameters/polynomial.rs +++ b/pywr-schema/src/parameters/polynomial.rs @@ -23,7 +23,7 @@ impl Polynomial1DParameter { HashMap::new() } - pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result { + pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result, SchemaError> { let metric = network.get_storage_node_metric(&self.storage_node, None, self.use_proportional_volume.unwrap_or(true))?; diff --git a/pywr-schema/src/parameters/profiles.rs b/pywr-schema/src/parameters/profiles.rs index 0a24b895..78bc1c89 100644 --- a/pywr-schema/src/parameters/profiles.rs +++ b/pywr-schema/src/parameters/profiles.rs @@ -31,7 +31,7 @@ impl DailyProfileParameter { &self, network: &mut pywr_core::network::Network, tables: &LoadedTableCollection, - ) -> Result { + ) -> Result, SchemaError> { let values = &self.values.load(tables)?[..366]; let p = pywr_core::parameters::DailyProfileParameter::new(&self.meta.name, values.try_into().expect("")); Ok(network.add_parameter(Box::new(p))?) @@ -101,7 +101,7 @@ impl MonthlyProfileParameter { &self, network: &mut pywr_core::network::Network, tables: &LoadedTableCollection, - ) -> Result { + ) -> Result, SchemaError> { let values = &self.values.load(tables)?[..12]; let p = pywr_core::parameters::MonthlyProfileParameter::new( &self.meta.name, @@ -175,7 +175,7 @@ impl UniformDrawdownProfileParameter { &self, network: &mut pywr_core::network::Network, tables: &LoadedTableCollection, - ) -> Result { + ) -> Result, SchemaError> { let reset_day = match &self.reset_day { Some(v) => v.load(tables)? as u32, None => 1, @@ -363,7 +363,7 @@ impl RbfProfileParameter { HashMap::new() } - pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result { + pub fn add_to_model(&self, network: &mut pywr_core::network::Network) -> Result, SchemaError> { let function = self.function.into_core_rbf(&self.points)?; let p = pywr_core::parameters::RbfProfileParameter::new(&self.meta.name, self.points.clone(), function); @@ -545,7 +545,7 @@ impl WeeklyProfileParameter { &self, network: &mut pywr_core::network::Network, tables: &LoadedTableCollection, - ) -> Result { + ) -> Result, SchemaError> { let p = pywr_core::parameters::WeeklyProfileParameter::new( &self.meta.name, WeeklyProfileValues::try_from(self.values.load(tables)?.as_slice()).map_err( diff --git a/pywr-schema/src/parameters/python.rs b/pywr-schema/src/parameters/python.rs index 52b62f61..ae42d25c 100644 --- a/pywr-schema/src/parameters/python.rs +++ b/pywr-schema/src/parameters/python.rs @@ -18,12 +18,22 @@ pub enum PythonModule { Path(PathBuf), } +/// The expected return type of the Python parameter. +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)] +#[serde(rename_all = "lowercase")] +pub enum PythonReturnType { + #[default] + Float, + Int, + Dict, +} + /// A Parameter that uses a Python object for its calculations. /// -/// This struct defines a schema for loading a [`crate::parameters::PyParameter`] from external +/// This struct defines a schema for loading a [`PyParameter`] from external /// sources. The user provides the name of an object in the given module. Typically, this object will be /// a class the user has written. For more information on the expected format and signature of -/// this object please refer to the [`crate::parameters::PyParameter`] documentation. The object +/// this object please refer to the [`PyParameter`] documentation. The object /// is initialised with user provided positional and/or keyword arguments that can be provided /// here. /// @@ -67,10 +77,10 @@ pub struct PythonParameter { pub module: PythonModule, /// The name of Python object from the module to use. pub object: String, - /// Is this a multi-valued parameter or not. If true then the calculation method should - /// return a dictionary with string keys and either floats or ints as values. + /// The return type of the Python calculation. This is used to convert the Python return value + /// to the appropriate type for the Parameter. #[serde(default)] - pub multi: bool, + pub return_type: PythonReturnType, /// Position arguments to pass to the object during setup. pub args: Vec, /// Keyword arguments to pass to the object during setup. @@ -193,10 +203,11 @@ impl PythonParameter { }; let p = PyParameter::new(&self.meta.name, object, args, kwargs, &metrics, &indices); - let pt = if self.multi { - ParameterType::Multi(network.add_multi_value_parameter(Box::new(p))?) - } else { - ParameterType::Parameter(network.add_parameter(Box::new(p))?) + + let pt = match self.return_type { + PythonReturnType::Float => ParameterType::Parameter(network.add_parameter(Box::new(p))?), + PythonReturnType::Int => ParameterType::Index(network.add_index_parameter(Box::new(p))?), + PythonReturnType::Dict => ParameterType::Multi(network.add_multi_value_parameter(Box::new(p))?), }; Ok(pt) @@ -212,42 +223,59 @@ mod tests { use pywr_core::network::Network; use pywr_core::test_utils::default_time_domain; use serde_json::json; - use std::fs::File; - use std::io::Write; - use tempfile::tempdir; + use std::path::PathBuf; #[test] - fn test_python_parameter() { - let dir = tempdir().unwrap(); - - let file_path = dir.path().join("my_parameter.py"); + fn test_python_float_parameter() { + let mut py_fn = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + py_fn.push("src/test_models/test_parameters.py"); let data = json!( { - "name": "my-custom-calculation", + "name": "my-float-parameter", "type": "Python", - "path": file_path, - "object": "MyParameter", + "path": py_fn, + "object": "FloatParameter", "args": [0, ], "kwargs": {}, } ) .to_string(); - let mut file = File::create(file_path).unwrap(); - write!( - file, - r#" -class MyParameter: - def __init__(self, count, *args, **kwargs): - self.count = 0 + // Init Python + pyo3::prepare_freethreaded_python(); + // Load the schema ... + let param: PythonParameter = serde_json::from_str(data.as_str()).unwrap(); + // ... add it to an empty network + // this should trigger loading the module and extracting the class + let domain: ModelDomain = default_time_domain().into(); + let schema = PywrNetwork::default(); + let mut network = Network::default(); + let tables = LoadedTableCollection::from_schema(None, None).unwrap(); + param + .add_to_model(&mut network, &schema, &domain, &tables, None, &[]) + .unwrap(); + + assert!(network.get_parameter_by_name("my-float-parameter").is_ok()); + } + + #[test] + fn test_python_int_parameter() { + let mut py_fn = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + py_fn.push("src/test_models/test_parameters.py"); - def calc(self, ts, si, p_values): - self.count += si - return float(self.count + ts.day) -"# + let data = json!( + { + "name": "my-int-parameter", + "type": "Python", + "path": py_fn, + "return_type": "int", + "object": "FloatParameter", + "args": [0, ], + "kwargs": {}, + } ) - .unwrap(); + .to_string(); // Init Python pyo3::prepare_freethreaded_python(); @@ -262,5 +290,7 @@ class MyParameter: param .add_to_model(&mut network, &schema, &domain, &tables, None, &[]) .unwrap(); + + assert!(network.get_index_parameter_by_name("my-int-parameter").is_ok()); } } diff --git a/pywr-schema/src/parameters/tables.rs b/pywr-schema/src/parameters/tables.rs index 6fe16f8b..5e9da5d8 100644 --- a/pywr-schema/src/parameters/tables.rs +++ b/pywr-schema/src/parameters/tables.rs @@ -33,7 +33,7 @@ impl TablesArrayParameter { network: &mut pywr_core::network::Network, domain: &ModelDomain, data_path: Option<&Path>, - ) -> Result { + ) -> Result, SchemaError> { // 1. Load the file from the HDF5 file (NB this is not Pandas format). // Handle the case of an optional data path with a relative url. diff --git a/pywr-schema/src/parameters/thresholds.rs b/pywr-schema/src/parameters/thresholds.rs index e24ba041..3770c6d0 100644 --- a/pywr-schema/src/parameters/thresholds.rs +++ b/pywr-schema/src/parameters/thresholds.rs @@ -5,7 +5,7 @@ use crate::parameters::{ DynamicFloatValue, DynamicFloatValueType, IntoV2Parameter, ParameterMeta, TryFromV1Parameter, TryIntoV2Parameter, }; use pywr_core::models::ModelDomain; -use pywr_core::parameters::IndexParameterIndex; +use pywr_core::parameters::ParameterIndex; use pywr_v1_schema::parameters::{ ParameterThresholdParameter as ParameterThresholdParameterV1, Predicate as PredicateV1, }; @@ -77,7 +77,7 @@ impl ParameterThresholdParameter { tables: &LoadedTableCollection, data_path: Option<&Path>, inter_network_transfers: &[PywrMultiNetworkTransfer], - ) -> Result { + ) -> Result, SchemaError> { let metric = self .parameter .load(network, schema, domain, tables, data_path, inter_network_transfers)?; diff --git a/pywr-schema/src/test_models/test_parameters.py b/pywr-schema/src/test_models/test_parameters.py new file mode 100644 index 00000000..47794588 --- /dev/null +++ b/pywr-schema/src/test_models/test_parameters.py @@ -0,0 +1,20 @@ +class FloatParameter: + """A simple float parameter""" + + def __init__(self, count, *args, **kwargs): + self.count = 0 + + def calc(self, ts, si, p_values) -> float: + self.count += si + return float(self.count + ts.day) + + +class IntParameter: + """A simple int parameter""" + + def __init__(self, count, *args, **kwargs): + self.count = 0 + + def calc(self, ts, si, p_values) -> int: + self.count += si + return self.count + ts.day