Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Make ParameterIndex generic #138

Merged
merged 1 commit into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pywr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ extern crate core;
use crate::derived_metric::DerivedMetricIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::node::NodeIndex;
use crate::parameters::{IndexParameterIndex, InterpolationError, MultiValueParameterIndex, ParameterIndex};
use crate::parameters::{InterpolationError, ParameterIndex};
use crate::recorders::{AggregationError, MetricSetIndex, RecorderIndex};
use crate::state::MultiValue;
use crate::virtual_storage::VirtualStorageIndex;
use pyo3::exceptions::{PyException, PyRuntimeError};
use pyo3::{create_exception, PyErr};
Expand Down Expand Up @@ -44,11 +45,11 @@ pub enum PywrError {
#[error("virtual storage index {0} not found")]
VirtualStorageIndexNotFound(VirtualStorageIndex),
#[error("parameter index {0} not found")]
ParameterIndexNotFound(ParameterIndex),
ParameterIndexNotFound(ParameterIndex<f64>),
#[error("index parameter index {0} not found")]
IndexParameterIndexNotFound(IndexParameterIndex),
IndexParameterIndexNotFound(ParameterIndex<usize>),
#[error("multi1 value parameter index {0} not found")]
MultiValueParameterIndexNotFound(MultiValueParameterIndex),
MultiValueParameterIndexNotFound(ParameterIndex<MultiValue>),
#[error("multi1 value parameter key {0} not found")]
MultiValueParameterKeyNotFound(String),
#[error("inter-network parameter state not initialised")]
Expand All @@ -72,9 +73,9 @@ pub enum PywrError {
#[error("node name `{0}` already exists")]
NodeNameAlreadyExists(String),
#[error("parameter name `{0}` already exists at index {1}")]
ParameterNameAlreadyExists(String, ParameterIndex),
ParameterNameAlreadyExists(String, ParameterIndex<f64>),
#[error("index parameter name `{0}` already exists at index {1}")]
IndexParameterNameAlreadyExists(String, IndexParameterIndex),
IndexParameterNameAlreadyExists(String, ParameterIndex<usize>),
#[error("metric set name `{0}` already exists")]
MetricSetNameAlreadyExists(String),
#[error("recorder name `{0}` already exists at index {1}")]
Expand Down Expand Up @@ -158,7 +159,7 @@ pub enum PywrError {
#[error("parameters do not provide an initial value")]
ParameterNoInitialValue,
#[error("parameter state not found for parameter index {0}")]
ParameterStateNotFound(ParameterIndex),
ParameterStateNotFound(ParameterIndex<f64>),
#[error("Could not create timestep range due to following error: {0}")]
TimestepRangeGenerationError(String),
#[error("Could not create timesteps for frequency '{0}'")]
Expand Down
10 changes: 5 additions & 5 deletions pywr-core/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::edge::EdgeIndex;
use crate::models::MultiNetworkTransferIndex;
use crate::network::Network;
use crate::node::NodeIndex;
use crate::parameters::{IndexParameterIndex, MultiValueParameterIndex, ParameterIndex};
use crate::state::State;
use crate::parameters::ParameterIndex;
use crate::state::{MultiValue, State};
use crate::virtual_storage::VirtualStorageIndex;
use crate::PywrError;
#[derive(Clone, Debug, PartialEq)]
Expand All @@ -18,8 +18,8 @@ pub enum Metric {
AggregatedNodeOutFlow(AggregatedNodeIndex),
AggregatedNodeVolume(AggregatedStorageNodeIndex),
EdgeFlow(EdgeIndex),
ParameterValue(ParameterIndex),
MultiParameterValue((MultiValueParameterIndex, String)),
ParameterValue(ParameterIndex<f64>),
MultiParameterValue((ParameterIndex<MultiValue>, String)),
VirtualStorageVolume(VirtualStorageIndex),
MultiNodeInFlow { indices: Vec<NodeIndex>, name: String },
MultiNodeOutFlow { indices: Vec<NodeIndex>, name: String },
Expand Down Expand Up @@ -87,7 +87,7 @@ impl Metric {

#[derive(Clone, Debug, PartialEq)]
pub enum IndexMetric {
IndexParameterValue(IndexParameterIndex),
IndexParameterValue(ParameterIndex<usize>),
Constant(usize),
}

Expand Down
45 changes: 24 additions & 21 deletions pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use crate::edge::{EdgeIndex, EdgeVec};
use crate::metric::Metric;
use crate::models::ModelDomain;
use crate::node::{ConstraintValue, Node, NodeVec, StorageInitialVolume};
use crate::parameters::{MultiValueParameterIndex, ParameterType, VariableConfig};
use crate::parameters::{ParameterType, VariableConfig};
use crate::recorders::{MetricSet, MetricSetIndex, MetricSetState};
use crate::scenario::ScenarioIndex;
use crate::solvers::{MultiStateSolver, Solver, SolverFeatures, SolverTimings};
use crate::state::{MultiValue, ParameterStates, State, StateBuilder};
use crate::timestep::Timestep;
use crate::virtual_storage::{VirtualStorage, VirtualStorageIndex, VirtualStorageReset, VirtualStorageVec};
use crate::{parameters, recorders, IndexParameterIndex, NodeIndex, ParameterIndex, PywrError, RecorderIndex};
use crate::{parameters, recorders, NodeIndex, ParameterIndex, PywrError, RecorderIndex};
use rayon::prelude::*;
use std::any::Any;
use std::collections::HashSet;
Expand Down Expand Up @@ -1089,7 +1089,7 @@ impl Network {
}

/// Get a `Parameter` from a parameter's name
pub fn get_parameter(&self, index: &ParameterIndex) -> Result<&dyn parameters::Parameter<f64>, PywrError> {
pub fn get_parameter(&self, index: &ParameterIndex<f64>) -> Result<&dyn parameters::Parameter<f64>, PywrError> {
match self.parameters.get(*index.deref()) {
Some(p) => Ok(p.as_ref()),
None => Err(PywrError::ParameterIndexNotFound(*index)),
Expand All @@ -1099,7 +1099,7 @@ impl Network {
/// Get a `Parameter` from a parameter's name
pub fn get_mut_parameter(
&mut self,
index: &ParameterIndex,
index: &ParameterIndex<f64>,
) -> Result<&mut dyn parameters::Parameter<f64>, PywrError> {
match self.parameters.get_mut(*index.deref()) {
Some(p) => Ok(p.as_mut()),
Expand All @@ -1116,7 +1116,7 @@ impl Network {
}

/// Get a `ParameterIndex` from a parameter's name
pub fn get_parameter_index_by_name(&self, name: &str) -> Result<ParameterIndex, PywrError> {
pub fn get_parameter_index_by_name(&self, name: &str) -> Result<ParameterIndex<f64>, PywrError> {
match self.parameters.iter().position(|p| p.name() == name) {
Some(idx) => Ok(ParameterIndex::new(idx)),
None => Err(PywrError::ParameterNotFound(name.to_string())),
Expand All @@ -1132,17 +1132,20 @@ impl Network {
}

/// Get a `IndexParameterIndex` from a parameter's name
pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result<IndexParameterIndex, PywrError> {
pub fn get_index_parameter_index_by_name(&self, name: &str) -> Result<ParameterIndex<usize>, PywrError> {
match self.index_parameters.iter().position(|p| p.name() == name) {
Some(idx) => Ok(IndexParameterIndex::new(idx)),
Some(idx) => Ok(ParameterIndex::new(idx)),
None => Err(PywrError::ParameterNotFound(name.to_string())),
}
}

/// Get a `MultiValueParameterIndex` from a parameter's name
pub fn get_multi_valued_parameter_index_by_name(&self, name: &str) -> Result<MultiValueParameterIndex, PywrError> {
pub fn get_multi_valued_parameter_index_by_name(
&self,
name: &str,
) -> Result<ParameterIndex<MultiValue>, PywrError> {
match self.multi_parameters.iter().position(|p| p.name() == name) {
Some(idx) => Ok(MultiValueParameterIndex::new(idx)),
Some(idx) => Ok(ParameterIndex::new(idx)),
None => Err(PywrError::ParameterNotFound(name.to_string())),
}
}
Expand Down Expand Up @@ -1317,7 +1320,7 @@ impl Network {
pub fn add_parameter(
&mut self,
parameter: Box<dyn parameters::Parameter<f64>>,
) -> Result<ParameterIndex, PywrError> {
) -> Result<ParameterIndex<f64>, PywrError> {
if let Ok(idx) = self.get_parameter_index_by_name(&parameter.meta().name) {
return Err(PywrError::ParameterNameAlreadyExists(
parameter.meta().name.to_string(),
Expand All @@ -1339,15 +1342,15 @@ impl Network {
pub fn add_index_parameter(
&mut self,
index_parameter: Box<dyn parameters::Parameter<usize>>,
) -> Result<IndexParameterIndex, PywrError> {
) -> Result<ParameterIndex<usize>, PywrError> {
if let Ok(idx) = self.get_index_parameter_index_by_name(&index_parameter.meta().name) {
return Err(PywrError::IndexParameterNameAlreadyExists(
index_parameter.meta().name.to_string(),
idx,
));
}

let parameter_index = IndexParameterIndex::new(self.index_parameters.len());
let parameter_index = ParameterIndex::new(self.index_parameters.len());

self.index_parameters.push(index_parameter);
// .. and add it to the resolve order
Expand All @@ -1360,15 +1363,15 @@ impl Network {
pub fn add_multi_value_parameter(
&mut self,
parameter: Box<dyn parameters::Parameter<MultiValue>>,
) -> Result<MultiValueParameterIndex, PywrError> {
) -> Result<ParameterIndex<MultiValue>, PywrError> {
if let Ok(idx) = self.get_parameter_index_by_name(&parameter.meta().name) {
return Err(PywrError::ParameterNameAlreadyExists(
parameter.meta().name.to_string(),
idx,
));
}

let parameter_index = MultiValueParameterIndex::new(self.multi_parameters.len());
let parameter_index = ParameterIndex::new(self.multi_parameters.len());

// Add the parameter ...
self.multi_parameters.push(parameter);
Expand Down Expand Up @@ -1457,7 +1460,7 @@ impl Network {
/// This will update the internal state of the parameter with the new values for all scenarios.
pub fn set_f64_parameter_variable_values(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
values: &[f64],
variable_config: &dyn VariableConfig,
state: &mut NetworkState,
Expand Down Expand Up @@ -1487,7 +1490,7 @@ impl Network {
/// Only the internal state of the parameter for the given scenario will be updated.
pub fn set_f64_parameter_variable_values_for_scenario(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
scenario_index: ScenarioIndex,
values: &[f64],
variable_config: &dyn VariableConfig,
Expand All @@ -1511,7 +1514,7 @@ impl Network {
/// Return a vector of the current values of active variable parameters.
pub fn get_f64_parameter_variable_values_for_scenario(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
scenario_index: ScenarioIndex,
state: &NetworkState,
) -> Result<Option<Vec<f64>>, PywrError> {
Expand All @@ -1533,7 +1536,7 @@ impl Network {

pub fn get_f64_parameter_variable_values(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
state: &NetworkState,
) -> Result<Vec<Option<Vec<f64>>>, PywrError> {
match self.parameters.get(*parameter_index.deref()) {
Expand Down Expand Up @@ -1563,7 +1566,7 @@ impl Network {
/// This will update the internal state of the parameter with the new values for scenarios.
pub fn set_u32_parameter_variable_values(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
values: &[u32],
variable_config: &dyn VariableConfig,
state: &mut NetworkState,
Expand Down Expand Up @@ -1593,7 +1596,7 @@ impl Network {
/// Only the internal state of the parameter for the given scenario will be updated.
pub fn set_u32_parameter_variable_values_for_scenario(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
scenario_index: ScenarioIndex,
values: &[u32],
variable_config: &dyn VariableConfig,
Expand All @@ -1617,7 +1620,7 @@ impl Network {
/// Return a vector of the current values of active variable parameters.
pub fn get_u32_parameter_variable_values_for_scenario(
&self,
parameter_index: ParameterIndex,
parameter_index: ParameterIndex<f64>,
scenario_index: ScenarioIndex,
state: &NetworkState,
) -> Result<Option<Vec<u32>>, PywrError> {
Expand Down
85 changes: 34 additions & 51 deletions pywr-core/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub use self::rhai::RhaiParameter;
use super::PywrError;
use crate::network::Network;
use crate::scenario::ScenarioIndex;
use crate::state::{ParameterState, State};
use crate::state::{MultiValue, ParameterState, State};
use crate::timestep::Timestep;
pub use activation_function::ActivationFunction;
pub use aggregated::{AggFunc, AggregatedParameter};
Expand Down Expand Up @@ -59,76 +59,59 @@ pub use profiles::{
pub use py::PyParameter;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::marker::PhantomData;
use std::ops::Deref;
pub use threshold::{Predicate, ThresholdParameter};
pub use vector::VectorParameter;

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ParameterIndex(usize);

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct IndexParameterIndex(usize);

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct MultiValueParameterIndex(usize);

impl ParameterIndex {
pub fn new(idx: usize) -> Self {
Self(idx)
}
}

impl IndexParameterIndex {
pub fn new(idx: usize) -> Self {
Self(idx)
}
/// Generic parameter index.
///
/// This is a wrapper around usize that is used to index parameters in the state. It is
/// generic over the type of the value that the parameter returns.
#[derive(Debug)]
pub struct ParameterIndex<T> {
idx: usize,
phantom: PhantomData<T>,
}

impl MultiValueParameterIndex {
pub fn new(idx: usize) -> Self {
Self(idx)
// These implementations are required because the derive macro does not work well with PhantomData.
// See issue: https://github.com/rust-lang/rust/issues/26925
impl<T> Clone for ParameterIndex<T> {
fn clone(&self) -> Self {
*self
}
}

impl Deref for ParameterIndex {
type Target = usize;
impl<T> Copy for ParameterIndex<T> {}

fn deref(&self) -> &Self::Target {
&self.0
impl<T> PartialEq<Self> for ParameterIndex<T> {
fn eq(&self, other: &Self) -> bool {
self.idx == other.idx
}
}

impl Deref for IndexParameterIndex {
type Target = usize;
impl<T> Eq for ParameterIndex<T> {}

fn deref(&self) -> &Self::Target {
&self.0
impl<T> ParameterIndex<T> {
pub fn new(idx: usize) -> Self {
Self {
idx,
phantom: PhantomData,
}
}
}

impl Deref for MultiValueParameterIndex {
impl<T> Deref for ParameterIndex<T> {
type Target = usize;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Display for ParameterIndex {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}

impl Display for IndexParameterIndex {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
&self.idx
}
}

impl Display for MultiValueParameterIndex {
impl<T> Display for ParameterIndex<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
write!(f, "{}", self.idx)
}
}

Expand Down Expand Up @@ -271,7 +254,7 @@ pub trait Parameter<T>: Send + Sync {
#[derive(Copy, Clone)]
pub enum IndexValue {
Constant(usize),
Dynamic(IndexParameterIndex),
Dynamic(ParameterIndex<usize>),
}

impl IndexValue {
Expand All @@ -284,9 +267,9 @@ impl IndexValue {
}

pub enum ParameterType {
Parameter(ParameterIndex),
Index(IndexParameterIndex),
Multi(MultiValueParameterIndex),
Parameter(ParameterIndex<f64>),
Index(ParameterIndex<usize>),
Multi(ParameterIndex<MultiValue>),
}

/// A parameter that can be optimised.
Expand Down
Loading