From cd70cecf40d20cb6e8db3bcfc9db91ce2d984aba Mon Sep 17 00:00:00 2001 From: Zsolt Cserna Date: Wed, 24 Jul 2024 10:02:19 +0200 Subject: [PATCH] Improve error messages for #[pyfunction] defined inside #[pymethods] (#4349) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve error messages for #[pyfunction] defined inside #[pymethods] Make error message more specific when `#[pyfunction]` is used in `#[pymethods]`. Effectively, this replaces the error message: ``` error: static method needs #[staticmethod] attribute ``` To: ``` functions inside #[pymethods] do not need to be annotated with #[pyfunction] ``` ...and also removes the other misleading error messages to the function in question. Fixes #4340 Co-authored-by: László Vaskó <1771332+vlaci@users.noreply.github.com> * review fixes --------- Co-authored-by: László Vaskó <1771332+vlaci@users.noreply.github.com> --- newsfragments/4349.fixed.md | 1 + pyo3-macros-backend/src/module.rs | 37 +------------------ pyo3-macros-backend/src/pyimpl.rs | 25 ++++++++++++- pyo3-macros-backend/src/utils.rs | 36 ++++++++++++++++++ tests/test_compile_error.rs | 1 + tests/ui/invalid_pyfunction_definition.rs | 15 ++++++++ tests/ui/invalid_pyfunction_definition.stderr | 5 +++ 7 files changed, 83 insertions(+), 37 deletions(-) create mode 100644 newsfragments/4349.fixed.md create mode 100644 tests/ui/invalid_pyfunction_definition.rs create mode 100644 tests/ui/invalid_pyfunction_definition.stderr diff --git a/newsfragments/4349.fixed.md b/newsfragments/4349.fixed.md new file mode 100644 index 00000000000..0895ffa1ae1 --- /dev/null +++ b/newsfragments/4349.fixed.md @@ -0,0 +1 @@ +Improve error messages for `#[pyfunction]` defined inside `#[pymethods]` diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index b729640b43e..785d3ac78d2 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -8,7 +8,7 @@ use crate::{ get_doc, pyclass::PyClassPyO3Option, pyfunction::{impl_wrap_pyfunction, PyFunctionOptions}, - utils::{Ctx, LitCStr, PyO3CratePath}, + utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr}, }; use proc_macro2::{Span, TokenStream}; use quote::quote; @@ -584,11 +584,6 @@ fn find_and_remove_attribute(attrs: &mut Vec, ident: &str) -> bo found } -enum IdentOrStr<'a> { - Str(&'a str), - Ident(syn::Ident), -} - impl<'a> PartialEq for IdentOrStr<'a> { fn eq(&self, other: &syn::Ident) -> bool { match self { @@ -597,36 +592,6 @@ impl<'a> PartialEq for IdentOrStr<'a> { } } } -fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { - has_attribute_with_namespace(attrs, None, &[ident]) -} - -fn has_attribute_with_namespace( - attrs: &[syn::Attribute], - crate_path: Option<&PyO3CratePath>, - idents: &[&str], -) -> bool { - let mut segments = vec![]; - if let Some(c) = crate_path { - match c { - PyO3CratePath::Given(paths) => { - for p in &paths.segments { - segments.push(IdentOrStr::Ident(p.ident.clone())); - } - } - PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")), - } - }; - for i in idents { - segments.push(IdentOrStr::Str(i)); - } - - attrs.iter().any(|attr| { - segments - .iter() - .eq(attr.path().segments.iter().map(|v| &v.ident)) - }) -} fn set_module_attribute(attrs: &mut Vec, module_name: &str) { attrs.push(parse_quote!(#[pyo3(module = #module_name)])); diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 6807f90831e..cca14ad1eff 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use crate::utils::Ctx; +use crate::utils::{has_attribute, has_attribute_with_namespace, Ctx, PyO3CratePath}; use crate::{ attributes::{take_pyo3_options, CrateAttribute}, konst::{ConstAttributes, ConstSpec}, @@ -10,6 +10,7 @@ use crate::{ use proc_macro2::TokenStream; use pymethod::GeneratedPyMethod; use quote::{format_ident, quote}; +use syn::ImplItemFn; use syn::{ parse::{Parse, ParseStream}, spanned::Spanned, @@ -84,6 +85,25 @@ pub fn build_py_methods( } } +fn check_pyfunction(pyo3_path: &PyO3CratePath, meth: &mut ImplItemFn) -> syn::Result<()> { + let mut error = None; + + meth.attrs.retain(|attr| { + let attrs = [attr.clone()]; + + if has_attribute(&attrs, "pyfunction") + || has_attribute_with_namespace(&attrs, Some(pyo3_path), &["pyfunction"]) + || has_attribute_with_namespace(&attrs, Some(pyo3_path), &["prelude", "pyfunction"]) { + error = Some(err_spanned!(meth.sig.span() => "functions inside #[pymethods] do not need to be annotated with #[pyfunction]")); + false + } else { + true + } + }); + + error.map_or(Ok(()), Err) +} + pub fn impl_methods( ty: &syn::Type, impls: &mut [syn::ImplItem], @@ -103,6 +123,9 @@ pub fn impl_methods( let ctx = &Ctx::new(&options.krate, Some(&meth.sig)); let mut fun_options = PyFunctionOptions::from_attrs(&mut meth.attrs)?; fun_options.krate = fun_options.krate.or_else(|| options.krate.clone()); + + check_pyfunction(&ctx.pyo3_path, meth)?; + match pymethod::gen_py_method(ty, &mut meth.sig, &mut meth.attrs, fun_options, ctx)? { GeneratedPyMethod::Method(MethodAndMethodDef { diff --git a/pyo3-macros-backend/src/utils.rs b/pyo3-macros-backend/src/utils.rs index 005884a557c..350abb6bbf6 100644 --- a/pyo3-macros-backend/src/utils.rs +++ b/pyo3-macros-backend/src/utils.rs @@ -291,3 +291,39 @@ pub fn apply_renaming_rule(rule: RenamingRule, name: &str) -> String { pub(crate) fn is_abi3() -> bool { pyo3_build_config::get().abi3 } + +pub(crate) enum IdentOrStr<'a> { + Str(&'a str), + Ident(syn::Ident), +} + +pub(crate) fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { + has_attribute_with_namespace(attrs, None, &[ident]) +} + +pub(crate) fn has_attribute_with_namespace( + attrs: &[syn::Attribute], + crate_path: Option<&PyO3CratePath>, + idents: &[&str], +) -> bool { + let mut segments = vec![]; + if let Some(c) = crate_path { + match c { + PyO3CratePath::Given(paths) => { + for p in &paths.segments { + segments.push(IdentOrStr::Ident(p.ident.clone())); + } + } + PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")), + } + }; + for i in idents { + segments.push(IdentOrStr::Str(i)); + } + + attrs.iter().any(|attr| { + segments + .iter() + .eq(attr.path().segments.iter().map(|v| &v.ident)) + }) +} diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 9e8b3b1a593..20ef0b799a5 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -11,6 +11,7 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_pyclass_enum.rs"); t.compile_fail("tests/ui/invalid_pyclass_item.rs"); t.compile_fail("tests/ui/invalid_pyfunction_signatures.rs"); + t.compile_fail("tests/ui/invalid_pyfunction_definition.rs"); #[cfg(any(not(Py_LIMITED_API), Py_3_11))] t.compile_fail("tests/ui/invalid_pymethods_buffer.rs"); // The output is not stable across abi3 / not abi3 and features diff --git a/tests/ui/invalid_pyfunction_definition.rs b/tests/ui/invalid_pyfunction_definition.rs new file mode 100644 index 00000000000..2f08ff421b9 --- /dev/null +++ b/tests/ui/invalid_pyfunction_definition.rs @@ -0,0 +1,15 @@ +#[pyo3::pymodule] +mod pyo3_scratch { + use pyo3::prelude::*; + + #[pyclass] + struct Foo {} + + #[pymethods] + impl Foo { + #[pyfunction] + fn bug() {} + } +} + +fn main() {} diff --git a/tests/ui/invalid_pyfunction_definition.stderr b/tests/ui/invalid_pyfunction_definition.stderr new file mode 100644 index 00000000000..9c7cac1f0f3 --- /dev/null +++ b/tests/ui/invalid_pyfunction_definition.stderr @@ -0,0 +1,5 @@ +error: functions inside #[pymethods] do not need to be annotated with #[pyfunction] + --> tests/ui/invalid_pyfunction_definition.rs:11:9 + | +11 | fn bug() {} + | ^^