Skip to content

Commit

Permalink
stream to use for syncing by madness tasks is no longer stored in TLS…
Browse files Browse the repository at this point in the history
…, but in task body so that streams are per-task, not per thread in case a task recursively executes other tasks by doing Future::get(dowork=true)
  • Loading branch information
evaleev committed Sep 28, 2023
1 parent fdaf8bc commit 37ee448
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 42 deletions.
20 changes: 15 additions & 5 deletions src/TiledArray/device/device_task_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,31 @@ struct deviceTaskFn : public TaskInterface {

protected:
void run(const TaskThreadEnv& env) override {
TA_ASSERT(!stream_);
TA_ASSERT(
TiledArray::device::detail::madness_task_stream_opt_ptr_accessor() ==
nullptr);
// tell the task to report stream to be synced with to this->stream_
TiledArray::device::detail::madness_task_stream_opt_ptr_accessor() =
&this->stream_;

// run the async function, the function must call synchronize_stream() to
// set the stream it used!!
task_->run_async();

// get the stream used by async function
auto stream_opt = TiledArray::device::detail::tls_stream_accessor();
// clear ptr to stream_
TiledArray::device::detail::madness_task_stream_opt_ptr_accessor() =
nullptr;

// WARNING, need to handle NoOp
if (!stream_opt) {
if (!stream_) {
task_->notify();
} else {
// TODO should we use device callback or device events??
// insert device callback
TiledArray::device::launchHostFunc(*stream_opt, device_callback, task_);
TiledArray::device::launchHostFunc(*stream_, device_callback, task_);
// processed sync, clear state
TiledArray::device::detail::tls_stream_accessor() = {};
stream_ = {};
}
}

Expand All @@ -137,6 +146,7 @@ struct deviceTaskFn : public TaskInterface {
}

deviceTaskFn_* task_;
std::optional<TiledArray::device::Stream> stream_; // stream to sync with
};

public:
Expand Down
23 changes: 15 additions & 8 deletions src/TiledArray/external/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,9 +798,14 @@ class Env {
};

namespace detail {
inline std::optional<Stream>& tls_stream_accessor() {
static thread_local std::optional<Stream> tls_stream;
return tls_stream;
inline std::optional<Stream>*& madness_task_stream_opt_ptr_accessor() {
static thread_local std::optional<Stream>* stream_opt_ptr = nullptr;
return stream_opt_ptr;
}

inline std::optional<Stream>& madness_task_stream_opt_accessor() {
TA_ASSERT(madness_task_stream_opt_ptr_accessor() != nullptr);
return *madness_task_stream_opt_ptr_accessor();
}
} // namespace detail

Expand All @@ -810,10 +815,10 @@ inline std::optional<Stream>& tls_stream_accessor() {
/// before task completion
/// \param s the stream to synchronize this task with
inline void sync_madness_task_with(const Stream& s) {
if (!detail::tls_stream_accessor())
detail::tls_stream_accessor() = s;
if (!detail::madness_task_stream_opt_accessor())
detail::madness_task_stream_opt_accessor() = s;
else {
TA_ASSERT(*detail::tls_stream_accessor() == s);
TA_ASSERT(*detail::madness_task_stream_opt_accessor() == s);
}
}

Expand Down Expand Up @@ -841,15 +846,17 @@ inline void sync_madness_task_with(stream_t stream) {

/// @return the optional Stream with which this task will be synced
inline std::optional<Stream> madness_task_current_stream() {
return detail::tls_stream_accessor();
return detail::madness_task_stream_opt_accessor();
}

/// should call this within a task submitted to
/// the MADNESS runtime via madness::add_device_task
/// to cancel the previous calls to sync_madness_task_with()
/// if, e.g., it synchronized with any work performed
/// before exiting
inline void cancel_madness_task_sync() { detail::tls_stream_accessor() = {}; }
inline void cancel_madness_task_sync() {
detail::madness_task_stream_opt_accessor() = {};
}

/// maps a (tile) Range to device::Stream; if had already pushed work into a
/// device::Stream (as indicated by madness_task_current_stream() )
Expand Down
83 changes: 54 additions & 29 deletions src/TiledArray/reduce_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,28 +456,33 @@ class ReduceTask {
ready_object_ = nullptr;
lock_.unlock(); // <<< End critical section

#ifdef TILEDARRAY_HAS_DEVICE
TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() ==
nullptr);
device::detail::madness_task_stream_opt_ptr_accessor() = &stream_;
#endif

// Reduce the argument that was held by ready_object_
op_(*result, ready_object->arg());

// cleanup the argument
#ifdef TILEDARRAY_HAS_DEVICE
auto& stream_opt = device::detail::tls_stream_accessor();
device::detail::madness_task_stream_opt_ptr_accessor() = nullptr;

// need to sync with a device stream?
if (!stream_opt) { // no
if (!stream_) { // no
ReduceObject::destroy(ready_object);
this->dec();
} else {
auto callback_object = new std::vector<void*>(3);
(*callback_object)[0] = &world_;
(*callback_object)[1] = this;
(*callback_object)[2] = ready_object;
DeviceSafeCall(device::setDevice(stream_opt->device));
DeviceSafeCall(device::setDevice(stream_->device));
DeviceSafeCall(device::launchHostFunc(
stream_opt->stream,
stream_->stream,
device_dependency_dec_reduceobject_delete_callback,
callback_object));
device::cancel_madness_task_sync();
// std::cout << std::to_string(world().rank()) + "
// add 3\n";
}
Expand All @@ -491,26 +496,32 @@ class ReduceTask {
ready_result_.reset();
lock_.unlock(); // <<< End critical section

#ifdef TILEDARRAY_HAS_DEVICE
TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() ==
nullptr);
device::detail::madness_task_stream_opt_ptr_accessor() = &stream_;
#endif

// Reduce the result that was held by ready_result_
op_(*result, *ready_result);

// cleanup the result
#ifdef TILEDARRAY_HAS_DEVICE
auto queue_opt = device::detail::tls_stream_accessor();
device::detail::madness_task_stream_opt_ptr_accessor() = nullptr;

// need to sync with a stream?
if (!queue_opt) { // no
if (!stream_) { // no
ready_result.reset();
} else { // yes
auto ready_result_heap =
new std::shared_ptr<result_type>(ready_result);
auto callback_object = new std::vector<void*>(2);
(*callback_object)[0] = &world_;
(*callback_object)[1] = ready_result_heap;
auto& [device, stream] = *queue_opt;
auto& [device, stream] = *stream_;
DeviceSafeCall(device::setDevice(device));
DeviceSafeCall(device::launchHostFunc(
stream, device_readyresult_reset_callback, callback_object));
device::cancel_madness_task_sync();
// std::cout << std::to_string(world().rank()) + "
// add 4\n";
}
Expand All @@ -532,43 +543,49 @@ class ReduceTask {
/// \param object The reduction argument to be reduced
void reduce_result_object(std::shared_ptr<result_type> result,
const ReduceObject* object) {
#ifdef TILEDARRAY_HAS_DEVICE
TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() ==
nullptr);
device::detail::madness_task_stream_opt_ptr_accessor() = &stream_;
#endif

// Reduce the argument
op_(*result, object->arg());

// Cleanup the argument
#ifdef TILEDARRAY_HAS_DEVICE
auto& stream_opt = device::detail::tls_stream_accessor();
if (!stream_opt) {
device::detail::madness_task_stream_opt_ptr_accessor() = nullptr;

if (!stream_) {
ReduceObject::destroy(object);
} else {
auto callback_object = new std::vector<void*>(2);
(*callback_object)[0] = &world_;
(*callback_object)[1] = const_cast<ReduceObject*>(object);
DeviceSafeCall(device::setDevice(stream_opt->device));
DeviceSafeCall(device::setDevice(stream_->device));
DeviceSafeCall(device::launchHostFunc(
stream_opt->stream, device_reduceobject_delete_callback,
stream_->stream, device_reduceobject_delete_callback,
callback_object));
device::cancel_madness_task_sync();
// std::cout << std::to_string(world().rank()) + " add 1\n";
}
#else
ReduceObject::destroy(object);
#endif

// Check for more reductions
reduce(result);

// Decrement the dependency counter for the argument. This must
// be done after the reduce call to avoid a race condition.
#ifdef TILEDARRAY_HAS_DEVICE
if (!stream_opt) {
if (!stream_) {
this->dec();
} else {
auto callback_object2 = new std::vector<void*>(1);
(*callback_object2)[0] = this;
DeviceSafeCall(device::setDevice(stream_opt->device));
DeviceSafeCall(device::launchHostFunc(stream_opt->stream,
device_dependency_dec_callback,
callback_object2));
DeviceSafeCall(device::setDevice(stream_->device));
DeviceSafeCall(device::launchHostFunc(
stream_->stream, device_dependency_dec_callback, callback_object2));
// std::cout << std::to_string(world().rank()) + " add 2\n";
}
#else
Expand All @@ -582,26 +599,32 @@ class ReduceTask {
// Construct an empty result object
auto result = std::make_shared<result_type>(op_());

#ifdef TILEDARRAY_HAS_DEVICE
TA_ASSERT(device::detail::madness_task_stream_opt_ptr_accessor() ==
nullptr);
device::detail::madness_task_stream_opt_ptr_accessor() = &stream_;
#endif

// Reduce the two arguments
op_(*result, object1->arg());
op_(*result, object2->arg());

// Cleanup arguments
#ifdef TILEDARRAY_HAS_DEVICE
auto& stream_opt = device::detail::tls_stream_accessor();
if (!stream_opt) {
device::detail::madness_task_stream_opt_ptr_accessor() = nullptr;

if (!stream_) {
ReduceObject::destroy(object1);
ReduceObject::destroy(object2);
} else {
auto callback_object1 = new std::vector<void*>(3);
(*callback_object1)[0] = &world_;
(*callback_object1)[1] = const_cast<ReduceObject*>(object1);
(*callback_object1)[2] = const_cast<ReduceObject*>(object2);
DeviceSafeCall(device::setDevice(stream_opt->device));
DeviceSafeCall(device::setDevice(stream_->device));
DeviceSafeCall(device::launchHostFunc(
stream_opt->stream, device_reduceobject_delete_callback,
stream_->stream, device_reduceobject_delete_callback,
callback_object1));
device::cancel_madness_task_sync();
// std::cout << std::to_string(world().rank()) + " add 1\n";
}
#else
Expand All @@ -615,17 +638,16 @@ class ReduceTask {
// Decrement the dependency counter for the two arguments. This
// must be done after the reduce call to avoid a race condition.
#ifdef TILEDARRAY_HAS_DEVICE
if (!stream_opt) {
if (!stream_) {
this->dec();
this->dec();
} else {
auto callback_object2 = new std::vector<void*>(2);
(*callback_object2)[0] = this;
(*callback_object2)[1] = this;
DeviceSafeCall(device::setDevice(stream_opt->device));
DeviceSafeCall(device::launchHostFunc(stream_opt->stream,
device_dependency_dec_callback,
callback_object2));
DeviceSafeCall(device::setDevice(stream_->device));
DeviceSafeCall(device::launchHostFunc(
stream_->stream, device_dependency_dec_callback, callback_object2));
// std::cout << std::to_string(world().rank()) + " add 2\n";
}

Expand Down Expand Up @@ -671,6 +693,9 @@ class ReduceTask {
madness::Spinlock lock_; ///< Task lock
madness::CallbackInterface* callback_; ///< The completion callback
int task_id_; ///< Task id
#ifdef TILEDARRAY_HAS_DEVICE
std::optional<device::Stream> stream_;
#endif

public:
/// Implementation constructor
Expand Down

0 comments on commit 37ee448

Please sign in to comment.