Skip to content

Commit

Permalink
feat: Implement Parameter<usize> for PyParameter.
Browse files Browse the repository at this point in the history
This supports using PyParameter as an index parameter with updated
schema to define the return type of the Python method.
  • Loading branch information
jetuk committed Mar 12, 2024
1 parent fb89773 commit f051b86
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 83 deletions.
135 changes: 81 additions & 54 deletions pywr-core/src/parameters/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,8 @@ impl PyParameter {

Ok(index_values.into_py_dict(py))
}
}

impl Parameter<f64> for PyParameter {
fn meta(&self) -> &ParameterMeta {
&self.meta
}

fn setup(
&self,
_timesteps: &[Timestep],
_scenario_index: &ScenarioIndex,
) -> Result<Option<Box<dyn ParameterState>>, PywrError> {
fn setup(&self) -> Result<Option<Box<dyn ParameterState>>, PywrError> {
pyo3::prepare_freethreaded_python();

let user_obj: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
Expand All @@ -92,26 +82,20 @@ impl Parameter<f64> for PyParameter {
Ok(Some(internal.into_boxed_any()))
}

// fn before(&self, internal_state: &mut Option<Box<dyn ParameterState>>) -> Result<(), PywrError> {
// let internal = downcast_internal_state::<Internal>(internal_state);
//
// Python::with_gil(|py| internal.user_obj.call_method0(py, "before"))
// .map_err(|e| PywrError::PythonError(e.to_string()))?;
//
// Ok(())
// }

fn compute(
fn compute<T>(
&self,
timestep: &Timestep,
scenario_index: &ScenarioIndex,
model: &Network,
state: &State,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<f64, PywrError> {
) -> Result<T, PywrError>
where
T: for<'a> FromPyObject<'a>,
{
let internal = downcast_internal_state_mut::<Internal>(internal_state);

let value: f64 = Python::with_gil(|py| {
let value: T = Python::with_gil(|py| {
let date = timestep.date.into_py(py);

let si = scenario_index.index.into_py(py);
Expand Down Expand Up @@ -160,7 +144,8 @@ impl Parameter<f64> for PyParameter {
}
}

impl Parameter<MultiValue> for PyParameter {
impl Parameter<f64> for PyParameter {

fn meta(&self) -> &ParameterMeta {
&self.meta
}
Expand All @@ -170,18 +155,80 @@ impl Parameter<MultiValue> for PyParameter {
_timesteps: &[Timestep],
_scenario_index: &ScenarioIndex,
) -> Result<Option<Box<dyn ParameterState>>, PywrError> {
pyo3::prepare_freethreaded_python();
self.setup()
}

let user_obj: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
let args = self.args.as_ref(py);
let kwargs = self.kwargs.as_ref(py);
self.object.call(py, args, Some(kwargs))
})
.unwrap();
fn compute(
&self,
timestep: &Timestep,
scenario_index: &ScenarioIndex,
model: &Network,
state: &State,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<f64, PywrError> {
self.compute(timestep, scenario_index, model, state, internal_state)
}

let internal = Internal { user_obj };
fn after(
&self,
timestep: &Timestep,
scenario_index: &ScenarioIndex,
model: &Network,
state: &State,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
self.after(timestep, scenario_index, model, state, internal_state)
}
}

Ok(Some(internal.into_boxed_any()))
impl Parameter<usize> for PyParameter {

fn meta(&self) -> &ParameterMeta {
&self.meta
}

fn setup(
&self,
_timesteps: &[Timestep],
_scenario_index: &ScenarioIndex,
) -> Result<Option<Box<dyn ParameterState>>, PywrError> {
self.setup()
}

fn compute(
&self,
timestep: &Timestep,
scenario_index: &ScenarioIndex,
model: &Network,
state: &State,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<usize, PywrError> {
self.compute(timestep, scenario_index, model, state, internal_state)
}

fn after(
&self,
timestep: &Timestep,
scenario_index: &ScenarioIndex,
model: &Network,
state: &State,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
self.after(timestep, scenario_index, model, state, internal_state)
}
}

impl Parameter<MultiValue> for PyParameter {
fn meta(&self) -> &ParameterMeta {
&self.meta
}

fn setup(
&self,
_timesteps: &[Timestep],
_scenario_index: &ScenarioIndex,
) -> Result<Option<Box<dyn ParameterState>>, PywrError> {
self.setup()
}

// fn before(&self, internal_state: &mut Option<Box<dyn ParameterState>>) -> Result<(), PywrError> {
Expand Down Expand Up @@ -257,27 +304,7 @@ impl Parameter<MultiValue> for PyParameter {
state: &State,
internal_state: &mut Option<Box<dyn ParameterState>>,
) -> Result<(), PywrError> {
let internal = downcast_internal_state_mut::<Internal>(internal_state);

Python::with_gil(|py| {
// Only do this if the object has an "after" method defined.
if internal.user_obj.getattr(py, "after").is_ok() {
let date = timestep.date.into_py(py);

let si = scenario_index.index.into_py(py);

let metric_dict = self.get_metrics_dict(model, state, py)?;
let index_dict = self.get_indices_dict(state, py)?;

let args = PyTuple::new(py, [date.as_ref(py), si.as_ref(py), metric_dict, index_dict]);

internal.user_obj.call_method1(py, "after", args)?;
}
Ok(())
})
.map_err(|e: PyErr| PywrError::PythonError(e.to_string()))?;

Ok(())
self.after(timestep, scenario_index, model, state, internal_state)
}
}

Expand Down
88 changes: 59 additions & 29 deletions pywr-schema/src/parameters/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ pub enum PythonModule {
Path(PathBuf),
}

/// The expected return type of the Python parameter.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)]
#[serde(rename_all = "lowercase")]
pub enum PythonReturnType {
#[default]
Float,
Int,
Dict,
}

/// A Parameter that uses a Python object for its calculations.
///
/// This struct defines a schema for loading a [`crate::parameters::PyParameter`] from external
Expand Down Expand Up @@ -67,10 +77,10 @@ pub struct PythonParameter {
pub module: PythonModule,
/// The name of Python object from the module to use.
pub object: String,
/// Is this a multi-valued parameter or not. If true then the calculation method should
/// return a dictionary with string keys and either floats or ints as values.
/// The return type of the Python calculation. This is used to convert the Python return value
/// to the appropriate type for the Parameter.
#[serde(default)]
pub multi: bool,
pub return_type: PythonReturnType,
/// Position arguments to pass to the object during setup.
pub args: Vec<serde_json::Value>,
/// Keyword arguments to pass to the object during setup.
Expand Down Expand Up @@ -193,10 +203,11 @@ impl PythonParameter {
};

let p = PyParameter::new(&self.meta.name, object, args, kwargs, &metrics, &indices);
let pt = if self.multi {
ParameterType::Multi(network.add_multi_value_parameter(Box::new(p))?)
} else {
ParameterType::Parameter(network.add_parameter(Box::new(p))?)

let pt = match self.return_type {
PythonReturnType::Float => ParameterType::Parameter(network.add_parameter(Box::new(p))?),
PythonReturnType::Int => ParameterType::Index(network.add_index_parameter(Box::new(p))?),
PythonReturnType::Dict => ParameterType::Multi(network.add_multi_value_parameter(Box::new(p))?),
};

Ok(pt)
Expand All @@ -212,42 +223,59 @@ mod tests {
use pywr_core::network::Network;
use pywr_core::test_utils::default_time_domain;
use serde_json::json;
use std::fs::File;
use std::io::Write;
use tempfile::tempdir;
use std::path::PathBuf;

#[test]
fn test_python_parameter() {
let dir = tempdir().unwrap();

let file_path = dir.path().join("my_parameter.py");
fn test_python_float_parameter() {
let mut py_fn = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
py_fn.push("src/test_models/test_parameters.py");

let data = json!(
{
"name": "my-custom-calculation",
"name": "my-float-parameter",
"type": "Python",
"path": file_path,
"object": "MyParameter",
"path": py_fn,
"object": "FloatParameter",
"args": [0, ],
"kwargs": {},
}
)
.to_string();

let mut file = File::create(file_path).unwrap();
write!(
file,
r#"
class MyParameter:
def __init__(self, count, *args, **kwargs):
self.count = 0
// Init Python
pyo3::prepare_freethreaded_python();
// Load the schema ...
let param: PythonParameter = serde_json::from_str(data.as_str()).unwrap();
// ... add it to an empty network
// this should trigger loading the module and extracting the class
let domain: ModelDomain = default_time_domain().into();
let schema = PywrNetwork::default();
let mut network = Network::default();
let tables = LoadedTableCollection::from_schema(None, None).unwrap();
param
.add_to_model(&mut network, &schema, &domain, &tables, None, &[])
.unwrap();

assert!(network.get_parameter_by_name("my-float-parameter").is_ok());
}

#[test]
fn test_python_int_parameter() {
let mut py_fn = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
py_fn.push("src/test_models/test_parameters.py");

def calc(self, ts, si, p_values):
self.count += si
return float(self.count + ts.day)
"#
let data = json!(
{
"name": "my-int-parameter",
"type": "Python",
"path": py_fn,
"return_type": "int",
"object": "FloatParameter",
"args": [0, ],
"kwargs": {},
}
)
.unwrap();
.to_string();

// Init Python
pyo3::prepare_freethreaded_python();
Expand All @@ -262,5 +290,7 @@ class MyParameter:
param
.add_to_model(&mut network, &schema, &domain, &tables, None, &[])
.unwrap();

assert!(network.get_index_parameter_by_name("my-int-parameter").is_ok());
}
}
20 changes: 20 additions & 0 deletions pywr-schema/src/test_models/test_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class FloatParameter:
"""A simple float parameter"""

def __init__(self, count, *args, **kwargs):
self.count = 0

def calc(self, ts, si, p_values) -> float:
self.count += si
return float(self.count + ts.day)


class IntParameter:
"""A simple int parameter"""

def __init__(self, count, *args, **kwargs):
self.count = 0

def calc(self, ts, si, p_values) -> int:
self.count += si
return self.count + ts.day

0 comments on commit f051b86

Please sign in to comment.