diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index c1a9188585..434eb59752 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -22,6 +22,7 @@ import io import numpy as np +from cpython.buffer cimport PyBUF_FULL_RO, PyBuffer_Release, PyObject_GetBuffer from cpython.object cimport PyObject from cython.operator cimport dereference as deref from libc.stddef cimport size_t @@ -47,10 +48,6 @@ from pylibraft.common.optional cimport make_optional, optional from pylibraft.common import DeviceResources -cdef extern from "Python.h": - Py_buffer* PyMemoryView_GET_BUFFER(PyObject* mview) - - def run_roundtrip_test_for_mdspan(X, fortran_order=False): if not isinstance(X, np.ndarray) or len(X.shape) != 2: raise ValueError("Please call this function with a NumPy array with" @@ -59,6 +56,9 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): cdef device_resources * handle_ = \ handle.getHandle() cdef ostringstream oss + cdef Py_buffer buf + PyObject_GetBuffer(X, &buf, PyBUF_FULL_RO) + cdef uintptr_t buf_ptr = buf.buf if X.dtype == np.float32: if fortran_order: serialize_mdspan[float, matrix_extent[size_t], col_major]( @@ -67,8 +67,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[float, size_t, col_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) else: serialize_mdspan[float, matrix_extent[size_t], row_major]( @@ -77,8 +76,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[float, size_t, row_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) elif X.dtype == np.float64: if fortran_order: @@ -88,8 +86,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[double, size_t, col_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) else: serialize_mdspan[double, matrix_extent[size_t], row_major]( @@ -98,8 +95,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[double, size_t, row_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) elif X.dtype == np.int32: if fortran_order: @@ -109,8 +105,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[int32_t, size_t, col_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) else: serialize_mdspan[int32_t, matrix_extent[size_t], row_major]( @@ -119,8 +114,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[int32_t, size_t, row_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) elif X.dtype == np.uint32: if fortran_order: @@ -130,8 +124,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[uint32_t, size_t, col_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) else: serialize_mdspan[uint32_t, matrix_extent[size_t], row_major]( @@ -140,11 +133,12 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): make_host_matrix_view[uint32_t, size_t, row_major]( - PyMemoryView_GET_BUFFER( - X.data).buf, + buf_ptr, X.shape[0], X.shape[1])) else: + PyBuffer_Release(&buf) raise NotImplementedError() + PyBuffer_Release(&buf) f = io.BytesIO(oss.str()) X2 = np.load(f) assert np.all(X.shape == X2.shape) diff --git a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd index 75c0c14aad..7b2cf59c81 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd @@ -75,19 +75,7 @@ cdef extern from "raft_runtime/neighbors/hnsw.hpp" \ host_matrix_view[uint64_t, int64_t, row_major] neighbors, host_matrix_view[float, int64_t, row_major] distances) except + - cdef unique_ptr[index[float]] deserialize_file[float]( - const device_resources& handle, - const string& filename, - int dim, - DistanceType metric) except + - - cdef unique_ptr[index[int8_t]] deserialize_file[int8_t]( - const device_resources& handle, - const string& filename, - int dim, - DistanceType metric) except + - - cdef unique_ptr[index[uint8_t]] deserialize_file[uint8_t]( + cdef unique_ptr[index[T]] deserialize_file[T]( const device_resources& handle, const string& filename, int dim,