support dlpack 1.0
leofang committed Sep 4, 2024
1 parent 16f541d commit 8c49acc
Showing 2 changed files with 90 additions and 35 deletions.
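
Background for the diff below: DLPack 1.0 introduces DLManagedTensorVersioned, which wraps the familiar DLTensor payload with an explicit (major, minor) version stamp and a flags word, and which travels in a capsule named "dltensor_versioned" instead of the legacy "dltensor". The changes teach the capsule producer (_dlpack.pyx) to build either struct, and the Python-level __dlpack__ entry point (_memory.py) to choose between them based on the consumer-supplied max_version.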
98 changes: 77 additions & 21 deletions cuda_py/cuda/py/_dlpack.pyx
@@ -9,17 +9,20 @@ cimport cpython
 from libc cimport stdlib
 from libc.stdint cimport uint8_t
 from libc.stdint cimport uint16_t
+from libc.stdint cimport uint32_t
 from libc.stdint cimport int32_t
 from libc.stdint cimport int64_t
+from libc.stdint cimport uint64_t
 from libc.stdint cimport intptr_t
 from libcpp.vector cimport vector
 
 from enum import IntEnum
 
 
 cdef extern from "dlpack.h" nogil:
+    """
+    #define DLPACK_TENSOR_UNUSED_NAME "dltensor"
+    #define DLPACK_VERSIONED_TENSOR_UNUSED_NAME "dltensor_versioned"
+    """
     ctypedef enum _DLDeviceType "DLDeviceType":
         _kDLCPU "kDLCPU"
         _kDLCUDA "kDLCUDA"
@@ -52,33 +55,89 @@ cdef extern from "dlpack.h" nogil:
         void* manager_ctx
         void (*deleter)(DLManagedTensor*)
 
+    ctypedef struct DLPackVersion:
+        uint32_t major
+        uint32_t minor
+
+    ctypedef struct DLManagedTensorVersioned:
+        DLPackVersion version
+        void* manager_ctx
+        void (*deleter)(DLManagedTensorVersioned*)
+        uint64_t flags
+        DLTensor dl_tensor
+
+    int DLPACK_MAJOR_VERSION
+    int DLPACK_MINOR_VERSION
+
+    const char* DLPACK_TENSOR_UNUSED_NAME
+    const char* DLPACK_VERSIONED_TENSOR_UNUSED_NAME


-cdef void pycapsule_deleter(object dltensor):
+cdef void pycapsule_deleter(object capsule):
     cdef DLManagedTensor* dlm_tensor
-    # Do not invoke the deleter on a used capsule
-    if cpython.PyCapsule_IsValid(dltensor, 'dltensor'):
-        dlm_tensor = <DLManagedTensor*>cpython.PyCapsule_GetPointer(
-            dltensor, 'dltensor')
-        dlm_tensor.deleter(dlm_tensor)
+    cdef DLManagedTensorVersioned* dlm_tensor_ver
+    # Do not invoke the deleter on a used capsule.
+    if cpython.PyCapsule_IsValid(
+            capsule, DLPACK_TENSOR_UNUSED_NAME):
+        dlm_tensor = <DLManagedTensor*>(
+            cpython.PyCapsule_GetPointer(
+                capsule, DLPACK_TENSOR_UNUSED_NAME))
+        if dlm_tensor.deleter:
+            dlm_tensor.deleter(dlm_tensor)
+    elif cpython.PyCapsule_IsValid(
+            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
+        dlm_tensor_ver = <DLManagedTensorVersioned*>(
+            cpython.PyCapsule_GetPointer(
+                capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME))
+        if dlm_tensor_ver.deleter:
+            dlm_tensor_ver.deleter(dlm_tensor_ver)


 cdef void deleter(DLManagedTensor* tensor) with gil:
-    if tensor.manager_ctx is NULL:
-        return
     stdlib.free(tensor.dl_tensor.shape)
-    cpython.Py_DECREF(<object>tensor.manager_ctx)
-    tensor.manager_ctx = NULL
+    if tensor.manager_ctx:
+        cpython.Py_DECREF(<object>tensor.manager_ctx)
+        tensor.manager_ctx = NULL
     stdlib.free(tensor)


-cpdef object make_py_capsule(object buf) except +:
-    cdef DLManagedTensor* dlm_tensor = \
-        <DLManagedTensor*>stdlib.malloc(sizeof(DLManagedTensor))
+cdef void versioned_deleter(DLManagedTensorVersioned* tensor) with gil:
+    stdlib.free(tensor.dl_tensor.shape)
+    if tensor.manager_ctx:
+        cpython.Py_DECREF(<object>tensor.manager_ctx)
+        tensor.manager_ctx = NULL
+    stdlib.free(tensor)
+
+
+cpdef object make_py_capsule(object buf, bint versioned) except +:
+    cdef DLManagedTensor* dlm_tensor
+    cdef DLManagedTensorVersioned* dlm_tensor_ver
+    cdef DLTensor* dl_tensor
+    cdef void* tensor_ptr
+    cdef const char* capsule_name
+
+    if versioned:
+        dlm_tensor_ver = <DLManagedTensorVersioned*>(
+            stdlib.malloc(sizeof(DLManagedTensorVersioned)))
+        dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
+        dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
+        dlm_tensor_ver.manager_ctx = <void*>buf
+        dlm_tensor_ver.deleter = versioned_deleter
+        dlm_tensor_ver.flags = 0
+        dl_tensor = &dlm_tensor_ver.dl_tensor
+        tensor_ptr = dlm_tensor_ver
+        capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
+    else:
+        dlm_tensor = <DLManagedTensor*>(
+            stdlib.malloc(sizeof(DLManagedTensor)))
+        dl_tensor = &dlm_tensor.dl_tensor
+        dlm_tensor.manager_ctx = <void*>buf
+        dlm_tensor.deleter = deleter
+        tensor_ptr = dlm_tensor
+        capsule_name = DLPACK_TENSOR_UNUSED_NAME

-    cdef DLTensor* dl_tensor = &dlm_tensor.dl_tensor
     dl_tensor.data = <void*><intptr_t>(int(buf.handle))
     dl_tensor.ndim = 1

     cdef int64_t* shape_strides = \
         <int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
     shape_strides[0] = <int64_t>buf.size
@@ -106,11 +165,8 @@ cpdef object make_py_capsule(object buf) except +:
     dtype.lanes = <uint16_t>1
     dtype.bits = <uint8_t>8
 
-    dlm_tensor.manager_ctx = <void*>buf
     cpython.Py_INCREF(buf)
-    dlm_tensor.deleter = deleter
 
-    return cpython.PyCapsule_New(dlm_tensor, 'dltensor', pycapsule_deleter)
+    return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
 
 
 class DLDeviceType(IntEnum):
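A note on the PyCapsule_IsValid checks above: the DLPack protocol has a consumer rename the capsule once it takes ownership of the tensor ("dltensor" becomes "used_dltensor", and likewise for the versioned name), so a capsule still carrying its original name was never consumed and the producer must run the deleter itself. Below is a minimal consumer-side sketch of that renaming, with ctypes used purely for illustration; mark_consumed is a hypothetical helper, not part of this commit.

import ctypes

# Post-consumption capsule names defined by the DLPack spec. Kept at module
# scope so the name buffers outlive any capsule they get attached to.
_USED_NAMES = {
    b"dltensor": b"used_dltensor",
    b"dltensor_versioned": b"used_dltensor_versioned",
}

def mark_consumed(capsule):
    # After extracting the DLManagedTensor(Versioned)* from the capsule,
    # rename it so the producer's pycapsule_deleter sees PyCapsule_IsValid
    # fail and skips the deleter, preventing a double free.
    get_name = ctypes.pythonapi.PyCapsule_GetName
    get_name.argtypes = [ctypes.py_object]
    get_name.restype = ctypes.c_char_p
    set_name = ctypes.pythonapi.PyCapsule_SetName
    set_name.argtypes = [ctypes.py_object, ctypes.c_char_p]
    set_name.restype = ctypes.c_int
    set_name(capsule, _USED_NAMES[get_name(capsule)])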
27 changes: 13 additions & 14 deletions cuda_py/cuda/py/_memory.py
@@ -107,20 +107,19 @@ def __dlpack__(self, *,
                      max_version: Optional[Tuple[int, int]] = None,
                      dl_device: Optional[Tuple[int, int]] = None,
                      copy: Optional[bool] = None) -> PyCapsule:
-        # Support for Python-level DLPack protocol.
-        if stream is not None:
-            warnings.warn("stream != None is ignored")
-        # TODO: add checks for dl_device and copy
-        # FIXME: fix v1.0 support
-        #if max_version is None:
-        #    versioned = False
-        #else:
-        #    assert len(max_version) == 2
-        #    if max_version >= (1, 0):
-        #        versioned = True
-        #    else:
-        #        versioned = False
-        capsule = make_py_capsule(self)#, versioned)
+        # Note: we ignore the stream argument entirely (as if it is -1).
+        # It is the user's responsibility to maintain stream order.
+        if dl_device is not None or copy is True:
+            raise BufferError
+        if max_version is None:
+            versioned = False
+        else:
+            assert len(max_version) == 2
+            if max_version >= (1, 0):
+                versioned = True
+            else:
+                versioned = False
+        capsule = make_py_capsule(self, versioned)
         return capsule
 
     def __dlpack_device__(self) -> Tuple[int, int]:
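For a sense of how the branch above plays out, here is a quick sketch of both protocol paths; buf stands for any object exposing this __dlpack__, and the max_version keyword is the standard DLPack Python protocol, not new API from this commit.

# No max_version -> versioned=False -> legacy "dltensor" capsule.
legacy = buf.__dlpack__()

# Consumer advertises DLPack >= 1.0 -> versioned=True
# -> "dltensor_versioned" capsule.
versioned = buf.__dlpack__(max_version=(1, 0))

# An older max_version such as (0, 8) also falls back to the legacy capsule.
old = buf.__dlpack__(max_version=(0, 8))

# Consumers typically query the device before importing the data.
dev_type, dev_id = buf.__dlpack_device__()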
