From 0a2c85328dadd6ddfed92917fb559e8d9b57d8b8 Mon Sep 17 00:00:00 2001 From: Dmitry Ustalov Date: Sun, 1 Dec 2024 21:49:56 +0100 Subject: [PATCH] Update pyo3 and rust-numpy --- Cargo.toml | 6 +++--- src/bradley_terry.rs | 5 ++++- src/python.rs | 36 +++++++++++++----------------------- 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7ca6dfe..a103684 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,10 @@ crate-type = ["cdylib"] [dependencies] approx = "^0.5.1" -ndarray = "^0.16.1" +ndarray = "^0.16.1" # numpy supports only >= 0.15, < 0.17 num-traits = "^0.2.19" -pyo3 = { version = "^0.22.3", features = ["extension-module", "abi3-py38"], optional = true } -numpy = { version = "^0.22.0", optional = true } +pyo3 = { version = "^0.23.2", features = ["extension-module", "abi3-py38"], optional = true } +numpy = { version = "^0.23.0", optional = true } [features] python = ["dep:pyo3", "dep:numpy"] diff --git a/src/bradley_terry.rs b/src/bradley_terry.rs index 7eb801f..bf8bcc4 100644 --- a/src/bradley_terry.rs +++ b/src/bradley_terry.rs @@ -85,7 +85,10 @@ pub fn newman( v = one_nan_to_num(v_new, tolerance); - let broadcast_scores_t = scores.clone().into_shape_with_order((1, scores.len())).unwrap(); + let broadcast_scores_t = scores + .clone() + .into_shape_with_order((1, scores.len())) + .unwrap(); let sqrt_scores_outer = (&broadcast_scores_t * &broadcast_scores_t.t()).mapv_into(f64::sqrt); let sum_scores = &broadcast_scores_t + &broadcast_scores_t.t(); diff --git a/src/python.rs b/src/python.rs index 08a462b..e171a3e 100644 --- a/src/python.rs +++ b/src/python.rs @@ -39,13 +39,12 @@ unsafe impl Element for Winner { Clone::clone(self) } - fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> { - numpy::dtype_bound::(py) + fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> { + numpy::dtype::(py) } } create_exception!(evalica, LengthMismatchError, PyValueError); - #[pyfunction] fn matrices_pyo3<'py>( py: Python<'py>, @@ -63,8 +62,8 @@ fn matrices_pyo3<'py>( total, ) { Ok((wins, ties)) => Ok(( - wins.into_pyarray_bound(py).unbind(), - ties.into_pyarray_bound(py).unbind(), + wins.into_pyarray(py).unbind(), + ties.into_pyarray(py).unbind(), )), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } @@ -77,7 +76,7 @@ fn pairwise_scores_pyo3<'py>( ) -> PyResult>> { let pairwise = pairwise_scores(&scores.as_array()); - Ok(pairwise.into_pyarray_bound(py).unbind()) + Ok(pairwise.into_pyarray(py).unbind()) } #[pyfunction] @@ -100,7 +99,7 @@ fn counting_pyo3<'py>( win_weight, tie_weight, ) { - Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()), + Ok(scores) => Ok(scores.into_pyarray(py).unbind()), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } } @@ -125,7 +124,7 @@ fn average_win_rate_pyo3<'py>( win_weight, tie_weight, ) { - Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()), + Ok(scores) => Ok(scores.into_pyarray(py).unbind()), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } } @@ -160,9 +159,7 @@ fn bradley_terry_pyo3<'py>( ); match bradley_terry(&matrix.view(), tolerance, limit) { - Ok((scores, iterations)) => { - Ok((scores.into_pyarray_bound(py).unbind(), iterations)) - } + Ok((scores, iterations)) => Ok((scores.into_pyarray(py).into(), iterations)), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } } @@ -206,7 +203,7 @@ fn newman_pyo3<'py>( limit, ) { Ok((scores, v, iterations)) => { - Ok((scores.into_pyarray_bound(py).unbind(), v, iterations)) + Ok((scores.into_pyarray(py).unbind(), v, iterations)) } Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } @@ -243,7 +240,7 @@ fn elo_pyo3<'py>( win_weight, tie_weight, ) { - Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()), + Ok(scores) => Ok(scores.into_pyarray(py).unbind()), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } } @@ -278,9 +275,7 @@ fn eigen_pyo3<'py>( ); match eigen(&matrix.view(), tolerance, limit) { - Ok((scores, iterations)) => { - Ok((scores.into_pyarray_bound(py).unbind(), iterations)) - } + Ok((scores, iterations)) => Ok((scores.into_pyarray(py).unbind(), iterations)), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } } @@ -319,9 +314,7 @@ fn pagerank_pyo3<'py>( ); match pagerank(&matrix.view(), damping, tolerance, limit) { - Ok((scores, iterations)) => { - Ok((scores.into_pyarray_bound(py).unbind(), iterations)) - } + Ok((scores, iterations)) => Ok((scores.into_pyarray(py).unbind(), iterations)), Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")), } } @@ -332,10 +325,7 @@ fn pagerank_pyo3<'py>( #[pymodule] fn evalica(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("__version__", env!("CARGO_PKG_VERSION"))?; - m.add( - "LengthMismatchError", - py.get_type_bound::(), - )?; + m.add("LengthMismatchError", py.get_type::())?; m.add_function(wrap_pyfunction!(matrices_pyo3, m)?)?; m.add_function(wrap_pyfunction!(pairwise_scores_pyo3, m)?)?; m.add_function(wrap_pyfunction!(counting_pyo3, m)?)?;