Skip to content

Commit

Permalink
feat: Implement TryFromV1Parameter for RbfProfileParameter
Browse files Browse the repository at this point in the history
- Also update to v1 schema v0.9.0
- Add code for estimating epsilon parameter.
- Add Quintic RBF
  • Loading branch information
jetuk committed Dec 9, 2023
1 parent ef1f84a commit e0538c6
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ tracing = "0.1"
csv = "1.1"
hdf5 = { version="0.8.1" }
hdf5-sys = { version="0.8.1", features=["static"] }
pywr-v1-schema = { git = "https://github.com/pywr/pywr-schema/", tag="v0.8.0", package = "pywr-schema" }
pywr-v1-schema = { git = "https://github.com/pywr/pywr-schema/", tag="v0.9.0", package = "pywr-schema" }
2 changes: 2 additions & 0 deletions pywr-core/src/parameters/profiles/rbf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ impl VariableParameter<u32> for RbfProfileParameter {
pub enum RadialBasisFunction {
Linear,
Cubic,
Quintic,
ThinPlateSpline,
Gaussian { epsilon: f64 },
MultiQuadric { epsilon: f64 },
Expand All @@ -224,6 +225,7 @@ impl RadialBasisFunction {
match self {
RadialBasisFunction::Linear => r,
RadialBasisFunction::Cubic => r.powi(3),
RadialBasisFunction::Quintic => r.powi(5),
RadialBasisFunction::ThinPlateSpline => r.powi(2) * r.ln(),
RadialBasisFunction::Gaussian { epsilon } => (-(epsilon * r).powi(2)).exp(),
RadialBasisFunction::MultiQuadric { epsilon } => (1.0 + (epsilon * r).powi(2)).sqrt(),
Expand Down
9 changes: 9 additions & 0 deletions pywr-schema/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub enum SchemaError {
UnexpectedParameterType(String),
#[error("mismatch in the length of data provided. expected: {expected}, found: {found}")]
DataLengthMismatch { expected: usize, found: usize },
#[error("Failed to estimate epsilon for use in the radial basis function.")]
RbfEpsilonEstimation,
}

impl From<SchemaError> for PyErr {
Expand Down Expand Up @@ -79,4 +81,11 @@ pub enum ConversionError {
UnsupportedFeature { feature: String, name: String },
#[error("Parameter {name:?} of type `{ty:?}` are not supported in Pywr v2. {instead:?}")]
DeprecatedParameter { ty: String, name: String, instead: String },
#[error("Unexpected type for attribute {attr} on parameter {name}. Expected `{expected}`, found `{actual}`")]
UnexpectedType {
attr: String,
name: String,
expected: String,
actual: String,
},
}
3 changes: 3 additions & 0 deletions pywr-schema/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ impl TryFromV1Parameter<ParameterV1> for Parameter {
instead: "Use a derived metric instead.".to_string(),
})
}
CoreParameter::RbfProfile(p) => {
Parameter::RbfProfile(p.try_into_v2_parameter(parent_node, unnamed_count)?)
}
},
ParameterV1::Custom(p) => {
println!("Custom parameter: {:?} ({})", p.meta.name, p.ty);
Expand Down
171 changes: 155 additions & 16 deletions pywr-schema/src/parameters/profiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::parameters::{
use pywr_core::parameters::ParameterIndex;
use pywr_v1_schema::parameters::{
DailyProfileParameter as DailyProfileParameterV1, MonthInterpDay as MonthInterpDayV1,
MonthlyProfileParameter as MonthlyProfileParameterV1,
MonthlyProfileParameter as MonthlyProfileParameterV1, RbfProfileParameter as RbfProfileParameterV1,
UniformDrawdownProfileParameter as UniformDrawdownProfileParameterV1,
};
use std::collections::HashMap;
Expand Down Expand Up @@ -224,27 +224,79 @@ impl TryFromV1Parameter<UniformDrawdownProfileParameterV1> for UniformDrawdownPr
pub enum RadialBasisFunction {
Linear,
Cubic,
Quintic,
ThinPlateSpline,
Gaussian { epsilon: f64 },
MultiQuadric { epsilon: f64 },
InverseMultiQuadric { epsilon: f64 },
Gaussian { epsilon: Option<f64> },
MultiQuadric { epsilon: Option<f64> },
InverseMultiQuadric { epsilon: Option<f64> },
}

impl Into<pywr_core::parameters::RadialBasisFunction> for RadialBasisFunction {
fn into(self) -> pywr_core::parameters::RadialBasisFunction {
match self {
impl RadialBasisFunction {
/// Convert the schema representation of the RBF into `pywr_core` type.
///
/// If required this will estimate values of from the provided points.
fn into_core_rbf(self, points: &[(u32, f64)]) -> Result<pywr_core::parameters::RadialBasisFunction, SchemaError> {
let rbf = match self {
Self::Linear => pywr_core::parameters::RadialBasisFunction::Linear,
Self::Cubic => pywr_core::parameters::RadialBasisFunction::Cubic,
Self::Quintic => pywr_core::parameters::RadialBasisFunction::Cubic,
Self::ThinPlateSpline => pywr_core::parameters::RadialBasisFunction::ThinPlateSpline,
Self::Gaussian { epsilon } => pywr_core::parameters::RadialBasisFunction::Gaussian { epsilon },
Self::MultiQuadric { epsilon } => pywr_core::parameters::RadialBasisFunction::MultiQuadric { epsilon },
Self::Gaussian { epsilon } => {
let epsilon = match epsilon {
Some(e) => e,
None => estimate_epsilon(points).ok_or(SchemaError::RbfEpsilonEstimation)?,
};

pywr_core::parameters::RadialBasisFunction::Gaussian { epsilon }
}
Self::MultiQuadric { epsilon } => {
let epsilon = match epsilon {
Some(e) => e,
None => estimate_epsilon(points).ok_or(SchemaError::RbfEpsilonEstimation)?,
};

pywr_core::parameters::RadialBasisFunction::MultiQuadric { epsilon }
}
Self::InverseMultiQuadric { epsilon } => {
let epsilon = match epsilon {
Some(e) => e,
None => estimate_epsilon(points).ok_or(SchemaError::RbfEpsilonEstimation)?,
};

pywr_core::parameters::RadialBasisFunction::InverseMultiQuadric { epsilon }
}
}
};

Ok(rbf)
}
}

/// Compute an estimate for epsilon.
///
/// If there `points` is empty then `None` is returned.
fn estimate_epsilon(points: &[(u32, f64)]) -> Option<f64> {
if points.is_empty() {
return None;
}

// SAFETY: Above check that points is non-empty should make these unwraps safe.
let x_min = points.iter().map(|(x, _)| *x).min().unwrap();
let x_max = points.iter().map(|(x, _)| *x).max().unwrap();
let y_min = points.iter().map(|(_, y)| *y).reduce(f64::min).unwrap();
let y_max = points.iter().map(|(_, y)| *y).reduce(f64::max).unwrap();

let mut x_range = x_max - x_min;
if x_range == 0 {
x_range = 1;
}
let mut y_range = y_max - y_min;
if y_range == 0.0 {
y_range = 1.0;
}

Some((x_range as f64 * y_range).powf(1.0 / points.len() as f64))
}

/// Settings for a variable RBF profile.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Copy)]
pub struct RbfProfileVariableSettings {
Expand Down Expand Up @@ -325,12 +377,99 @@ impl RbfProfileParameter {
}
};

let p = pywr_core::parameters::RbfProfileParameter::new(
&self.meta.name,
self.points.clone(),
self.function.into(),
variable,
);
let function = self.function.into_core_rbf(&self.points)?;

let p =
pywr_core::parameters::RbfProfileParameter::new(&self.meta.name, self.points.clone(), function, variable);
Ok(model.add_parameter(Box::new(p))?)
}
}

impl TryFromV1Parameter<RbfProfileParameterV1> for RbfProfileParameter {
type Error = ConversionError;

fn try_from_v1_parameter(
v1: RbfProfileParameterV1,
parent_node: Option<&str>,
unnamed_count: &mut usize,
) -> Result<Self, Self::Error> {
let meta: ParameterMeta = v1.meta.into_v2_parameter(parent_node, unnamed_count);

let points = v1
.days_of_year
.into_iter()
.zip(v1.values.into_iter())
.map(|(doy, v)| (doy, v))
.collect();

if v1.rbf_kwargs.contains_key("smooth") {
return Err(ConversionError::UnsupportedFeature {
feature: "The RBF `smooth` keyword argument is not supported.".to_string(),
name: meta.name,
});
}

if v1.rbf_kwargs.contains_key("norm") {
return Err(ConversionError::UnsupportedFeature {
feature: "The RBF `norm` keyword argument is not supported.".to_string(),
name: meta.name,
});
}

// Parse any epsilon value; we expect a float here.
let epsilon = if let Some(epsilon_value) = v1.rbf_kwargs.get("epsilon") {
if let Some(epsilon_f64) = epsilon_value.as_f64() {
Some(epsilon_f64)
} else {
return Err(ConversionError::UnexpectedType {
attr: "epsilon".to_string(),
name: meta.name,
expected: "float".to_string(),
actual: format!("{}", epsilon_value),
});
}
} else {
None
};

let function = if let Some(function_value) = v1.rbf_kwargs.get("function") {
if let Some(function_str) = function_value.as_str() {
// Function kwarg is a string!
let f = match function_str {
"multiquadric" => RadialBasisFunction::MultiQuadric { epsilon },
"inverse" => RadialBasisFunction::InverseMultiQuadric { epsilon },
"gaussian" => RadialBasisFunction::Gaussian { epsilon },
"linear" => RadialBasisFunction::Linear,
"cubic" => RadialBasisFunction::Cubic,
"thin_plate" => RadialBasisFunction::ThinPlateSpline,
_ => {
return Err(ConversionError::UnsupportedFeature {
feature: format!("Radial basis function `{}` not supported.", function_str),
name: meta.name.clone(),
})
}
};
f
} else {
return Err(ConversionError::UnexpectedType {
attr: "function".to_string(),
name: meta.name,
expected: "string".to_string(),
actual: format!("{}", function_value),
});
}
} else {
// Default to multi-quadratic
RadialBasisFunction::MultiQuadric { epsilon }
};

let p = Self {
meta,
points,
function,
variable: None,
};

Ok(p)
}
}

0 comments on commit e0538c6

Please sign in to comment.