Skip to content

Commit

Permalink
chore: Merge branch 'refs/heads/main' into deps/polars-v0_39
Browse files Browse the repository at this point in the history
  • Loading branch information
jetuk committed Apr 26, 2024
2 parents aade2ec + 90e42b3 commit d005aff
Show file tree
Hide file tree
Showing 47 changed files with 220 additions and 149 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ num = "0.4.0"
float-cmp = "0.9.0"
ndarray = "0.15.3"
polars = { version = "0.39", features = ["lazy", "rows", "ndarray"] }
pyo3-polars = "0.12.0"
pyo3 = { version = "0.20.2", default-features = false }
pyo3-log = "0.9.0"
pyo3-polars = "0.13"
pyo3 = { version = "0.21", default-features = false }
pyo3-log = "0.10"
tracing = { version = "0.1", features = ["log"] }
csv = "1.1"
hdf5 = { git = "https://github.com/aldanor/hdf5-rust.git", package = "hdf5", features = ["static", "zlib"] }
pywr-v1-schema = { git = "https://github.com/pywr/pywr-schema/", tag = "v0.12.0", package = "pywr-schema" }
chrono = { version = "0.4.34" }
schemars = { version = "0.8.16", features = ["chrono"] }
5 changes: 3 additions & 2 deletions pywr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ categories = ["science", "simulation"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version="4.0", features=["derive"] }
clap = { version = "4.0", features = ["derive"] }
anyhow = "1.0.69"
tracing = { workspace = true }
tracing-subscriber = { version ="0.3.17", features=["env-filter"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
rand = "0.8.5"
rand_chacha = "0.3.1"
serde = { workspace = true }
serde_json = { workspace = true }
pywr-v1-schema = { workspace = true }
schemars = { workspace = true }

pywr-core = { path = "../pywr-core" }
pywr-schema = { path = "../pywr-schema" }
19 changes: 18 additions & 1 deletion pywr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod tracing;

use crate::tracing::setup_tracing;
use ::tracing::info;
use anyhow::Result;
use anyhow::{Context, Result};
use clap::{Parser, Subcommand, ValueEnum};
#[cfg(feature = "ipm-ocl")]
use pywr_core::solvers::{ClIpmF32Solver, ClIpmF64Solver, ClIpmSolverSettings};
Expand All @@ -15,6 +15,7 @@ use pywr_core::test_utils::make_random_model;
use pywr_schema::model::{PywrModel, PywrMultiNetworkModel};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use schemars::schema_for;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};

Expand Down Expand Up @@ -109,6 +110,10 @@ enum Commands {
#[arg(short, long, default_value_t=Solver::Clp)]
solver: Solver,
},
ExportSchema {
/// Path to save the JSON schema.
out: PathBuf,
},
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -140,6 +145,7 @@ fn main() -> Result<()> {
num_scenarios,
solver,
} => run_random(*num_systems, *density, *num_scenarios, solver),
Commands::ExportSchema { out } => export_schema(out)?,
},
None => {}
}
Expand Down Expand Up @@ -254,3 +260,14 @@ fn run_random(num_systems: usize, density: usize, num_scenarios: usize, solver:
}
.unwrap();
}

fn export_schema(out_path: &Path) -> Result<()> {
let schema = schema_for!(PywrModel);
std::fs::write(
out_path,
serde_json::to_string_pretty(&schema).with_context(|| "Failed serialise Pywr schema".to_string())?,
)
.with_context(|| format!("Failed to write file: {:?}", out_path))?;

Ok(())
}
45 changes: 27 additions & 18 deletions pywr-core/src/parameters/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,38 +52,38 @@ impl PyParameter {
network: &Network,
state: &State,
py: Python<'py>,
) -> Result<&'py PyDict, PywrError> {
) -> Result<Bound<'py, PyDict>, PywrError> {
let metric_values: Vec<(&str, f64)> = self
.metrics
.iter()
.map(|(k, value)| Ok((k.as_str(), value.get_value(network, state)?)))
.collect::<Result<Vec<_>, PywrError>>()?;

Ok(metric_values.into_py_dict(py))
Ok(metric_values.into_py_dict_bound(py))
}

fn get_indices_dict<'py>(
&self,
network: &Network,
state: &State,
py: Python<'py>,
) -> Result<&'py PyDict, PywrError> {
) -> Result<Bound<'py, PyDict>, PywrError> {
let index_values: Vec<(&str, usize)> = self
.indices
.iter()
.map(|(k, value)| Ok((k.as_str(), value.get_value(network, state)?)))
.collect::<Result<Vec<_>, PywrError>>()?;

Ok(index_values.into_py_dict(py))
Ok(index_values.into_py_dict_bound(py))
}

fn setup(&self) -> Result<Option<Box<dyn ParameterState>>, PywrError> {
pyo3::prepare_freethreaded_python();

let user_obj: PyObject = Python::with_gil(|py| -> PyResult<PyObject> {
let args = self.args.as_ref(py);
let kwargs = self.kwargs.as_ref(py);
self.object.call(py, args, Some(kwargs))
let args = self.args.bind(py);
let kwargs = self.kwargs.bind(py);
self.object.call_bound(py, args, Some(kwargs))
})
.unwrap();

Expand Down Expand Up @@ -113,7 +113,10 @@ impl PyParameter {
let metric_dict = self.get_metrics_dict(network, state, py)?;
let index_dict = self.get_indices_dict(network, state, py)?;

let args = PyTuple::new(py, [date.as_ref(py), si.as_ref(py), metric_dict, index_dict]);
let args = PyTuple::new_bound(
py,
[date.bind(py), si.bind(py), metric_dict.as_any(), index_dict.as_any()],
);

internal.user_obj.call_method1(py, "calc", args)?.extract(py)
})
Expand Down Expand Up @@ -142,7 +145,10 @@ impl PyParameter {
let metric_dict = self.get_metrics_dict(network, state, py)?;
let index_dict = self.get_indices_dict(network, state, py)?;

let args = PyTuple::new(py, [date.as_ref(py), si.as_ref(py), metric_dict, index_dict]);
let args = PyTuple::new_bound(
py,
[date.bind(py), si.bind(py), metric_dict.as_any(), index_dict.as_any()],
);

internal.user_obj.call_method1(py, "after", args)?;
}
Expand Down Expand Up @@ -266,7 +272,10 @@ impl Parameter<MultiValue> for PyParameter {
let metric_dict = self.get_metrics_dict(network, state, py)?;
let index_dict = self.get_indices_dict(network, state, py)?;

let args = PyTuple::new(py, [date.as_ref(py), si.as_ref(py), metric_dict, index_dict]);
let args = PyTuple::new_bound(
py,
[date.bind(py), si.bind(py), metric_dict.as_any(), index_dict.as_any()],
);

let py_values: HashMap<String, PyObject> = internal
.user_obj
Expand All @@ -278,15 +287,15 @@ impl Parameter<MultiValue> for PyParameter {
// Try to convert the floats
let values: HashMap<String, f64> = py_values
.iter()
.filter_map(|(k, v)| match v.downcast::<PyFloat>(py) {
.filter_map(|(k, v)| match v.downcast_bound::<PyFloat>(py) {
Ok(v) => Some((k.clone(), v.extract().unwrap())),
Err(_) => None,
})
.collect();

let indices: HashMap<String, usize> = py_values
.iter()
.filter_map(|(k, v)| match v.downcast::<PyLong>(py) {
.filter_map(|(k, v)| match v.downcast_bound::<PyLong>(py) {
Ok(v) => Some((k.clone(), v.extract().unwrap())),
Err(_) => None,
})
Expand Down Expand Up @@ -332,7 +341,7 @@ mod tests {
pyo3::prepare_freethreaded_python();

let class = Python::with_gil(|py| {
let test_module = PyModule::from_code(
let test_module = PyModule::from_code_bound(
py,
r#"
class MyParameter:
Expand All @@ -351,8 +360,8 @@ class MyParameter:
test_module.getattr("MyParameter").unwrap().into()
});

let args = Python::with_gil(|py| PyTuple::new(py, [0]).into());
let kwargs = Python::with_gil(|py| PyDict::new(py).into());
let args = Python::with_gil(|py| PyTuple::new_bound(py, [0]).into());
let kwargs = Python::with_gil(|py| PyDict::new_bound(py).into());

let param = PyParameter::new("my-parameter", class, args, kwargs, &HashMap::new(), &HashMap::new());
let timestepper = default_timestepper();
Expand Down Expand Up @@ -395,7 +404,7 @@ class MyParameter:
pyo3::prepare_freethreaded_python();

let class = Python::with_gil(|py| {
let test_module = PyModule::from_code(
let test_module = PyModule::from_code_bound(
py,
r#"
import math
Expand All @@ -420,8 +429,8 @@ class MyParameter:
test_module.getattr("MyParameter").unwrap().into()
});

let args = Python::with_gil(|py| PyTuple::new(py, [0]).into());
let kwargs = Python::with_gil(|py| PyDict::new(py).into());
let args = Python::with_gil(|py| PyTuple::new_bound(py, [0]).into());
let kwargs = Python::with_gil(|py| PyDict::new_bound(py).into());

let param = PyParameter::new("my-parameter", class, args, kwargs, &HashMap::new(), &HashMap::new());
let timestepper = default_timestepper();
Expand Down
12 changes: 6 additions & 6 deletions pywr-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct Schema {
#[pymethods]
impl Schema {
#[new]
fn new(title: &str, start: &PyDateTime, end: &PyDateTime) -> Self {
fn new(title: &str, start: &Bound<'_, PyDateTime>, end: &Bound<'_, PyDateTime>) -> Self {
// SAFETY: We know that the date & month are valid because it is a Python date.
let start = DateType::DateTime(
NaiveDate::from_ymd_opt(start.get_year(), start.get_month() as u32, start.get_day() as u32)
Expand All @@ -90,15 +90,15 @@ impl Schema {

/// Create a new schema object from a file path.
#[classmethod]
fn from_path(_cls: &PyType, path: PathBuf) -> PyResult<Self> {
fn from_path(_cls: &Bound<'_, PyType>, path: PathBuf) -> PyResult<Self> {
Ok(Self {
schema: pywr_schema::PywrModel::from_path(path)?,
})
}

/// Create a new schema object from a JSON string.
#[classmethod]
fn from_json_string(_cls: &PyType, data: &str) -> PyResult<Self> {
fn from_json_string(_cls: &Bound<'_, PyType>, data: &str) -> PyResult<Self> {
Ok(Self {
schema: pywr_schema::PywrModel::from_str(data)?,
})
Expand All @@ -124,7 +124,7 @@ pub struct Model {

#[pymethods]
impl Model {
fn run(&self, solver_name: &str, solver_kwargs: Option<&PyDict>) -> PyResult<()> {
fn run(&self, solver_name: &str, solver_kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
match solver_name {
"clp" => {
let settings = build_clp_settings(solver_kwargs)?;
Expand All @@ -146,7 +146,7 @@ impl Model {
}
}

fn build_clp_settings(kwargs: Option<&PyDict>) -> PyResult<ClpSolverSettings> {
fn build_clp_settings(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<ClpSolverSettings> {
let mut builder = ClpSolverSettingsBuilder::default();

if let Some(kwargs) = kwargs {
Expand Down Expand Up @@ -211,7 +211,7 @@ fn build_highs_settings(kwargs: Option<&PyDict>) -> PyResult<HighsSolverSettings

/// A Python module implemented in Rust.
#[pymodule]
fn pywr(_py: Python, m: &PyModule) -> PyResult<()> {
fn pywr(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init();

m.add_class::<Schema>()?;
Expand Down
2 changes: 1 addition & 1 deletion pywr-schema/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pyo3 = { workspace = true, optional = true }
pyo3-polars = { workspace = true, optional = true }
strum = "0.26"
strum_macros = "0.26"

schemars = { workspace = true }
hdf5 = { workspace = true, optional = true }
csv = { workspace = true, optional = true }
tracing = { workspace = true, optional = true }
Expand Down
11 changes: 6 additions & 5 deletions pywr-schema/src/data_tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use pywr_v1_schema::parameters::TableDataRef as TableDataRefV1;
use scalar::{
load_csv_row2_scalar_table_one, load_csv_row_col_scalar_table_one, load_csv_row_scalar_table_one, LoadedScalarTable,
};
use schemars::JsonSchema;
#[cfg(feature = "core")]
use std::collections::HashMap;
use std::path::{Path, PathBuf};
Expand All @@ -19,7 +20,7 @@ use tracing::{debug, info};
#[cfg(feature = "core")]
use vec::{load_csv_row2_vec_table_one, load_csv_row_vec_table_one, LoadedVecTable};

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum DataTableType {
Scalar,
Expand All @@ -31,7 +32,7 @@ pub enum DataTableFormat {
CSV,
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(tag = "format", rename_all = "lowercase")]
pub enum DataTable {
CSV(CsvDataTable),
Expand All @@ -52,7 +53,7 @@ impl DataTable {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum CsvDataTableLookup {
Row(usize),
Expand All @@ -61,7 +62,7 @@ pub enum CsvDataTableLookup {
}

/// An external table of data that can be referenced
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct CsvDataTable {
pub name: String,
#[serde(rename = "type")]
Expand Down Expand Up @@ -234,7 +235,7 @@ impl LoadedTableCollection {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct TableDataRef {
pub table: String,
pub column: Option<TableIndex>,
Expand Down
4 changes: 3 additions & 1 deletion pywr-schema/src/edge.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[derive(serde::Deserialize, serde::Serialize, Clone)]
use schemars::JsonSchema;

#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema)]
pub struct Edge {
pub from_node: String,
pub to_node: String,
Expand Down
2 changes: 2 additions & 0 deletions pywr-schema/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub enum SchemaError {
Timeseries(#[from] TimeseriesError),
#[error("The output of literal constant values is not supported. This is because they do not have a unique identifier such as a name. If you would like to output a constant value please use a `Constant` parameter.")]
LiteralConstantOutputNotSupported,
#[error("Chrono out of range error: {0}")]
OutOfRange(#[from] chrono::OutOfRange),
}

#[cfg(feature = "core")]
Expand Down
Loading

0 comments on commit d005aff

Please sign in to comment.