Skip to content

Commit

Permalink
DeviceBuffer: accept memory resource when taking ownership
Browse files Browse the repository at this point in the history
In c_from_unique_ptr we should not just rely on
get_current_device_resource, but rather allow the user to pass in the
memory resource they _know_ was used to allocate the buffer we are
taking ownership of.

So that we are backwards-compatible we default, as before, to the
current device resource.
  • Loading branch information
wence- committed May 3, 2024
1 parent aaddfb1 commit e0a4d05
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/rmm/rmm/_lib/device_buffer.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ cdef class DeviceBuffer:
@staticmethod
cdef DeviceBuffer c_from_unique_ptr(
unique_ptr[device_buffer] ptr,
Stream stream=*
Stream stream=*,
DeviceMemoryResource mr=*,
)

@staticmethod
Expand Down
14 changes: 10 additions & 4 deletions python/rmm/rmm/_lib/device_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from cuda.ccudart cimport (
)

from rmm._lib.memory_resource cimport (
DeviceMemoryResource,
device_memory_resource,
get_current_device_resource,
)
Expand All @@ -48,7 +49,8 @@ cdef class DeviceBuffer:
def __cinit__(self, *,
uintptr_t ptr=0,
size_t size=0,
Stream stream=DEFAULT_STREAM):
Stream stream=DEFAULT_STREAM,
DeviceMemoryResource mr=None):
"""Construct a ``DeviceBuffer`` with optional size and data pointer
Parameters
Expand All @@ -65,6 +67,9 @@ cdef class DeviceBuffer:
scope while the DeviceBuffer is in use. Destroying the
underlying stream while the DeviceBuffer is in use will
result in undefined behavior.
mr : optional
DeviceMemoryResource for the allocation, if not provided
defaults to the current device resource.
Note
----
Expand All @@ -80,7 +85,7 @@ cdef class DeviceBuffer:
cdef const void* c_ptr
cdef device_memory_resource * mr_ptr
# Save a reference to the MR and stream used for allocation
self.mr = get_current_device_resource()
self.mr = get_current_device_resource() if mr is None else mr
self.stream = stream

mr_ptr = self.mr.get_mr()
Expand Down Expand Up @@ -162,13 +167,14 @@ cdef class DeviceBuffer:
@staticmethod
cdef DeviceBuffer c_from_unique_ptr(
unique_ptr[device_buffer] ptr,
Stream stream=DEFAULT_STREAM
Stream stream=DEFAULT_STREAM,
DeviceMemoryResource mr=None,
):
cdef DeviceBuffer buf = DeviceBuffer.__new__(DeviceBuffer)
if stream.c_is_default():
stream.c_synchronize()
buf.c_obj = move(ptr)
buf.mr = get_current_device_resource()
buf.mr = get_current_device_resource() if mr is None else mr
buf.stream = stream
return buf

Expand Down

0 comments on commit e0a4d05

Please sign in to comment.