Skip to content

Commit

Permalink
fix: Remove Box from variable API.
Browse files Browse the repository at this point in the history
  • Loading branch information
jetuk committed Feb 13, 2024
1 parent da6babb commit 6577373
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 29 deletions.
10 changes: 5 additions & 5 deletions pywr-core/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ impl Network {
&self,
parameter_index: ParameterIndex,
values: &[f64],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
state: &mut NetworkState,
) -> Result<(), PywrError> {
match self.parameters.get(*parameter_index.deref()) {
Expand Down Expand Up @@ -1454,7 +1454,7 @@ impl Network {
parameter_index: ParameterIndex,
scenario_index: ScenarioIndex,
values: &[f64],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
state: &mut NetworkState,
) -> Result<(), PywrError> {
match self.parameters.get(*parameter_index.deref()) {
Expand Down Expand Up @@ -1529,7 +1529,7 @@ impl Network {
&self,
parameter_index: ParameterIndex,
values: &[u32],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
state: &mut NetworkState,
) -> Result<(), PywrError> {
match self.parameters.get(*parameter_index.deref()) {
Expand Down Expand Up @@ -1560,7 +1560,7 @@ impl Network {
parameter_index: ParameterIndex,
scenario_index: ScenarioIndex,
values: &[u32],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
state: &mut NetworkState,
) -> Result<(), PywrError> {
match self.parameters.get(*parameter_index.deref()) {
Expand Down Expand Up @@ -1934,7 +1934,7 @@ mod tests {
fn test_variable_api() {
let mut model = simple_model(1);

let variable: Box<dyn VariableConfig> = Box::new(ActivationFunction::Unit { min: 0.0, max: 10.0 });
let variable = ActivationFunction::Unit { min: 0.0, max: 10.0 };
let input_max_flow = parameters::ConstantParameter::new("my-constant", 10.0);

assert!(input_max_flow.can_be_f64_variable());
Expand Down
12 changes: 6 additions & 6 deletions pywr-core/src/parameters/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ impl VariableParameter<f64> for ConstantParameter {
&self.meta
}

fn size(&self, _variable_config: &Box<dyn VariableConfig>) -> usize {
fn size(&self, _variable_config: &dyn VariableConfig) -> usize {
1
}

fn set_variables(
&self,
values: &[f64],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
let activation_function = downcast_variable_config_ref::<ActivationFunction>(variable_config);
Expand All @@ -108,28 +108,28 @@ impl VariableParameter<f64> for ConstantParameter {
}
}

fn get_lower_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<f64>, PywrError> {
fn get_lower_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<f64>, PywrError> {
let activation_function = downcast_variable_config_ref::<ActivationFunction>(variable_config);
Ok(vec![activation_function.lower_bound()])
}

fn get_upper_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<f64>, PywrError> {
fn get_upper_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<f64>, PywrError> {
let activation_function = downcast_variable_config_ref::<ActivationFunction>(variable_config);
Ok(vec![activation_function.upper_bound()])
}
}

#[cfg(test)]
mod tests {
use crate::parameters::{ActivationFunction, ConstantParameter, Parameter, VariableConfig, VariableParameter};
use crate::parameters::{ActivationFunction, ConstantParameter, Parameter, VariableParameter};
use crate::test_utils::default_domain;
use float_cmp::assert_approx_eq;

#[test]
fn test_variable_api() {
let domain = default_domain();

let var: Box<dyn VariableConfig> = Box::new(ActivationFunction::Unit { min: 0.0, max: 2.0 });
let var = ActivationFunction::Unit { min: 0.0, max: 2.0 };
let p = ConstantParameter::new("test", 1.0);
let mut state = p
.setup(
Expand Down
12 changes: 6 additions & 6 deletions pywr-core/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ where
}

/// Helper function to downcast to variable config and print a helpful panic message if this fails.
pub fn downcast_variable_config_ref<T: 'static>(variable_config: &Box<dyn VariableConfig>) -> &T {
pub fn downcast_variable_config_ref<T: 'static>(variable_config: &dyn VariableConfig) -> &T {
// Downcast the internal state to the correct type
match variable_config.as_ref().as_any().downcast_ref::<T>() {
match variable_config.as_any().downcast_ref::<T>() {
Some(pa) => pa,
None => panic!("Variable config did not downcast to the correct type! :("),
}
Expand Down Expand Up @@ -369,20 +369,20 @@ pub trait VariableParameter<T> {
}

/// Return the number of variables required
fn size(&self, variable_config: &Box<dyn VariableConfig>) -> usize;
fn size(&self, variable_config: &dyn VariableConfig) -> usize;
/// Apply new variable values to the parameter's state
fn set_variables(
&self,
values: &[T],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError>;
/// Get the current variable values
fn get_variables(&self, internal_state: &Option<Box<dyn ParameterState>>) -> Option<Vec<T>>;
/// Get variable lower bounds
fn get_lower_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<T>, PywrError>;
fn get_lower_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<T>, PywrError>;
/// Get variable upper bounds
fn get_upper_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<T>, PywrError>;
fn get_upper_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<T>, PywrError>;
}

#[cfg(test)]
Expand Down
8 changes: 4 additions & 4 deletions pywr-core/src/parameters/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ impl VariableParameter<f64> for OffsetParameter {
&self.meta
}

fn size(&self, _variable_config: &Box<dyn VariableConfig>) -> usize {
fn size(&self, _variable_config: &dyn VariableConfig) -> usize {
1
}

fn set_variables(
&self,
values: &[f64],
variable_config: &Box<dyn VariableConfig>,
variable_config: &dyn VariableConfig,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
let activation_function = downcast_variable_config_ref::<ActivationFunction>(variable_config);
Expand All @@ -103,12 +103,12 @@ impl VariableParameter<f64> for OffsetParameter {
}
}

fn get_lower_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<f64>, PywrError> {
fn get_lower_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<f64>, PywrError> {
let activation_function = downcast_variable_config_ref::<ActivationFunction>(variable_config);
Ok(vec![activation_function.lower_bound()])
}

fn get_upper_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<f64>, PywrError> {
fn get_upper_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<f64>, PywrError> {
let activation_function = downcast_variable_config_ref::<ActivationFunction>(variable_config);
Ok(vec![activation_function.upper_bound()])
}
Expand Down
16 changes: 8 additions & 8 deletions pywr-core/src/parameters/profiles/rbf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,15 @@ impl VariableParameter<f64> for RbfProfileParameter {
}

/// The size is the number of points that define the profile.
fn size(&self, _variable_config: &Box<dyn VariableConfig>) -> usize {
fn size(&self, _variable_config: &dyn VariableConfig) -> usize {
self.points.len()
}

/// The f64 values update the profile value of each point.
fn set_variables(
&self,
values: &[f64],
_variable_config: &Box<dyn VariableConfig>,
_variable_config: &dyn VariableConfig,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
if values.len() == self.points.len() {
Expand All @@ -184,13 +184,13 @@ impl VariableParameter<f64> for RbfProfileParameter {
value.points_y.clone()
}

fn get_lower_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<f64>, PywrError> {
fn get_lower_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<f64>, PywrError> {
let config = downcast_variable_config_ref::<RbfProfileVariableConfig>(variable_config);
let lb = (0..self.points.len()).map(|_| config.value_lower_bounds).collect();
Ok(lb)
}

fn get_upper_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<f64>, PywrError> {
fn get_upper_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<f64>, PywrError> {
let config = downcast_variable_config_ref::<RbfProfileVariableConfig>(variable_config);
let lb = (0..self.points.len()).map(|_| config.value_upper_bounds).collect();
Ok(lb)
Expand All @@ -202,7 +202,7 @@ impl VariableParameter<u32> for RbfProfileParameter {
&self.meta
}
/// The size is the number of points that define the profile.
fn size(&self, variable_config: &Box<dyn VariableConfig>) -> usize {
fn size(&self, variable_config: &dyn VariableConfig) -> usize {
let config = downcast_variable_config_ref::<RbfProfileVariableConfig>(variable_config);
match config.days_of_year_range {
Some(_) => self.points.len(),
Expand All @@ -214,7 +214,7 @@ impl VariableParameter<u32> for RbfProfileParameter {
fn set_variables(
&self,
values: &[u32],
_variable_config: &Box<dyn VariableConfig>,
_variable_config: &dyn VariableConfig,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
if values.len() == self.points.len() {
Expand All @@ -235,7 +235,7 @@ impl VariableParameter<u32> for RbfProfileParameter {
value.points_x.clone()
}

fn get_lower_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<u32>, PywrError> {
fn get_lower_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<u32>, PywrError> {
let config = downcast_variable_config_ref::<RbfProfileVariableConfig>(variable_config);

if let Some(days_of_year_range) = &config.days_of_year_range {
Expand All @@ -252,7 +252,7 @@ impl VariableParameter<u32> for RbfProfileParameter {
}
}

fn get_upper_bounds(&self, variable_config: &Box<dyn VariableConfig>) -> Result<Vec<u32>, PywrError> {
fn get_upper_bounds(&self, variable_config: &dyn VariableConfig) -> Result<Vec<u32>, PywrError> {
let config = downcast_variable_config_ref::<RbfProfileVariableConfig>(variable_config);

if let Some(days_of_year_range) = &config.days_of_year_range {
Expand Down

0 comments on commit 6577373

Please sign in to comment.