From fc6d176633fe54691f931f4d258c9f0be22fa29f Mon Sep 17 00:00:00 2001 From: ksimpson Date: Thu, 24 Oct 2024 17:16:25 -0700 Subject: [PATCH 01/13] update tests --- cuda_core/tests/test_context.py | 9 ++ cuda_core/tests/test_device.py | 66 ++++++++++ cuda_core/tests/test_event.py | 38 ++++++ cuda_core/tests/test_launcher.py | 77 ++++++++++++ cuda_core/tests/test_memory.py | 199 +++++++++++++++++++++++++++++++ cuda_core/tests/test_module.py | 35 ++++++ cuda_core/tests/test_program.py | 58 +++++++++ cuda_core/tests/test_stream.py | 82 +++++++++++++ 8 files changed, 564 insertions(+) create mode 100644 cuda_core/tests/test_context.py create mode 100644 cuda_core/tests/test_device.py create mode 100644 cuda_core/tests/test_event.py create mode 100644 cuda_core/tests/test_launcher.py create mode 100644 cuda_core/tests/test_memory.py create mode 100644 cuda_core/tests/test_module.py create mode 100644 cuda_core/tests/test_program.py create mode 100644 cuda_core/tests/test_stream.py diff --git a/cuda_core/tests/test_context.py b/cuda_core/tests/test_context.py new file mode 100644 index 0000000..823e0f2 --- /dev/null +++ b/cuda_core/tests/test_context.py @@ -0,0 +1,9 @@ +from cuda.core.experimental._context import Context + +def test_context_initialization(): + try: + context = Context() + except NotImplementedError as e: + assert True + else: + assert False, "Expected NotImplementedError was not raised" \ No newline at end of file diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py new file mode 100644 index 0000000..4f4af0f --- /dev/null +++ b/cuda_core/tests/test_device.py @@ -0,0 +1,66 @@ +from cuda import cuda, cudart +from cuda.core.experimental._device import Device +from cuda.core.experimental._utils import handle_return, ComputeCapability, CUDAError, \ + precondition +import pytest + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() + +def test_device_initialization(): + device = Device() + assert device is not None + +def test_device_repr(): + device = Device() + assert str(device).startswith('= 11040: + uuid = handle_return(cuda.cuDeviceGetUuid_v2(device.device_id)) + else: + uuid = handle_return(cuda.cuDeviceGetUuid(device.device_id)) + uuid = uuid.bytes.hex() + expected_uuid = f"{uuid[:8]}-{uuid[8:12]}-{uuid[12:16]}-{uuid[16:20]}-{uuid[20:]}" + assert device.uuid == expected_uuid + +def test_name(): + device = Device() + name = handle_return(cuda.cuDeviceGetName(128, device.device_id)) + name = name.split(b'\0')[0] + assert device.name == name.decode() + +def test_compute_capability(): + device = Device() + major = handle_return(cudart.cudaDeviceGetAttribute( + cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device.device_id)) + minor = handle_return(cudart.cudaDeviceGetAttribute( + cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device.device_id)) + expected_cc = ComputeCapability(major, minor) + assert device.compute_capability == expected_cc \ No newline at end of file diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py new file mode 100644 index 0000000..e725c88 --- /dev/null +++ b/cuda_core/tests/test_event.py @@ -0,0 +1,38 @@ +from cuda import cuda +from cuda.core.experimental._event import EventOptions, Event +from cuda.core.experimental._utils import handle_return + +def test_is_timing_disabled(): + options = EventOptions(enable_timing=False) + event = Event._init(options) + assert event.is_timing_disabled == True + +def test_is_sync_busy_waited(): + options = EventOptions(busy_waited_sync=True) + event = Event._init(options) + assert event.is_sync_busy_waited == True + +def test_is_ipc_supported(): + options = EventOptions(support_ipc=True) + try: + event = Event._init(options) + except NotImplementedError: + assert True + else: + assert False + +def test_sync(): + options = EventOptions() + event = Event._init(options) + event.sync() + assert event.is_done == True + +def test_is_done(): + options = EventOptions() + event = Event._init(options) + assert event.is_done == True + +def test_handle(): + options = EventOptions() + event = Event._init(options) + assert isinstance(event.handle, int) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py new file mode 100644 index 0000000..a54a8fb --- /dev/null +++ b/cuda_core/tests/test_launcher.py @@ -0,0 +1,77 @@ +from cuda.core.experimental._launcher import LaunchConfig +from cuda.core.experimental._stream import Stream +from cuda.core.experimental._device import Device +from cuda.core.experimental._utils import handle_return +from cuda import cuda + +def test_launch_config_init(): + config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0) + assert config.grid == (1, 1, 1) + assert config.block == (1, 1, 1) + assert config.stream is None + assert config.shmem_size == 0 + + config = LaunchConfig(grid=(2, 2, 2), block=(2, 2, 2), stream=Device().create_stream(), shmem_size=1024) + assert config.grid == (2, 2, 2) + assert config.block == (2, 2, 2) + assert isinstance(config.stream, Stream) + assert config.shmem_size == 1024 + +def test_launch_config_cast_to_3_tuple(): + config = LaunchConfig(grid=1, block=1) + assert config._cast_to_3_tuple(1) == (1, 1, 1) + assert config._cast_to_3_tuple((1, 2)) == (1, 2, 1) + assert config._cast_to_3_tuple((1, 2, 3)) == (1, 2, 3) + + # Edge cases + assert config._cast_to_3_tuple(999) == (999, 1, 1) + assert config._cast_to_3_tuple((999, 888)) == (999, 888, 1) + assert config._cast_to_3_tuple((999, 888, 777)) == (999, 888, 777) + +def test_launch_config_invalid_values(): + try: + LaunchConfig(grid=0, block=1) + except ValueError: + assert True + else: + assert False + + try: + LaunchConfig(grid=(0, 1), block=1) + except ValueError: + assert True + else: + assert False + + try: + LaunchConfig(grid=(1, 1, 1), block=0) + except ValueError: + assert True + else: + assert False + + try: + LaunchConfig(grid=(1, 1, 1), block=(0, 1)) + except ValueError: + assert True + else: + assert False + +def test_launch_config_stream(): + stream = Device().create_stream() + config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=stream, shmem_size=0) + assert config.stream == stream + + try: + LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream="invalid_stream", shmem_size=0) + except ValueError: + assert True + else: + assert False + +def test_launch_config_shmem_size(): + config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=2048) + assert config.shmem_size == 2048 + + config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None) + assert config.shmem_size == 0 \ No newline at end of file diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py new file mode 100644 index 0000000..3fddc3d --- /dev/null +++ b/cuda_core/tests/test_memory.py @@ -0,0 +1,199 @@ +# FILE: test_memory.py + +from cuda.core.experimental._memory import Buffer, MemoryResource +from cuda.core.experimental._device import Device +from cuda import cuda +from cuda.core.experimental._utils import handle_return +import ctypes + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() + +class DummyDeviceMemoryResource(MemoryResource): + def __init__(self, device): + self.device = device + pass + + def allocate(self, size, stream=None) -> Buffer: + ptr = handle_return(cuda.cuMemAlloc(size)) + return Buffer(ptr=ptr, size=size, mr=self) + + def deallocate(self, ptr, size, stream=None): + cuda.cuMemFree(ptr) + + @property + def is_device_accessible(self) -> bool: + return True + + @property + def is_host_accessible(self) -> bool: + return False + + @property + def device_id(self) -> int: + return 0 + +class DummyHostMemoryResource(MemoryResource): + def __init__(self): + pass + + def allocate(self, size, stream=None) -> Buffer: + # Allocate a ctypes buffer of size `size` + ptr = (ctypes.c_byte * size)() + return Buffer(ptr=ptr, size=size, mr=self) + + def deallocate(self, ptr, size, stream=None): + #the memory is deallocated per the ctypes deallocation at garbage collection time + pass + + @property + def is_device_accessible(self) -> bool: + return False + + @property + def is_host_accessible(self) -> bool: + return True + + @property + def device_id(self) -> int: + raise RuntimeError("the pinned memory resource is not bound to any GPU") + +class DummyUnifiedMemoryResource(MemoryResource): + def __init__(self, device): + self.device = device + pass + + def allocate(self, size, stream=None) -> Buffer: + ptr = handle_return(cuda.cuMemAllocManaged(size, cuda.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value)) + return Buffer(ptr=ptr, size=size, mr=self) + + def deallocate(self, ptr, size, stream=None): + cuda.cuMemFree(ptr) + + @property + def is_device_accessible(self) -> bool: + return True + + @property + def is_host_accessible(self) -> bool: + return True + + @property + def device_id(self) -> int: + return 0 + +class DummyPinnedMemoryResource(MemoryResource): + def __init__(self, device): + self.device = device + pass + + def allocate(self, size, stream=None) -> Buffer: + ptr = handle_return(cuda.cuMemAllocHost(size)) + return Buffer(ptr=ptr, size=size, mr=self) + + def deallocate(self, ptr, size, stream=None): + cuda.cuMemFreeHost(ptr) + + @property + def is_device_accessible(self) -> bool: + return True + + @property + def is_host_accessible(self) -> bool: + return True + + @property + def device_id(self) -> int: + raise RuntimeError("the pinned memory resource is not bound to any GPU") + +def buffer_initialization(dummy_mr : MemoryResource): + buffer = dummy_mr.allocate(size=1024) + assert buffer.handle != 0 + assert buffer.size == 1024 + assert buffer.memory_resource == dummy_mr + assert buffer.is_device_accessible == dummy_mr.is_device_accessible + assert buffer.is_host_accessible == dummy_mr.is_host_accessible + dummy_mr.deallocate(buffer.handle, buffer.size) + +def test_buffer_initialization(): + device = Device() + device.set_current() + buffer_initialization(DummyDeviceMemoryResource(device)) + buffer_initialization(DummyHostMemoryResource()) + buffer_initialization(DummyUnifiedMemoryResource(device)) + buffer_initialization(DummyPinnedMemoryResource(device)) + +def buffer_copy_to(dummy_mr : MemoryResource, device : Device, check = False): + src_buffer = dummy_mr.allocate(size=1024) + dst_buffer = dummy_mr.allocate(size=1024) + stream = device.create_stream() + + if check: + src_ptr = ctypes.cast(src_buffer.handle, ctypes.POINTER(ctypes.c_byte)) + for i in range(1024): + src_ptr[i] = ctypes.c_byte(i) + + src_buffer.copy_to(dst_buffer, stream=stream) + device.sync() + + if check: + dst_ptr = ctypes.cast(dst_buffer.handle, ctypes.POINTER(ctypes.c_byte)) + + for i in range(10): + assert dst_ptr[i] == src_ptr[i] + + dummy_mr.deallocate(src_buffer.handle, src_buffer.size) + dummy_mr.deallocate(dst_buffer.handle, dst_buffer.size) + +def test_buffer_copy_to(): + device = Device() + device.set_current() + buffer_copy_to(DummyDeviceMemoryResource(device), device) + buffer_copy_to(DummyUnifiedMemoryResource(device), device) + buffer_copy_to(DummyPinnedMemoryResource(device), device, check = True) + +def buffer_copy_from(dummy_mr : MemoryResource, device, check = False): + src_buffer = dummy_mr.allocate(size=1024) + dst_buffer = dummy_mr.allocate(size=1024) + stream = device.create_stream() + + if check: + src_ptr = ctypes.cast(src_buffer.handle, ctypes.POINTER(ctypes.c_byte)) + for i in range(1024): + src_ptr[i] = ctypes.c_byte(i) + + dst_buffer.copy_from(src_buffer, stream=stream) + device.sync() + + if check: + dst_ptr = ctypes.cast(dst_buffer.handle, ctypes.POINTER(ctypes.c_byte)) + + for i in range(10): + assert dst_ptr[i] == src_ptr[i] + + dummy_mr.deallocate(src_buffer.handle, src_buffer.size) + dummy_mr.deallocate(dst_buffer.handle, dst_buffer.size) + +def test_buffer_copy_from(): + device = Device() + device.set_current() + buffer_copy_from(DummyDeviceMemoryResource(device), device) + buffer_copy_from(DummyUnifiedMemoryResource(device), device) + buffer_copy_from(DummyPinnedMemoryResource(device), device, check = True) + +def buffer_close(dummy_mr : MemoryResource): + buffer = dummy_mr.allocate(size=1024) + buffer.close() + assert buffer.handle == 0 + assert buffer.memory_resource == None + +def test_buffer_close(): + device = Device() + device.set_current() + buffer_close(DummyDeviceMemoryResource(device)) + buffer_close(DummyHostMemoryResource()) + buffer_close(DummyUnifiedMemoryResource(device)) + buffer_close(DummyPinnedMemoryResource(device)) + +test_buffer_copy_to() \ No newline at end of file diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py new file mode 100644 index 0000000..60e4e5c --- /dev/null +++ b/cuda_core/tests/test_module.py @@ -0,0 +1,35 @@ +import pytest +from cuda import cuda +from cuda.core.experimental._device import Device +from cuda.core.experimental._module import Kernel, ObjectCode +from cuda.core.experimental._utils import handle_return + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() + +def test_object_code_initialization(): + # Test with supported code types + for code_type in ["cubin", "ptx", "fatbin"]: + module_data = b"dummy_data" + obj_code = ObjectCode(module_data, code_type) + assert obj_code._code_type == code_type + assert obj_code._module == module_data + assert obj_code._handle is not None + + # Test with unsupported code type + with pytest.raises(ValueError): + ObjectCode(b"dummy_data", "unsupported_code_type") + +#TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile +def test_object_code_initialization_with_str(): + assert True + +def test_object_code_initialization_with_jit_options(): + assert True + +def test_object_code_get_kernel(): + assert True + +def test_kernel_from_obj(): + assert True diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py new file mode 100644 index 0000000..3c8e81d --- /dev/null +++ b/cuda_core/tests/test_program.py @@ -0,0 +1,58 @@ +import pytest +from cuda import nvrtc +from cuda.core.experimental._program import Program +from cuda.core.experimental._module import ObjectCode, Kernel +from cuda.core.experimental._device import Device + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() + +def test_program_init_valid_code_type(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + assert program.backend == "nvrtc" + assert program.handle is not None + +def test_program_init_invalid_code_type(): + code = "extern \"C\" __global__ void my_kernel() {}" + with pytest.raises(NotImplementedError): + Program(code, "python") + +def test_program_init_invalid_code_format(): + code = 12345 + with pytest.raises(TypeError): + Program(code, "c++") + +def test_program_compile_valid_target_type(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + object_code = program.compile("ptx") + kernel = object_code.get_kernel("my_kernel") + assert isinstance(object_code, ObjectCode) + assert isinstance(kernel, Kernel) + +def test_program_compile_invalid_target_type(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + with pytest.raises(NotImplementedError): + program.compile("invalid_target") + +def test_program_backend_property(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + assert program.backend == "nvrtc" + +def test_program_handle_property(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + assert program.handle is not None + +def test_program_close(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + program.close() + assert program.handle is None + +Device().set_current() +test_program_compile_valid_target_type() \ No newline at end of file diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py new file mode 100644 index 0000000..95ec5e2 --- /dev/null +++ b/cuda_core/tests/test_stream.py @@ -0,0 +1,82 @@ +import pytest +from cuda.core.experimental._stream import Stream, StreamOptions, LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream +from cuda.core.experimental._event import Event, EventOptions +from cuda.core.experimental._device import Device + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() + +def test_stream_init(): + with pytest.raises(NotImplementedError): + Stream() + +def test_stream_init_with_options(): + stream = Stream._init(options=StreamOptions(nonblocking=True, priority=0)) + assert stream.is_nonblocking is True + assert stream.priority == 0 + +def test_stream_handle(): + stream = Stream._init(options=StreamOptions()) + assert isinstance(stream.handle, int) + +def test_stream_is_nonblocking(): + stream = Stream._init(options=StreamOptions(nonblocking=True)) + assert stream.is_nonblocking is True + +def test_stream_priority(): + stream = Stream._init(options=StreamOptions(priority=0)) + assert stream.priority == 0 + stream = Stream._init(options=StreamOptions(priority=-1)) + assert stream.priority == -1 + with pytest.raises(ValueError): + stream = Stream._init(options=StreamOptions(priority=1)) + +def test_stream_sync(): + stream = Stream._init(options=StreamOptions()) + stream.sync() # Should not raise any exceptions + +def test_stream_record(): + stream = Stream._init(options=StreamOptions()) + event = stream.record() + assert isinstance(event, Event) + +def test_stream_record_invalid_event(): + stream = Stream._init(options=StreamOptions()) + with pytest.raises(TypeError): + stream.record(event="invalid_event") + +def test_stream_wait_event(): + stream = Stream._init(options=StreamOptions()) + event = Event._init() + stream.record(event) + stream.wait(event) # Should not raise any exceptions + +def test_stream_wait_invalid_event(): + stream = Stream._init(options=StreamOptions()) + with pytest.raises(ValueError): + stream.wait(event_or_stream="invalid_event") + +def test_stream_device(): + stream = Stream._init(options=StreamOptions()) + device = stream.device + assert isinstance(device, Device) + +def test_stream_context(): + stream = Stream._init(options=StreamOptions()) + context = stream.context + assert context is not None + +def test_stream_from_handle(): + stream = Stream.from_handle(0) + assert isinstance(stream, Stream) + +def test_legacy_default_stream(): + assert isinstance(LEGACY_DEFAULT_STREAM, Stream) + +def test_per_thread_default_stream(): + assert isinstance(PER_THREAD_DEFAULT_STREAM, Stream) + +def test_default_stream(): + stream = default_stream() + assert isinstance(stream, Stream) \ No newline at end of file From 197c6fb4ac12dfd8164361b4283cf316cbec5d63 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Fri, 25 Oct 2024 10:29:07 -0700 Subject: [PATCH 02/13] add license headers (copied from other tests, and test_examples --- cuda_core/tests/test_context.py | 14 +++ cuda_core/tests/test_device.py | 8 ++ cuda_core/tests/test_event.py | 14 +++ cuda_core/tests/test_examples.py | 163 +++++++++++++++++++++++++++++++ cuda_core/tests/test_launcher.py | 13 +++ cuda_core/tests/test_memory.py | 9 +- cuda_core/tests/test_module.py | 10 +- cuda_core/tests/test_program.py | 10 +- cuda_core/tests/test_stream.py | 10 +- 9 files changed, 247 insertions(+), 4 deletions(-) create mode 100644 cuda_core/tests/test_examples.py diff --git a/cuda_core/tests/test_context.py b/cuda_core/tests/test_context.py index 823e0f2..1127540 100644 --- a/cuda_core/tests/test_context.py +++ b/cuda_core/tests/test_context.py @@ -1,4 +1,18 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda.core.experimental._context import Context +from cuda.core.experimental._device import Device +import pytest + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() def test_context_initialization(): try: diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 4f4af0f..4deb9ef 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -1,3 +1,11 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda import cuda, cudart from cuda.core.experimental._device import Device from cuda.core.experimental._utils import handle_return, ComputeCapability, CUDAError, \ diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py index e725c88..717a251 100644 --- a/cuda_core/tests/test_event.py +++ b/cuda_core/tests/test_event.py @@ -1,6 +1,20 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda import cuda from cuda.core.experimental._event import EventOptions, Event from cuda.core.experimental._utils import handle_return +from cuda.core.experimental._device import Device +import pytest + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() def test_is_timing_disabled(): options = EventOptions(enable_timing=False) diff --git a/cuda_core/tests/test_examples.py b/cuda_core/tests/test_examples.py new file mode 100644 index 0000000..ae8cf36 --- /dev/null +++ b/cuda_core/tests/test_examples.py @@ -0,0 +1,163 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + +from cuda.core.experimental import Device +from cuda.core.experimental import LaunchConfig, launch +from cuda.core.experimental import Program +import sys +import pytest + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() + +#saxpy example +def test_saxpy_example(): + import cupy as cp + # compute out = a * x + y + code = """ + template + __global__ void saxpy(const T a, + const T* x, + const T* y, + T* out, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i", "saxpy")) + + # run in single precision + ker = mod.get_kernel("saxpy") + dtype = cp.float32 + + # prepare input/output + size = cp.uint64(64) + a = dtype(10) + x = cp.random.random(size, dtype=dtype) + y = cp.random.random(size, dtype=dtype) + out = cp.empty_like(x) + dev.sync() # cupy runs on a different stream from s, so sync before accessing + + # prepare launch + block = 32 + grid = int((size + block - 1) // block) + config = LaunchConfig(grid=grid, block=block, stream=s) + ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) + + # launch kernel on stream s + launch(ker, config, *ker_args) + s.sync() + + # check result + assert cp.allclose(out, a*x+y) + + # let's repeat again, this time allocates our own out buffer instead of cupy's + # run in double precision + ker = mod.get_kernel("saxpy") + dtype = cp.float64 + + # prepare input + size = cp.uint64(128) + a = dtype(42) + x = cp.random.random(size, dtype=dtype) + y = cp.random.random(size, dtype=dtype) + dev.sync() + + # prepare output + buf = dev.allocate(size * 8, # = dtype.itemsize + stream=s) + + # prepare launch + block = 64 + grid = int((size + block - 1) // block) + config = LaunchConfig(grid=grid, block=block, stream=s) + ker_args = (a, x.data.ptr, y.data.ptr, buf, size) + + # launch kernel on stream s + launch(ker, config, *ker_args) + s.sync() + + # check result + # we wrap output buffer as a cupy array for simplicity + out = cp.ndarray(size, dtype=dtype, + memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) + assert cp.allclose(out, a*x+y) + + # clean up resources that we allocate + # cupy cleans up automatically the rest + buf.close(s) + s.close() + +def test_vector_add_example(): + import cupy as cp + # compute c = a + b + code = """ + template + __global__ void vector_add(const T* A, + const T* B, + T* C, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i",)) + + # run in single precision + ker = mod.get_kernel("vector_add") + dtype = cp.float32 + + # prepare input/output + size = 50000 + a = cp.random.random(size, dtype=dtype) + b = cp.random.random(size, dtype=dtype) + c = cp.empty_like(a) + + # cupy runs on a different stream from s, so sync before accessing + dev.sync() + + # prepare launch + block = 256 + grid = (size + block - 1) // block + config = LaunchConfig(grid=grid, block=block, stream=s) + + # launch kernel on stream s + launch(ker, config, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) + s.sync() + + # check result + assert cp.allclose(c, a+b) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index a54a8fb..a58fb93 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -1,8 +1,21 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda.core.experimental._launcher import LaunchConfig from cuda.core.experimental._stream import Stream from cuda.core.experimental._device import Device from cuda.core.experimental._utils import handle_return from cuda import cuda +import pytest + +@pytest.fixture(scope='module') +def init_cuda(): + Device().set_current() def test_launch_config_init(): config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 3fddc3d..eac610a 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -1,10 +1,17 @@ -# FILE: test_memory.py +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. from cuda.core.experimental._memory import Buffer, MemoryResource from cuda.core.experimental._device import Device from cuda import cuda from cuda.core.experimental._utils import handle_return import ctypes +import pytest @pytest.fixture(scope='module') def init_cuda(): diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 60e4e5c..1ca1785 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -1,8 +1,16 @@ -import pytest +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda import cuda from cuda.core.experimental._device import Device from cuda.core.experimental._module import Kernel, ObjectCode from cuda.core.experimental._utils import handle_return +import pytest @pytest.fixture(scope='module') def init_cuda(): diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 3c8e81d..8cc9061 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -1,8 +1,16 @@ -import pytest +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda import nvrtc from cuda.core.experimental._program import Program from cuda.core.experimental._module import ObjectCode, Kernel from cuda.core.experimental._device import Device +import pytest @pytest.fixture(scope='module') def init_cuda(): diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 95ec5e2..31c4ba0 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -1,7 +1,15 @@ -import pytest +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + from cuda.core.experimental._stream import Stream, StreamOptions, LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream from cuda.core.experimental._event import Event, EventOptions from cuda.core.experimental._device import Device +import pytest @pytest.fixture(scope='module') def init_cuda(): From 2112636a2b849dbef00caa6ce9f0c14beb2c32a9 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Fri, 25 Oct 2024 10:32:21 -0700 Subject: [PATCH 03/13] remove context test --- cuda_core/tests/test_context.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 cuda_core/tests/test_context.py diff --git a/cuda_core/tests/test_context.py b/cuda_core/tests/test_context.py deleted file mode 100644 index 1127540..0000000 --- a/cuda_core/tests/test_context.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. -# -# Please refer to the NVIDIA end user license agreement (EULA) associated -# with this source code for terms and conditions that govern your use of -# this software. Any use, reproduction, disclosure, or distribution of -# this software and related documentation outside the terms of the EULA -# is strictly prohibited. - -from cuda.core.experimental._context import Context -from cuda.core.experimental._device import Device -import pytest - -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() - -def test_context_initialization(): - try: - context = Context() - except NotImplementedError as e: - assert True - else: - assert False, "Expected NotImplementedError was not raised" \ No newline at end of file From 3f7b802d790a910f7ab20ede7b6fae4f7f7a6b59 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Fri, 25 Oct 2024 16:50:05 -0700 Subject: [PATCH 04/13] add global fixture, run examples in new proc. One issue with fixtures and device init status --- cuda_core/examples/saxpy.py | 186 +++++++++--------- cuda_core/tests/conftest.py | 16 ++ .../example_tests/test_basic_examples.py | 15 ++ cuda_core/tests/example_tests/utils.py | 55 ++++++ cuda_core/tests/test_context.py | 33 ++++ cuda_core/tests/test_device.py | 24 +-- cuda_core/tests/test_event.py | 3 - cuda_core/tests/test_examples.py | 163 --------------- cuda_core/tests/test_launcher.py | 34 +--- cuda_core/tests/test_memory.py | 4 - cuda_core/tests/test_module.py | 4 - cuda_core/tests/test_program.py | 22 +-- cuda_core/tests/test_stream.py | 4 - 13 files changed, 235 insertions(+), 328 deletions(-) create mode 100644 cuda_core/tests/conftest.py create mode 100644 cuda_core/tests/example_tests/test_basic_examples.py create mode 100644 cuda_core/tests/example_tests/utils.py create mode 100644 cuda_core/tests/test_context.py delete mode 100644 cuda_core/tests/test_examples.py diff --git a/cuda_core/examples/saxpy.py b/cuda_core/examples/saxpy.py index 37ad493..c11c91d 100644 --- a/cuda_core/examples/saxpy.py +++ b/cuda_core/examples/saxpy.py @@ -10,95 +10,99 @@ import cupy as cp - -# compute out = a * x + y -code = """ -template -__global__ void saxpy(const T a, - const T* x, - const T* y, - T* out, - size_t N) { - const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (size_t i=tid; i + __global__ void saxpy(const T a, + const T* x, + const T* y, + T* out, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i", "saxpy")) - -# run in single precision -ker = mod.get_kernel("saxpy") -dtype = cp.float32 - -# prepare input/output -size = cp.uint64(64) -a = dtype(10) -x = cp.random.random(size, dtype=dtype) -y = cp.random.random(size, dtype=dtype) -out = cp.empty_like(x) -dev.sync() # cupy runs on a different stream from s, so sync before accessing - -# prepare launch -block = 32 -grid = int((size + block - 1) // block) -config = LaunchConfig(grid=grid, block=block, stream=s) -ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) - -# launch kernel on stream s -launch(ker, config, *ker_args) -s.sync() - -# check result -assert cp.allclose(out, a*x+y) - -# let's repeat again, this time allocates our own out buffer instead of cupy's -# run in double precision -ker = mod.get_kernel("saxpy") -dtype = cp.float64 - -# prepare input -size = cp.uint64(128) -a = dtype(42) -x = cp.random.random(size, dtype=dtype) -y = cp.random.random(size, dtype=dtype) -dev.sync() - -# prepare output -buf = dev.allocate(size * 8, # = dtype.itemsize - stream=s) - -# prepare launch -block = 64 -grid = int((size + block - 1) // block) -config = LaunchConfig(grid=grid, block=block, stream=s) -ker_args = (a, x.data.ptr, y.data.ptr, buf, size) - -# launch kernel on stream s -launch(ker, config, *ker_args) -s.sync() - -# check result -# we wrap output buffer as a cupy array for simplicity -out = cp.ndarray(size, dtype=dtype, - memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) -assert cp.allclose(out, a*x+y) - -# clean up resources that we allocate -# cupy cleans up automatically the rest -buf.close(s) -s.close() - -print("done!") + """ + + + dev = Device() + dev.set_current() + s = dev.create_stream() + + # prepare program + prog = Program(code, code_type="c++") + mod = prog.compile( + "cubin", + options=("-std=c++11", "-arch=sm_" + "".join(f"{i}" for i in dev.compute_capability),), + logs=sys.stdout, + name_expressions=("saxpy", "saxpy")) + + # run in single precision + ker = mod.get_kernel("saxpy") + dtype = cp.float32 + + # prepare input/output + size = cp.uint64(64) + a = dtype(10) + x = cp.random.random(size, dtype=dtype) + y = cp.random.random(size, dtype=dtype) + out = cp.empty_like(x) + dev.sync() # cupy runs on a different stream from s, so sync before accessing + + # prepare launch + block = 32 + grid = int((size + block - 1) // block) + config = LaunchConfig(grid=grid, block=block, stream=s) + ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) + + # launch kernel on stream s + launch(ker, config, *ker_args) + s.sync() + + # check result + assert cp.allclose(out, a*x+y) + + # let's repeat again, this time allocates our own out buffer instead of cupy's + # run in double precision + ker = mod.get_kernel("saxpy") + dtype = cp.float64 + + # prepare input + size = cp.uint64(128) + a = dtype(42) + x = cp.random.random(size, dtype=dtype) + y = cp.random.random(size, dtype=dtype) + dev.sync() + + # prepare output + buf = dev.allocate(size * 8, # = dtype.itemsize + stream=s) + + # prepare launch + block = 64 + grid = int((size + block - 1) // block) + config = LaunchConfig(grid=grid, block=block, stream=s) + ker_args = (a, x.data.ptr, y.data.ptr, buf, size) + + # launch kernel on stream s + launch(ker, config, *ker_args) + s.sync() + + # check result + # we wrap output buffer as a cupy array for simplicity + out = cp.ndarray(size, dtype=dtype, + memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) + assert cp.allclose(out, a*x+y) + + # clean up resources that we allocate + # cupy cleans up automatically the rest + buf.close(s) + s.close() + + print("done!") + +if __name__ == "__main__": + saxpy() diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py new file mode 100644 index 0000000..862a938 --- /dev/null +++ b/cuda_core/tests/conftest.py @@ -0,0 +1,16 @@ +import pytest +from cuda import cuda +from cuda.core.experimental._device import Device +from cuda.core.experimental._context import Context +from cuda.core.experimental._utils import handle_return + +@pytest.fixture(scope="module") +def init_cuda(): + device = Device() + device.set_current() + +@pytest.fixture(scope="function") +def reestablish_valid_context(): + yield + device = Device() + device.set_current() \ No newline at end of file diff --git a/cuda_core/tests/example_tests/test_basic_examples.py b/cuda_core/tests/example_tests/test_basic_examples.py new file mode 100644 index 0000000..9a976ec --- /dev/null +++ b/cuda_core/tests/example_tests/test_basic_examples.py @@ -0,0 +1,15 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + +# If we have subcategories of examples in the future, this file can be split along those lines + +from utils import run_example + +def test_basic_examples(): + run_example("../examples", "saxpy.py") + run_example("../examples", "vector_add.py") diff --git a/cuda_core/tests/example_tests/utils.py b/cuda_core/tests/example_tests/utils.py new file mode 100644 index 0000000..3ba956d --- /dev/null +++ b/cuda_core/tests/example_tests/utils.py @@ -0,0 +1,55 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + +import gc +import os +import sys +import pytest +from cuda import cuda +import cupy as cp + +class SampleTestError(Exception): + pass + +def parse_python_script(filepath): + if filepath.endswith('.py'): + with open(filepath, "r", encoding='utf-8') as f: + script = f.read() + else: + raise ValueError(f"{filepath} not supported") + return script + + +def run_example(samples_path, filename, env=None): + fullpath = os.path.join(samples_path, filename) + script = parse_python_script(fullpath) + try: + old_argv = sys.argv + sys.argv = [fullpath] + SYS_PATH_BACKUP = sys.path.copy() + sys.path.append(samples_path) + exec(script, env if env is not None else {}) + except ImportError as e: + # for samples requiring any of optional dependencies + for m in ('cupy',): + if f"No module named '{m}'" in str(e): + pytest.skip(f'{m} not installed, skipping related tests') + break + else: + raise + except Exception as e: + msg = "\n" + msg += f'Got error ({filename}):\n' + msg += str(e) + raise SampleTestError(msg) from e + finally: + sys.path = SYS_PATH_BACKUP + sys.argv = old_argv + # further reduce the memory watermark + gc.collect() + cp.get_default_memory_pool().free_all_blocks() \ No newline at end of file diff --git a/cuda_core/tests/test_context.py b/cuda_core/tests/test_context.py new file mode 100644 index 0000000..7c62165 --- /dev/null +++ b/cuda_core/tests/test_context.py @@ -0,0 +1,33 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + +from cuda.core.experimental._context import Context +from cuda.core.experimental._device import Device +from cuda.core.experimental._utils import handle_return +from cuda import cuda +import pytest + + +def test_context_initialization(): + context = Context() + assert context is not None + +def test_context_from_ctx(reestablish_valid_context): + device = Device() + dev_id = 0 + + # push the primary context and set it as the current context for the device + ctx = handle_return(cuda.cuDevicePrimaryCtxRetain(dev_id)) + handle_return(cuda.cuCtxPushCurrent(ctx)) + device.set_current(Context._from_ctx(ctx, 0)) + + # pop the context + handle_return(cuda.cuCtxPopCurrent()) + + # the device's context *has* been initialized, but if this is the method used to guard active context dependant calls, it should return false + assert device._check_context_initialized() == False \ No newline at end of file diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 4deb9ef..479764d 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -12,26 +12,18 @@ precondition import pytest -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() - -def test_device_initialization(): - device = Device() - assert device is not None def test_device_repr(): - device = Device() + device = Device(0) assert str(device).startswith(' - __global__ void saxpy(const T a, - const T* x, - const T* y, - T* out, - size_t N) { - const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (size_t i=tid; i", "saxpy")) - - # run in single precision - ker = mod.get_kernel("saxpy") - dtype = cp.float32 - - # prepare input/output - size = cp.uint64(64) - a = dtype(10) - x = cp.random.random(size, dtype=dtype) - y = cp.random.random(size, dtype=dtype) - out = cp.empty_like(x) - dev.sync() # cupy runs on a different stream from s, so sync before accessing - - # prepare launch - block = 32 - grid = int((size + block - 1) // block) - config = LaunchConfig(grid=grid, block=block, stream=s) - ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) - - # launch kernel on stream s - launch(ker, config, *ker_args) - s.sync() - - # check result - assert cp.allclose(out, a*x+y) - - # let's repeat again, this time allocates our own out buffer instead of cupy's - # run in double precision - ker = mod.get_kernel("saxpy") - dtype = cp.float64 - - # prepare input - size = cp.uint64(128) - a = dtype(42) - x = cp.random.random(size, dtype=dtype) - y = cp.random.random(size, dtype=dtype) - dev.sync() - - # prepare output - buf = dev.allocate(size * 8, # = dtype.itemsize - stream=s) - - # prepare launch - block = 64 - grid = int((size + block - 1) // block) - config = LaunchConfig(grid=grid, block=block, stream=s) - ker_args = (a, x.data.ptr, y.data.ptr, buf, size) - - # launch kernel on stream s - launch(ker, config, *ker_args) - s.sync() - - # check result - # we wrap output buffer as a cupy array for simplicity - out = cp.ndarray(size, dtype=dtype, - memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) - assert cp.allclose(out, a*x+y) - - # clean up resources that we allocate - # cupy cleans up automatically the rest - buf.close(s) - s.close() - -def test_vector_add_example(): - import cupy as cp - # compute c = a + b - code = """ - template - __global__ void vector_add(const T* A, - const T* B, - T* C, - size_t N) { - const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (size_t i=tid; i",)) - - # run in single precision - ker = mod.get_kernel("vector_add") - dtype = cp.float32 - - # prepare input/output - size = 50000 - a = cp.random.random(size, dtype=dtype) - b = cp.random.random(size, dtype=dtype) - c = cp.empty_like(a) - - # cupy runs on a different stream from s, so sync before accessing - dev.sync() - - # prepare launch - block = 256 - grid = (size + block - 1) // block - config = LaunchConfig(grid=grid, block=block, stream=s) - - # launch kernel on stream s - launch(ker, config, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) - s.sync() - - # check result - assert cp.allclose(c, a+b) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index a58fb93..d6c923e 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -13,10 +13,6 @@ from cuda import cuda import pytest -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() - def test_launch_config_init(): config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=0) assert config.grid == (1, 1, 1) @@ -42,45 +38,25 @@ def test_launch_config_cast_to_3_tuple(): assert config._cast_to_3_tuple((999, 888, 777)) == (999, 888, 777) def test_launch_config_invalid_values(): - try: + with pytest.raises(ValueError): LaunchConfig(grid=0, block=1) - except ValueError: - assert True - else: - assert False - try: + with pytest.raises(ValueError): LaunchConfig(grid=(0, 1), block=1) - except ValueError: - assert True - else: - assert False - try: + with pytest.raises(ValueError): LaunchConfig(grid=(1, 1, 1), block=0) - except ValueError: - assert True - else: - assert False - try: + with pytest.raises(ValueError): LaunchConfig(grid=(1, 1, 1), block=(0, 1)) - except ValueError: - assert True - else: - assert False def test_launch_config_stream(): stream = Device().create_stream() config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=stream, shmem_size=0) assert config.stream == stream - try: + with pytest.raises(ValueError): LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream="invalid_stream", shmem_size=0) - except ValueError: - assert True - else: - assert False def test_launch_config_shmem_size(): config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None, shmem_size=2048) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index eac610a..0ec0fb8 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -13,10 +13,6 @@ import ctypes import pytest -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() - class DummyDeviceMemoryResource(MemoryResource): def __init__(self, device): self.device = device diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 1ca1785..45b59ba 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -12,10 +12,6 @@ from cuda.core.experimental._utils import handle_return import pytest -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() - def test_object_code_initialization(): # Test with supported code types for code_type in ["cubin", "ptx", "fatbin"]: diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 8cc9061..4feb911 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -12,9 +12,6 @@ from cuda.core.experimental._device import Device import pytest -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() def test_program_init_valid_code_type(): code = "extern \"C\" __global__ void my_kernel() {}" @@ -32,13 +29,13 @@ def test_program_init_invalid_code_format(): with pytest.raises(TypeError): Program(code, "c++") -def test_program_compile_valid_target_type(): - code = "extern \"C\" __global__ void my_kernel() {}" - program = Program(code, "c++") - object_code = program.compile("ptx") - kernel = object_code.get_kernel("my_kernel") - assert isinstance(object_code, ObjectCode) - assert isinstance(kernel, Kernel) +# def test_program_compile_valid_target_type(): +# code = "extern \"C\" __global__ void my_kernel() {}" +# program = Program(code, "c++") +# object_code = program.compile("ptx") +# kernel = object_code.get_kernel("my_kernel") +# assert isinstance(object_code, ObjectCode) +# assert isinstance(kernel, Kernel) def test_program_compile_invalid_target_type(): code = "extern \"C\" __global__ void my_kernel() {}" @@ -60,7 +57,4 @@ def test_program_close(): code = "extern \"C\" __global__ void my_kernel() {}" program = Program(code, "c++") program.close() - assert program.handle is None - -Device().set_current() -test_program_compile_valid_target_type() \ No newline at end of file + assert program.handle is None \ No newline at end of file diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 31c4ba0..9ecbd86 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -11,10 +11,6 @@ from cuda.core.experimental._device import Device import pytest -@pytest.fixture(scope='module') -def init_cuda(): - Device().set_current() - def test_stream_init(): with pytest.raises(NotImplementedError): Stream() From d23f2b0ea4c949762f05187918d002bd6924e73e Mon Sep 17 00:00:00 2001 From: ksimpson Date: Fri, 25 Oct 2024 18:03:31 -0700 Subject: [PATCH 05/13] remove context test for now --- cuda_core/examples/saxpy.py | 186 ++++++++++++++++---------------- cuda_core/tests/test_context.py | 33 ------ 2 files changed, 91 insertions(+), 128 deletions(-) delete mode 100644 cuda_core/tests/test_context.py diff --git a/cuda_core/examples/saxpy.py b/cuda_core/examples/saxpy.py index c11c91d..37ad493 100644 --- a/cuda_core/examples/saxpy.py +++ b/cuda_core/examples/saxpy.py @@ -10,99 +10,95 @@ import cupy as cp -def saxpy(): - - # compute out = a * x + y - code = """ - template - __global__ void saxpy(const T a, - const T* x, - const T* y, - T* out, - size_t N) { - const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (size_t i=tid; i +__global__ void saxpy(const T a, + const T* x, + const T* y, + T* out, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i=tid; i", "saxpy")) - - # run in single precision - ker = mod.get_kernel("saxpy") - dtype = cp.float32 - - # prepare input/output - size = cp.uint64(64) - a = dtype(10) - x = cp.random.random(size, dtype=dtype) - y = cp.random.random(size, dtype=dtype) - out = cp.empty_like(x) - dev.sync() # cupy runs on a different stream from s, so sync before accessing - - # prepare launch - block = 32 - grid = int((size + block - 1) // block) - config = LaunchConfig(grid=grid, block=block, stream=s) - ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) - - # launch kernel on stream s - launch(ker, config, *ker_args) - s.sync() - - # check result - assert cp.allclose(out, a*x+y) - - # let's repeat again, this time allocates our own out buffer instead of cupy's - # run in double precision - ker = mod.get_kernel("saxpy") - dtype = cp.float64 - - # prepare input - size = cp.uint64(128) - a = dtype(42) - x = cp.random.random(size, dtype=dtype) - y = cp.random.random(size, dtype=dtype) - dev.sync() - - # prepare output - buf = dev.allocate(size * 8, # = dtype.itemsize - stream=s) - - # prepare launch - block = 64 - grid = int((size + block - 1) // block) - config = LaunchConfig(grid=grid, block=block, stream=s) - ker_args = (a, x.data.ptr, y.data.ptr, buf, size) - - # launch kernel on stream s - launch(ker, config, *ker_args) - s.sync() - - # check result - # we wrap output buffer as a cupy array for simplicity - out = cp.ndarray(size, dtype=dtype, - memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) - assert cp.allclose(out, a*x+y) - - # clean up resources that we allocate - # cupy cleans up automatically the rest - buf.close(s) - s.close() - - print("done!") - -if __name__ == "__main__": - saxpy() +} +""" + + +dev = Device() +dev.set_current() +s = dev.create_stream() + +# prepare program +prog = Program(code, code_type="c++") +mod = prog.compile( + "cubin", + options=("-std=c++11", "-arch=sm_" + "".join(f"{i}" for i in dev.compute_capability),), + logs=sys.stdout, + name_expressions=("saxpy", "saxpy")) + +# run in single precision +ker = mod.get_kernel("saxpy") +dtype = cp.float32 + +# prepare input/output +size = cp.uint64(64) +a = dtype(10) +x = cp.random.random(size, dtype=dtype) +y = cp.random.random(size, dtype=dtype) +out = cp.empty_like(x) +dev.sync() # cupy runs on a different stream from s, so sync before accessing + +# prepare launch +block = 32 +grid = int((size + block - 1) // block) +config = LaunchConfig(grid=grid, block=block, stream=s) +ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size) + +# launch kernel on stream s +launch(ker, config, *ker_args) +s.sync() + +# check result +assert cp.allclose(out, a*x+y) + +# let's repeat again, this time allocates our own out buffer instead of cupy's +# run in double precision +ker = mod.get_kernel("saxpy") +dtype = cp.float64 + +# prepare input +size = cp.uint64(128) +a = dtype(42) +x = cp.random.random(size, dtype=dtype) +y = cp.random.random(size, dtype=dtype) +dev.sync() + +# prepare output +buf = dev.allocate(size * 8, # = dtype.itemsize + stream=s) + +# prepare launch +block = 64 +grid = int((size + block - 1) // block) +config = LaunchConfig(grid=grid, block=block, stream=s) +ker_args = (a, x.data.ptr, y.data.ptr, buf, size) + +# launch kernel on stream s +launch(ker, config, *ker_args) +s.sync() + +# check result +# we wrap output buffer as a cupy array for simplicity +out = cp.ndarray(size, dtype=dtype, + memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0)) +assert cp.allclose(out, a*x+y) + +# clean up resources that we allocate +# cupy cleans up automatically the rest +buf.close(s) +s.close() + +print("done!") diff --git a/cuda_core/tests/test_context.py b/cuda_core/tests/test_context.py deleted file mode 100644 index 7c62165..0000000 --- a/cuda_core/tests/test_context.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. -# -# Please refer to the NVIDIA end user license agreement (EULA) associated -# with this source code for terms and conditions that govern your use of -# this software. Any use, reproduction, disclosure, or distribution of -# this software and related documentation outside the terms of the EULA -# is strictly prohibited. - -from cuda.core.experimental._context import Context -from cuda.core.experimental._device import Device -from cuda.core.experimental._utils import handle_return -from cuda import cuda -import pytest - - -def test_context_initialization(): - context = Context() - assert context is not None - -def test_context_from_ctx(reestablish_valid_context): - device = Device() - dev_id = 0 - - # push the primary context and set it as the current context for the device - ctx = handle_return(cuda.cuDevicePrimaryCtxRetain(dev_id)) - handle_return(cuda.cuCtxPushCurrent(ctx)) - device.set_current(Context._from_ctx(ctx, 0)) - - # pop the context - handle_return(cuda.cuCtxPopCurrent()) - - # the device's context *has* been initialized, but if this is the method used to guard active context dependant calls, it should return false - assert device._check_context_initialized() == False \ No newline at end of file From 059b8f49fbfae47112c738223151aeb1363d9646 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Fri, 25 Oct 2024 18:04:26 -0700 Subject: [PATCH 06/13] remove fixture used by old context test --- cuda_core/tests/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 862a938..b88047b 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -6,11 +6,5 @@ @pytest.fixture(scope="module") def init_cuda(): - device = Device() - device.set_current() - -@pytest.fixture(scope="function") -def reestablish_valid_context(): - yield device = Device() device.set_current() \ No newline at end of file From b0a39229cdce3b31a49e180f9d44a76b111e8a47 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 13:27:19 -0700 Subject: [PATCH 07/13] address Vlad's comments --- cuda_core/tests/conftest.py | 14 ++++++++++---- cuda_core/tests/test_device.py | 4 +++- cuda_core/tests/test_launcher.py | 3 ++- cuda_core/tests/test_memory.py | 2 +- cuda_core/tests/test_program.py | 3 ++- cuda_core/tests/test_stream.py | 3 ++- 6 files changed, 20 insertions(+), 9 deletions(-) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index b88047b..e622fc8 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -1,10 +1,16 @@ +# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# +# Please refer to the NVIDIA end user license agreement (EULA) associated +# with this source code for terms and conditions that govern your use of +# this software. Any use, reproduction, disclosure, or distribution of +# this software and related documentation outside the terms of the EULA +# is strictly prohibited. + import pytest -from cuda import cuda from cuda.core.experimental._device import Device -from cuda.core.experimental._context import Context -from cuda.core.experimental._utils import handle_return @pytest.fixture(scope="module") def init_cuda(): device = Device() - device.set_current() \ No newline at end of file + device.set_current() + \ No newline at end of file diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 479764d..9686ced 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -33,6 +33,7 @@ def test_device_create_stream(): device = Device() stream = device.create_stream() assert stream is not None + assert stream.handle def test_pci_bus_id(): device = Device() @@ -63,4 +64,5 @@ def test_compute_capability(): minor = handle_return(cudart.cudaDeviceGetAttribute( cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device.device_id)) expected_cc = ComputeCapability(major, minor) - assert device.compute_capability == expected_cc \ No newline at end of file + assert device.compute_capability == expected_cc + \ No newline at end of file diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index d6c923e..c519f0e 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -63,4 +63,5 @@ def test_launch_config_shmem_size(): assert config.shmem_size == 2048 config = LaunchConfig(grid=(1, 1, 1), block=(1, 1, 1), stream=None) - assert config.shmem_size == 0 \ No newline at end of file + assert config.shmem_size == 0 + \ No newline at end of file diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 0ec0fb8..9497e2f 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -199,4 +199,4 @@ def test_buffer_close(): buffer_close(DummyUnifiedMemoryResource(device)) buffer_close(DummyPinnedMemoryResource(device)) -test_buffer_copy_to() \ No newline at end of file +test_buffer_copy_to() diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 4feb911..4c31070 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -57,4 +57,5 @@ def test_program_close(): code = "extern \"C\" __global__ void my_kernel() {}" program = Program(code, "c++") program.close() - assert program.handle is None \ No newline at end of file + assert program.handle is None + \ No newline at end of file diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 9ecbd86..b85660d 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -83,4 +83,5 @@ def test_per_thread_default_stream(): def test_default_stream(): stream = default_stream() - assert isinstance(stream, Stream) \ No newline at end of file + assert isinstance(stream, Stream) + \ No newline at end of file From f76d6833390aa60767fa17e79df99ee51fc19598 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 13:40:38 -0700 Subject: [PATCH 08/13] address utils nits --- cuda_core/tests/example_tests/utils.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/cuda_core/tests/example_tests/utils.py b/cuda_core/tests/example_tests/utils.py index 3ba956d..807ebea 100644 --- a/cuda_core/tests/example_tests/utils.py +++ b/cuda_core/tests/example_tests/utils.py @@ -6,22 +6,21 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. +from cuda import cuda import gc import os import sys import pytest -from cuda import cuda import cupy as cp class SampleTestError(Exception): pass def parse_python_script(filepath): - if filepath.endswith('.py'): - with open(filepath, "r", encoding='utf-8') as f: - script = f.read() - else: + if not filepath.endswith('.py'): raise ValueError(f"{filepath} not supported") + with open(filepath, "r", encoding='utf-8') as f: + script = f.read() return script @@ -31,9 +30,9 @@ def run_example(samples_path, filename, env=None): try: old_argv = sys.argv sys.argv = [fullpath] - SYS_PATH_BACKUP = sys.path.copy() + old_sys_path = sys.path.copy() sys.path.append(samples_path) - exec(script, env if env is not None else {}) + exec(script, env if env else {}) except ImportError as e: # for samples requiring any of optional dependencies for m in ('cupy',): @@ -48,7 +47,7 @@ def run_example(samples_path, filename, env=None): msg += str(e) raise SampleTestError(msg) from e finally: - sys.path = SYS_PATH_BACKUP + sys.path = old_sys_path sys.argv = old_argv # further reduce the memory watermark gc.collect() From 1da4e7ed3003e416adf82ccc6e32a341840c8d46 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 14:06:08 -0700 Subject: [PATCH 09/13] address example_tests comments --- cuda_core/tests/example_tests/__init__.py | 0 .../tests/example_tests/test_basic_examples.py | 18 ++++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 cuda_core/tests/example_tests/__init__.py diff --git a/cuda_core/tests/example_tests/__init__.py b/cuda_core/tests/example_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_core/tests/example_tests/test_basic_examples.py b/cuda_core/tests/example_tests/test_basic_examples.py index 9a976ec..87c0c08 100644 --- a/cuda_core/tests/example_tests/test_basic_examples.py +++ b/cuda_core/tests/example_tests/test_basic_examples.py @@ -8,8 +8,18 @@ # If we have subcategories of examples in the future, this file can be split along those lines -from utils import run_example +from .utils import run_example +import os +import glob +import pytest -def test_basic_examples(): - run_example("../examples", "saxpy.py") - run_example("../examples", "vector_add.py") +samples_path = os.path.join( + os.path.dirname(__file__), '..', '..', 'examples') +sample_files = glob.glob(samples_path+'**/*.py', recursive=True) +@pytest.mark.parametrize( + 'example', sample_files +) +class TestExamples: + def test_example(self, example): + filename = os.path.basename(example) + run_example(samples_path, example) From 1a3a148f1d940a92357c6d03919e6cffd3ccc2a5 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 15:38:09 -0700 Subject: [PATCH 10/13] addressing final formatting comments --- cuda_core/tests/conftest.py | 4 ++-- .../example_tests/test_basic_examples.py | 2 +- cuda_core/tests/example_tests/utils.py | 4 ++-- cuda_core/tests/test_device.py | 19 +++++++++---------- cuda_core/tests/test_event.py | 3 +-- cuda_core/tests/test_launcher.py | 5 ++--- cuda_core/tests/test_memory.py | 14 +++++--------- cuda_core/tests/test_module.py | 2 +- cuda_core/tests/test_program.py | 19 ++++++++----------- cuda_core/tests/test_stream.py | 3 +-- 10 files changed, 32 insertions(+), 43 deletions(-) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index e622fc8..3ff6ce0 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of @@ -6,8 +6,8 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. -import pytest from cuda.core.experimental._device import Device +import pytest @pytest.fixture(scope="module") def init_cuda(): diff --git a/cuda_core/tests/example_tests/test_basic_examples.py b/cuda_core/tests/example_tests/test_basic_examples.py index 87c0c08..e490892 100644 --- a/cuda_core/tests/example_tests/test_basic_examples.py +++ b/cuda_core/tests/example_tests/test_basic_examples.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of diff --git a/cuda_core/tests/example_tests/utils.py b/cuda_core/tests/example_tests/utils.py index 807ebea..5f4e14b 100644 --- a/cuda_core/tests/example_tests/utils.py +++ b/cuda_core/tests/example_tests/utils.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of @@ -51,4 +51,4 @@ def run_example(samples_path, filename, env=None): sys.argv = old_argv # further reduce the memory watermark gc.collect() - cp.get_default_memory_pool().free_all_blocks() \ No newline at end of file + cp.get_default_memory_pool().free_all_blocks() diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 9686ced..653dac0 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of @@ -12,18 +12,17 @@ precondition import pytest - def test_device_repr(): device = Device(0) assert str(device).startswith(' Buffer: ptr = handle_return(cuda.cuMemAlloc(size)) return Buffer(ptr=ptr, size=size, mr=self) def deallocate(self, ptr, size, stream=None): - cuda.cuMemFree(ptr) + handle_return(cuda.cuMemFree(ptr)) @property def is_device_accessible(self) -> bool: @@ -65,14 +64,13 @@ def device_id(self) -> int: class DummyUnifiedMemoryResource(MemoryResource): def __init__(self, device): self.device = device - pass def allocate(self, size, stream=None) -> Buffer: ptr = handle_return(cuda.cuMemAllocManaged(size, cuda.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value)) return Buffer(ptr=ptr, size=size, mr=self) def deallocate(self, ptr, size, stream=None): - cuda.cuMemFree(ptr) + handle_return(cuda.cuMemFree(ptr)) @property def is_device_accessible(self) -> bool: @@ -89,14 +87,13 @@ def device_id(self) -> int: class DummyPinnedMemoryResource(MemoryResource): def __init__(self, device): self.device = device - pass def allocate(self, size, stream=None) -> Buffer: ptr = handle_return(cuda.cuMemAllocHost(size)) return Buffer(ptr=ptr, size=size, mr=self) def deallocate(self, ptr, size, stream=None): - cuda.cuMemFreeHost(ptr) + handle_return(cuda.cuMemFreeHost(ptr)) @property def is_device_accessible(self) -> bool: @@ -199,4 +196,3 @@ def test_buffer_close(): buffer_close(DummyUnifiedMemoryResource(device)) buffer_close(DummyPinnedMemoryResource(device)) -test_buffer_copy_to() diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 45b59ba..cc5cf57 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 4c31070..39ce4dc 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of @@ -6,13 +6,11 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. -from cuda import nvrtc from cuda.core.experimental._program import Program from cuda.core.experimental._module import ObjectCode, Kernel from cuda.core.experimental._device import Device import pytest - def test_program_init_valid_code_type(): code = "extern \"C\" __global__ void my_kernel() {}" program = Program(code, "c++") @@ -29,13 +27,13 @@ def test_program_init_invalid_code_format(): with pytest.raises(TypeError): Program(code, "c++") -# def test_program_compile_valid_target_type(): -# code = "extern \"C\" __global__ void my_kernel() {}" -# program = Program(code, "c++") -# object_code = program.compile("ptx") -# kernel = object_code.get_kernel("my_kernel") -# assert isinstance(object_code, ObjectCode) -# assert isinstance(kernel, Kernel) +def test_program_compile_valid_target_type(): + code = "extern \"C\" __global__ void my_kernel() {}" + program = Program(code, "c++") + object_code = program.compile("ptx") + kernel = object_code.get_kernel("my_kernel") + assert isinstance(object_code, ObjectCode) + assert isinstance(kernel, Kernel) def test_program_compile_invalid_target_type(): code = "extern \"C\" __global__ void my_kernel() {}" @@ -58,4 +56,3 @@ def test_program_close(): program = Program(code, "c++") program.close() assert program.handle is None - \ No newline at end of file diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index b85660d..e0a98c1 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 NVIDIA Corporation. All rights reserved. +# Copyright 2024 NVIDIA Corporation. All rights reserved. # # Please refer to the NVIDIA end user license agreement (EULA) associated # with this source code for terms and conditions that govern your use of @@ -84,4 +84,3 @@ def test_per_thread_default_stream(): def test_default_stream(): stream = default_stream() assert isinstance(stream, Stream) - \ No newline at end of file From df514918ab823d40444b6843e4d4813f5946893e Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 15:46:32 -0700 Subject: [PATCH 11/13] extra line missed --- cuda_core/tests/test_memory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 510d344..ce1b3f7 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -195,4 +195,3 @@ def test_buffer_close(): buffer_close(DummyHostMemoryResource()) buffer_close(DummyUnifiedMemoryResource(device)) buffer_close(DummyPinnedMemoryResource(device)) - From 976a8fabc37f25ccd12c2818838120bf4385f0f2 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 15:48:07 -0700 Subject: [PATCH 12/13] remove ipc test from event --- cuda_core/tests/test_event.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py index 5f49d1c..b6cfe64 100644 --- a/cuda_core/tests/test_event.py +++ b/cuda_core/tests/test_event.py @@ -22,15 +22,6 @@ def test_is_sync_busy_waited(): event = Event._init(options) assert event.is_sync_busy_waited == True -def test_is_ipc_supported(): - options = EventOptions(support_ipc=True) - try: - event = Event._init(options) - except NotImplementedError: - assert True - else: - assert False - def test_sync(): options = EventOptions() event = Event._init(options) From 364e600cd71da1a9033dc4b533abc4f441a11451 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 28 Oct 2024 16:35:15 -0700 Subject: [PATCH 13/13] fix some warnings and some exceptions --- cuda_core/cuda/core/experimental/_program.py | 2 +- cuda_core/tests/test_memory.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index ae5928e..ec0778a 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -14,9 +14,9 @@ class Program: _supported_target_type = ("ptx", "cubin", "ltoir", ) def __init__(self, code, code_type): + self._handle = None if code_type not in self._supported_code_type: raise NotImplementedError - self._handle = None if code_type.lower() == "c++": if not isinstance(code, str): diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index ce1b3f7..4085526 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -114,7 +114,7 @@ def buffer_initialization(dummy_mr : MemoryResource): assert buffer.memory_resource == dummy_mr assert buffer.is_device_accessible == dummy_mr.is_device_accessible assert buffer.is_host_accessible == dummy_mr.is_host_accessible - dummy_mr.deallocate(buffer.handle, buffer.size) + buffer.close() def test_buffer_initialization(): device = Device() @@ -143,8 +143,8 @@ def buffer_copy_to(dummy_mr : MemoryResource, device : Device, check = False): for i in range(10): assert dst_ptr[i] == src_ptr[i] - dummy_mr.deallocate(src_buffer.handle, src_buffer.size) - dummy_mr.deallocate(dst_buffer.handle, dst_buffer.size) + dst_buffer.close() + src_buffer.close() def test_buffer_copy_to(): device = Device() @@ -172,8 +172,8 @@ def buffer_copy_from(dummy_mr : MemoryResource, device, check = False): for i in range(10): assert dst_ptr[i] == src_ptr[i] - dummy_mr.deallocate(src_buffer.handle, src_buffer.size) - dummy_mr.deallocate(dst_buffer.handle, dst_buffer.size) + dst_buffer.close() + src_buffer.close() def test_buffer_copy_from(): device = Device()