From fbfeb2ff034027201e24dc46cd70e50016970f7d Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Tue, 13 Feb 2024 01:09:41 +0100 Subject: [PATCH] update `#[derive(FromPyObject)]` to use `extract_bound` (#3828) * update `#[derive(FromPyObject)]` to use `extract_bound` * type inference for `from_py_with` using function pointers --- pyo3-macros-backend/src/frompyobject.rs | 15 ++++---- src/impl_/frompyobject.rs | 48 +++++++++++++++++++------ tests/test_frompyobject.rs | 16 ++++++--- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index f4774d88c0b..5e193bf4a24 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -271,7 +271,7 @@ impl<'a> Container<'a> { value: expr_path, .. }) => quote! { Ok(#self_ty { - #ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj, #struct_name, #field_name)? + #ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)? }) }, } @@ -283,7 +283,7 @@ impl<'a> Container<'a> { Some(FromPyWithAttribute { value: expr_path, .. }) => quote! ( - _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, obj, #struct_name, 0).map(#self_ty) + _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty) ), } } @@ -298,12 +298,12 @@ impl<'a> Container<'a> { let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| { match &field.from_py_with { None => quote!( - _pyo3::impl_::frompyobject::extract_tuple_struct_field(#ident, #struct_name, #index)? + _pyo3::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)? ), Some(FromPyWithAttribute { value: expr_path, .. }) => quote! ( - _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, #ident, #struct_name, #index)? + _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)? ), } }); @@ -339,12 +339,12 @@ impl<'a> Container<'a> { }; let extractor = match &field.from_py_with { None => { - quote!(_pyo3::impl_::frompyobject::extract_struct_field(obj.#getter?, #struct_name, #field_name)?) + quote!(_pyo3::impl_::frompyobject::extract_struct_field(&obj.#getter?, #struct_name, #field_name)?) } Some(FromPyWithAttribute { value: expr_path, .. }) => { - quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj.#getter?, #struct_name, #field_name)?) + quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &obj.#getter?, #struct_name, #field_name)?) } }; @@ -606,10 +606,11 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { Ok(quote!( const _: () = { use #krate as _pyo3; + use _pyo3::prelude::PyAnyMethods; #[automatically_derived] impl #trait_generics _pyo3::FromPyObject<#lt_param> for #ident #generics #where_clause { - fn extract(obj: &#lt_param _pyo3::PyAny) -> _pyo3::PyResult { + fn extract_bound(obj: &_pyo3::Bound<#lt_param, _pyo3::PyAny>) -> _pyo3::PyResult { #derives } } diff --git a/src/impl_/frompyobject.rs b/src/impl_/frompyobject.rs index a0c7b13df7c..5ab595ca784 100644 --- a/src/impl_/frompyobject.rs +++ b/src/impl_/frompyobject.rs @@ -1,5 +1,33 @@ +use crate::types::any::PyAnyMethods; +use crate::Bound; use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python}; +pub enum Extractor<'a, 'py, T> { + Bound(fn(&'a Bound<'py, PyAny>) -> PyResult), + GilRef(fn(&'a PyAny) -> PyResult), +} + +impl<'a, 'py, T> From) -> PyResult> for Extractor<'a, 'py, T> { + fn from(value: fn(&'a Bound<'py, PyAny>) -> PyResult) -> Self { + Self::Bound(value) + } +} + +impl<'a, T> From PyResult> for Extractor<'a, '_, T> { + fn from(value: fn(&'a PyAny) -> PyResult) -> Self { + Self::GilRef(value) + } +} + +impl<'a, 'py, T> Extractor<'a, 'py, T> { + fn call(self, obj: &'a Bound<'py, PyAny>) -> PyResult { + match self { + Extractor::Bound(f) => f(obj), + Extractor::GilRef(f) => f(obj.as_gil_ref()), + } + } +} + #[cold] pub fn failed_to_extract_enum( py: Python<'_>, @@ -41,7 +69,7 @@ fn extract_traceback(py: Python<'_>, mut error: PyErr) -> String { } pub fn extract_struct_field<'py, T>( - obj: &'py PyAny, + obj: &Bound<'py, PyAny>, struct_name: &str, field_name: &str, ) -> PyResult @@ -59,13 +87,13 @@ where } } -pub fn extract_struct_field_with<'py, T>( - extractor: impl FnOnce(&'py PyAny) -> PyResult, - obj: &'py PyAny, +pub fn extract_struct_field_with<'a, 'py, T>( + extractor: impl Into>, + obj: &'a Bound<'py, PyAny>, struct_name: &str, field_name: &str, ) -> PyResult { - match extractor(obj) { + match extractor.into().call(obj) { Ok(value) => Ok(value), Err(err) => Err(failed_to_extract_struct_field( obj.py(), @@ -92,7 +120,7 @@ fn failed_to_extract_struct_field( } pub fn extract_tuple_struct_field<'py, T>( - obj: &'py PyAny, + obj: &Bound<'py, PyAny>, struct_name: &str, index: usize, ) -> PyResult @@ -110,13 +138,13 @@ where } } -pub fn extract_tuple_struct_field_with<'py, T>( - extractor: impl FnOnce(&'py PyAny) -> PyResult, - obj: &'py PyAny, +pub fn extract_tuple_struct_field_with<'a, 'py, T>( + extractor: impl Into>, + obj: &'a Bound<'py, PyAny>, struct_name: &str, index: usize, ) -> PyResult { - match extractor(obj) { + match extractor.into().call(obj) { Ok(value) => Ok(value), Err(err) => Err(failed_to_extract_tuple_struct_field( obj.py(), diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index c475d8ea81f..5c57a954023 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -502,7 +502,7 @@ pub struct Zap { #[pyo3(item)] name: String, - #[pyo3(from_py_with = "PyAny::len", item("my_object"))] + #[pyo3(from_py_with = "Bound::<'_, PyAny>::len", item("my_object"))] some_object_length: usize, } @@ -525,7 +525,10 @@ fn test_from_py_with() { } #[derive(Debug, FromPyObject)] -pub struct ZapTuple(String, #[pyo3(from_py_with = "PyAny::len")] usize); +pub struct ZapTuple( + String, + #[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize, +); #[test] fn test_from_py_with_tuple_struct() { @@ -560,8 +563,11 @@ fn test_from_py_with_tuple_struct_error() { #[derive(Debug, FromPyObject, PartialEq, Eq)] pub enum ZapEnum { - Zip(#[pyo3(from_py_with = "PyAny::len")] usize), - Zap(String, #[pyo3(from_py_with = "PyAny::len")] usize), + Zip(#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize), + Zap( + String, + #[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize, + ), } #[test] @@ -581,7 +587,7 @@ fn test_from_py_with_enum() { #[derive(Debug, FromPyObject, PartialEq, Eq)] #[pyo3(transparent)] pub struct TransparentFromPyWith { - #[pyo3(from_py_with = "PyAny::len")] + #[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] len: usize, }