support CAI too
leofang committed Sep 6, 2024
1 parent 94ec937 · commit 60682de
Showing 2 changed files with 52 additions and 4 deletions.
cuda_py/cuda/py/_dlpack.pyx (1 addition, 1 deletion)
@@ -41,7 +41,7 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
     stdlib.free(tensor)
 
 
-cpdef object make_py_capsule(object buf, bint versioned) except +:
+cpdef object make_py_capsule(object buf, bint versioned):
     cdef DLManagedTensor* dlm_tensor
     cdef DLManagedTensorVersioned* dlm_tensor_ver
     cdef DLTensor* dl_tensor
cuda_py/cuda/py/_memoryview.pyx (51 additions, 3 deletions)
@@ -9,8 +9,14 @@ from ._dlpack cimport *
 import functools
 from typing import Any, Optional
 
+from cuda import cuda
 import numpy
 
+from cuda.py._utils import handle_return
+
+
+# TODO(leofang): support NumPy structured dtypes
+
 
 @cython.dataclasses.dataclass
 cdef class GPUMemoryView:
@@ -37,6 +43,7 @@ cdef class GPUMemoryView:
 
 
 cdef str get_simple_repr(obj):
+    # TODO: better handling in np.dtype objects
     cdef object obj_class
     cdef str obj_repr
     if isinstance(obj, type):
@@ -71,8 +78,7 @@ cdef class _GPUMemoryViewProxy:
         if self.has_dlpack:
            return view_as_dlpack(self.obj, stream_ptr)
         else:
-            # TODO: Support CAI
-            raise NotImplementedError("TODO")
+            return view_as_cai(self.obj, stream_ptr)
 
 
 cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
@@ -216,7 +222,49 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
     else:
         raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code))
 
-    return np_dtype
+    # We want the dtype object not just the type object
+    return numpy.dtype(np_dtype)
 
 
+cdef GPUMemoryView view_as_cai(obj, stream_ptr):
+    cdef dict cai_data = obj.__cuda_array_interface__
+    if cai_data["version"] < 3:
+        raise BufferError("only CUDA Array Interface v3 or above is supported")
+    if cai_data.get("mask") is not None:
+        raise BufferError("mask is not supported")
+    if stream_ptr is None:
+        raise BufferError("stream=None is ambiguous with view()")
+
+    cdef GPUMemoryView buf = GPUMemoryView()
+    buf.obj = obj
+    buf.ptr, buf.readonly = cai_data["data"]
+    buf.shape = cai_data["shape"]
+    # TODO: this only works for built-in numeric types
+    buf.dtype = numpy.dtype(cai_data["typestr"])
+    buf.strides = cai_data.get("strides")
+    if buf.strides is not None:
+        # convert to counts
+        buf.strides = tuple(s // buf.dtype.itemsize for s in buf.strides)
+    buf.device_accessible = True
+    buf.device_id = handle_return(
+        cuda.cuPointerGetAttribute(
+            cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
+            buf.ptr))
+
+    cdef intptr_t producer_s, consumer_s
+    stream = cai_data.get("stream")
+    if stream is not None:
+        producer_s = <intptr_t>(stream)
+        consumer_s = <intptr_t>(stream_ptr)
+        assert producer_s > 0
+        # establish stream order
+        if producer_s != consumer_s:
+            e = handle_return(cuda.cuEventCreate(
+                cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING))
+            handle_return(cuda.cuEventRecord(e, producer_s))
+            handle_return(cuda.cuStreamWaitEvent(consumer_s, e, 0))
+
+    return buf
+
+
 def viewable(tuple arg_indices):
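Note on the new view_as_cai path: it consumes the dictionary a producer exposes via __cuda_array_interface__ (CUDA Array Interface v3). A rough sketch of such a producer is below; the class is hypothetical and the pointer/shape/stream values are placeholders for illustration, not a real device allocation:

    # Hypothetical CAI v3 producer; pointer and stream values are placeholders.
    class FakeCAIProducer:
        def __init__(self, ptr, shape, typestr, stream=None):
            self._ptr = ptr
            self._shape = shape
            self._typestr = typestr
            self._stream = stream

        @property
        def __cuda_array_interface__(self):
            return {
                "version": 3,                # view_as_cai rejects anything below v3
                "data": (self._ptr, False),  # (device pointer, read-only flag)
                "shape": self._shape,
                "typestr": self._typestr,    # e.g. "<f4" for little-endian float32
                "strides": None,             # None means C-contiguous
                "stream": self._stream,      # producer stream handle; None skips the event sync
            }

These are exactly the keys the added code reads: "version", "mask" (rejected unless None), "data", "shape", "typestr", "strides", and "stream".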

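Note on the strides conversion: the interface reports strides in bytes, while the resulting view stores them in element counts, hence the division by buf.dtype.itemsize. A quick host-side illustration of the arithmetic (plain NumPy, no device memory involved):

    import numpy

    # A 2x3 C-contiguous float32 array has byte strides (12, 4);
    # dividing by itemsize (4) yields element strides (3, 1).
    dtype = numpy.dtype("<f4")
    byte_strides = (12, 4)
    elem_strides = tuple(s // dtype.itemsize for s in byte_strides)
    assert elem_strides == (3, 1)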

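Note on the dtype_dlpack_to_numpy change: per the added comment, np_dtype may be a type object (e.g. numpy.float32) rather than a numpy.dtype instance, and only the latter exposes the dtype API (for example a usable itemsize, which the strides conversion above depends on); hence the wrap in numpy.dtype(...). A minimal check:

    import numpy

    # numpy.float32 is a scalar type; numpy.dtype(numpy.float32) is a dtype instance.
    assert not isinstance(numpy.float32, numpy.dtype)
    assert isinstance(numpy.dtype(numpy.float32), numpy.dtype)
    assert numpy.dtype(numpy.float32).itemsize == 4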