Skip to content

Commit

Permalink
removed ArgByRef
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 13, 2023
1 parent 5154a20 commit efa4ca8
Showing 1 changed file with 15 additions and 29 deletions.
44 changes: 15 additions & 29 deletions cpp/include/kvikio/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,24 @@ namespace kvikio {
* like most other asynchronous CUDA functions that take by-value arguments.
*
* To support by-value arguments, we allocate the arguments on the heap (malloc `ArgByVal`) and have
* the by-reference arguments (`ArgByRef`) points into `ArgByVal`. This way, the `read_async` and
* `write_async` can call `.get_args()` to get the by-reference arguments required by cuFile's
* stream API. However, this also means that the caller of `read_async` and `write_async` MUST keep
* the returned `StreamFuture` alive until the operations is done otherwise `StreamFuture` will free
* `ArgByVal` before cuFile had a change to use them!
* the by-reference argumentspoints into `ArgByVal`. This way, the `read_async` and `write_async`
* can call `.get_args()` to get the by-reference arguments required by cuFile's stream API.
* However, this also means that the caller of `read_async` and `write_async` MUST keep the returned
* `StreamFuture` alive until the operations is done otherwise `StreamFuture` will free `ArgByVal`
* before cuFile had a change to use them!
*/
class StreamFuture {
public:
private:
struct ArgByVal {
std::size_t size;
off_t file_offset;
off_t devPtr_offset;
ssize_t bytes_done;
};
struct ArgByRef {
std::size_t* size_p;
off_t* file_offset_p;
off_t* devPtr_offset_p;
ssize_t* bytes_done_p;
};

private:
void* _devPtr_base{nullptr};
CUstream _stream{nullptr};
ArgByVal* _val{nullptr};
ArgByRef _ref{nullptr};

public:
StreamFuture() noexcept = default;
Expand All @@ -82,10 +74,6 @@ class StreamFuture {
}
*_val = {
.size = size, .file_offset = file_offset, .devPtr_offset = devPtr_offset, .bytes_done = 0};
_ref = {.size_p = &_val->size,
.file_offset_p = &_val->file_offset,
.devPtr_offset_p = &_val->devPtr_offset,
.bytes_done_p = &_val->bytes_done};
}

/**
Expand All @@ -96,16 +84,14 @@ class StreamFuture {
StreamFuture(StreamFuture&& o) noexcept
: _devPtr_base{std::exchange(o._devPtr_base, nullptr)},
_stream{std::exchange(o._stream, nullptr)},
_val{std::exchange(o._val, nullptr)},
_ref{o._ref}
_val{std::exchange(o._val, nullptr)}
{
}
StreamFuture& operator=(StreamFuture&& o) noexcept
{
_devPtr_base = std::exchange(o._devPtr_base, nullptr);
_stream = std::exchange(o._stream, nullptr);
_val = std::exchange(o._val, nullptr);
_ref = o._ref;
return *this;
}

Expand All @@ -122,12 +108,12 @@ class StreamFuture {
*/
std::tuple<void*, std::size_t*, off_t*, off_t*, ssize_t*, CUstream> get_args() const noexcept
{
return std::make_tuple(_devPtr_base,
_ref.size_p,
_ref.file_offset_p,
_ref.devPtr_offset_p,
_ref.bytes_done_p,
_stream);
return {_devPtr_base,
&_val->size,
&_val->file_offset,
&_val->devPtr_offset,
&_val->bytes_done,
_stream};
}

/**
Expand All @@ -140,10 +126,10 @@ class StreamFuture {
std::size_t check_bytes_done() const
{
CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(_stream));
CUFILE_CHECK_STREAM_IO(_ref.bytes_done_p);
CUFILE_CHECK_STREAM_IO(&_val->bytes_done);
// At this point, we know `*_bytes_done_p` is positive otherwise
// CUFILE_CHECK_STREAM_IO() would have raised an exception.
return static_cast<std::size_t>(*_ref.bytes_done_p);
return static_cast<std::size_t>(_val->bytes_done);
}
};

Expand Down

0 comments on commit efa4ca8

Please sign in to comment.