Skip to content

Commit

Permalink
Update dict.get_item binding to use PyDict_GetItemRef (#4355)
Browse files Browse the repository at this point in the history
* Update dict.get_item binding to use PyDict_GetItemRef

Refs #4265

* test: add test for dict.get_item error path

* test: add test for dict.get_item error path

* test: add test for dict.get_item error path

* fix: fix logic error in dict.get_item bindings

* update: apply david's review suggestions for dict.get_item bindings

* update: create ffi::compat to store compatibility shims

* update: move PyDict_GetItemRef bindings to spot in order from dictobject.h

* build: fix build warning with --no-default-features

* doc: expand release note fragments

* fix: fix clippy warnings

* respond to review comments

* Apply suggestion from @mejrs

* refactor so cfg is applied to functions

* properly set cfgs

* fix clippy lints

* Apply @davidhewitt's suggestion

* deal with upstream deprecation of new_bound
  • Loading branch information
ngoldbaum authored and davidhewitt committed Sep 3, 2024
1 parent cd70cec commit 6710dcd
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 5 deletions.
10 changes: 10 additions & 0 deletions newsfragments/4355.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
* Added an `ffi::compat` namespace to store compatibility shims for C API
functions added in recent versions of Python.

* Added bindings for `PyDict_GetItemRef` on Python 3.13 and newer. Also added
`ffi::compat::PyDict_GetItemRef` which re-exports the FFI binding on Python
3.13 or newer and defines a compatibility version on older versions of
Python. This function is inherently safer to use than `PyDict_GetItem` and has
an API that is easier to use than `PyDict_GetItemWithError`. It returns a
strong reference to value, as opposed to the two older functions which return
a possibly unsafe borrowed reference.
2 changes: 2 additions & 0 deletions newsfragments/4355.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Avoid creating temporary borrowed reference in dict.get_item bindings. Borrowed
references like this are unsafe in the free-threading build.
44 changes: 44 additions & 0 deletions pyo3-ffi/src/compat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//! C API Compatibility Shims
//!
//! Some CPython C API functions added in recent versions of Python are
//! inherently safer to use than older C API constructs. This module
//! exposes functions available on all Python versions that wrap the
//! old C API on old Python versions and wrap the function directly
//! on newer Python versions.

// Unless otherwise noted, the compatibility shims are adapted from
// the pythoncapi-compat project: https://github.com/python/pythoncapi-compat

#[cfg(not(Py_3_13))]
use crate::object::PyObject;
#[cfg(not(Py_3_13))]
use std::os::raw::c_int;

#[cfg_attr(docsrs, doc(cfg(all)))]
#[cfg(Py_3_13)]
pub use crate::dictobject::PyDict_GetItemRef;

#[cfg_attr(docsrs, doc(cfg(all)))]
#[cfg(not(Py_3_13))]
pub unsafe fn PyDict_GetItemRef(
dp: *mut PyObject,
key: *mut PyObject,
result: *mut *mut PyObject,
) -> c_int {
{
use crate::dictobject::PyDict_GetItemWithError;
use crate::object::_Py_NewRef;
use crate::pyerrors::PyErr_Occurred;

let item: *mut PyObject = PyDict_GetItemWithError(dp, key);
if !item.is_null() {
*result = _Py_NewRef(item);
return 1; // found
}
*result = std::ptr::null_mut();
if PyErr_Occurred().is_null() {
return 0; // not found
}
-1
}
}
6 changes: 6 additions & 0 deletions pyo3-ffi/src/dictobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ extern "C" {
) -> c_int;
#[cfg_attr(PyPy, link_name = "PyPyDict_DelItemString")]
pub fn PyDict_DelItemString(dp: *mut PyObject, key: *const c_char) -> c_int;
#[cfg(Py_3_13)]
pub fn PyDict_GetItemRef(
dp: *mut PyObject,
key: *mut PyObject,
result: *mut *mut PyObject,
) -> c_int;
// skipped 3.10 / ex-non-limited PyObject_GenericGetDict
}

Expand Down
3 changes: 3 additions & 0 deletions pyo3-ffi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
//! Raw FFI declarations for Python's C API.
//!
//! PyO3 can be used to write native Python modules or run Python code and modules from Rust.
Expand Down Expand Up @@ -290,6 +291,8 @@ pub const fn _cstr_from_utf8_with_nul_checked(s: &str) -> &CStr {

use std::ffi::CStr;

pub mod compat;

pub use self::abstract_::*;
pub use self::bltinmodule::*;
pub use self::boolobject::*;
Expand Down
46 changes: 41 additions & 5 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,13 +436,13 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
key: Bound<'_, PyAny>,
) -> PyResult<Option<Bound<'py, PyAny>>> {
let py = dict.py();
let mut result: *mut ffi::PyObject = std::ptr::null_mut();
match unsafe {
ffi::PyDict_GetItemWithError(dict.as_ptr(), key.as_ptr())
.assume_borrowed_or_opt(py)
.map(Borrowed::to_owned)
ffi::compat::PyDict_GetItemRef(dict.as_ptr(), key.as_ptr(), &mut result)
} {
some @ Some(_) => Ok(some),
None => PyErr::take(py).map(Err).transpose(),
std::os::raw::c_int::MIN..=-1 => Err(PyErr::fetch(py)),
0 => Ok(None),
1..=std::os::raw::c_int::MAX => Ok(Some(unsafe { result.assume_owned(py) })),
}
}

Expand Down Expand Up @@ -957,6 +957,42 @@ mod tests {
});
}

#[cfg(feature = "macros")]
#[test]
fn test_get_item_error_path() {
use crate::exceptions::PyTypeError;

#[crate::pyclass(crate = "crate")]
struct HashErrors;

#[crate::pymethods(crate = "crate")]
impl HashErrors {
#[new]
fn new() -> Self {
HashErrors {}
}

fn __hash__(&self) -> PyResult<isize> {
Err(PyTypeError::new_err("Error from __hash__"))
}
}

Python::with_gil(|py| {
let class = py.get_type_bound::<HashErrors>();
let instance = class.call0().unwrap();
let d = PyDict::new_bound(py);
match d.get_item(instance) {
Ok(_) => {
panic!("this get_item call should always error")
}
Err(err) => {
assert!(err.is_instance_of::<PyTypeError>(py));
assert_eq!(err.value_bound(py).to_string(), "Error from __hash__")
}
}
})
}

#[test]
#[allow(deprecated)]
#[cfg(all(not(any(PyPy, GraalPy)), feature = "gil-refs"))]
Expand Down

0 comments on commit 6710dcd

Please sign in to comment.