From 48a305c98530792ebcf9b47550193dd36a911f1c Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 5 Sep 2024 13:15:04 +0000 Subject: [PATCH] fix dtype repr and stream pass-through --- cuda_py/cuda/py/_memoryview.pyx | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/cuda_py/cuda/py/_memoryview.pyx b/cuda_py/cuda/py/_memoryview.pyx index fd09c7a3..25f2fc89 100644 --- a/cuda_py/cuda/py/_memoryview.pyx +++ b/cuda_py/cuda/py/_memoryview.pyx @@ -29,7 +29,7 @@ cdef class GPUMemoryView: return (f"GPUMemoryView(ptr={self.ptr},\n" + f" shape={self.shape},\n" + f" strides={self.strides},\n" - + f" dtype={get_simple_repr(numpy.dtype(self.dtype))},\n" + + f" dtype={self.dtype.__name__},\n" + f" device_id={self.device_id},\n" + f" device_accessible={self.device_accessible},\n" + f" readonly={self.readonly},\n" @@ -39,7 +39,7 @@ cdef class GPUMemoryView: cdef str get_simple_repr(obj): cdef object obj_class = obj.__class__ cdef str obj_repr - if obj_class.__module__ in (None, "__builtin__"): + if obj_class.__module__ in (None, "builtins"): obj_repr = obj_class.__name__ else: obj_repr = f"{obj_class.__module__}.{obj_class.__name__}" @@ -78,17 +78,24 @@ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr): if dldevice == _kDLCPU: device_accessible = False assert device_id == 0 - stream_ptr = None + if stream_ptr is None: + raise BufferError("stream=None is ambiguous with view()") + elif stream_ptr == -1: + stream_ptr = None elif dldevice == _kDLCUDA: device_accessible = True - stream_ptr = -1 + # no need to check other stream values, it's a pass-through + if stream_ptr is None: + raise BufferError("stream=None is ambiguous with view()") elif dldevice == _kDLCUDAHost: device_accessible = True assert device_id == 0 - stream_ptr = None + # just do a pass-through without any checks, as pinned memory can be + # accessed on both host and device elif dldevice == _kDLCUDAManaged: device_accessible = True - stream_ptr = -1 + # just do a pass-through without any checks, as managed memory can be + # accessed on both host and device else: raise BufferError("device not supported")