Skip to content

Commit

Permalink
feat: Finalise ParameterCollection API for usize and multi.
Browse files Browse the repository at this point in the history
Implements the same API for all 3 types. Adds checks to ensure
names are unique across all parameter variants.
  • Loading branch information
jetuk committed Jul 4, 2024
1 parent 7f53994 commit 07ad27a
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 83 deletions.
8 changes: 4 additions & 4 deletions pywr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ pub enum PywrError {
DerivedMetricIndexNotFound(DerivedMetricIndex),
#[error("node name `{0}` already exists")]
NodeNameAlreadyExists(String),
#[error("parameter name `{0}` already exists at index {1}")]
ParameterNameAlreadyExists(String, ParameterIndex<f64>),
#[error("parameter name `{0}` already exists")]
ParameterNameAlreadyExists(String),
#[error("index parameter name `{0}` already exists at index {1}")]
IndexParameterNameAlreadyExists(String, GeneralParameterIndex<usize>),
IndexParameterNameAlreadyExists(String, ParameterIndex<usize>),
#[error("multi-value parameter name `{0}` already exists at index {1}")]
MultiValueParameterNameAlreadyExists(String, GeneralParameterIndex<MultiValue>),
MultiValueParameterNameAlreadyExists(String, ParameterIndex<MultiValue>),
#[error("metric set name `{0}` already exists")]
MetricSetNameAlreadyExists(String),
#[error("recorder name `{0}` already exists at index {1}")]
Expand Down
12 changes: 12 additions & 0 deletions pywr-core/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,18 @@ impl From<ParameterIndex<usize>> for MetricF64 {
}
}

impl From<(ParameterIndex<MultiValue>, String)> for MetricF64 {
fn from((idx, key): (ParameterIndex<MultiValue>, String)) -> Self {
match idx {
ParameterIndex::General(idx) => Self::MultiParameterValue((idx, key)),
ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricF64::MultiParameterValue((idx, key))),
ParameterIndex::Const(idx) => Self::Simple(SimpleMetricF64::Constant(
ConstantMetricF64::MultiParameterValue((idx, key)),
)),
}
}
}

impl TryFrom<ParameterIndex<f64>> for SimpleMetricF64 {
type Error = PywrError;
fn try_from(idx: ParameterIndex<f64>) -> Result<Self, Self::Error> {
Expand Down
32 changes: 13 additions & 19 deletions pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::solvers::{MultiStateSolver, Solver, SolverFeatures, SolverTimings};
use crate::state::{MultiValue, State, StateBuilder};
use crate::timestep::Timestep;
use crate::virtual_storage::{VirtualStorage, VirtualStorageBuilder, VirtualStorageIndex, VirtualStorageVec};
use crate::{parameters, recorders, GeneralParameterIndex, NodeIndex, PywrError, RecorderIndex};
use crate::{parameters, recorders, NodeIndex, PywrError, RecorderIndex};
use rayon::prelude::*;
use std::any::Any;
use std::collections::HashSet;
Expand Down Expand Up @@ -623,7 +623,7 @@ impl Network {
GeneralParameterType::Index(idx) => {
let p = self
.parameters
.get_usize(*idx)
.get_general_usize(*idx)
.ok_or(PywrError::GeneralIndexParameterIndexNotFound(*idx))?;

// .. and its internal state
Expand All @@ -638,7 +638,7 @@ impl Network {
GeneralParameterType::Multi(idx) => {
let p = self
.parameters
.get_multi(*idx)
.get_general_multi(idx)
.ok_or(PywrError::GeneralMultiValueParameterIndexNotFound(*idx))?;

// .. and its internal state
Expand Down Expand Up @@ -714,7 +714,7 @@ impl Network {
GeneralParameterType::Index(idx) => {
let p = self
.parameters
.get_usize(*idx)
.get_general_usize(*idx)
.ok_or(PywrError::GeneralIndexParameterIndexNotFound(*idx))?;

// .. and its internal state
Expand All @@ -727,7 +727,7 @@ impl Network {
GeneralParameterType::Multi(idx) => {
let p = self
.parameters
.get_multi(*idx)
.get_general_multi(idx)
.ok_or(PywrError::GeneralMultiValueParameterIndexNotFound(*idx))?;

// .. and its internal state
Expand Down Expand Up @@ -1123,29 +1123,23 @@ impl Network {
}

/// Get a [`Parameter<usize>`] from its index.
pub fn get_index_parameter(
&self,
index: GeneralParameterIndex<usize>,
) -> Result<&dyn parameters::GeneralParameter<usize>, PywrError> {
pub fn get_index_parameter(&self, index: ParameterIndex<usize>) -> Result<&dyn parameters::Parameter, PywrError> {
match self.parameters.get_usize(index) {
Some(p) => Ok(p),
None => Err(PywrError::GeneralIndexParameterIndexNotFound(index)),
None => Err(PywrError::IndexParameterIndexNotFound(index)),
}
}

/// Get a `IndexParameter` from a parameter's name
pub fn get_index_parameter_by_name(
&self,
name: &str,
) -> Result<&dyn parameters::GeneralParameter<usize>, PywrError> {
pub fn get_index_parameter_by_name(&self, name: &str) -> Result<&dyn parameters::Parameter, PywrError> {
match self.parameters.get_usize_by_name(name) {
Some(parameter) => Ok(parameter),
None => Err(PywrError::ParameterNotFound(name.to_string())),
}
}

/// Get a `IndexParameterIndex` from a parameter's name
pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result<GeneralParameterIndex<usize>, PywrError> {
pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result<ParameterIndex<usize>, PywrError> {
match self.parameters.get_usize_index_by_name(name) {
Some(idx) => Ok(idx),
None => Err(PywrError::ParameterNotFound(name.to_string())),
Expand All @@ -1155,19 +1149,19 @@ impl Network {
/// Get a `MultiValueParameterIndex` from a parameter's name
pub fn get_multi_valued_parameter(
&self,
index: GeneralParameterIndex<MultiValue>,
) -> Result<&dyn parameters::GeneralParameter<MultiValue>, PywrError> {
index: &ParameterIndex<MultiValue>,
) -> Result<&dyn parameters::Parameter, PywrError> {
match self.parameters.get_multi(index) {
Some(p) => Ok(p),
None => Err(PywrError::GeneralMultiValueParameterIndexNotFound(index)),
None => Err(PywrError::MultiValueParameterIndexNotFound(index.clone())),
}
}

/// Get a `MultiValueParameterIndex` from a parameter's name
pub fn get_multi_valued_parameter_index_by_name(
&self,
name: &str,
) -> Result<GeneralParameterIndex<MultiValue>, PywrError> {
) -> Result<ParameterIndex<MultiValue>, PywrError> {
match self.parameters.get_multi_index_by_name(name) {
Some(idx) => Ok(idx),
None => Err(PywrError::ParameterNotFound(name.to_string())),
Expand Down
Loading

0 comments on commit 07ad27a

Please sign in to comment.