Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add pywr-schema-macro crate #136

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"pywr-schema",
"pywr-cli",
"pywr-python",
"pywr-schema-macros",
]
exclude = [
"tests/models/simple-wasm/simple-wasm-parameter"
Expand Down
22 changes: 22 additions & 0 deletions pywr-schema-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "pywr-schema-macros"
version = "2.0.0-dev"
edition = "2021"
rust-version = "1.60"
description = "A generalised water resource allocation model."
readme = "../README.md"
repository = "https://github.com/pywr/pywr-next/"
license = "MIT OR Apache-2.0"
keywords = ["water", "modelling"]
categories = ["science", "simulation"]

[lib]
name = "pywr_schema_macros"
path = "src/lib.rs"
proc-macro = true

[dependencies]
syn = "2.0.52"
quote = "1.0.35"


198 changes: 198 additions & 0 deletions pywr-schema-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use proc_macro::TokenStream;
use quote::quote;

/// A derive macro for Pywr nodes that implements `parameters` and `parameters_mut` methods.
#[proc_macro_derive(PywrNode)]
pub fn pywr_node_macro(input: TokenStream) -> TokenStream {
// Parse the input tokens into a syntax tree
let input = syn::parse_macro_input!(input as syn::DeriveInput);
impl_parameter_references_derive(&input)
}

enum PywrField {
Optional(syn::Ident),
Required(syn::Ident),
}

/// Generates a [`TokenStream`] containing the implementation of two methods, `parameters`
/// and `parameters_mut`, for the given struct.
///
/// Both method returns a [`HashMap`] of parameter names to [`DynamicFloatValue`]. This
/// is intended to be used for nodes and parameter structs in the Pywr schema.
fn impl_parameter_references_derive(ast: &syn::DeriveInput) -> TokenStream {
// Name of the node type
let name = &ast.ident;

if let syn::Data::Struct(data) = &ast.data {
// Only apply this to structs

// Help struct for capturing parameter fields and whether they are optional.
struct ParamField {
field_name: syn::Ident,
optional: bool,
}

// Iterate through all fields of the struct. Try to find fields that reference
// parameters (e.g. `Option<DynamicFloatValue>` or `ParameterValue`).
let parameter_fields: Vec<ParamField> = data
.fields
.iter()
.filter_map(|field| {
let field_ident = field.ident.as_ref()?;
// Identify optional fields
match type_to_ident(&field.ty) {
Some(PywrField::Optional(ident)) => {
// If optional and a parameter identifier then add to the list
is_parameter_ident(&ident).then_some(ParamField {
field_name: field_ident.clone(),
optional: true,
})
}
Some(PywrField::Required(ident)) => {
// Otherwise, if a parameter identifier then add to the list
is_parameter_ident(&ident).then_some(ParamField {
field_name: field_ident.clone(),
optional: false,
})
}
None => None, // All other fields are ignored.
}
})
.collect();

// Insert statements for non-mutable version
let inserts = parameter_fields
.iter()
.map(|param_field| {
let ident = &param_field.field_name;
let key = ident.to_string();
if param_field.optional {
quote! {
if let Some(p) = &self.#ident {
attributes.insert(#key, p.into());
}
}
} else {
quote! {
let #ident = &self.#ident;
attributes.insert(#key, #ident.into());
}
}
})
.collect::<Vec<_>>();

// Insert statements for mutable version
let inserts_mut = parameter_fields
.iter()
.map(|param_field| {
let ident = &param_field.field_name;
let key = ident.to_string();
if param_field.optional {
quote! {
if let Some(p) = &mut self.#ident {
attributes.insert(#key, p.into());
}
}
} else {
quote! {
let #ident = &mut self.#ident;
attributes.insert(#key, #ident.into());
}
}
})
.collect::<Vec<_>>();

// Create the two parameter methods using the insert statements
let expanded = quote! {
impl #name {
pub fn parameters(&self) -> HashMap<&str, &DynamicFloatValue> {
let mut attributes = HashMap::new();
#(
#inserts
)*
attributes
}

pub fn parameters_mut(&mut self) -> HashMap<&str, &mut DynamicFloatValue> {
let mut attributes = HashMap::new();
#(
#inserts_mut
)*
attributes
}
}
};

// Hand the output tokens back to the compiler.
TokenStream::from(expanded)
} else {
panic!("Only structs are supported for #[derive(PywrNode)]")
}
}

/// Returns the last segment of a type path as an identifier
fn type_to_ident(ty: &syn::Type) -> Option<PywrField> {
match ty {
// Match type's that are a path and not a self type.
syn::Type::Path(type_path) if type_path.qself.is_none() => {
// Match on the last segment
match type_path.path.segments.last() {
Some(last_segment) => {
let ident = &last_segment.ident;

if ident == "Option" {
// The last segment is an Option, now we need to parse the argument
// I.e. the bit in inside the angle brackets.
let first_arg = match &last_segment.arguments {
syn::PathArguments::AngleBracketed(params) => params.args.first(),
_ => None,
};

// Find type arguments; ignore others
let arg_ty = match first_arg {
Some(syn::GenericArgument::Type(ty)) => Some(ty),
_ => None,
};

// Match on path types that are no self types.
let arg_type_path = match arg_ty {
Some(ty) => match ty {
syn::Type::Path(type_path) if type_path.qself.is_none() => {
Some(type_path)
}
_ => None,
},
None => None,
};

// Get the last segment of the path
let last_segment = match arg_type_path {
Some(type_path) => type_path.path.segments.last(),
None => None,
};

// Finally, if there's a last segment return this as an optional `PywrField`
match last_segment {
Some(last_segment) => {
let ident = &last_segment.ident;
Some(PywrField::Optional(ident.clone()))
}
None => None,
}
} else {
// Otherwise, assume this a simple required field
Some(PywrField::Required(ident.clone()))
}
}
None => None,
}
}
_ => None,
}
}

fn is_parameter_ident(ident: &syn::Ident) -> bool {
// TODO this currenty omits more complex attributes, such as `factors` for AggregatedNode
// and steps for PiecewiseLinks, that can internally contain `DynamicFloatValue` fields
ident == "DynamicFloatValue"
}
1 change: 1 addition & 0 deletions pywr-schema/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ thiserror = { workspace = true }
pywr-v1-schema = { workspace = true }
pywr-core = { path="../pywr-core" }
chrono = { workspace = true, features = ["serde"] }
pywr-schema-macros = { path = "../pywr-schema-macros" }

[dev-dependencies]
tempfile = "3.3.0"
4 changes: 3 additions & 1 deletion pywr-schema/src/nodes/annual_virtual_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use pywr_core::metric::Metric;
use pywr_core::models::ModelDomain;
use pywr_core::node::ConstraintValue;
use pywr_core::virtual_storage::VirtualStorageReset;
use pywr_schema_macros::PywrNode;
use pywr_v1_schema::nodes::AnnualVirtualStorageNode as AnnualVirtualStorageNodeV1;
use std::collections::HashMap;
use std::path::Path;

#[derive(serde::Deserialize, serde::Serialize, Clone)]
Expand All @@ -29,7 +31,7 @@ impl Default for AnnualReset {
}
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Default)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Default, PywrNode)]
pub struct AnnualVirtualStorageNode {
#[serde(flatten)]
pub meta: NodeMeta,
Expand Down
Loading