From 37ee4488f658cad4d5dfdfedb323c36a9b20df62 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 28 Sep 2023 01:06:21 -0400 Subject: [PATCH] stream to use for syncing by madness tasks is no longer stored in TLS, but in task body so that streams are per-task, not per thread in case a task recursively executes other tasks by doing Future::get(dowork=true) --- src/TiledArray/device/device_task_fn.h | 20 +++++-- src/TiledArray/external/device.h | 23 ++++--- src/TiledArray/reduce_task.h | 83 +++++++++++++++++--------- 3 files changed, 84 insertions(+), 42 deletions(-) diff --git a/src/TiledArray/device/device_task_fn.h b/src/TiledArray/device/device_task_fn.h index 6b7105c550..fada332c63 100644 --- a/src/TiledArray/device/device_task_fn.h +++ b/src/TiledArray/device/device_task_fn.h @@ -99,22 +99,31 @@ struct deviceTaskFn : public TaskInterface { protected: void run(const TaskThreadEnv& env) override { + TA_ASSERT(!stream_); + TA_ASSERT( + TiledArray::device::detail::madness_task_stream_opt_ptr_accessor() == + nullptr); + // tell the task to report stream to be synced with to this->stream_ + TiledArray::device::detail::madness_task_stream_opt_ptr_accessor() = + &this->stream_; + // run the async function, the function must call synchronize_stream() to // set the stream it used!! task_->run_async(); - // get the stream used by async function - auto stream_opt = TiledArray::device::detail::tls_stream_accessor(); + // clear ptr to stream_ + TiledArray::device::detail::madness_task_stream_opt_ptr_accessor() = + nullptr; // WARNING, need to handle NoOp - if (!stream_opt) { + if (!stream_) { task_->notify(); } else { // TODO should we use device callback or device events?? // insert device callback - TiledArray::device::launchHostFunc(*stream_opt, device_callback, task_); + TiledArray::device::launchHostFunc(*stream_, device_callback, task_); // processed sync, clear state - TiledArray::device::detail::tls_stream_accessor() = {}; + stream_ = {}; } } @@ -137,6 +146,7 @@ struct deviceTaskFn : public TaskInterface { } deviceTaskFn_* task_; + std::optional stream_; // stream to sync with }; public: diff --git a/src/TiledArray/external/device.h b/src/TiledArray/external/device.h index 685c804dff..133bb11c56 100644 --- a/src/TiledArray/external/device.h +++ b/src/TiledArray/external/device.h @@ -798,9 +798,14 @@ class Env { }; namespace detail { -inline std::optional& tls_stream_accessor() { - static thread_local std::optional tls_stream; - return tls_stream; +inline std::optional*& madness_task_stream_opt_ptr_accessor() { + static thread_local std::optional* stream_opt_ptr = nullptr; + return stream_opt_ptr; +} + +inline std::optional& madness_task_stream_opt_accessor() { + TA_ASSERT(madness_task_stream_opt_ptr_accessor() != nullptr); + return *madness_task_stream_opt_ptr_accessor(); } } // namespace detail @@ -810,10 +815,10 @@ inline std::optional& tls_stream_accessor() { /// before task completion /// \param s the stream to synchronize this task with inline void sync_madness_task_with(const Stream& s) { - if (!detail::tls_stream_accessor()) - detail::tls_stream_accessor() = s; + if (!detail::madness_task_stream_opt_accessor()) + detail::madness_task_stream_opt_accessor() = s; else { - TA_ASSERT(*detail::tls_stream_accessor() == s); + TA_ASSERT(*detail::madness_task_stream_opt_accessor() == s); } } @@ -841,7 +846,7 @@ inline void sync_madness_task_with(stream_t stream) { /// @return the optional Stream with which this task will be synced inline std::optional madness_task_current_stream() { - return detail::tls_stream_accessor(); + return detail::madness_task_stream_opt_accessor(); } /// should call this within a task submitted to @@ -849,7 +854,9 @@ inline std::optional madness_task_current_stream() { /// to cancel the previous calls to sync_madness_task_with() /// if, e.g., it synchronized with any work performed /// before exiting -inline void cancel_madness_task_sync() { detail::tls_stream_accessor() = {}; } +inline void cancel_madness_task_sync() { + detail::madness_task_stream_opt_accessor() = {}; +} /// maps a (tile) Range to device::Stream; if had already pushed work into a /// device::Stream (as indicated by madness_task_current_stream() ) diff --git a/src/TiledArray/reduce_task.h b/src/TiledArray/reduce_task.h index 34a2fef9ea..2a5813ff10 100644 --- a/src/TiledArray/reduce_task.h +++ b/src/TiledArray/reduce_task.h @@ -456,15 +456,21 @@ class ReduceTask { ready_object_ = nullptr; lock_.unlock(); // <<< End critical section +#ifdef TILEDARRAY_HAS_DEVICE + TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() == + nullptr); + device::detail::madness_task_stream_opt_ptr_accessor() = &stream_; +#endif + // Reduce the argument that was held by ready_object_ op_(*result, ready_object->arg()); // cleanup the argument #ifdef TILEDARRAY_HAS_DEVICE - auto& stream_opt = device::detail::tls_stream_accessor(); + device::detail::madness_task_stream_opt_ptr_accessor() = nullptr; // need to sync with a device stream? - if (!stream_opt) { // no + if (!stream_) { // no ReduceObject::destroy(ready_object); this->dec(); } else { @@ -472,12 +478,11 @@ class ReduceTask { (*callback_object)[0] = &world_; (*callback_object)[1] = this; (*callback_object)[2] = ready_object; - DeviceSafeCall(device::setDevice(stream_opt->device)); + DeviceSafeCall(device::setDevice(stream_->device)); DeviceSafeCall(device::launchHostFunc( - stream_opt->stream, + stream_->stream, device_dependency_dec_reduceobject_delete_callback, callback_object)); - device::cancel_madness_task_sync(); // std::cout << std::to_string(world().rank()) + " // add 3\n"; } @@ -491,14 +496,21 @@ class ReduceTask { ready_result_.reset(); lock_.unlock(); // <<< End critical section +#ifdef TILEDARRAY_HAS_DEVICE + TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() == + nullptr); + device::detail::madness_task_stream_opt_ptr_accessor() = &stream_; +#endif + // Reduce the result that was held by ready_result_ op_(*result, *ready_result); // cleanup the result #ifdef TILEDARRAY_HAS_DEVICE - auto queue_opt = device::detail::tls_stream_accessor(); + device::detail::madness_task_stream_opt_ptr_accessor() = nullptr; + // need to sync with a stream? - if (!queue_opt) { // no + if (!stream_) { // no ready_result.reset(); } else { // yes auto ready_result_heap = @@ -506,11 +518,10 @@ class ReduceTask { auto callback_object = new std::vector(2); (*callback_object)[0] = &world_; (*callback_object)[1] = ready_result_heap; - auto& [device, stream] = *queue_opt; + auto& [device, stream] = *stream_; DeviceSafeCall(device::setDevice(device)); DeviceSafeCall(device::launchHostFunc( stream, device_readyresult_reset_callback, callback_object)); - device::cancel_madness_task_sync(); // std::cout << std::to_string(world().rank()) + " // add 4\n"; } @@ -532,43 +543,49 @@ class ReduceTask { /// \param object The reduction argument to be reduced void reduce_result_object(std::shared_ptr result, const ReduceObject* object) { +#ifdef TILEDARRAY_HAS_DEVICE + TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() == + nullptr); + device::detail::madness_task_stream_opt_ptr_accessor() = &stream_; +#endif + // Reduce the argument op_(*result, object->arg()); // Cleanup the argument #ifdef TILEDARRAY_HAS_DEVICE - auto& stream_opt = device::detail::tls_stream_accessor(); - if (!stream_opt) { + device::detail::madness_task_stream_opt_ptr_accessor() = nullptr; + + if (!stream_) { ReduceObject::destroy(object); } else { auto callback_object = new std::vector(2); (*callback_object)[0] = &world_; (*callback_object)[1] = const_cast(object); - DeviceSafeCall(device::setDevice(stream_opt->device)); + DeviceSafeCall(device::setDevice(stream_->device)); DeviceSafeCall(device::launchHostFunc( - stream_opt->stream, device_reduceobject_delete_callback, + stream_->stream, device_reduceobject_delete_callback, callback_object)); - device::cancel_madness_task_sync(); // std::cout << std::to_string(world().rank()) + " add 1\n"; } #else ReduceObject::destroy(object); #endif + // Check for more reductions reduce(result); // Decrement the dependency counter for the argument. This must // be done after the reduce call to avoid a race condition. #ifdef TILEDARRAY_HAS_DEVICE - if (!stream_opt) { + if (!stream_) { this->dec(); } else { auto callback_object2 = new std::vector(1); (*callback_object2)[0] = this; - DeviceSafeCall(device::setDevice(stream_opt->device)); - DeviceSafeCall(device::launchHostFunc(stream_opt->stream, - device_dependency_dec_callback, - callback_object2)); + DeviceSafeCall(device::setDevice(stream_->device)); + DeviceSafeCall(device::launchHostFunc( + stream_->stream, device_dependency_dec_callback, callback_object2)); // std::cout << std::to_string(world().rank()) + " add 2\n"; } #else @@ -582,14 +599,21 @@ class ReduceTask { // Construct an empty result object auto result = std::make_shared(op_()); +#ifdef TILEDARRAY_HAS_DEVICE + TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() == + nullptr); + device::detail::madness_task_stream_opt_ptr_accessor() = &stream_; +#endif + // Reduce the two arguments op_(*result, object1->arg()); op_(*result, object2->arg()); // Cleanup arguments #ifdef TILEDARRAY_HAS_DEVICE - auto& stream_opt = device::detail::tls_stream_accessor(); - if (!stream_opt) { + device::detail::madness_task_stream_opt_ptr_accessor() = nullptr; + + if (!stream_) { ReduceObject::destroy(object1); ReduceObject::destroy(object2); } else { @@ -597,11 +621,10 @@ class ReduceTask { (*callback_object1)[0] = &world_; (*callback_object1)[1] = const_cast(object1); (*callback_object1)[2] = const_cast(object2); - DeviceSafeCall(device::setDevice(stream_opt->device)); + DeviceSafeCall(device::setDevice(stream_->device)); DeviceSafeCall(device::launchHostFunc( - stream_opt->stream, device_reduceobject_delete_callback, + stream_->stream, device_reduceobject_delete_callback, callback_object1)); - device::cancel_madness_task_sync(); // std::cout << std::to_string(world().rank()) + " add 1\n"; } #else @@ -615,17 +638,16 @@ class ReduceTask { // Decrement the dependency counter for the two arguments. This // must be done after the reduce call to avoid a race condition. #ifdef TILEDARRAY_HAS_DEVICE - if (!stream_opt) { + if (!stream_) { this->dec(); this->dec(); } else { auto callback_object2 = new std::vector(2); (*callback_object2)[0] = this; (*callback_object2)[1] = this; - DeviceSafeCall(device::setDevice(stream_opt->device)); - DeviceSafeCall(device::launchHostFunc(stream_opt->stream, - device_dependency_dec_callback, - callback_object2)); + DeviceSafeCall(device::setDevice(stream_->device)); + DeviceSafeCall(device::launchHostFunc( + stream_->stream, device_dependency_dec_callback, callback_object2)); // std::cout << std::to_string(world().rank()) + " add 2\n"; } @@ -671,6 +693,9 @@ class ReduceTask { madness::Spinlock lock_; ///< Task lock madness::CallbackInterface* callback_; ///< The completion callback int task_id_; ///< Task id +#ifdef TILEDARRAY_HAS_DEVICE + std::optional stream_; +#endif public: /// Implementation constructor