diff --git a/cuda_py/cuda/py/_dlpack.pyx b/cuda_py/cuda/py/_dlpack.pyx
index b3037878..dda08eab 100644
--- a/cuda_py/cuda/py/_dlpack.pyx
+++ b/cuda_py/cuda/py/_dlpack.pyx
@@ -41,7 +41,7 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
     stdlib.free(tensor)
 
 
-cpdef object make_py_capsule(object buf, bint versioned) except +:
+cpdef object make_py_capsule(object buf, bint versioned):
     cdef DLManagedTensor* dlm_tensor
     cdef DLManagedTensorVersioned* dlm_tensor_ver
     cdef DLTensor* dl_tensor
diff --git a/cuda_py/cuda/py/_memoryview.pyx b/cuda_py/cuda/py/_memoryview.pyx
index c746df5e..b45c9f02 100644
--- a/cuda_py/cuda/py/_memoryview.pyx
+++ b/cuda_py/cuda/py/_memoryview.pyx
@@ -9,8 +9,14 @@ from ._dlpack cimport *
 import functools
 from typing import Any, Optional
 
+from cuda import cuda
 import numpy
 
+from cuda.py._utils import handle_return
+
+
+# TODO(leofang): support NumPy structured dtypes
+
 
 @cython.dataclasses.dataclass
 cdef class GPUMemoryView:
@@ -37,6 +43,7 @@ cdef class GPUMemoryView:
 
 
 cdef str get_simple_repr(obj):
+    # TODO: better handling in np.dtype objects
     cdef object obj_class
     cdef str obj_repr
     if isinstance(obj, type):
@@ -71,8 +78,7 @@ cdef class _GPUMemoryViewProxy:
         if self.has_dlpack:
             return view_as_dlpack(self.obj, stream_ptr)
         else:
-            # TODO: Support CAI
-            raise NotImplementedError("TODO")
+            return view_as_cai(self.obj, stream_ptr)
 
 
 cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
@@ -216,7 +222,64 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
     else:
         raise TypeError('Unsupported dtype. \ndtype code: {}'.format(dtype.code))
 
-    return np_dtype
+    # We want the dtype object not just the type object
+    return numpy.dtype(np_dtype)
+
+
+cdef GPUMemoryView view_as_cai(obj, stream_ptr):
+    # Build a GPUMemoryView from an object that exposes the CUDA Array
+    # Interface (CAI). stream_ptr is the consumer stream handle; when the
+    # producer advertises a different stream, the consumer stream is made to
+    # wait on the producer's pending work before the view is returned.
+    # Raises BufferError for CAI versions below 3, masked arrays, a missing
+    # consumer stream, or a non-positive producer stream.
+    cdef dict cai_data = obj.__cuda_array_interface__
+    if cai_data["version"] < 3:
+        raise BufferError("only CUDA Array Interface v3 or above is supported")
+    if cai_data.get("mask") is not None:
+        raise BufferError("mask is not supported")
+    if stream_ptr is None:
+        raise BufferError("stream=None is ambiguous with view()")
+
+    cdef GPUMemoryView buf = GPUMemoryView()
+    buf.obj = obj
+    buf.ptr, buf.readonly = cai_data["data"]
+    buf.shape = cai_data["shape"]
+    # TODO: this only works for built-in numeric types
+    buf.dtype = numpy.dtype(cai_data["typestr"])
+    buf.strides = cai_data.get("strides")
+    if buf.strides is not None:
+        # CAI reports strides in bytes; convert to element counts
+        buf.strides = tuple(s // buf.dtype.itemsize for s in buf.strides)
+    buf.device_accessible = True
+    buf.device_id = handle_return(
+        cuda.cuPointerGetAttribute(
+            cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
+            buf.ptr))
+
+    cdef intptr_t producer_s, consumer_s
+    stream = cai_data.get("stream")
+    if stream is not None:
+        producer_s = (stream)
+        consumer_s = (stream_ptr)
+        # Explicit check rather than assert: assertions can be compiled out,
+        # and a non-positive handle must never reach the driver calls below.
+        if producer_s <= 0:
+            raise BufferError("invalid stream in __cuda_array_interface__")
+        # establish stream order
+        if producer_s != consumer_s:
+            e = handle_return(cuda.cuEventCreate(
+                cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING))
+            try:
+                handle_return(cuda.cuEventRecord(e, producer_s))
+                handle_return(cuda.cuStreamWaitEvent(consumer_s, e, 0))
+            finally:
+                # The event is only needed to enqueue the wait; destroy it so
+                # each call does not leak a driver event (the driver defers
+                # the actual release until the event has completed).
+                handle_return(cuda.cuEventDestroy(e))
+
+    return buf
 
 
 def viewable(tuple arg_indices):