Skip to content

Commit

Permalink
Improve error messages for #[pyfunction] defined inside #[pymethods] (#…
Browse files Browse the repository at this point in the history
…4349)

* Improve error messages for #[pyfunction] defined inside #[pymethods]

Make error message more specific when `#[pyfunction]` is used in
`#[pymethods]`.

Effectively, this replaces the error message:

```
error: static method needs #[staticmethod] attribute
```

To:
```
functions inside #[pymethods] do not need to be annotated with #[pyfunction]
```

...and also removes the other misleading error messages to the function in
question.

Fixes #4340

Co-authored-by: László Vaskó <[email protected]>

* review fixes

---------

Co-authored-by: László Vaskó <[email protected]>
  • Loading branch information
2 people authored and davidhewitt committed Sep 3, 2024
1 parent be10c5e commit cd70cec
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 37 deletions.
1 change: 1 addition & 0 deletions newsfragments/4349.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve error messages for `#[pyfunction]` defined inside `#[pymethods]`
37 changes: 1 addition & 36 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
get_doc,
pyclass::PyClassPyO3Option,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
utils::{Ctx, LitCStr, PyO3CratePath},
utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
};
use proc_macro2::{Span, TokenStream};
use quote::quote;
Expand Down Expand Up @@ -584,11 +584,6 @@ fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bo
found
}

enum IdentOrStr<'a> {
Str(&'a str),
Ident(syn::Ident),
}

impl<'a> PartialEq<syn::Ident> for IdentOrStr<'a> {
fn eq(&self, other: &syn::Ident) -> bool {
match self {
Expand All @@ -597,36 +592,6 @@ impl<'a> PartialEq<syn::Ident> for IdentOrStr<'a> {
}
}
}
fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
has_attribute_with_namespace(attrs, None, &[ident])
}

fn has_attribute_with_namespace(
attrs: &[syn::Attribute],
crate_path: Option<&PyO3CratePath>,
idents: &[&str],
) -> bool {
let mut segments = vec![];
if let Some(c) = crate_path {
match c {
PyO3CratePath::Given(paths) => {
for p in &paths.segments {
segments.push(IdentOrStr::Ident(p.ident.clone()));
}
}
PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
}
};
for i in idents {
segments.push(IdentOrStr::Str(i));
}

attrs.iter().any(|attr| {
segments
.iter()
.eq(attr.path().segments.iter().map(|v| &v.ident))
})
}

fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
Expand Down
25 changes: 24 additions & 1 deletion pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashSet;

use crate::utils::Ctx;
use crate::utils::{has_attribute, has_attribute_with_namespace, Ctx, PyO3CratePath};
use crate::{
attributes::{take_pyo3_options, CrateAttribute},
konst::{ConstAttributes, ConstSpec},
Expand All @@ -10,6 +10,7 @@ use crate::{
use proc_macro2::TokenStream;
use pymethod::GeneratedPyMethod;
use quote::{format_ident, quote};
use syn::ImplItemFn;
use syn::{
parse::{Parse, ParseStream},
spanned::Spanned,
Expand Down Expand Up @@ -84,6 +85,25 @@ pub fn build_py_methods(
}
}

fn check_pyfunction(pyo3_path: &PyO3CratePath, meth: &mut ImplItemFn) -> syn::Result<()> {
let mut error = None;

meth.attrs.retain(|attr| {
let attrs = [attr.clone()];

if has_attribute(&attrs, "pyfunction")
|| has_attribute_with_namespace(&attrs, Some(pyo3_path), &["pyfunction"])
|| has_attribute_with_namespace(&attrs, Some(pyo3_path), &["prelude", "pyfunction"]) {
error = Some(err_spanned!(meth.sig.span() => "functions inside #[pymethods] do not need to be annotated with #[pyfunction]"));
false
} else {
true
}
});

error.map_or(Ok(()), Err)
}

pub fn impl_methods(
ty: &syn::Type,
impls: &mut [syn::ImplItem],
Expand All @@ -103,6 +123,9 @@ pub fn impl_methods(
let ctx = &Ctx::new(&options.krate, Some(&meth.sig));
let mut fun_options = PyFunctionOptions::from_attrs(&mut meth.attrs)?;
fun_options.krate = fun_options.krate.or_else(|| options.krate.clone());

check_pyfunction(&ctx.pyo3_path, meth)?;

match pymethod::gen_py_method(ty, &mut meth.sig, &mut meth.attrs, fun_options, ctx)?
{
GeneratedPyMethod::Method(MethodAndMethodDef {
Expand Down
36 changes: 36 additions & 0 deletions pyo3-macros-backend/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,39 @@ pub fn apply_renaming_rule(rule: RenamingRule, name: &str) -> String {
pub(crate) fn is_abi3() -> bool {
pyo3_build_config::get().abi3
}

pub(crate) enum IdentOrStr<'a> {
Str(&'a str),
Ident(syn::Ident),
}

pub(crate) fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
has_attribute_with_namespace(attrs, None, &[ident])
}

pub(crate) fn has_attribute_with_namespace(
attrs: &[syn::Attribute],
crate_path: Option<&PyO3CratePath>,
idents: &[&str],
) -> bool {
let mut segments = vec![];
if let Some(c) = crate_path {
match c {
PyO3CratePath::Given(paths) => {
for p in &paths.segments {
segments.push(IdentOrStr::Ident(p.ident.clone()));
}
}
PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
}
};
for i in idents {
segments.push(IdentOrStr::Str(i));
}

attrs.iter().any(|attr| {
segments
.iter()
.eq(attr.path().segments.iter().map(|v| &v.ident))
})
}
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fn test_compile_errors() {
t.compile_fail("tests/ui/invalid_pyclass_enum.rs");
t.compile_fail("tests/ui/invalid_pyclass_item.rs");
t.compile_fail("tests/ui/invalid_pyfunction_signatures.rs");
t.compile_fail("tests/ui/invalid_pyfunction_definition.rs");
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
t.compile_fail("tests/ui/invalid_pymethods_buffer.rs");
// The output is not stable across abi3 / not abi3 and features
Expand Down
15 changes: 15 additions & 0 deletions tests/ui/invalid_pyfunction_definition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#[pyo3::pymodule]
mod pyo3_scratch {
use pyo3::prelude::*;

#[pyclass]
struct Foo {}

#[pymethods]
impl Foo {
#[pyfunction]
fn bug() {}
}
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/ui/invalid_pyfunction_definition.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: functions inside #[pymethods] do not need to be annotated with #[pyfunction]
--> tests/ui/invalid_pyfunction_definition.rs:11:9
|
11 | fn bug() {}
| ^^

0 comments on commit cd70cec

Please sign in to comment.