diff --git a/pywr-core/src/parameters/py.rs b/pywr-core/src/parameters/py.rs index 0181edbe..9289ce39 100644 --- a/pywr-core/src/parameters/py.rs +++ b/pywr-core/src/parameters/py.rs @@ -66,18 +66,8 @@ impl PyParameter { Ok(index_values.into_py_dict(py)) } -} - -impl Parameter for PyParameter { - fn meta(&self) -> &ParameterMeta { - &self.meta - } - fn setup( - &self, - _timesteps: &[Timestep], - _scenario_index: &ScenarioIndex, - ) -> Result>, PywrError> { + fn setup(&self) -> Result>, PywrError> { pyo3::prepare_freethreaded_python(); let user_obj: PyObject = Python::with_gil(|py| -> PyResult { @@ -92,26 +82,20 @@ impl Parameter for PyParameter { Ok(Some(internal.into_boxed_any())) } - // fn before(&self, internal_state: &mut Option>) -> Result<(), PywrError> { - // let internal = downcast_internal_state::(internal_state); - // - // Python::with_gil(|py| internal.user_obj.call_method0(py, "before")) - // .map_err(|e| PywrError::PythonError(e.to_string()))?; - // - // Ok(()) - // } - - fn compute( + fn compute( &self, timestep: &Timestep, scenario_index: &ScenarioIndex, model: &Network, state: &State, internal_state: &mut Option>, - ) -> Result { + ) -> Result + where + T: for<'a> FromPyObject<'a>, + { let internal = downcast_internal_state_mut::(internal_state); - let value: f64 = Python::with_gil(|py| { + let value: T = Python::with_gil(|py| { let date = timestep.date.into_py(py); let si = scenario_index.index.into_py(py); @@ -160,7 +144,8 @@ impl Parameter for PyParameter { } } -impl Parameter for PyParameter { +impl Parameter for PyParameter { + fn meta(&self) -> &ParameterMeta { &self.meta } @@ -170,18 +155,80 @@ impl Parameter for PyParameter { _timesteps: &[Timestep], _scenario_index: &ScenarioIndex, ) -> Result>, PywrError> { - pyo3::prepare_freethreaded_python(); + self.setup() + } - let user_obj: PyObject = Python::with_gil(|py| -> PyResult { - let args = self.args.as_ref(py); - let kwargs = self.kwargs.as_ref(py); - self.object.call(py, args, Some(kwargs)) - }) - .unwrap(); + fn compute( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result { + self.compute(timestep, scenario_index, model, state, internal_state) + } - let internal = Internal { user_obj }; + fn after( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result<(), PywrError> { + self.after(timestep, scenario_index, model, state, internal_state) + } +} - Ok(Some(internal.into_boxed_any())) +impl Parameter for PyParameter { + + fn meta(&self) -> &ParameterMeta { + &self.meta + } + + fn setup( + &self, + _timesteps: &[Timestep], + _scenario_index: &ScenarioIndex, + ) -> Result>, PywrError> { + self.setup() + } + + fn compute( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result { + self.compute(timestep, scenario_index, model, state, internal_state) + } + + fn after( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + model: &Network, + state: &State, + internal_state: &mut Option>, + ) -> Result<(), PywrError> { + self.after(timestep, scenario_index, model, state, internal_state) + } +} + +impl Parameter for PyParameter { + fn meta(&self) -> &ParameterMeta { + &self.meta + } + + fn setup( + &self, + _timesteps: &[Timestep], + _scenario_index: &ScenarioIndex, + ) -> Result>, PywrError> { + self.setup() } // fn before(&self, internal_state: &mut Option>) -> Result<(), PywrError> { @@ -257,27 +304,7 @@ impl Parameter for PyParameter { state: &State, internal_state: &mut Option>, ) -> Result<(), PywrError> { - let internal = downcast_internal_state_mut::(internal_state); - - Python::with_gil(|py| { - // Only do this if the object has an "after" method defined. - if internal.user_obj.getattr(py, "after").is_ok() { - let date = timestep.date.into_py(py); - - let si = scenario_index.index.into_py(py); - - let metric_dict = self.get_metrics_dict(model, state, py)?; - let index_dict = self.get_indices_dict(state, py)?; - - let args = PyTuple::new(py, [date.as_ref(py), si.as_ref(py), metric_dict, index_dict]); - - internal.user_obj.call_method1(py, "after", args)?; - } - Ok(()) - }) - .map_err(|e: PyErr| PywrError::PythonError(e.to_string()))?; - - Ok(()) + self.after(timestep, scenario_index, model, state, internal_state) } } diff --git a/pywr-schema/src/parameters/python.rs b/pywr-schema/src/parameters/python.rs index 52b62f61..0a76ac97 100644 --- a/pywr-schema/src/parameters/python.rs +++ b/pywr-schema/src/parameters/python.rs @@ -18,6 +18,16 @@ pub enum PythonModule { Path(PathBuf), } +/// The expected return type of the Python parameter. +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)] +#[serde(rename_all = "lowercase")] +pub enum PythonReturnType { + #[default] + Float, + Int, + Dict, +} + /// A Parameter that uses a Python object for its calculations. /// /// This struct defines a schema for loading a [`crate::parameters::PyParameter`] from external @@ -67,10 +77,10 @@ pub struct PythonParameter { pub module: PythonModule, /// The name of Python object from the module to use. pub object: String, - /// Is this a multi-valued parameter or not. If true then the calculation method should - /// return a dictionary with string keys and either floats or ints as values. + /// The return type of the Python calculation. This is used to convert the Python return value + /// to the appropriate type for the Parameter. #[serde(default)] - pub multi: bool, + pub return_type: PythonReturnType, /// Position arguments to pass to the object during setup. pub args: Vec, /// Keyword arguments to pass to the object during setup. @@ -193,10 +203,11 @@ impl PythonParameter { }; let p = PyParameter::new(&self.meta.name, object, args, kwargs, &metrics, &indices); - let pt = if self.multi { - ParameterType::Multi(network.add_multi_value_parameter(Box::new(p))?) - } else { - ParameterType::Parameter(network.add_parameter(Box::new(p))?) + + let pt = match self.return_type { + PythonReturnType::Float => ParameterType::Parameter(network.add_parameter(Box::new(p))?), + PythonReturnType::Int => ParameterType::Index(network.add_index_parameter(Box::new(p))?), + PythonReturnType::Dict => ParameterType::Multi(network.add_multi_value_parameter(Box::new(p))?), }; Ok(pt) @@ -212,42 +223,59 @@ mod tests { use pywr_core::network::Network; use pywr_core::test_utils::default_time_domain; use serde_json::json; - use std::fs::File; - use std::io::Write; - use tempfile::tempdir; + use std::path::PathBuf; #[test] - fn test_python_parameter() { - let dir = tempdir().unwrap(); - - let file_path = dir.path().join("my_parameter.py"); + fn test_python_float_parameter() { + let mut py_fn = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + py_fn.push("src/test_models/test_parameters.py"); let data = json!( { - "name": "my-custom-calculation", + "name": "my-float-parameter", "type": "Python", - "path": file_path, - "object": "MyParameter", + "path": py_fn, + "object": "FloatParameter", "args": [0, ], "kwargs": {}, } ) .to_string(); - let mut file = File::create(file_path).unwrap(); - write!( - file, - r#" -class MyParameter: - def __init__(self, count, *args, **kwargs): - self.count = 0 + // Init Python + pyo3::prepare_freethreaded_python(); + // Load the schema ... + let param: PythonParameter = serde_json::from_str(data.as_str()).unwrap(); + // ... add it to an empty network + // this should trigger loading the module and extracting the class + let domain: ModelDomain = default_time_domain().into(); + let schema = PywrNetwork::default(); + let mut network = Network::default(); + let tables = LoadedTableCollection::from_schema(None, None).unwrap(); + param + .add_to_model(&mut network, &schema, &domain, &tables, None, &[]) + .unwrap(); + + assert!(network.get_parameter_by_name("my-float-parameter").is_ok()); + } + + #[test] + fn test_python_int_parameter() { + let mut py_fn = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + py_fn.push("src/test_models/test_parameters.py"); - def calc(self, ts, si, p_values): - self.count += si - return float(self.count + ts.day) -"# + let data = json!( + { + "name": "my-int-parameter", + "type": "Python", + "path": py_fn, + "return_type": "int", + "object": "FloatParameter", + "args": [0, ], + "kwargs": {}, + } ) - .unwrap(); + .to_string(); // Init Python pyo3::prepare_freethreaded_python(); @@ -262,5 +290,7 @@ class MyParameter: param .add_to_model(&mut network, &schema, &domain, &tables, None, &[]) .unwrap(); + + assert!(network.get_index_parameter_by_name("my-int-parameter").is_ok()); } } diff --git a/pywr-schema/src/test_models/test_parameters.py b/pywr-schema/src/test_models/test_parameters.py new file mode 100644 index 00000000..47794588 --- /dev/null +++ b/pywr-schema/src/test_models/test_parameters.py @@ -0,0 +1,20 @@ +class FloatParameter: + """A simple float parameter""" + + def __init__(self, count, *args, **kwargs): + self.count = 0 + + def calc(self, ts, si, p_values) -> float: + self.count += si + return float(self.count + ts.day) + + +class IntParameter: + """A simple int parameter""" + + def __init__(self, count, *args, **kwargs): + self.count = 0 + + def calc(self, ts, si, p_values) -> int: + self.count += si + return self.count + ts.day