Skip to content

Commit

Permalink
~StreamFuture(): sync stream
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 13, 2023
1 parent efa4ca8 commit 3d7d50a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
6 changes: 3 additions & 3 deletions cpp/examples/basic_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ int main()
kvikio::FileHandle f_handle("/tmp/test-file", "w+");
check(cudaMemcpyAsync(a_dev, a, SIZE, cudaMemcpyHostToDevice, stream) == cudaSuccess);

// Notice, we MUST keep `res` alive until the data has been written to disk
// Notice, we get a handle `res`, which will synchronize the CUDA stream on destruction
kvikio::StreamFuture res = f_handle.write_async(a_dev, SIZE, 0, 0, stream);
// We can use `check_bytes_done()` to sync the associated stream and return the number
// of bytes written.
// But we can also trigger the synchronization and get the bytes written by calling
// `check_bytes_done()`.
check(res.check_bytes_done() == SIZE);
cout << "File async write: " << res.check_bytes_done() << endl;

Expand Down
56 changes: 39 additions & 17 deletions cpp/include/kvikio/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,8 @@ namespace kvikio {
* like most other asynchronous CUDA functions that take by-value arguments.
*
* To support by-value arguments, we allocate the arguments on the heap (malloc `ArgByVal`) and have
* the by-reference argumentspoints into `ArgByVal`. This way, the `read_async` and `write_async`
* the by-reference arguments point into `ArgByVal`. This way, the `read_async` and `write_async`
* can call `.get_args()` to get the by-reference arguments required by cuFile's stream API.
* However, this also means that the caller of `read_async` and `write_async` MUST keep the returned
* `StreamFuture` alive until the operations is done otherwise `StreamFuture` will free `ArgByVal`
* before cuFile had a change to use them!
*/
class StreamFuture {
private:
Expand All @@ -59,6 +56,7 @@ class StreamFuture {
void* _devPtr_base{nullptr};
CUstream _stream{nullptr};
ArgByVal* _val{nullptr};
bool _stream_synchronized{false};

public:
StreamFuture() noexcept = default;
Expand All @@ -84,30 +82,30 @@ class StreamFuture {
StreamFuture(StreamFuture&& o) noexcept
: _devPtr_base{std::exchange(o._devPtr_base, nullptr)},
_stream{std::exchange(o._stream, nullptr)},
_val{std::exchange(o._val, nullptr)}
_val{std::exchange(o._val, nullptr)},
_stream_synchronized{o._stream_synchronized}
{
}
StreamFuture& operator=(StreamFuture&& o) noexcept
{
_devPtr_base = std::exchange(o._devPtr_base, nullptr);
_stream = std::exchange(o._stream, nullptr);
_val = std::exchange(o._val, nullptr);
_devPtr_base = std::exchange(o._devPtr_base, nullptr);
_stream = std::exchange(o._stream, nullptr);
_val = std::exchange(o._val, nullptr);
_stream_synchronized = o._stream_synchronized;
return *this;
}

~StreamFuture() noexcept
{
if (_val != nullptr) { free(_val); }
}

/**
* @brief Return the arguments of the future call
*
* @return Tuple of the arguments in the order matching `FileHandle.read()` and
* `FileHandle.write()`
*/
std::tuple<void*, std::size_t*, off_t*, off_t*, ssize_t*, CUstream> get_args() const noexcept
std::tuple<void*, std::size_t*, off_t*, off_t*, ssize_t*, CUstream> get_args() const
{
if (_val == nullptr) {
throw kvikio::CUfileException("cannot get arguments from an uninitialized StreamFuture");
}
return {_devPtr_base,
&_val->size,
&_val->file_offset,
Expand All @@ -123,14 +121,38 @@ class StreamFuture {
*
* @return Number of bytes read or written by the future operation.
*/
std::size_t check_bytes_done() const
std::size_t check_bytes_done()
{
CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(_stream));
if (_val == nullptr) {
throw kvikio::CUfileException("cannot check bytes done on an uninitialized StreamFuture");
}

if (!_stream_synchronized) {
_stream_synchronized = true;
CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(_stream));
}

CUFILE_CHECK_STREAM_IO(&_val->bytes_done);
// At this point, we know `*_bytes_done_p` is positive otherwise
// At this point, we know `_val->bytes_done` is a positive value otherwise
// CUFILE_CHECK_STREAM_IO() would have raised an exception.
return static_cast<std::size_t>(_val->bytes_done);
}

/**
 * @brief Destructor: finish the pending operation, then release the by-value arguments.
 *
 * A future that holds no arguments (`_val` is null) is destroyed without any further work.
 * Otherwise `check_bytes_done()` is called first, which synchronizes the associated CUDA
 * stream unless that has already happened, and only afterwards is `_val` freed.
 *
 * The destructor is `noexcept`: a `kvikio::CUfileException` raised by `check_bytes_done()`
 * is reported on `std::cerr` instead of being propagated.
 */
~StreamFuture() noexcept
{
  // Nothing is owned, so there is nothing to wait for or free.
  if (_val == nullptr) { return; }

  try {
    // Called for its side effect (stream synchronization); the byte count is not needed here.
    check_bytes_done();
  } catch (const kvikio::CUfileException& e) {
    std::cerr << e.what() << std::endl;
  }

  // The arguments are released only after the synchronization attempt above.
  free(_val);
}
};

} // namespace kvikio

0 comments on commit 3d7d50a

Please sign in to comment.