diff --git a/cuda_core/MANIFEST.in b/cuda_core/MANIFEST.in new file mode 100644 index 0000000..f0b3354 --- /dev/null +++ b/cuda_core/MANIFEST.in @@ -0,0 +1 @@ +recursive-include cuda/core *.pyx *.pxd diff --git a/cuda_core/README.md b/cuda_core/README.md new file mode 100644 index 0000000..e979fb7 --- /dev/null +++ b/cuda_core/README.md @@ -0,0 +1,9 @@ +# `cuda.core`: (experimental) pythonic CUDA module + +Currently under active development. To build from source, just do: +```shell +$ git clone https://github.com/NVIDIA/cuda-python +$ cd cuda-python/cuda_core # move to the directory where this README locates +$ pip install . +``` +For now `cuda-python` is a required dependency. diff --git a/cuda_core/cuda/core/__init__.pxd b/cuda_core/cuda/core/__init__.pxd new file mode 100644 index 0000000..e69de29 diff --git a/cuda_core/cuda/core/__init__.py b/cuda_core/cuda/core/__init__.py new file mode 100644 index 0000000..cec6e8d --- /dev/null +++ b/cuda_core/cuda/core/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from cuda.core._device import Device +from cuda.core._event import EventOptions +from cuda.core._launcher import LaunchConfig, launch +from cuda.core._program import Program +from cuda.core._stream import Stream, StreamOptions +from cuda.core._version import __version__ diff --git a/cuda_core/cuda/core/_context.py b/cuda_core/cuda/core/_context.py new file mode 100644 index 0000000..5d0f5ad --- /dev/null +++ b/cuda_core/cuda/core/_context.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from dataclasses import dataclass + +from cuda import cuda, cudart +from cuda.core._utils import handle_return + + +@dataclass +class ContextOptions: + pass # TODO + + +class Context: + + __slots__ = ("_handle", "_id") + + def __init__(self): + raise NotImplementedError("TODO") + + @staticmethod + def _from_ctx(obj, dev_id): + assert isinstance(obj, cuda.CUcontext) + ctx = Context.__new__(Context) + ctx._handle = obj + ctx._id = dev_id + return ctx diff --git a/cuda_core/cuda/core/_device.py b/cuda_core/cuda/core/_device.py new file mode 100644 index 0000000..1268da3 --- /dev/null +++ b/cuda_core/cuda/core/_device.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +import threading +from typing import Optional, Union +import warnings + +from cuda import cuda, cudart +from cuda.core._utils import handle_return, ComputeCapability, CUDAError, \ + precondition +from cuda.core._context import Context, ContextOptions +from cuda.core._memory import _DefaultAsyncMempool, Buffer, MemoryResource +from cuda.core._stream import default_stream, Stream, StreamOptions + + +_tls = threading.local() +_tls_lock = threading.Lock() + + +class Device: + + __slots__ = ("_id", "_mr", "_has_inited") + + def __new__(cls, device_id=None): + # important: creating a Device instance does not initialize the GPU! 
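        # A minimal usage sketch (illustrative only; see also set_current() below):
        #
        #     dev = Device()        # cheap: returns the cached per-thread instance
        #     dev.set_current()     # this is the call that sets up a CUDA context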
+ if device_id is None: + device_id = handle_return(cudart.cudaGetDevice()) + assert isinstance(device_id, int), f"{device_id=}" + else: + total = handle_return(cudart.cudaGetDeviceCount()) + if not isinstance(device_id, int) or not (0 <= device_id < total): + raise ValueError( + f"device_id must be within [0, {total}), got {device_id}") + + # ensure Device is singleton + with _tls_lock: + if not hasattr(_tls, "devices"): + total = handle_return(cudart.cudaGetDeviceCount()) + _tls.devices = [] + for dev_id in range(total): + dev = super().__new__(cls) + dev._id = dev_id + dev._mr = _DefaultAsyncMempool(dev_id) + dev._has_inited = False + _tls.devices.append(dev) + + return _tls.devices[device_id] + + def _check_context_initialized(self, *args, **kwargs): + if not self._has_inited: + raise CUDAError("the device is not yet initialized, " + "perhaps you forgot to call .set_current() first?") + + @property + def device_id(self) -> int: + return self._id + + @property + def pci_bus_id(self) -> str: + bus_id = handle_return(cudart.cudaDeviceGetPCIBusId(13, self._id)) + return bus_id[:12].decode() + + @property + def uuid(self) -> str: + driver_ver = handle_return(cuda.cuDriverGetVersion()) + if driver_ver >= 11040: + uuid = handle_return(cuda.cuDeviceGetUuid_v2(self._id)) + else: + uuid = handle_return(cuda.cuDeviceGetUuid(self._id)) + uuid = uuid.bytes.hex() + # 8-4-4-4-12 + return f"{uuid[:8]}-{uuid[8:12]}-{uuid[12:16]}-{uuid[16:20]}-{uuid[20:]}" + + @property + def name(self) -> str: + # assuming a GPU name is less than 128 characters... + name = handle_return(cuda.cuDeviceGetName(128, self._id)) + name = name.split(b'\0')[0] + return name.decode() + + @property + def properties(self) -> dict: + # TODO: pythonize the key names + return handle_return(cudart.cudaGetDeviceProperties(self._id)) + + @property + def compute_capability(self) -> ComputeCapability: + """Returns a named tuple with 2 fields: major and minor. """ + major = handle_return(cudart.cudaDeviceGetAttribute( + cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, self._id)) + minor = handle_return(cudart.cudaDeviceGetAttribute( + cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, self._id)) + return ComputeCapability(major, minor) + + @property + @precondition(_check_context_initialized) + def context(self) -> Context: + ctx = handle_return(cuda.cuCtxGetCurrent()) + assert int(ctx) != 0 + return Context._from_ctx(ctx, self._id) + + @property + def memory_resource(self) -> MemoryResource: + return self._mr + + @memory_resource.setter + def memory_resource(self, mr): + if not isinstance(mr, MemoryResource): + raise TypeError + self._mr = mr + + @property + def default_stream(self) -> Stream: + return default_stream() + + def __int__(self): + return self._id + + def __repr__(self): + return f"" + + def set_current(self, ctx: Context=None) -> Union[Context, None]: + """ + Entry point of this object. Users always start a code by + calling this method, e.g. + + >>> from cuda.core import Device + >>> dev0 = Device(0) + >>> dev0.set_current() + >>> # ... do work on device 0 ... + + The optional ctx argument is for advanced users to bind a + CUDA context with the device. In this case, the previously + set context is popped and returned to the user. 
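
        A sketch of the advanced path (hypothetical for now, as
        Device.create_context() is not yet implemented in this change):

        >>> ctx = dev0.create_context()       # a context that is not yet current
        >>> prev_ctx = dev0.set_current(ctx)  # bind it; the popped context is returned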
+ """ + if ctx is not None: + if not isinstance(ctx, Context): + raise TypeError("a Context object is required") + if ctx._id != self._id: + raise RuntimeError("the provided context was created on a different " + f"device {ctx._id} other than the target {self._id}") + prev_ctx = handle_return(cuda.cuCtxPopCurrent()) + handle_return(cuda.cuCtxPushCurrent(ctx._handle)) + self._has_inited = True + if int(prev_ctx) != 0: + return Context._from_ctx(prev_ctx, self._id) + else: + ctx = handle_return(cuda.cuCtxGetCurrent()) + if int(ctx) == 0: + # use primary ctx + ctx = handle_return(cuda.cuDevicePrimaryCtxRetain(self._id)) + handle_return(cuda.cuCtxPushCurrent(ctx)) + else: + ctx_id = handle_return(cuda.cuCtxGetDevice()) + if ctx_id != self._id: + # use primary ctx + ctx = handle_return(cuda.cuDevicePrimaryCtxRetain(self._id)) + handle_return(cuda.cuCtxPushCurrent(ctx)) + else: + # no-op, a valid context already exists and is set current + pass + self._has_inited = True + + def create_context(self, options: ContextOptions = None) -> Context: + # Create a Context object (but do NOT set it current yet!). + # ContextOptions is a dataclass for setting e.g. affinity or CIG + # options. + raise NotImplementedError("TODO") + + @precondition(_check_context_initialized) + def create_stream(self, obj=None, options: StreamOptions=None) -> Stream: + # Create a Stream object by either holding a newly created + # CUDA stream or wrapping an existing foreign object supporting + # the __cuda_stream__ protocol. In the latter case, a reference + # to obj is held internally so that its lifetime is managed. + return Stream._init(obj=obj, options=options) + + @precondition(_check_context_initialized) + def allocate(self, size, stream=None) -> Buffer: + if stream is None: + stream = default_stream() + return self._mr.allocate(size, stream) + + @precondition(_check_context_initialized) + def sync(self): + handle_return(cudart.cudaDeviceSynchronize()) diff --git a/cuda_core/cuda/core/_dlpack.pxd b/cuda_core/cuda/core/_dlpack.pxd new file mode 100644 index 0000000..1868287 --- /dev/null +++ b/cuda_core/cuda/core/_dlpack.pxd @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +cimport cpython + +from libc cimport stdlib +from libc.stdint cimport uint8_t +from libc.stdint cimport uint16_t +from libc.stdint cimport uint32_t +from libc.stdint cimport int32_t +from libc.stdint cimport int64_t +from libc.stdint cimport uint64_t +from libc.stdint cimport intptr_t + + +cdef extern from "dlpack.h" nogil: + """ + #define DLPACK_TENSOR_UNUSED_NAME "dltensor" + #define DLPACK_VERSIONED_TENSOR_UNUSED_NAME "dltensor_versioned" + #define DLPACK_TENSOR_USED_NAME "used_dltensor" + #define DLPACK_VERSIONED_TENSOR_USED_NAME "used_dltensor_versioned" + """ + ctypedef enum _DLDeviceType "DLDeviceType": + _kDLCPU "kDLCPU" + _kDLCUDA "kDLCUDA" + _kDLCUDAHost "kDLCUDAHost" + _kDLCUDAManaged "kDLCUDAManaged" + + ctypedef struct DLDevice: + _DLDeviceType device_type + int32_t device_id + + cdef enum DLDataTypeCode: + kDLInt + kDLUInt + kDLFloat + kDLBfloat + kDLComplex + kDLBool + + ctypedef struct DLDataType: + uint8_t code + uint8_t bits + uint16_t lanes + + ctypedef struct DLTensor: + void* data + DLDevice device + int32_t ndim + DLDataType dtype + int64_t* shape + int64_t* strides + uint64_t byte_offset + + ctypedef struct DLManagedTensor: + DLTensor dl_tensor + void* manager_ctx + void (*deleter)(DLManagedTensor*) + + ctypedef struct DLPackVersion: + uint32_t major + uint32_t minor + + ctypedef struct DLManagedTensorVersioned: + DLPackVersion version + void* manager_ctx + void (*deleter)(DLManagedTensorVersioned*) + uint64_t flags + DLTensor dl_tensor + + int DLPACK_MAJOR_VERSION + int DLPACK_MINOR_VERSION + int DLPACK_FLAG_BITMASK_READ_ONLY + + const char* DLPACK_TENSOR_UNUSED_NAME + const char* DLPACK_VERSIONED_TENSOR_UNUSED_NAME + const char* DLPACK_TENSOR_USED_NAME + const char* DLPACK_VERSIONED_TENSOR_USED_NAME diff --git a/cuda_core/cuda/core/_dlpack.pyx b/cuda_core/cuda/core/_dlpack.pyx new file mode 100644 index 0000000..dda08ea --- /dev/null +++ b/cuda_core/cuda/core/_dlpack.pyx @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from enum import IntEnum + + +cdef void pycapsule_deleter(object capsule) noexcept: + cdef DLManagedTensor* dlm_tensor + cdef DLManagedTensorVersioned* dlm_tensor_ver + # Do not invoke the deleter on a used capsule. 
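    # (A consumer that extracts the tensor renames the capsule to the matching
    # "used_dltensor*" name, as view_as_dlpack() in _memoryview.pyx does, and
    # from that point on the consumer owns the deleter call.)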
+ if cpython.PyCapsule_IsValid( + capsule, DLPACK_TENSOR_UNUSED_NAME): + dlm_tensor = ( + cpython.PyCapsule_GetPointer( + capsule, DLPACK_TENSOR_UNUSED_NAME)) + if dlm_tensor.deleter: + dlm_tensor.deleter(dlm_tensor) + elif cpython.PyCapsule_IsValid( + capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME): + dlm_tensor_ver = ( + cpython.PyCapsule_GetPointer( + capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)) + if dlm_tensor_ver.deleter: + dlm_tensor_ver.deleter(dlm_tensor_ver) + + +cdef void deleter(DLManagedTensor* tensor) noexcept with gil: + stdlib.free(tensor.dl_tensor.shape) + if tensor.manager_ctx: + cpython.Py_DECREF(tensor.manager_ctx) + tensor.manager_ctx = NULL + stdlib.free(tensor) + + +cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil: + stdlib.free(tensor.dl_tensor.shape) + if tensor.manager_ctx: + cpython.Py_DECREF(tensor.manager_ctx) + tensor.manager_ctx = NULL + stdlib.free(tensor) + + +cpdef object make_py_capsule(object buf, bint versioned): + cdef DLManagedTensor* dlm_tensor + cdef DLManagedTensorVersioned* dlm_tensor_ver + cdef DLTensor* dl_tensor + cdef void* tensor_ptr + cdef const char* capsule_name + + if versioned: + dlm_tensor_ver = ( + stdlib.malloc(sizeof(DLManagedTensorVersioned))) + dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION + dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION + dlm_tensor_ver.manager_ctx = buf + dlm_tensor_ver.deleter = versioned_deleter + dlm_tensor_ver.flags = 0 + dl_tensor = &dlm_tensor_ver.dl_tensor + tensor_ptr = dlm_tensor_ver + capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME + else: + dlm_tensor = ( + stdlib.malloc(sizeof(DLManagedTensor))) + dl_tensor = &dlm_tensor.dl_tensor + dlm_tensor.manager_ctx = buf + dlm_tensor.deleter = deleter + tensor_ptr = dlm_tensor + capsule_name = DLPACK_TENSOR_UNUSED_NAME + + dl_tensor.data = (int(buf.handle)) + dl_tensor.ndim = 1 + cdef int64_t* shape_strides = \ + stdlib.malloc(sizeof(int64_t) * 2) + shape_strides[0] = buf.size + shape_strides[1] = 1 # redundant + dl_tensor.shape = shape_strides + dl_tensor.strides = NULL + dl_tensor.byte_offset = 0 + + cdef DLDevice* device = &dl_tensor.device + # buf should be a Buffer instance + if buf.is_device_accessible and not buf.is_host_accessible: + device.device_type = _kDLCUDA + device.device_id = buf.device_id + elif buf.is_device_accessible and buf.is_host_accessible: + device.device_type = _kDLCUDAHost + device.device_id = 0 + elif not buf.is_device_accessible and buf.is_host_accessible: + device.device_type = _kDLCPU + device.device_id = 0 + else: # not buf.is_device_accessible and not buf.is_host_accessible + raise BufferError("invalid buffer") + + cdef DLDataType* dtype = &dl_tensor.dtype + dtype.code = kDLInt + dtype.lanes = 1 + dtype.bits = 8 + + cpython.Py_INCREF(buf) + return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter) + + +class DLDeviceType(IntEnum): + kDLCPU = _kDLCPU + kDLCUDA = _kDLCUDA + kDLCUDAHost = _kDLCUDAHost + kDLCUDAManaged = _kDLCUDAManaged diff --git a/cuda_core/cuda/core/_event.py b/cuda_core/cuda/core/_event.py new file mode 100644 index 0000000..5fbacae --- /dev/null +++ b/cuda_core/cuda/core/_event.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from dataclasses import dataclass +from typing import Optional + +from cuda import cuda +from cuda.core._utils import check_or_create_options +from cuda.core._utils import CUDAError +from cuda.core._utils import handle_return + + +@dataclass +class EventOptions: + enable_timing: Optional[bool] = False + busy_waited_sync: Optional[bool] = False + support_ipc: Optional[bool] = False + + +class Event: + + __slots__ = ("_handle", "_timing_disabled", "_busy_waited") + + def __init__(self): + # minimal requirements for the destructor + self._handle = None + raise NotImplementedError( + "directly creating an Event object can be ambiguous. Please call " + "call Stream.record().") + + @staticmethod + def _init(options: Optional[EventOptions]=None): + self = Event.__new__(Event) + # minimal requirements for the destructor + self._handle = None + + options = check_or_create_options(EventOptions, options, "Event options") + flags = 0x0 + self._timing_disabled = False + self._busy_waited = False + if not options.enable_timing: + flags |= cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING + self._timing_disabled = True + if options.busy_waited_sync: + flags |= cuda.CUevent_flags.CU_EVENT_BLOCKING_SYNC + self._busy_waited = True + if options.support_ipc: + raise NotImplementedError("TODO") + self._handle = handle_return(cuda.cuEventCreate(flags)) + return self + + def __del__(self): + self.close() + + def close(self): + # Destroy the event. + if self._handle: + handle_return(cuda.cuEventDestroy(self._handle)) + self._handle = None + + @property + def is_timing_disabled(self) -> bool: + # Check if this instance can be used for the timing purpose. + return self._timing_disabled + + @property + def is_sync_busy_waited(self) -> bool: + # Check if the event synchronization would keep the CPU busy-waiting. + return self._busy_waited + + @property + def is_ipc_supported(self) -> bool: + # Check if this instance can be used for IPC. + raise NotImplementedError("TODO") + + def sync(self): + # Sync over the event. + handle_return(cuda.cuEventSynchronize(self._handle)) + + @property + def is_done(self) -> bool: + # Return True if all captured works have been completed, + # otherwise False. + result, = cuda.cuEventQuery(self._handle) + if result == cuda.CUresult.CUDA_SUCCESS: + return True + elif result == cuda.CUresult.CUDA_ERROR_NOT_READY: + return False + else: + raise CUDAError(f"unexpected error: {result}") + + @property + def handle(self) -> int: + return int(self._handle) diff --git a/cuda_core/cuda/core/_kernel_arg_handler.pyx b/cuda_core/cuda/core/_kernel_arg_handler.pyx new file mode 100644 index 0000000..f2d392a --- /dev/null +++ b/cuda_core/cuda/core/_kernel_arg_handler.pyx @@ -0,0 +1,218 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from cpython.mem cimport PyMem_Malloc, PyMem_Free +from libc.stdint cimport (intptr_t, + int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t,) +from libcpp cimport bool as cpp_bool +from libcpp.complex cimport complex as cpp_complex +from libcpp cimport nullptr +from libcpp cimport vector + +import ctypes + +import numpy + +from cuda.core._memory import Buffer + + +ctypedef cpp_complex.complex[float] cpp_single_complex +ctypedef cpp_complex.complex[double] cpp_double_complex + + +ctypedef fused supported_type: + cpp_bool + int8_t + int16_t + int32_t + int64_t + uint8_t + uint16_t + uint32_t + uint64_t + float + double + intptr_t + cpp_single_complex + cpp_double_complex + + +# cache ctypes/numpy type objects to avoid attribute access +cdef object ctypes_bool = ctypes.c_bool +cdef object ctypes_int8 = ctypes.c_int8 +cdef object ctypes_int16 = ctypes.c_int16 +cdef object ctypes_int32 = ctypes.c_int32 +cdef object ctypes_int64 = ctypes.c_int64 +cdef object ctypes_uint8 = ctypes.c_uint8 +cdef object ctypes_uint16 = ctypes.c_uint16 +cdef object ctypes_uint32 = ctypes.c_uint32 +cdef object ctypes_uint64 = ctypes.c_uint64 +cdef object ctypes_float = ctypes.c_float +cdef object ctypes_double = ctypes.c_double +cdef object numpy_bool = numpy.bool_ +cdef object numpy_int8 = numpy.int8 +cdef object numpy_int16 = numpy.int16 +cdef object numpy_int32 = numpy.int32 +cdef object numpy_int64 = numpy.int64 +cdef object numpy_uint8 = numpy.uint8 +cdef object numpy_uint16 = numpy.uint16 +cdef object numpy_uint32 = numpy.uint32 +cdef object numpy_uint64 = numpy.uint64 +cdef object numpy_float16 = numpy.float16 +cdef object numpy_float32 = numpy.float32 +cdef object numpy_float64 = numpy.float64 +cdef object numpy_complex64 = numpy.complex64 +cdef object numpy_complex128 = numpy.complex128 + + +# limitation due to cython/cython#534 +ctypedef void* voidptr + + +# Cython can't infer the overload without at least one input argument with fused type +cdef inline int prepare_arg( + vector.vector[void*]& data, + vector.vector[void*]& data_addresses, + arg, # important: keep it a Python object and don't cast + const size_t idx, + const supported_type* __unused=NULL) except -1: + cdef void* ptr = PyMem_Malloc(sizeof(supported_type)) + # note: this should also work once ctypes has complex support: + # python/cpython#121248 + if supported_type is cpp_single_complex: + (ptr)[0] = cpp_complex.complex[float](arg.real, arg.imag) + elif supported_type is cpp_double_complex: + (ptr)[0] = cpp_complex.complex[double](arg.real, arg.imag) + else: + (ptr)[0] = (arg) + data_addresses[idx] = ptr # take the address to the scalar + data[idx] = ptr # for later dealloc + return 0 + + +cdef inline int prepare_ctypes_arg( + vector.vector[void*]& data, + vector.vector[void*]& data_addresses, + arg, + const size_t idx) except -1: + if isinstance(arg, ctypes_bool): + return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int8): + return prepare_arg[int8_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int16): + return prepare_arg[int16_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int32): + return prepare_arg[int32_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int64): + return prepare_arg[int64_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint8): + return prepare_arg[uint8_t](data, data_addresses, arg.value, idx) + 
elif isinstance(arg, ctypes_uint16): + return prepare_arg[uint16_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint32): + return prepare_arg[uint32_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint64): + return prepare_arg[uint64_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_float): + return prepare_arg[float](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_double): + return prepare_arg[double](data, data_addresses, arg.value, idx) + else: + return 1 + + +cdef inline int prepare_numpy_arg( + vector.vector[void*]& data, + vector.vector[void*]& data_addresses, + arg, + const size_t idx) except -1: + if isinstance(arg, numpy_bool): + return prepare_arg[cpp_bool](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int8): + return prepare_arg[int8_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int16): + return prepare_arg[int16_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int32): + return prepare_arg[int32_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int64): + return prepare_arg[int64_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint8): + return prepare_arg[uint8_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint16): + return prepare_arg[uint16_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint32): + return prepare_arg[uint32_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint64): + return prepare_arg[uint64_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_float16): + # use int16 as a proxy + return prepare_arg[int16_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_float32): + return prepare_arg[float](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_float64): + return prepare_arg[double](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_complex64): + return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_complex128): + return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx) + else: + return 1 + + +cdef class ParamHolder: + + cdef: + vector.vector[void*] data + vector.vector[void*] data_addresses + object kernel_args + readonly intptr_t ptr + + def __init__(self, kernel_args): + if len(kernel_args) == 0: + self.ptr = 0 + return + + cdef size_t n_args = len(kernel_args) + cdef size_t i + cdef int not_prepared + self.data = vector.vector[voidptr](n_args, nullptr) + self.data_addresses = vector.vector[voidptr](n_args) + for i, arg in enumerate(kernel_args): + if isinstance(arg, Buffer): + # we need the address of where the actual buffer address is stored + self.data_addresses[i] = (arg._ptr.getPtr()) + continue + elif isinstance(arg, int): + # Here's the dilemma: We want to have a fast path to pass in Python + # integers as pointer addresses, but one could also (mistakenly) pass + # it with the intention of passing a scalar integer. It's a mistake + # bacause a Python int is ambiguous (arbitrary width). Our judgement + # call here is to treat it as a pointer address, without any warning! 
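                # (To pass a true integer scalar instead, wrap it explicitly,
                # e.g. as numpy.int32(x) or ctypes.c_int32(x); such values are
                # routed through prepare_numpy_arg/prepare_ctypes_arg below.)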
+ prepare_arg[intptr_t](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, float): + prepare_arg[double](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, complex): + prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, bool): + prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) + continue + + not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i) + if not_prepared: + not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i) + if not_prepared: + # TODO: support ctypes/numpy struct + raise TypeError + + self.kernel_args = kernel_args + self.ptr = self.data_addresses.data() + + def __dealloc__(self): + for data in self.data: + if data: + PyMem_Free(data) diff --git a/cuda_core/cuda/core/_launcher.py b/cuda_core/cuda/core/_launcher.py new file mode 100644 index 0000000..03d7fc0 --- /dev/null +++ b/cuda_core/cuda/core/_launcher.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np + +from cuda import cuda, cudart +from cuda.core._kernel_arg_handler import ParamHolder +from cuda.core._memory import Buffer +from cuda.core._module import Kernel +from cuda.core._stream import Stream +from cuda.core._utils import CUDAError, check_or_create_options, handle_return + + +@dataclass +class LaunchConfig: + """ + """ + grid: Union[tuple, int] = None + block: Union[tuple, int] = None + stream: Stream = None + shmem_size: Optional[int] = None + + def __post_init__(self): + self.grid = self._cast_to_3_tuple(self.grid) + self.block = self._cast_to_3_tuple(self.block) + # we handle "stream=None" in the launch API + if self.stream is not None: + if not isinstance(self.stream, Stream): + try: + self.stream = Stream._init(self.stream) + except Exception as e: + raise ValueError( + "stream must either be a Stream object " + "or support __cuda_stream__") from e + if self.shmem_size is None: + self.shmem_size = 0 + + def _cast_to_3_tuple(self, cfg): + if isinstance(cfg, int): + if cfg < 1: + raise ValueError + return (cfg, 1, 1) + elif isinstance(cfg, tuple): + size = len(cfg) + if size == 1: + cfg = cfg[0] + if cfg < 1: + raise ValueError + return (cfg, 1, 1) + elif size == 2: + if cfg[0] < 1 or cfg[1] < 1: + raise ValueError + return (*cfg, 1) + elif size == 3: + if cfg[0] < 1 or cfg[1] < 1 or cfg[2] < 1: + raise ValueError + return cfg + else: + raise ValueError + + +def launch(kernel, config, *kernel_args): + if not isinstance(kernel, Kernel): + raise ValueError + config = check_or_create_options(LaunchConfig, config, "launch config") + # TODO: can we ensure kernel_args is valid/safe to use here? + + driver_ver = handle_return(cuda.cuDriverGetVersion()) + if driver_ver >= 12000: + drv_cfg = cuda.CUlaunchConfig() + drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid + drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block + if config.stream is None: + raise CUDAError("stream cannot be None") + drv_cfg.hStream = config.stream._handle + drv_cfg.sharedMemBytes = config.shmem_size + drv_cfg.numAttrs = 0 # FIXME + + # TODO: merge with HelperKernelParams? 
+ kernel_args = ParamHolder(kernel_args) + args_ptr = kernel_args.ptr + + handle_return(cuda.cuLaunchKernelEx( + drv_cfg, int(kernel._handle), args_ptr, 0)) + else: + raise NotImplementedError("TODO") diff --git a/cuda_core/cuda/core/_memory.py b/cuda_core/cuda/core/_memory.py new file mode 100644 index 0000000..0d5dd0d --- /dev/null +++ b/cuda_core/cuda/core/_memory.py @@ -0,0 +1,241 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from __future__ import annotations + +import abc +from typing import Optional, Tuple, TypeVar +import warnings + +from cuda import cuda +from cuda.core._dlpack import DLDeviceType, make_py_capsule +from cuda.core._stream import default_stream +from cuda.core._utils import handle_return + + +PyCapsule = TypeVar("PyCapsule") + + +# TODO: define a memory property mixin class and make Buffer and +# MemoryResource both inherit from it + + +class Buffer: + + # TODO: handle ownership? (_mr could be None) + __slots__ = ("_ptr", "_size", "_mr",) + + def __init__(self, ptr, size, mr: MemoryResource=None): + self._ptr = ptr + self._size = size + self._mr = mr + + def __del__(self): + self.close(default_stream()) + + def close(self, stream=None): + if self._ptr and self._mr is not None: + if stream is None: + stream = default_stream() + self._mr.deallocate(self._ptr, self._size, stream) + self._ptr = 0 + self._mr = None + + @property + def handle(self): + return self._ptr + + @property + def size(self): + return self._size + + @property + def memory_resource(self) -> MemoryResource: + # Return the memory resource from which this buffer was allocated. + return self._mr + + @property + def is_device_accessible(self) -> bool: + # Check if this buffer can be accessed from GPUs. + if self._mr is not None: + return self._mr.is_device_accessible + raise NotImplementedError + + @property + def is_host_accessible(self) -> bool: + # Check if this buffer can be accessed from CPUs. + if self._mr is not None: + return self._mr.is_host_accessible + raise NotImplementedError + + @property + def device_id(self) -> int: + if self._mr is not None: + return self._mr.device_id + raise NotImplementedError + + def copy_to(self, dst: Buffer=None, *, stream) -> Buffer: + # Copy from this buffer to the dst buffer asynchronously on the + # given stream. The dst buffer is returned. If the dst is not provided, + # allocate one from self.memory_resource. Raise an exception if the + # stream is not provided. + if stream is None: + raise ValueError("stream must be provided") + if dst is None: + if self._mr is None: + raise ValueError("a destination buffer must be provided") + dst = self._mr.allocate(self._size, stream) + if dst._size != self._size: + raise ValueError("buffer sizes mismatch between src and dst") + handle_return( + cuda.cuMemcpyAsync(dst._ptr, self._ptr, self._size, stream._handle)) + return dst + + def copy_from(self, src: Buffer, *, stream): + # Copy from the src buffer to this buffer asynchronously on the + # given stream. Raise an exception if the stream is not provided. 
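        # A minimal sketch (illustrative only; dev, src and s are assumed to be
        # an initialized Device, a source Buffer and a Stream, respectively):
        #
        #     dst = dev.allocate(src.size, stream=s)
        #     dst.copy_from(src, stream=s)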
+ if stream is None: + raise ValueError("stream must be provided") + if src._size != self._size: + raise ValueError("buffer sizes mismatch between src and dst") + handle_return( + cuda.cuMemcpyAsync(self._ptr, src._ptr, self._size, stream._handle)) + + def __dlpack__(self, *, + stream: Optional[int] = None, + max_version: Optional[Tuple[int, int]] = None, + dl_device: Optional[Tuple[int, int]] = None, + copy: Optional[bool] = None) -> PyCapsule: + # Note: we ignore the stream argument entirely (as if it is -1). + # It is the user's responsibility to maintain stream order. + if dl_device is not None or copy is True: + raise BufferError + if max_version is None: + versioned = False + else: + assert len(max_version) == 2 + if max_version >= (1, 0): + versioned = True + else: + versioned = False + capsule = make_py_capsule(self, versioned) + return capsule + + def __dlpack_device__(self) -> Tuple[int, int]: + if self.is_device_accessible and not self.is_host_accessible: + return (DLDeviceType.kDLCUDA, self.device_id) + elif self.is_device_accessible and self.is_host_accessible: + # TODO: this can also be kDLCUDAManaged, we need more fine-grained checks + return (DLDeviceType.kDLCUDAHost, 0) + elif not self.is_device_accessible and self.is_host_accessible: + return (DLDeviceType.kDLCPU, 0) + else: # not self.is_device_accessible and not self.is_host_accessible + raise BufferError("invalid buffer") + + def __buffer__(self, flags: int, /) -> memoryview: + # Support for Python-level buffer protocol as per PEP 688. + # This raises a BufferError unless: + # 1. Python is 3.12+ + # 2. This Buffer object is host accessible + raise NotImplementedError("TODO") + + def __release_buffer__(self, buffer: memoryview, /): + # Supporting methond paired with __buffer__. + raise NotImplementedError("TODO") + + +class MemoryResource(abc.ABC): + + __slots__ = ("_handle",) + + @abc.abstractmethod + def __init__(self, *args, **kwargs): + ... + + @abc.abstractmethod + def allocate(self, size, stream=None) -> Buffer: + ... + + @abc.abstractmethod + def deallocate(self, ptr, size, stream=None): + ... + + @property + @abc.abstractmethod + def is_device_accessible(self) -> bool: + # Check if the buffers allocated from this MR can be accessed from + # GPUs. + ... + + @property + @abc.abstractmethod + def is_host_accessible(self) -> bool: + # Check if the buffers allocated from this MR can be accessed from + # CPUs. + ... + + @property + @abc.abstractmethod + def device_id(self) -> int: + # Return the device ID if this MR is for single devices. Raise an + # exception if it is not. + ... + + +class _DefaultAsyncMempool(MemoryResource): + + __slots__ = ("_dev_id",) + + def __init__(self, dev_id): + self._handle = handle_return(cuda.cuDeviceGetMemPool(dev_id)) + self._dev_id = dev_id + + def allocate(self, size, stream=None) -> Buffer: + if stream is None: + stream = default_stream() + ptr = handle_return(cuda.cuMemAllocFromPoolAsync(size, self._handle, stream._handle)) + return Buffer(ptr, size, self) + + def deallocate(self, ptr, size, stream=None): + if stream is None: + stream = default_stream() + handle_return(cuda.cuMemFreeAsync(ptr, stream._handle)) + + @property + def is_device_accessible(self) -> bool: + return True + + @property + def is_host_accessible(self) -> bool: + return False + + @property + def device_id(self) -> int: + return self._dev_id + + +class _DefaultPinnedMemorySource(MemoryResource): + + def __init__(self): + # TODO: support flags from cuMemHostAlloc? 
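        # (Buffers allocated from this resource are pinned host memory, so they
        # can serve as the host side of Buffer.copy_to()/copy_from() transfers.)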
+ self._handle = None + + def allocate(self, size, stream=None) -> Buffer: + ptr = handle_return(cuda.cuMemAllocHost(size)) + return Buffer(ptr, size, self) + + def deallocate(self, ptr, size, stream=None): + handle_return(cuda.cuMemFreeHost(ptr)) + + @property + def is_device_accessible(self) -> bool: + return True + + @property + def is_host_accessible(self) -> bool: + return True + + @property + def device_id(self) -> int: + raise RuntimeError("the pinned memory resource is not bound to any GPU") diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx new file mode 100644 index 0000000..8f7cc94 --- /dev/null +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -0,0 +1,297 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +cimport cython + +from ._dlpack cimport * + +import functools +from typing import Any, Optional + +from cuda import cuda +import numpy + +from cuda.core._utils import handle_return + + +# TODO(leofang): support NumPy structured dtypes + + +@cython.dataclasses.dataclass +cdef class StridedMemoryView: + + # TODO: switch to use Cython's cdef typing? + ptr: int = None + shape: tuple = None + strides: tuple = None # in counts, not bytes + dtype: numpy.dtype = None + device_id: int = None # -1 for CPU + device_accessible: bool = None + readonly: bool = None + obj: Any = None + + def __init__(self, obj=None, stream_ptr=None): + if obj is not None: + # populate self's attributes + if check_has_dlpack(obj): + view_as_dlpack(obj, stream_ptr, self) + else: + view_as_cai(obj, stream_ptr, self) + else: + # default construct + pass + + def __repr__(self): + return (f"StridedMemoryView(ptr={self.ptr},\n" + + f" shape={self.shape},\n" + + f" strides={self.strides},\n" + + f" dtype={get_simple_repr(self.dtype)},\n" + + f" device_id={self.device_id},\n" + + f" device_accessible={self.device_accessible},\n" + + f" readonly={self.readonly},\n" + + f" obj={get_simple_repr(self.obj)})") + + +cdef str get_simple_repr(obj): + # TODO: better handling in np.dtype objects + cdef object obj_class + cdef str obj_repr + if isinstance(obj, type): + obj_class = obj + else: + obj_class = obj.__class__ + if obj_class.__module__ in (None, "builtins"): + obj_repr = obj_class.__name__ + else: + obj_repr = f"{obj_class.__module__}.{obj_class.__name__}" + return obj_repr + + +cdef bint check_has_dlpack(obj) except*: + cdef bint has_dlpack + if hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"): + has_dlpack = True + elif hasattr(obj, "__cuda_array_interface__"): + has_dlpack = False + else: + raise RuntimeError( + "the input object does not support any data exchange protocol") + return has_dlpack + + +cdef class _StridedMemoryViewProxy: + + cdef: + object obj + bint has_dlpack + + def __init__(self, obj): + self.obj = obj + self.has_dlpack = check_has_dlpack(obj) + + cpdef StridedMemoryView view(self, stream_ptr=None): + if self.has_dlpack: + return view_as_dlpack(self.obj, stream_ptr) + else: + return view_as_cai(self.obj, stream_ptr) + + +cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None): + cdef int dldevice, device_id, i + cdef bint device_accessible, versioned, is_readonly + dldevice, device_id = obj.__dlpack_device__() + if dldevice == _kDLCPU: + device_accessible = False + assert device_id == 0 + if stream_ptr is None: + raise BufferError("stream=None is ambiguous with view()") + elif stream_ptr == -1: + stream_ptr = None + elif dldevice == _kDLCUDA: + 
device_accessible = True + # no need to check other stream values, it's a pass-through + if stream_ptr is None: + raise BufferError("stream=None is ambiguous with view()") + elif dldevice == _kDLCUDAHost: + device_accessible = True + assert device_id == 0 + # just do a pass-through without any checks, as pinned memory can be + # accessed on both host and device + elif dldevice == _kDLCUDAManaged: + device_accessible = True + # just do a pass-through without any checks, as managed memory can be + # accessed on both host and device + else: + raise BufferError("device not supported") + + cdef object capsule + try: + capsule = obj.__dlpack__( + stream=stream_ptr, + max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION)) + versioned = True + except TypeError: + capsule = obj.__dlpack__( + stream=stream_ptr) + versioned = False + + cdef void* data = NULL + if versioned and cpython.PyCapsule_IsValid( + capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME): + data = cpython.PyCapsule_GetPointer( + capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME) + elif not versioned and cpython.PyCapsule_IsValid( + capsule, DLPACK_TENSOR_UNUSED_NAME): + data = cpython.PyCapsule_GetPointer( + capsule, DLPACK_TENSOR_UNUSED_NAME) + else: + assert False + + cdef DLManagedTensor* dlm_tensor + cdef DLManagedTensorVersioned* dlm_tensor_ver + cdef DLTensor* dl_tensor + if versioned: + dlm_tensor_ver = data + dl_tensor = &dlm_tensor_ver.dl_tensor + is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0) + else: + dlm_tensor = data + dl_tensor = &dlm_tensor.dl_tensor + is_readonly = False + + cdef StridedMemoryView buf = StridedMemoryView() if view is None else view + buf.ptr = (dl_tensor.data) + buf.shape = tuple(int(dl_tensor.shape[i]) for i in range(dl_tensor.ndim)) + if dl_tensor.strides: + buf.strides = tuple( + int(dl_tensor.strides[i]) for i in range(dl_tensor.ndim)) + else: + # C-order + buf.strides = None + buf.dtype = dtype_dlpack_to_numpy(&dl_tensor.dtype) + buf.device_id = device_id + buf.device_accessible = device_accessible + buf.readonly = is_readonly + buf.obj = obj + + cdef const char* used_name = ( + DLPACK_VERSIONED_TENSOR_USED_NAME if versioned else DLPACK_TENSOR_USED_NAME) + cpython.PyCapsule_SetName(capsule, used_name) + + return buf + + +cdef object dtype_dlpack_to_numpy(DLDataType* dtype): + cdef int bits = dtype.bits + if dtype.lanes != 1: + # TODO: return a NumPy structured dtype? 
+ raise NotImplementedError( + f'vector dtypes (lanes={dtype.lanes}) is not supported') + if dtype.code == kDLUInt: + if bits == 8: + np_dtype = numpy.uint8 + elif bits == 16: + np_dtype = numpy.uint16 + elif bits == 32: + np_dtype = numpy.uint32 + elif bits == 64: + np_dtype = numpy.uint64 + else: + raise TypeError('uint{} is not supported.'.format(bits)) + elif dtype.code == kDLInt: + if bits == 8: + np_dtype = numpy.int8 + elif bits == 16: + np_dtype = numpy.int16 + elif bits == 32: + np_dtype = numpy.int32 + elif bits == 64: + np_dtype = numpy.int64 + else: + raise TypeError('int{} is not supported.'.format(bits)) + elif dtype.code == kDLFloat: + if bits == 16: + np_dtype = numpy.float16 + elif bits == 32: + np_dtype = numpy.float32 + elif bits == 64: + np_dtype = numpy.float64 + else: + raise TypeError('float{} is not supported.'.format(bits)) + elif dtype.code == kDLComplex: + # TODO(leofang): support complex32 + if bits == 64: + np_dtype = numpy.complex64 + elif bits == 128: + np_dtype = numpy.complex128 + else: + raise TypeError('complex{} is not supported.'.format(bits)) + elif dtype.code == kDLBool: + if bits == 8: + np_dtype = numpy.bool_ + else: + raise TypeError(f'{bits}-bit bool is not supported') + elif dtype.code == kDLBfloat: + # TODO(leofang): use ml_dtype.bfloat16? + raise NotImplementedError('bfloat is not supported yet') + else: + raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code)) + + # We want the dtype object not just the type object + return numpy.dtype(np_dtype) + + +cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): + cdef dict cai_data = obj.__cuda_array_interface__ + if cai_data["version"] < 3: + raise BufferError("only CUDA Array Interface v3 or above is supported") + if cai_data.get("mask") is not None: + raise BufferError("mask is not supported") + if stream_ptr is None: + raise BufferError("stream=None is ambiguous with view()") + + cdef StridedMemoryView buf = StridedMemoryView() if view is None else view + buf.obj = obj + buf.ptr, buf.readonly = cai_data["data"] + buf.shape = cai_data["shape"] + # TODO: this only works for built-in numeric types + buf.dtype = numpy.dtype(cai_data["typestr"]) + buf.strides = cai_data.get("strides") + if buf.strides is not None: + # convert to counts + buf.strides = tuple(s // buf.dtype.itemsize for s in buf.strides) + buf.device_accessible = True + buf.device_id = handle_return( + cuda.cuPointerGetAttribute( + cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + buf.ptr)) + + cdef intptr_t producer_s, consumer_s + stream = cai_data.get("stream") + if stream is not None: + producer_s = (stream) + consumer_s = (stream_ptr) + assert producer_s > 0 + # establish stream order + if producer_s != consumer_s: + e = handle_return(cuda.cuEventCreate( + cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING)) + handle_return(cuda.cuEventRecord(e, producer_s)) + handle_return(cuda.cuStreamWaitEvent(consumer_s, e, 0)) + handle_return(cuda.cuEventDestroy(e)) + + return buf + + +def viewable(tuple arg_indices): + def wrapped_func_with_indices(func): + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + args = list(args) + cdef int idx + for idx in arg_indices: + args[idx] = _StridedMemoryViewProxy(args[idx]) + return func(*args, **kwargs) + return wrapped_func + return wrapped_func_with_indices diff --git a/cuda_core/cuda/core/_module.py b/cuda_core/cuda/core/_module.py new file mode 100644 index 0000000..9892636 --- /dev/null +++ b/cuda_core/cuda/core/_module.py @@ -0,0 +1,85 @@ +# Copyright (c) 
2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from cuda import cuda, cudart +from cuda.core._utils import handle_return + + +_backend = { + "new": { + "file": cuda.cuLibraryLoadFromFile, + "data": cuda.cuLibraryLoadData, + "kernel": cuda.cuLibraryGetKernel, + }, + "old": { + "file": cuda.cuModuleLoad, + "data": cuda.cuModuleLoadDataEx, + "kernel": cuda.cuModuleGetFunction, + }, +} + + +class Kernel: + + __slots__ = ("_handle", "_module",) + + def __init__(self): + raise NotImplementedError("directly constructing a Kernel instance is not supported") + + @staticmethod + def _from_obj(obj, mod): + assert isinstance(obj, (cuda.CUkernel, cuda.CUfunction)) + assert isinstance(mod, ObjectCode) + ker = Kernel.__new__(Kernel) + ker._handle = obj + ker._module = mod + return ker + + +class ObjectCode: + + __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map") + _supported_code_type = ("cubin", "ptx", "fatbin") + + def __init__(self, module, code_type, jit_options=None, *, + symbol_mapping=None): + if code_type not in self._supported_code_type: + raise ValueError + self._handle = None + + driver_ver = handle_return(cuda.cuDriverGetVersion()) + self._loader = _backend["new"] if driver_ver >= 12000 else _backend["old"] + + if isinstance(module, str): + if driver_ver < 12000 and jit_options is not None: + raise ValueError + module = module.encode() + self._handle = handle_return(self._loader["file"](module)) + else: + assert isinstance(module, bytes) + if jit_options is None: + jit_options = {} + if driver_ver >= 12000: + args = (module, list(jit_options.keys()), list(jit_options.values()), len(jit_options), + # TODO: support library options + [], [], 0) + else: + args = (module, len(jit_options), jit_options.keys(), jit_options.values()) + self._handle = handle_return(self._loader["data"](*args)) + + self._code_type = code_type + self._module = module + self._sym_map = {} if symbol_mapping is None else symbol_mapping + + def __del__(self): + # TODO: do we want to unload? Probably not.. + pass + + def get_kernel(self, name): + try: + name = self._sym_map[name] + except KeyError: + name = name.encode() + data = handle_return(self._loader["kernel"](self._handle, name)) + return Kernel._from_obj(data, self) diff --git a/cuda_core/cuda/core/_program.py b/cuda_core/cuda/core/_program.py new file mode 100644 index 0000000..0c0f02d --- /dev/null +++ b/cuda_core/cuda/core/_program.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
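#
# A minimal end-to-end sketch of how Program, ObjectCode and launch() are meant
# to fit together (illustrative only; the kernel source, names and compile
# options below are made up for the example, and a CUDA 12+ driver is assumed):
#
#     from cuda.core import Device, LaunchConfig, Program, launch
#
#     dev = Device()
#     dev.set_current()
#     s = dev.create_stream()
#
#     prog = Program('extern "C" __global__ void noop() {}', "c++")
#     mod = prog.compile("cubin", options=("-arch=sm_80",))
#     ker = mod.get_kernel("noop")
#
#     launch(ker, LaunchConfig(grid=1, block=1, stream=s))
#     s.sync()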
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from cuda import nvrtc +from cuda.core._utils import handle_return +from cuda.core._module import ObjectCode + + +class Program: + + __slots__ = ("_handle", "_backend", ) + _supported_code_type = ("c++", ) + _supported_target_type = ("ptx", "cubin", "ltoir", ) + + def __init__(self, code, code_type): + if code_type not in self._supported_code_type: + raise NotImplementedError + self._handle = None + + if code_type.lower() == "c++": + if not isinstance(code, str): + raise TypeError + # TODO: support pre-loaded headers & include names + # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + self._handle = handle_return( + nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], [])) + self._backend = "nvrtc" + else: + raise NotImplementedError + + def __del__(self): + self.close() + + def close(self): + if self._handle is not None: + handle_return(nvrtc.nvrtcDestroyProgram(self._handle)) + self._handle = None + + def compile(self, target_type, options=(), name_expressions=(), logs=None): + if target_type not in self._supported_target_type: + raise NotImplementedError + + if self._backend == "nvrtc": + if name_expressions: + for n in name_expressions: + handle_return( + nvrtc.nvrtcAddNameExpression(self._handle, n.encode()), + handle=self._handle) + # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + options = list(o.encode() for o in options) + handle_return( + nvrtc.nvrtcCompileProgram(self._handle, len(options), options), + handle=self._handle) + + size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size") + comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}") + size = handle_return(size_func(self._handle), handle=self._handle) + data = b" " * size + handle_return(comp_func(self._handle, data), handle=self._handle) + + symbol_mapping = {} + if name_expressions: + for n in name_expressions: + symbol_mapping[n] = handle_return(nvrtc.nvrtcGetLoweredName( + self._handle, n.encode())) + + if logs is not None: + logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle)) + if logsize > 1: + log = b" " * logsize + handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log)) + logs.write(log.decode()) + + # TODO: handle jit_options for ptx? + + return ObjectCode(data, target_type, symbol_mapping=symbol_mapping) + + @property + def backend(self): + return self._backend + + @property + def handle(self): + return self._handle diff --git a/cuda_core/cuda/core/_stream.py b/cuda_core/cuda/core/_stream.py new file mode 100644 index 0000000..e815f9a --- /dev/null +++ b/cuda_core/cuda/core/_stream.py @@ -0,0 +1,243 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from __future__ import annotations + +from dataclasses import dataclass +import os +from typing import Optional, Tuple, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from cuda.core._device import Device +from cuda import cuda, cudart +from cuda.core._context import Context +from cuda.core._event import Event, EventOptions +from cuda.core._utils import check_or_create_options +from cuda.core._utils import get_device_from_ctx +from cuda.core._utils import handle_return + + +@dataclass +class StreamOptions: + + nonblocking: bool = True + priority: Optional[int] = None + + +class Stream: + + __slots__ = ("_handle", "_nonblocking", "_priority", "_owner", "_builtin", + "_device_id", "_ctx_handle") + + def __init__(self): + # minimal requirements for the destructor + self._handle = None + self._owner = None + self._builtin = False + raise NotImplementedError( + "directly creating a Stream object can be ambiguous. Please either " + "call Device.create_stream() or, if a stream pointer is already " + "available from somewhere else, Stream.from_handle()") + + @staticmethod + def _init(obj=None, *, options: Optional[StreamOptions]=None): + self = Stream.__new__(Stream) + + # minimal requirements for the destructor + self._handle = None + self._owner = None + self._builtin = False + + if obj is not None and options is not None: + raise ValueError("obj and options cannot be both specified") + if obj is not None: + if not hasattr(obj, "__cuda_stream__"): + raise ValueError + info = obj.__cuda_stream__ + assert info[0] == 0 + self._handle = cuda.CUstream(info[1]) + # TODO: check if obj is created under the current context/device + self._owner = obj + self._nonblocking = None # delayed + self._priority = None # delayed + self._device_id = None # delayed + self._ctx_handle = None # delayed + return self + + options = check_or_create_options(StreamOptions, options, "Stream options") + nonblocking = options.nonblocking + priority = options.priority + + if nonblocking: + flags = cuda.CUstream_flags.CU_STREAM_NON_BLOCKING + else: + flags = cuda.CUstream_flags.CU_STREAM_DEFAULT + + if priority is not None: + high, low = handle_return( + cudart.cudaDeviceGetStreamPriorityRange()) + if not (low <= priority <= high): + raise ValueError(f"{priority=} is out of range {[low, high]}") + else: + priority = 0 + + self._handle = handle_return( + cuda.cuStreamCreateWithPriority(flags, priority)) + self._owner = None + self._nonblocking = nonblocking + self._priority = priority + # don't defer this because we will have to pay a cost for context + # switch later + self._device_id = int(handle_return(cuda.cuCtxGetDevice())) + self._ctx_handle = None # delayed + return self + + def __del__(self): + self.close() + + def close(self): + if self._owner is None: + if self._handle and not self._builtin: + handle_return(cuda.cuStreamDestroy(self._handle)) + else: + self._owner = None + self._handle = None + + @property + def __cuda_stream__(self) -> Tuple[int, int]: + return (0, int(self._handle)) + + @property + def handle(self) -> int: + # Return the underlying cudaStream_t pointer address as Python int. 
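        # (The reverse direction is also supported: a foreign stream address can
        # be wrapped with Stream.from_handle(), defined below, or by passing any
        # object exposing __cuda_stream__ to Device.create_stream().)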
+ return int(self._handle) + + @property + def is_nonblocking(self) -> bool: + if self._nonblocking is None: + flag = handle_return(cuda.cuStreamGetFlags(self._handle)) + if flag == cuda.CUstream_flags.CU_STREAM_NON_BLOCKING: + self._nonblocking = True + else: + self._nonblocking = False + return self._nonblocking + + @property + def priority(self) -> int: + if self._priority is None: + prio = handle_return(cuda.cuStreamGetPriority(self._handle)) + self._priority = prio + return self._priority + + def sync(self): + handle_return(cuda.cuStreamSynchronize(self._handle)) + + def record(self, event: Event=None, options: EventOptions=None) -> Event: + # Create an Event object (or reusing the given one) by recording + # on the stream. Event flags such as disabling timing, nonblocking, + # and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions. + if event is None: + event = Event._init(options) + elif not isinstance(event, Event): + raise TypeError("record only takes an Event object") + handle_return(cuda.cuEventRecord(event.handle, self._handle)) + return event + + def wait(self, event_or_stream: Union[Event, Stream]): + # Wait for a CUDA event or a CUDA stream to establish a stream order. + # + # If a Stream instance is provided, the effect is as if an event is + # recorded on the given stream, and then self waits on the recorded + # event. + if isinstance(event_or_stream, Event): + event = event_or_stream.handle + discard_event = False + else: + if not isinstance(event_or_stream, Stream): + try: + stream = Stream._init(event_or_stream) + except Exception as e: + raise ValueError( + "only an Event, Stream, or object supporting " + "__cuda_stream__ can be waited") from e + else: + stream = event_or_stream + event = handle_return( + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING)) + handle_return(cuda.cuEventRecord(event, stream.handle)) + discard_event = True + + # TODO: support flags other than 0? + handle_return(cuda.cuStreamWaitEvent(self._handle, event, 0)) + if discard_event: + handle_return(cuda.cuEventDestroy(event)) + + @property + def device(self) -> Device: + # Inverse look-up to find on which device this stream instance was + # created. + # + # Note that Stream.device.context might not necessarily agree with + # Stream.context, in cases where a different CUDA context is set + # current after a stream was created. 
+ from cuda.core._device import Device # avoid circular import + if self._device_id is None: + # Get the stream context first + if self._ctx_handle is None: + self._ctx_handle = handle_return( + cuda.cuStreamGetCtx(self._handle)) + self._device_id = get_device_from_ctx(self._ctx_handle) + return Device(self._device_id) + + @property + def context(self) -> Context: + # Inverse look-up to find in which CUDA context this stream instance + # was created + if self._ctx_handle is None: + self._ctx_handle = handle_return( + cuda.cuStreamGetCtx(self._handle)) + if self._device_id is None: + self._device_id = get_device_from_ctx(self._ctx_handle) + return Context._from_ctx(self._ctx_handle, self._device_id) + + @staticmethod + def from_handle(handle: int) -> Stream: + class _stream_holder: + @property + def __cuda_stream__(self): + return (0, handle) + return Stream._init(obj=_stream_holder()) + + +class _LegacyDefaultStream(Stream): + + def __init__(self): + self._handle = cuda.CUstream(cuda.CU_STREAM_LEGACY) + self._owner = None + self._nonblocking = None # delayed + self._priority = None # delayed + self._builtin = True + + +class _PerThreadDefaultStream(Stream): + + def __init__(self): + self._handle = cuda.CUstream(cuda.CU_STREAM_PER_THREAD) + self._owner = None + self._nonblocking = None # delayed + self._priority = None # delayed + self._builtin = True + + +LEGACY_DEFAULT_STREAM = _LegacyDefaultStream() +PER_THREAD_DEFAULT_STREAM = _PerThreadDefaultStream() + + +def default_stream(): + # TODO: flip the default + use_ptds = int(os.environ.get('CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM', 0)) + if use_ptds: + return PER_THREAD_DEFAULT_STREAM + else: + return LEGACY_DEFAULT_STREAM diff --git a/cuda_core/cuda/core/_utils.py b/cuda_core/cuda/core/_utils.py new file mode 100644 index 0000000..bd3c5cd --- /dev/null +++ b/cuda_core/cuda/core/_utils.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from collections import namedtuple +import functools +from typing import Callable, Dict + +from cuda import cuda, cudart, nvrtc + + +class CUDAError(Exception): pass + + +class NVRTCError(CUDAError): pass + + +ComputeCapability = namedtuple("ComputeCapability", ("major", "minor")) + + +def _check_error(error, handle=None): + if isinstance(error, cuda.CUresult): + if error == cuda.CUresult.CUDA_SUCCESS: + return + err, name = cuda.cuGetErrorName(error) + if err == cuda.CUresult.CUDA_SUCCESS: + err, desc = cuda.cuGetErrorString(error) + if err == cuda.CUresult.CUDA_SUCCESS: + raise CUDAError(f"{name.decode()}: {desc.decode()}") + else: + raise CUDAError(f"unknown error: {error}") + elif isinstance(error, cudart.cudaError_t): + if error == cudart.cudaError_t.cudaSuccess: + return + err, name = cudart.cudaGetErrorName(error) + if err == cudart.cudaError_t.cudaSuccess: + err, desc = cudart.cudaGetErrorString(error) + if err == cudart.cudaError_t.cudaSuccess: + raise CUDAError(f"{name.decode()}: {desc.decode()}") + else: + raise CUDAError(f"unknown error: {error}") + elif isinstance(error, nvrtc.nvrtcResult): + if error == nvrtc.nvrtcResult.NVRTC_SUCCESS: + return + assert handle is not None + _, logsize = nvrtc.nvrtcGetProgramLogSize(handle) + log = b" " * logsize + _ = nvrtc.nvrtcGetProgramLog(handle, log) + err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}, " \ + f"compilation log:\n\n{log.decode()}" + raise NVRTCError(err) + else: + raise RuntimeError('Unknown error type: {}'.format(error)) + + +def handle_return(result, handle=None): + _check_error(result[0], handle=handle) + if len(result) == 1: + return + elif len(result) == 2: + return result[1] + else: + return result[1:] + + +def check_or_create_options(cls, options, options_description, *, keep_none=False): + """ + Create the specified options dataclass from a dictionary of options or None. + """ + + if options is None: + if keep_none: + return options + options = cls() + elif isinstance(options, Dict): + options = cls(**options) + + if not isinstance(options, cls): + raise TypeError(f"The {options_description} must be provided as an object " + f"of type {cls.__name__} or as a dict with valid {options_description}. " + f"The provided object is '{options}'.") + + return options + + +def precondition(checker: Callable[..., None], what: str = "") -> Callable: + """ + A decorator that adds checks to ensure any preconditions are met. + + Args: + checker: The function to call to check whether the preconditions are met. It has the same signature as the wrapped + function with the addition of the keyword argument `what`. + what: A string that is passed in to `checker` to provide context information. + + Returns: + Callable: A decorator that creates the wrapping. + """ + def outer(wrapped_function): + """ + A decorator that actually wraps the function for checking preconditions. + """ + @functools.wraps(wrapped_function) + def inner(*args, **kwargs): + """ + Check preconditions and if they are met, call the wrapped function. 
+ """ + checker(*args, **kwargs, what=what) + result = wrapped_function(*args, **kwargs) + + return result + + return inner + + return outer + + +def get_device_from_ctx(ctx_handle) -> int: + """Get device ID from the given ctx.""" + prev_ctx = Device().context.handle + if ctx_handle != prev_ctx: + switch_context = True + else: + switch_context = False + if switch_context: + assert prev_ctx == handle_return(cuda.cuCtxPopCurrent()) + handle_return(cuda.cuCtxPushCurrent(ctx_handle)) + device_id = int(handle_return(cuda.cuCtxGetDevice())) + if switch_context: + assert ctx_handle == handle_return(cuda.cuCtxPopCurrent()) + handle_return(cuda.cuCtxPushCurrent(prev_ctx)) + return device_id diff --git a/cuda_core/cuda/core/_version.py b/cuda_core/cuda/core/_version.py new file mode 100644 index 0000000..cc83b46 --- /dev/null +++ b/cuda_core/cuda/core/_version.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +__version__ = "0.0.1" diff --git a/cuda_core/cuda/core/dlpack.h b/cuda_core/cuda/core/dlpack.h new file mode 100644 index 0000000..bcb7794 --- /dev/null +++ b/cuda_core/cuda/core/dlpack.h @@ -0,0 +1,332 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +/** + * \brief Compatibility with C++ + */ +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 + +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 0 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief The DLPack version. + * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. + * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + +/*! + * \brief The device type in DLDevice. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! 
\brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, + /*! \brief Qualcomm Hexagon DSP */ + kDLHexagon = 16, + /*! \brief Microsoft MAIA devices */ + kDLMAIA = 17, +} DLDeviceType; + +/*! + * \brief A Device for Tensor and operator. + */ +typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; +} DLDevice; + +/*! + * \brief The type code options DLDataType. + */ +typedef enum { + /*! \brief signed integer */ + kDLInt = 0U, + /*! \brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, + /*! \brief boolean */ + kDLBool = 6U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. The data type is assumed to follow the + * native endian-ness. An explicit error message should be raised when attempting to + * export an array with non-native endianness + * + * Examples + * - float: type_code = 2, bits = 32, lanes = 1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4 + * - int8: type_code = 0, bits = 8, lanes = 1 + * - std::complex: type_code = 5, bits = 64, lanes = 1 + * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits) + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! + * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. + * + * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte aligment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. 
This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. + */ + void* data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! \brief The shape of the tensor */ + int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes) + * can be NULL, indicating tensor is compact and row-majored. + */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to facilitate the borrowing of DLTensor by another framework. It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8 + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions. + * + * \sa DLManagedTensorVersioned + */ +typedef struct DLManagedTensor { + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void * manager_ctx; + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor * self); +} DLManagedTensor; + +// bit masks used in in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief bit mask to indicate that the tensor is a copy made by the producer. + * + * If set, the tensor is considered solely owned throughout its lifetime by the + * consumer, until the producer-provided deleter is invoked. + */ +#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) + +/*! + * \brief A versioned and managed C Tensor object, manage memory of DLTensor. + * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. 
+ * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +}; + +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ diff --git a/cuda_core/cuda/core/utils.py b/cuda_core/cuda/core/utils.py new file mode 100644 index 0000000..3debe1d --- /dev/null +++ b/cuda_core/cuda/core/utils.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from cuda.core._memoryview import StridedMemoryView, viewable diff --git a/cuda_core/examples/saxpy.py b/cuda_core/examples/saxpy.py new file mode 100644 index 0000000..7d296de --- /dev/null +++ b/cuda_core/examples/saxpy.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +import sys + +from cuda.core import Device +from cuda.core import LaunchConfig, launch +from cuda.core import Program + +import cupy as cp + + +# compute out = a * x + y +code = """ +template +__global__ void saxpy(const T a, + const T* x, + const T* y, + T* out, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i", "saxpy")) + +# run in single precision +ker = mod.get_kernel("saxpy") +dtype = cp.float32 + +# prepare input/output +size = cp.uint64(64) +a = dtype(10) +x = cp.random.random(size, dtype=dtype) +y = cp.random.random(size, dtype=dtype) +out = cp.empty_like(x) +dev.sync() # cupy runs on a different stream from s, so sync before accessing + +# prepare launch +block = 32 +grid = int((size + block - 1) // block) +config = LaunchConfig(grid=grid, block=block, stream=s) +ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) + +# launch kernel on stream s +launch(ker, config, *ker_args) +s.sync() + +# check result +assert cp.allclose(out, a*x+y) + +# let's repeat again, this time allocates our own out buffer instead of cupy's +# run in double precision +ker = mod.get_kernel("saxpy") +dtype = cp.float64 + +# prepare input +size = cp.uint64(128) +a = dtype(42) +x = cp.random.random(size, dtype=dtype) +y = cp.random.random(size, dtype=dtype) +dev.sync() + +# prepare output +buf = dev.allocate(size * 8, # = dtype.itemsize + stream=s) + +# prepare launch +block = 64 +grid = int((size + block - 1) // block) +config = LaunchConfig(grid=grid, block=block, stream=s) +ker_args = (a, x.data.ptr, y.data.ptr, buf, size) + +# launch kernel on stream s +launch(ker, config, *ker_args) +s.sync() + +# check result +# we wrap output buffer as a cupy array for simplicity +out = cp.ndarray(size, dtype=dtype, + memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) 
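+# (note: cp.cuda.UnownedMemory does not take ownership of the allocation; it only
+# keeps a reference to the `buf` object passed as its owner argument, so this
+# cupy view keeps the Buffer alive until we close it below)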
+assert cp.allclose(out, a*x+y) + +# clean up resources that we allocate +# cupy cleans up automatically the rest +buf.close(s) +s.close() + +print("done!") diff --git a/cuda_core/examples/vector_add.py b/cuda_core/examples/vector_add.py new file mode 100644 index 0000000..8248ad3 --- /dev/null +++ b/cuda_core/examples/vector_add.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +from cuda.core import Device +from cuda.core import LaunchConfig, launch +from cuda.core import Program + +import cupy as cp + + +# compute c = a + b +code = """ +template +__global__ void vector_add(const T* A, + const T* B, + T* C, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i",)) + +# run in single precision +ker = mod.get_kernel("vector_add") +dtype = cp.float32 + +# prepare input/output +size = 50000 +a = cp.random.random(size, dtype=dtype) +b = cp.random.random(size, dtype=dtype) +c = cp.empty_like(a) + +# cupy runs on a different stream from s, so sync before accessing +dev.sync() + +# prepare launch +block = 256 +grid = (size + block - 1) // block +config = LaunchConfig(grid=grid, block=block, stream=s) + +# launch kernel on stream s +launch(ker, config, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) +s.sync() + +# check result +assert cp.allclose(c, a+b) +print("done!") diff --git a/cuda_core/pyproject.toml b/cuda_core/pyproject.toml new file mode 100644 index 0000000..cf1e5b4 --- /dev/null +++ b/cuda_core/pyproject.toml @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +[build-system] +requires = ["setuptools", "Cython>=3.0"] +build-backend = "setuptools.build_meta" + + +[project] +name = "cuda-core" +dynamic = [ + "version", + "readme", +] +requires-python = '>=3.9' +description = "cuda.core: (experimental) pythonic CUDA module" +authors = [ + { name = "NVIDIA Corporation" } +] +license = {text = "NVIDIA Software License"} +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: End Users/Desktop", + "Natural Language :: English", + "License :: Other/Proprietary License", + "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", + "Topic :: Education", + "Topic :: Scientific/Engineering", + "Topic :: Software Development :: Libraries", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Environment :: GPU :: NVIDIA CUDA", + "Environment :: GPU :: NVIDIA CUDA :: 11", + "Environment :: GPU :: NVIDIA CUDA :: 12", +] +dependencies = [ + "numpy", +] + + +[tool.setuptools] +packages = ["cuda", "cuda.core"] + + +[tool.setuptools.dynamic] +version = { attr = "cuda.core._version.__version__" } +readme = { file = ["README.md"], content-type = "text/markdown" } diff --git a/cuda_core/setup.py b/cuda_core/setup.py new file mode 100644 index 0000000..862d38d --- /dev/null +++ b/cuda_core/setup.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 
+# +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +import os + +from Cython.Build import cythonize +from setuptools import setup, Extension, find_packages +from setuptools.command.build_ext import build_ext as _build_ext + + +ext_modules = ( + Extension( + "cuda.core._dlpack", + sources=["cuda/core/_dlpack.pyx"], + language="c++", + ), + Extension( + "cuda.core._memoryview", + sources=["cuda/core/_memoryview.pyx"], + language="c++", + ), + Extension( + "cuda.core._kernel_arg_handler", + sources=["cuda/core/_kernel_arg_handler.pyx"], + language="c++", + ), +) + + +class build_ext(_build_ext): + + def build_extensions(self): + self.parallel = os.cpu_count() // 2 + super().build_extensions() + + +setup( + ext_modules=cythonize(ext_modules, + verbose=True, language_level=3, + compiler_directives={'embedsignature': True}), + packages=find_packages(include=['cuda.core', 'cuda.core.*']), + package_data=dict.fromkeys( + find_packages(include=["cuda.core.*"]), + ["*.pxd", "*.pyx", "*.py"], + ), + cmdclass = {'build_ext': build_ext,}, + zip_safe=False, +)
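A closing note on the error-handling convention used throughout the package: the low-level `cuda`/`cudart`/`nvrtc` bindings return a tuple whose first element is a status code, and `handle_return()` from `_utils.py` above either raises `CUDAError`/`NVRTCError` or unpacks and returns the remaining value(s). A minimal sketch of the pattern (the CUDA runtime calls below are ordinary runtime-API functions chosen for illustration, not part of this diff):
```python
from cuda import cudart
from cuda.core._utils import handle_return, CUDAError

# (err, value) -> value, or a raised CUDAError if err != cudaSuccess
count = handle_return(cudart.cudaGetDeviceCount())

# (err, v1, v2) -> (v1, v2)
free, total = handle_return(cudart.cudaMemGetInfo())

# failures surface as Python exceptions rather than status codes
try:
    handle_return(cudart.cudaSetDevice(count + 100))  # deliberately out of range
except CUDAError as exc:
    print(exc)  # e.g. "cudaErrorInvalidDevice: invalid device ordinal"
```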