Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safe handling of control plane promises #380

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions ci/conda/environments/clang_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ name: mrc
channels:
- conda-forge
dependencies:
- clang=16
- clang-tools=16
- clangdev=16
- clangxx=16
- libclang=16
- libclang-cpp=16
- llvmdev=16
- include-what-you-use
- clang=15
- clang-tools=15
- clangdev=15
- clangxx=15
- libclang=15
- libclang-cpp=15
- llvmdev=15
- include-what-you-use=0.19
4 changes: 2 additions & 2 deletions ci/conda/environments/dev_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies:
- autoconf>=2.69
- bash-completion
- benchmark=1.6.0
- boost-cpp=1.74
- boost-cpp=1.82
- ccache
- cmake=3.24
- cuda-toolkit # Version comes from the channel above
Expand All @@ -46,7 +46,7 @@ dependencies:
- isort
- jinja2=3.0
- lcov=1.15
- libhwloc=2.5
- libhwloc=2.9.2
- libprotobuf=3.21
- librmm=23.06
- libtool
Expand Down
74 changes: 61 additions & 13 deletions cpp/mrc/src/internal/control_plane/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

#include "internal/control_plane/client/connections_manager.hpp"
#include "internal/grpc/progress_engine.hpp"
#include "internal/grpc/promise_handler.hpp"
#include "internal/grpc/promise_handler.hpp" // for PromiseHandler
#include "internal/grpc/stream_writer.hpp" // for StreamWriter
#include "internal/runnable/runnable_resources.hpp"
#include "internal/service.hpp"
#include "internal/system/system.hpp"

#include "mrc/channel/status.hpp"
Expand All @@ -33,15 +35,20 @@
#include "mrc/runnable/launch_control.hpp"
#include "mrc/runnable/launcher.hpp"
#include "mrc/runnable/runner.hpp"
#include "mrc/types.hpp"

#include <boost/fiber/future/promise.hpp> // for promise
#include <google/protobuf/any.pb.h>
#include <grpcpp/grpcpp.h>
#include <rxcpp/rx.hpp>

#include <mutex>
#include <ostream>

namespace mrc::control_plane {

// Counter from which AsyncEventStatus request IDs are drawn. The type matches the in-class
// declaration (std::atomic_size_t); std::atomic_uint64_t is a different type on platforms where
// size_t is not uint64_t, which would make this definition conflict with the declaration.
// Starts at zero; the AsyncEventStatus constructor pre-increments it, so IDs begin at 1.
std::atomic_size_t AsyncEventStatus::s_request_id_counter{0};

Client::Client(resources::PartitionResourceBase& base, std::shared_ptr<grpc::CompletionQueue> cq) :
resources::PartitionResourceBase(base),
m_cq(std::move(cq)),
Expand Down Expand Up @@ -73,13 +80,11 @@
if (m_owns_progress_engine)
{
CHECK(m_cq);
auto progress_engine = std::make_unique<rpc::ProgressEngine>(m_cq);
auto progress_handler = std::make_unique<rpc::PromiseHandler>();
auto progress_engine = std::make_unique<rpc::ProgressEngine>(m_cq);
m_progress_handler = std::make_unique<rpc::PromiseHandler>();

mrc::make_edge(*progress_engine, *progress_handler);
mrc::make_edge(*progress_engine, *m_progress_handler);

m_progress_handler =
runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_handler))->ignition();
m_progress_engine =
runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_engine))->ignition();
}
Expand Down Expand Up @@ -135,7 +140,6 @@
if (m_owns_progress_engine)
{
m_progress_engine->await_live();
m_progress_handler->await_live();
}
m_event_handler->await_live();
}
Expand All @@ -150,7 +154,6 @@
{
m_cq->Shutdown();
m_progress_engine->await_join();
m_progress_handler->await_join();
}
}

Expand All @@ -161,10 +164,21 @@
// handle a subset of events directly on the event handler

case protos::EventType::Response: {
auto* promise = reinterpret_cast<Promise<protos::Event>*>(event.msg.tag());
if (promise != nullptr)
auto event_tag = event.msg.tag();

if (event_tag != 0)
{
promise->set_value(std::move(event.msg));
// Lock to prevent multiple threads
std::unique_lock<decltype(m_mutex)> lock(m_mutex);

// Find the promise associated with the event tag
auto promise = m_pending_events.extract(event_tag);

// Unlock to allow other threads to continue as soon as possible
lock.unlock();

// Finally, set the value
promise.mapped().set_value(std::move(event.msg));
}
}
break;
Expand Down Expand Up @@ -242,11 +256,12 @@
return m_launch_options;
}

// Builds an Event carrying only the given event type (no message payload) and hands it to
// write_event with await_response = false, returning the resulting AsyncEventStatus.
// NOTE(review): this is a diff view, so both the previous signature / write call and their
// replacements appear below without +/- markers.
void Client::issue_event(const protos::EventType& event_type)
AsyncEventStatus Client::issue_event(const protos::EventType& event_type)
{
protos::Event event;
event.set_event(event_type);
m_writer->await_write(std::move(event));
// m_writer->await_write(std::move(event));
return this->write_event(std::move(event), false);
}

void Client::request_update()
Expand All @@ -260,4 +275,37 @@
// }
}

// Writes an event through m_writer and returns an AsyncEventStatus for it.
//
// event:          the event to send; any tag already set on it is cleared with a warning.
// await_response: when true, a promise is created, its future is attached to the returned status,
//                 the event is tagged with the status' request ID, and the promise is stored in
//                 m_pending_events under that ID. When false, no future is attached to the status.
AsyncEventStatus Client::write_event(protos::Event event, bool await_response)
{
// The tag field is reserved for this client, so discard any caller-provided value
if (event.tag() != 0)
{
LOG(WARNING) << "event tag is set but this field should exclusively be used by the control plane client. "
"Clearing to avoid confusion";
event.clear_tag();

Check warning on line 284 in cpp/mrc/src/internal/control_plane/client.cpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/src/internal/control_plane/client.cpp#L282-L284

Added lines #L282 - L284 were not covered by tests
}

// Constructing the status assigns it a fresh request ID
AsyncEventStatus status;

if (await_response)
{
// If we are supporting awaiting, create the promise now
Promise<protos::Event> promise;

// Set the future to the status
status.set_future(promise.get_future());

// Set the tag to the request ID to allow looking up the promise later
event.set_tag(status.request_id());

// Save the promise to the pending promises to be retrieved later
// The lock is scoped to this block, so it is released before the write below
std::unique_lock<decltype(m_mutex)> lock(m_mutex);

m_pending_events[status.request_id()] = std::move(promise);
}

// Finally, write the event
// NOTE(review): the promise is registered before the write; whether a failed write leaves a
// stale entry in m_pending_events depends on await_write, which is not visible here -- confirm
m_writer->await_write(std::move(event));

return status;
}
} // namespace mrc::control_plane
127 changes: 75 additions & 52 deletions cpp/mrc/src/internal/control_plane/client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@

#include "internal/control_plane/client/instance.hpp" // IWYU pragma: keep
#include "internal/grpc/client_streaming.hpp"
#include "internal/grpc/stream_writer.hpp"
#include "internal/resources/partition_resources_base.hpp"
#include "internal/service.hpp"

#include "mrc/core/error.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/forward.hpp"
#include "mrc/node/writable_entrypoint.hpp"
#include "mrc/protos/architect.grpc.pb.h"
#include "mrc/protos/architect.pb.h"
#include "mrc/runnable/launch_options.hpp"
#include "mrc/types.hpp"
#include "mrc/utils/macros.hpp"

#include <boost/fiber/future/future.hpp>
#include <glog/logging.h>

#include <atomic>
#include <cstddef> // for size_t
#include <cstdint>
#include <map>
#include <memory>
Expand Down Expand Up @@ -65,10 +65,62 @@
class Runner;
} // namespace mrc::runnable

namespace mrc::rpc {
class PromiseHandler;
template <typename T>
struct StreamWriter;
} // namespace mrc::rpc

namespace mrc::control_plane {

template <typename ResponseT>
class AsyncStatus;
// Handle for an event issued through Client. It carries a request ID and, optionally, a future
// for the response event. The constructor and set_future are private and Client is a friend, so
// instances are created and wired up by Client.
class AsyncEventStatus
{
public:
// Returns the request ID assigned to this status when it was constructed.
size_t request_id() const
{
return m_request_id;
}

// Waits on the stored future and converts the resulting event into a ResponseT.
// Returns an Error built from the event's error message when the event reports an error.
// Throws MrcRuntimeError when the stored future is not valid, and throws when the event's
// message cannot be unpacked into ResponseT.
template <typename ResponseT>
Expected<ResponseT> await_response()
{
// Guard against waiting on a future that is not valid
if (!m_future.valid())
{
throw exceptions::MrcRuntimeError(

Check warning on line 89 in cpp/mrc/src/internal/control_plane/client.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/src/internal/control_plane/client.hpp#L89

Added line #L89 was not covered by tests
"This AsyncEventStatus is not expecting a response or the response has already been awaited");
}

// Obtain the response event from the future
auto event = m_future.get();

// An error reported in the event is returned to the caller rather than thrown
if (event.has_error())
{
return Error::create(event.error().message());

Check warning on line 97 in cpp/mrc/src/internal/control_plane/client.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/src/internal/control_plane/client.hpp#L97

Added line #L97 was not covered by tests
}

// Unpack the event's message payload into the requested response type
ResponseT response;
if (!event.message().UnpackTo(&response))
{
throw Error::create("fatal error: unable to unpack message; server sent the wrong message type");

Check warning on line 103 in cpp/mrc/src/internal/control_plane/client.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/src/internal/control_plane/client.hpp#L103

Added line #L103 was not covered by tests
}

return response;
}

private:
// Takes the next request ID by pre-incrementing the shared atomic counter
AsyncEventStatus() : m_request_id(++s_request_id_counter) {}

// Attaches the future on which await_response will wait
void set_future(Future<protos::Event> future)
{
m_future = std::move(future);
}

// Counter shared by all instances; source of request IDs
static std::atomic_size_t s_request_id_counter;

// ID assigned at construction
size_t m_request_id;
// Future for the response event; left default-constructed unless set_future is called
Future<protos::Event> m_future;

friend class Client;
};

/**
* @brief Primary Control Plane Client
Expand Down Expand Up @@ -128,13 +180,13 @@
template <typename ResponseT, typename RequestT>
Expected<ResponseT> await_unary(const protos::EventType& event_type, RequestT&& request);

template <typename ResponseT, typename RequestT>
void async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus<ResponseT>& status);
template <typename RequestT>
AsyncEventStatus async_unary(const protos::EventType& event_type, RequestT&& request);

template <typename MessageT>
void issue_event(const protos::EventType& event_type, MessageT&& message);
AsyncEventStatus issue_event(const protos::EventType& event_type, MessageT&& message);

void issue_event(const protos::EventType& event_type);
AsyncEventStatus issue_event(const protos::EventType& event_type);

bool has_subscription_service(const std::string& name) const;

Expand All @@ -150,6 +202,8 @@
void request_update();

private:
AsyncEventStatus write_event(protos::Event event, bool await_response = false);

void route_state_update(std::uint64_t tag, protos::StateUpdate&& update);

void do_service_start() final;
Expand All @@ -175,7 +229,7 @@
// if true, then the following runners should not be null
// if false, then the following runners must be null
const bool m_owns_progress_engine;
std::unique_ptr<mrc::runnable::Runner> m_progress_handler;
std::unique_ptr<mrc::rpc::PromiseHandler> m_progress_handler;
std::unique_ptr<mrc::runnable::Runner> m_progress_engine;
std::unique_ptr<mrc::runnable::Runner> m_event_handler;

Expand All @@ -201,70 +255,39 @@

std::mutex m_mutex;

std::map<size_t, Promise<protos::Event>> m_pending_events;

friend network::NetworkResources;
};

// todo: create this object from the client which will own the stop_source
// create this object with a stop_token associated with the client's stop_source

// Owns the promise for a single request and converts the resulting event into a ResponseT.
// Instances can be neither copied nor moved; Client is a friend and reads m_promise directly.
template <typename ResponseT>
class AsyncStatus
{
  public:
    AsyncStatus() = default;

    DELETE_COPYABILITY(AsyncStatus);
    DELETE_MOVEABILITY(AsyncStatus);

    // Waits for the promise to be fulfilled, then unpacks the event's message into ResponseT.
    // Returns an Error when the event reports an error; throws when the payload cannot be
    // unpacked into ResponseT.
    Expected<ResponseT> await_response()
    {
        // TODO(ryan): expand this into a wait_until with a deadline and a stop token
        auto pending = m_promise.get_future();
        auto reply   = pending.get();

        if (reply.has_error())
        {
            return Error::create(reply.error().message());
        }

        ResponseT unpacked;
        const bool unpack_ok = reply.message().UnpackTo(&unpacked);
        if (unpack_ok)
        {
            return unpacked;
        }

        throw Error::create("fatal error: unable to unpack message; server sent the wrong message type");
    }

  private:
    Promise<protos::Event> m_promise;
    friend Client;
};

// Issues a unary request via async_unary and waits for the response, returning it as
// Expected<ResponseT>.
// NOTE(review): this is a diff view; the three statements using AsyncStatus are the previous
// implementation and the two statements after them are its replacement.
template <typename ResponseT, typename RequestT>
Expected<ResponseT> Client::await_unary(const protos::EventType& event_type, RequestT&& request)
{
AsyncStatus<ResponseT> status;
async_unary(event_type, std::move(request), status);
return status.await_response();
auto status = this->async_unary(event_type, std::move(request));
// The `template` keyword is used here to call the member template with an explicit type argument
return status.template await_response<ResponseT>();
}

// Packs the request into an Event of the given type and sends it, returning a status on which
// the response can be awaited (write_event is called with await_response = true).
// NOTE(review): this is a diff view; the first signature, the set_tag(reinterpret_cast...) line
// and the direct await_write call belong to the previous implementation, which stored the
// address of the caller's promise in the event tag.
template <typename ResponseT, typename RequestT>
void Client::async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus<ResponseT>& status)
template <typename RequestT>
AsyncEventStatus Client::async_unary(const protos::EventType& event_type, RequestT&& request)
{
protos::Event event;
event.set_event(event_type);
event.set_tag(reinterpret_cast<std::uint64_t>(&status.m_promise));
// CHECK is applied to PackFrom's result when packing the request into the message field
CHECK(event.mutable_message()->PackFrom(request));
m_writer->await_write(std::move(event));

return this->write_event(std::move(event), true);
}

// Packs the message into an Event of the given type and sends it without registering for a
// response (write_event is called with await_response = false).
// NOTE(review): this is a diff view; the `void` signature and the direct await_write call are the
// previous implementation, shown alongside their replacements.
template <typename MessageT>
void Client::issue_event(const protos::EventType& event_type, MessageT&& message)
AsyncEventStatus Client::issue_event(const protos::EventType& event_type, MessageT&& message)

Check warning on line 284 in cpp/mrc/src/internal/control_plane/client.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/src/internal/control_plane/client.hpp#L284

Added line #L284 was not covered by tests
{
protos::Event event;
event.set_event(event_type);
// CHECK is applied to PackFrom's result when packing the message into the message field
CHECK(event.mutable_message()->PackFrom(message));
m_writer->await_write(std::move(event));

return this->write_event(std::move(event), false);

Check warning on line 290 in cpp/mrc/src/internal/control_plane/client.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/src/internal/control_plane/client.hpp#L290

Added line #L290 was not covered by tests
}

} // namespace mrc::control_plane
Loading