diff --git a/ci/conda/environments/clang_env.yml b/ci/conda/environments/clang_env.yml index 50d6cc655..bebe11bfd 100644 --- a/ci/conda/environments/clang_env.yml +++ b/ci/conda/environments/clang_env.yml @@ -26,4 +26,4 @@ dependencies: - libclang=16 - libclang-cpp=16 - llvmdev=16 - - include-what-you-use + - include-what-you-use=0.20 diff --git a/ci/conda/environments/dev_env.yml b/ci/conda/environments/dev_env.yml index 58d83d9a7..5af8a91c9 100644 --- a/ci/conda/environments/dev_env.yml +++ b/ci/conda/environments/dev_env.yml @@ -25,7 +25,7 @@ dependencies: - autoconf>=2.69 - bash-completion - benchmark=1.6.0 - - boost-cpp=1.74 + - boost-cpp=1.82 - ccache - cmake=3.24 - cuda-toolkit # Version comes from the channel above @@ -46,7 +46,7 @@ dependencies: - isort - jinja2=3.0 - lcov=1.15 - - libhwloc=2.5 + - libhwloc=2.9.2 - libprotobuf=3.21 - librmm=23.06 - libtool diff --git a/ci/scripts/cpp_checks.sh b/ci/scripts/cpp_checks.sh index 416c92167..b83df0727 100755 --- a/ci/scripts/cpp_checks.sh +++ b/ci/scripts/cpp_checks.sh @@ -80,9 +80,22 @@ if [[ -n "${MRC_MODIFIED_FILES}" ]]; then # Include What You Use if [[ "${SKIP_IWYU}" == "" ]]; then - IWYU_DIRS="cpp python" + # Remove .h, .hpp, and .cu files from the modified list + shopt -s extglob + IWYU_MODIFIED_FILES=( "${MRC_MODIFIED_FILES[@]/*.@(h|hpp|cu)/}" ) + + # Get the list of compiled files relative to this directory + WORKING_PREFIX="${PWD}/" + COMPILED_FILES=( $(jq -r .[].file ${BUILD_DIR}/compile_commands.json | sort -u ) ) + COMPILED_FILES=( "${COMPILED_FILES[@]/#$WORKING_PREFIX/}" ) + COMBINED_FILES=("${COMPILED_FILES[@]}") + COMBINED_FILES+=("${IWYU_MODIFIED_FILES[@]}") + + # Find the intersection between compiled files and modified files + IWYU_MODIFIED_FILES=( $(printf '%s\0' "${COMBINED_FILES[@]}" | sort -z | uniq -d -z | xargs -0n1) ) + NUM_PROC=$(get_num_proc) - IWYU_OUTPUT=`${IWYU_TOOL} -p ${BUILD_DIR} -j ${NUM_PROC} ${IWYU_DIRS} 2>&1` + IWYU_OUTPUT=`${IWYU_TOOL} -p ${BUILD_DIR} -j ${NUM_PROC} ${IWYU_MODIFIED_FILES[@]} 2>&1` IWYU_RETVAL=$? fi else diff --git a/cpp/mrc/src/internal/control_plane/client.cpp b/cpp/mrc/src/internal/control_plane/client.cpp index 7a85adc2e..03af88e3e 100644 --- a/cpp/mrc/src/internal/control_plane/client.cpp +++ b/cpp/mrc/src/internal/control_plane/client.cpp @@ -19,8 +19,10 @@ #include "internal/control_plane/client/connections_manager.hpp" #include "internal/grpc/progress_engine.hpp" -#include "internal/grpc/promise_handler.hpp" +#include "internal/grpc/promise_handler.hpp" // for PromiseHandler +#include "internal/grpc/stream_writer.hpp" // for StreamWriter #include "internal/runnable/runnable_resources.hpp" +#include "internal/service.hpp" #include "internal/system/system.hpp" #include "mrc/channel/status.hpp" @@ -33,15 +35,20 @@ #include "mrc/runnable/launch_control.hpp" #include "mrc/runnable/launcher.hpp" #include "mrc/runnable/runner.hpp" +#include "mrc/types.hpp" +#include // for promise #include #include #include +#include #include namespace mrc::control_plane { +std::atomic_uint64_t AsyncEventStatus::s_request_id_counter; + Client::Client(resources::PartitionResourceBase& base, std::shared_ptr cq) : resources::PartitionResourceBase(base), m_cq(std::move(cq)), @@ -73,13 +80,11 @@ void Client::do_service_start() if (m_owns_progress_engine) { CHECK(m_cq); - auto progress_engine = std::make_unique(m_cq); - auto progress_handler = std::make_unique(); + auto progress_engine = std::make_unique(m_cq); + m_progress_handler = std::make_unique(); - mrc::make_edge(*progress_engine, *progress_handler); + mrc::make_edge(*progress_engine, *m_progress_handler); - m_progress_handler = - runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_handler))->ignition(); m_progress_engine = runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_engine))->ignition(); } @@ -135,7 +140,6 @@ void Client::do_service_await_live() if (m_owns_progress_engine) { m_progress_engine->await_live(); - m_progress_handler->await_live(); } m_event_handler->await_live(); } @@ -150,7 +154,6 @@ void Client::do_service_await_join() { m_cq->Shutdown(); m_progress_engine->await_join(); - m_progress_handler->await_join(); } } @@ -161,10 +164,21 @@ void Client::do_handle_event(event_t&& event) // handle a subset of events directly on the event handler case protos::EventType::Response: { - auto* promise = reinterpret_cast*>(event.msg.tag()); - if (promise != nullptr) + auto event_tag = event.msg.tag(); + + if (event_tag != 0) { - promise->set_value(std::move(event.msg)); + // Lock to prevent multiple threads + std::unique_lock lock(m_mutex); + + // Find the promise associated with the event tag + auto promise = m_pending_events.extract(event_tag); + + // Unlock to allow other threads to continue as soon as possible + lock.unlock(); + + // Finally, set the value + promise.mapped().set_value(std::move(event.msg)); } } break; @@ -242,11 +256,12 @@ const mrc::runnable::LaunchOptions& Client::launch_options() const return m_launch_options; } -void Client::issue_event(const protos::EventType& event_type) +AsyncEventStatus Client::issue_event(const protos::EventType& event_type) { protos::Event event; event.set_event(event_type); - m_writer->await_write(std::move(event)); + // m_writer->await_write(std::move(event)); + return this->write_event(std::move(event), false); } void Client::request_update() @@ -260,4 +275,37 @@ void Client::request_update() // } } +AsyncEventStatus Client::write_event(protos::Event event, bool await_response) +{ + if (event.tag() != 0) + { + LOG(WARNING) << "event tag is set but this field should exclusively be used by the control plane client. " + "Clearing to avoid confusion"; + event.clear_tag(); + } + + AsyncEventStatus status; + + if (await_response) + { + // If we are supporting awaiting, create the promise now + Promise promise; + + // Set the future to the status + status.set_future(promise.get_future()); + + // Set the tag to the request ID to allow looking up the promise later + event.set_tag(status.request_id()); + + // Save the promise to the pending promises to be retrieved later + std::unique_lock lock(m_mutex); + + m_pending_events[status.request_id()] = std::move(promise); + } + + // Finally, write the event + m_writer->await_write(std::move(event)); + + return status; +} } // namespace mrc::control_plane diff --git a/cpp/mrc/src/internal/control_plane/client.hpp b/cpp/mrc/src/internal/control_plane/client.hpp index 0a07991a6..f23990614 100644 --- a/cpp/mrc/src/internal/control_plane/client.hpp +++ b/cpp/mrc/src/internal/control_plane/client.hpp @@ -19,22 +19,22 @@ #include "internal/control_plane/client/instance.hpp" // IWYU pragma: keep #include "internal/grpc/client_streaming.hpp" -#include "internal/grpc/stream_writer.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/service.hpp" #include "mrc/core/error.hpp" +#include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/forward.hpp" #include "mrc/node/writable_entrypoint.hpp" #include "mrc/protos/architect.grpc.pb.h" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_options.hpp" #include "mrc/types.hpp" -#include "mrc/utils/macros.hpp" -#include #include +#include +#include // for size_t #include #include #include @@ -65,10 +65,62 @@ namespace mrc::runnable { class Runner; } // namespace mrc::runnable +namespace mrc::rpc { +class PromiseHandler; +template +struct StreamWriter; +} // namespace mrc::rpc + namespace mrc::control_plane { -template -class AsyncStatus; +class AsyncEventStatus +{ + public: + size_t request_id() const + { + return m_request_id; + } + + template + Expected await_response() + { + if (!m_future.valid()) + { + throw exceptions::MrcRuntimeError( + "This AsyncEventStatus is not expecting a response or the response has already been awaited"); + } + + auto event = m_future.get(); + + if (event.has_error()) + { + return Error::create(event.error().message()); + } + + ResponseT response; + if (!event.message().UnpackTo(&response)) + { + throw Error::create("fatal error: unable to unpack message; server sent the wrong message type"); + } + + return response; + } + + private: + AsyncEventStatus() : m_request_id(++s_request_id_counter) {} + + void set_future(Future future) + { + m_future = std::move(future); + } + + static std::atomic_size_t s_request_id_counter; + + size_t m_request_id; + Future m_future; + + friend class Client; +}; /** * @brief Primary Control Plane Client @@ -128,13 +180,13 @@ class Client final : public resources::PartitionResourceBase, public Service template Expected await_unary(const protos::EventType& event_type, RequestT&& request); - template - void async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus& status); + template + AsyncEventStatus async_unary(const protos::EventType& event_type, RequestT&& request); template - void issue_event(const protos::EventType& event_type, MessageT&& message); + AsyncEventStatus issue_event(const protos::EventType& event_type, MessageT&& message); - void issue_event(const protos::EventType& event_type); + AsyncEventStatus issue_event(const protos::EventType& event_type); bool has_subscription_service(const std::string& name) const; @@ -150,6 +202,8 @@ class Client final : public resources::PartitionResourceBase, public Service void request_update(); private: + AsyncEventStatus write_event(protos::Event event, bool await_response = false); + void route_state_update(std::uint64_t tag, protos::StateUpdate&& update); void do_service_start() final; @@ -175,7 +229,7 @@ class Client final : public resources::PartitionResourceBase, public Service // if true, then the following runners should not be null // if false, then the following runners must be null const bool m_owns_progress_engine; - std::unique_ptr m_progress_handler; + std::unique_ptr m_progress_handler; std::unique_ptr m_progress_engine; std::unique_ptr m_event_handler; @@ -201,70 +255,39 @@ class Client final : public resources::PartitionResourceBase, public Service std::mutex m_mutex; + std::map> m_pending_events; + friend network::NetworkResources; }; // todo: create this object from the client which will own the stop_source // create this object with a stop_token associated with the client's stop_source -template -class AsyncStatus -{ - public: - AsyncStatus() = default; - - DELETE_COPYABILITY(AsyncStatus); - DELETE_MOVEABILITY(AsyncStatus); - - Expected await_response() - { - // todo(ryan): expand this into a wait_until with a deadline and a stop token - auto event = m_promise.get_future().get(); - - if (event.has_error()) - { - return Error::create(event.error().message()); - } - - ResponseT response; - if (!event.message().UnpackTo(&response)) - { - throw Error::create("fatal error: unable to unpack message; server sent the wrong message type"); - } - - return response; - } - - private: - Promise m_promise; - friend Client; -}; - template Expected Client::await_unary(const protos::EventType& event_type, RequestT&& request) { - AsyncStatus status; - async_unary(event_type, std::move(request), status); - return status.await_response(); + auto status = this->async_unary(event_type, std::move(request)); + return status.template await_response(); } -template -void Client::async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus& status) +template +AsyncEventStatus Client::async_unary(const protos::EventType& event_type, RequestT&& request) { protos::Event event; event.set_event(event_type); - event.set_tag(reinterpret_cast(&status.m_promise)); CHECK(event.mutable_message()->PackFrom(request)); - m_writer->await_write(std::move(event)); + + return this->write_event(std::move(event), true); } template -void Client::issue_event(const protos::EventType& event_type, MessageT&& message) +AsyncEventStatus Client::issue_event(const protos::EventType& event_type, MessageT&& message) { protos::Event event; event.set_event(event_type); CHECK(event.mutable_message()->PackFrom(message)); - m_writer->await_write(std::move(event)); + + return this->write_event(std::move(event), false); } } // namespace mrc::control_plane diff --git a/cpp/mrc/src/internal/grpc/promise_handler.hpp b/cpp/mrc/src/internal/grpc/promise_handler.hpp index 437a22e69..812a683e3 100644 --- a/cpp/mrc/src/internal/grpc/promise_handler.hpp +++ b/cpp/mrc/src/internal/grpc/promise_handler.hpp @@ -28,13 +28,19 @@ namespace mrc::rpc { /** * @brief MRC Sink to handle ProgressEvents which correspond to Promise tags */ -class PromiseHandler final : public mrc::node::GenericSink +class PromiseHandler final : public mrc::node::GenericSinkComponent { - void on_data(ProgressEvent&& event) final + mrc::channel::Status on_data(ProgressEvent&& event) final { auto* promise = static_cast*>(event.tag); promise->set_value(event.ok); - } + return mrc::channel::Status::success; + }; + + void on_complete() override + { + SinkProperties::release_edge_connection(); + }; }; } // namespace mrc::rpc diff --git a/cpp/mrc/src/internal/grpc/server.cpp b/cpp/mrc/src/internal/grpc/server.cpp index 9e0c0ecb4..65de1417e 100644 --- a/cpp/mrc/src/internal/grpc/server.cpp +++ b/cpp/mrc/src/internal/grpc/server.cpp @@ -18,7 +18,7 @@ #include "internal/grpc/server.hpp" #include "internal/grpc/progress_engine.hpp" -#include "internal/grpc/promise_handler.hpp" +#include "internal/grpc/promise_handler.hpp" // for PromiseHandler #include "internal/runnable/runnable_resources.hpp" #include "mrc/edge/edge_builder.hpp" @@ -47,11 +47,10 @@ void Server::do_service_start() m_server = m_builder.BuildAndStart(); auto progress_engine = std::make_unique(m_cq); - auto event_handler = std::make_unique(); - mrc::make_edge(*progress_engine, *event_handler); + m_event_hander = std::make_unique(); + mrc::make_edge(*progress_engine, *m_event_hander); m_progress_engine = m_runnable.launch_control().prepare_launcher(std::move(progress_engine))->ignition(); - m_event_hander = m_runnable.launch_control().prepare_launcher(std::move(event_handler))->ignition(); } void Server::do_service_stop() @@ -70,19 +69,17 @@ void Server::do_service_kill() void Server::do_service_await_live() { - if (m_progress_engine && m_event_hander) + if (m_progress_engine) { m_progress_engine->await_live(); - m_event_hander->await_live(); } } void Server::do_service_await_join() { - if (m_progress_engine && m_event_hander) + if (m_progress_engine) { m_progress_engine->await_join(); - m_event_hander->await_join(); } } diff --git a/cpp/mrc/src/internal/grpc/server.hpp b/cpp/mrc/src/internal/grpc/server.hpp index cacd4602d..db9436d95 100644 --- a/cpp/mrc/src/internal/grpc/server.hpp +++ b/cpp/mrc/src/internal/grpc/server.hpp @@ -34,6 +34,10 @@ namespace mrc::runnable { class Runner; } // namespace mrc::runnable +namespace mrc::rpc { +class PromiseHandler; +} // namespace mrc::rpc + namespace mrc::rpc { class Server : public Service @@ -61,7 +65,7 @@ class Server : public Service std::shared_ptr m_cq; std::unique_ptr m_server; std::unique_ptr m_progress_engine; - std::unique_ptr m_event_hander; + std::unique_ptr m_event_hander; }; } // namespace mrc::rpc diff --git a/cpp/mrc/src/tests/test_control_plane.cpp b/cpp/mrc/src/tests/test_control_plane.cpp index 96d85945c..c4ee114ce 100644 --- a/cpp/mrc/src/tests/test_control_plane.cpp +++ b/cpp/mrc/src/tests/test_control_plane.cpp @@ -27,6 +27,7 @@ #include "internal/runnable/runnable_resources.hpp" #include "internal/runtime/partition.hpp" #include "internal/runtime/runtime.hpp" +#include "internal/system/partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" #include "internal/system/system_provider.hpp" @@ -43,7 +44,6 @@ #include "mrc/pubsub/subscriber.hpp" #include "mrc/types.hpp" -#include #include #include #include @@ -121,6 +121,35 @@ TEST_F(TestControlPlane, SingleClientConnectDisconnect) server->service_await_join(); } +TEST_F(TestControlPlane, SingleClientConnectDisconnectSingleCore) +{ + // Similar to SingleClientConnectDisconnect except both client & server are locked to the same core + // making issue #379 easier to reproduce. + auto sr = make_runtime([](Options& options) { + options.topology().user_cpuset("0"); + }); + auto server = std::make_unique(sr->partition(0).resources().runnable()); + + server->service_start(); + server->service_await_live(); + + auto cr = make_runtime([](Options& options) { + options.topology().user_cpuset("0"); + options.architect_url("localhost:13337"); + }); + + // the total number of partition is system dependent + auto expected_partitions = cr->resources().system().partitions().flattened().size(); + EXPECT_EQ(cr->partition(0).resources().network()->control_plane().client().connections().instance_ids().size(), + expected_partitions); + + // destroying the resources should gracefully shutdown the data plane and the control plane. + cr.reset(); + + server->service_stop(); + server->service_await_join(); +} + TEST_F(TestControlPlane, DoubleClientConnectExchangeDisconnect) { auto sr = make_runtime();