diff --git a/cpp/mrc/include/mrc/edge/edge_builder.hpp b/cpp/mrc/include/mrc/edge/edge_builder.hpp index 78e88b577..e4bd0e571 100644 --- a/cpp/mrc/include/mrc/edge/edge_builder.hpp +++ b/cpp/mrc/include/mrc/edge/edge_builder.hpp @@ -285,8 +285,11 @@ class DeferredWritableMultiEdge : public MultiEdgeHolder, public DeferredWritableMultiEdgeBase { public: - DeferredWritableMultiEdge(determine_indices_fn_t indices_fn = nullptr, bool deep_copy = false) : - m_indices_fn(std::move(indices_fn)) + DeferredWritableMultiEdge(determine_indices_fn_t indices_fn = nullptr, + bool deep_copy = false, + std::string name = std::string()) : + m_indices_fn(std::move(indices_fn)), + MultiEdgeHolder(std::move(name)) { // // Generate warning if deep_copy = True but type does not support it // if constexpr (!std::is_copy_constructible_v) diff --git a/cpp/mrc/include/mrc/edge/edge_holder.hpp b/cpp/mrc/include/mrc/edge/edge_holder.hpp index 0262a7e71..6486d955d 100644 --- a/cpp/mrc/include/mrc/edge/edge_holder.hpp +++ b/cpp/mrc/include/mrc/edge/edge_holder.hpp @@ -37,7 +37,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -50,7 +52,8 @@ template class EdgeHolder { public: - EdgeHolder() = default; + EdgeHolder(std::string name = std::string()) : m_name(std::move(name)){}; + virtual ~EdgeHolder() { // Drop any edge connections before this object goes out of scope. This should execute any disconnectors @@ -58,12 +61,35 @@ class EdgeHolder if (this->check_active_connection(false)) { - LOG(FATAL) << "A node was destructed which still had dependent connections. Nodes must be kept alive while " - "dependent connections are still active"; + LOG(FATAL) << "EdgeHolder(" << this << ")[" << m_name << "] " + << "A node was destructed which still had dependent connections. Nodes must be kept alive while " + "dependent connections are still active\n" + << this->connection_info(); } } + const std::string& name() const + { + return m_name; + }; + protected: + std::string connection_info() const + { + std::stringstream ss; + ss << "m_owned_edge=" << m_owned_edge.lock() << "\tm_owned_edge_lifetime=" << m_owned_edge_lifetime + << "\tm_connected_edge=" << m_connected_edge; + + bool is_connected = false; + if (m_connected_edge) + { + is_connected = m_connected_edge->is_connected(); + } + + ss << "\tis_connected=" << is_connected << "\tcheck_active_connection=" << this->check_active_connection(false); + return ss.str(); + } + bool check_active_connection(bool do_throw = true) const { // Alive connection exists when the lock is true, lifetime is false or a connction object has been set @@ -155,6 +181,13 @@ class EdgeHolder m_connected_edge.reset(); } + void disconnect() + { + m_connected_edge.reset(); + m_owned_edge_lifetime.reset(); + m_owned_edge.reset(); + } + const std::shared_ptr>& get_connected_edge() const { return m_connected_edge; @@ -188,6 +221,8 @@ class EdgeHolder // Holds a pointer to any set edge (different from init edge). Maintains lifetime std::shared_ptr> m_connected_edge; + std::string m_name; + // Allow edge builder to call set_edge friend EdgeBuilder; @@ -200,10 +235,25 @@ template class MultiEdgeHolder { public: - MultiEdgeHolder() = default; + MultiEdgeHolder(std::string name = std::string()) : m_name(std::move(name)){}; virtual ~MultiEdgeHolder() = default; + const std::string& name() const + { + return m_name; + }; + protected: + std::string connection_info() const + { + std::stringstream ss; + ss << "m_edges.size()=" << m_edges.size(); + for (const auto& [key, edge_pair] : m_edges) + { + ss << "\n\tkey=" << key << "\t" << edge_pair.connection_info(); + } + return ss.str(); + } void init_owned_edge(KeyT key, std::shared_ptr> edge) { auto& edge_pair = this->get_edge_pair(key, true); @@ -276,7 +326,9 @@ class MultiEdgeHolder { if (create_if_missing) { - m_edges[key] = EdgeHolder(); + std::ostringstream edge_name; + edge_name << m_name << "_" << key; + m_edges[key] = EdgeHolder(edge_name.str()); return m_edges[key]; } @@ -321,6 +373,8 @@ class MultiEdgeHolder // Keeps pairs of get_edge/set_edge for each key std::map> m_edges; + std::string m_name; + // Allow edge builder to call set_edge friend EdgeBuilder; }; diff --git a/cpp/mrc/include/mrc/manifold/composite_manifold.hpp b/cpp/mrc/include/mrc/manifold/composite_manifold.hpp index 974729468..7d77c6285 100644 --- a/cpp/mrc/include/mrc/manifold/composite_manifold.hpp +++ b/cpp/mrc/include/mrc/manifold/composite_manifold.hpp @@ -59,6 +59,17 @@ class CompositeManifold : public Manifold mrc::make_edge(*m_ingress, *m_egress); } + ~CompositeManifold() override + { + shutdown(); + }; + + void shutdown() final + { + m_ingress->shutdown(); + m_egress->shutdown(); + } + protected: IngressT& ingress() { diff --git a/cpp/mrc/include/mrc/manifold/egress.hpp b/cpp/mrc/include/mrc/manifold/egress.hpp index 781122d61..51fe23f74 100644 --- a/cpp/mrc/include/mrc/manifold/egress.hpp +++ b/cpp/mrc/include/mrc/manifold/egress.hpp @@ -35,6 +35,7 @@ struct EgressDelegate { virtual ~EgressDelegate() = default; virtual void add_output(const SegmentAddress& address, edge::IWritableProviderBase* output_sink) = 0; + virtual void shutdown(){}; }; template @@ -55,6 +56,13 @@ class TypedEgress : public EgressDelegate template class RoundRobinEgress : public node::Router, public TypedEgress { + public: + void shutdown() final + { + DVLOG(10) << "Releasing edges from manifold egress"; + node::Router::release_edge_connections(); + } + protected: SegmentAddress determine_key_for_value(const T& t) override { diff --git a/cpp/mrc/include/mrc/manifold/ingress.hpp b/cpp/mrc/include/mrc/manifold/ingress.hpp index 060446b79..2842be4ba 100644 --- a/cpp/mrc/include/mrc/manifold/ingress.hpp +++ b/cpp/mrc/include/mrc/manifold/ingress.hpp @@ -23,6 +23,8 @@ #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" +#include + #include namespace mrc::manifold { @@ -31,6 +33,7 @@ struct IngressDelegate { virtual ~IngressDelegate() = default; virtual void add_input(const SegmentAddress& address, edge::IWritableAcceptorBase* input_source) = 0; + virtual void shutdown(){}; }; template @@ -51,6 +54,13 @@ class TypedIngress : public IngressDelegate template class MuxedIngress : public node::Muxer, public TypedIngress { + public: + void shutdown() final + { + DVLOG(10) << "Releasing edges from manifold ingress"; + node::SourceProperties::release_edge_connection(); + } + protected: void do_add_input(const SegmentAddress& address, edge::IWritableAcceptor* source) final { diff --git a/cpp/mrc/include/mrc/manifold/interface.hpp b/cpp/mrc/include/mrc/manifold/interface.hpp index 5c3d28fa4..706487091 100644 --- a/cpp/mrc/include/mrc/manifold/interface.hpp +++ b/cpp/mrc/include/mrc/manifold/interface.hpp @@ -27,9 +27,11 @@ struct Interface virtual ~Interface() = default; virtual const PortName& port_name() const = 0; + virtual const std::string& info() const = 0; - virtual void start() = 0; - virtual void join() = 0; + virtual void start() = 0; + virtual void join() = 0; + virtual void shutdown() = 0; virtual void add_input(const SegmentAddress& address, edge::IWritableAcceptorBase* input_source) = 0; virtual void add_output(const SegmentAddress& address, edge::IWritableProviderBase* output_sink) = 0; diff --git a/cpp/mrc/include/mrc/manifold/manifold.hpp b/cpp/mrc/include/mrc/manifold/manifold.hpp index 4cb567341..046934331 100644 --- a/cpp/mrc/include/mrc/manifold/manifold.hpp +++ b/cpp/mrc/include/mrc/manifold/manifold.hpp @@ -39,12 +39,12 @@ class Manifold : public Interface ~Manifold() override; const PortName& port_name() const final; + const std::string& info() const final; + void shutdown() override; protected: runnable::IRunnableResources& resources(); - const std::string& info() const; - private: void add_input(const SegmentAddress& address, edge::IWritableAcceptorBase* input_source) final; void add_output(const SegmentAddress& address, edge::IWritableProviderBase* output_sink) final; diff --git a/cpp/mrc/include/mrc/node/generic_sink.hpp b/cpp/mrc/include/mrc/node/generic_sink.hpp index de0fbec7c..05f96af6e 100644 --- a/cpp/mrc/include/mrc/node/generic_sink.hpp +++ b/cpp/mrc/include/mrc/node/generic_sink.hpp @@ -63,7 +63,7 @@ template class GenericSinkComponent : public RxSinkComponent { public: - GenericSinkComponent() + GenericSinkComponent(std::string name = std::string()) : RxSinkComponent(std::move(name)) { RxSinkComponent::set_observer(rxcpp::make_observer_dynamic( [this](T data) { diff --git a/cpp/mrc/include/mrc/node/generic_source.hpp b/cpp/mrc/include/mrc/node/generic_source.hpp index 19956d422..96406006c 100644 --- a/cpp/mrc/include/mrc/node/generic_source.hpp +++ b/cpp/mrc/include/mrc/node/generic_source.hpp @@ -71,7 +71,7 @@ template class GenericSourceComponent : public ForwardingEgressProvider { public: - GenericSourceComponent() = default; + GenericSourceComponent(std::string name = std::string()) : m_name(std::move(name)) {} ~GenericSourceComponent() override = default; private: @@ -81,6 +81,8 @@ class GenericSourceComponent : public ForwardingEgressProvider } virtual mrc::channel::Status get_data(T& data) = 0; + + std::string m_name; }; template @@ -90,6 +92,10 @@ class LambdaSourceComponent : public GenericSourceComponent using get_data_fn_t = std::function; LambdaSourceComponent(get_data_fn_t get_data_fn) : m_get_data_fn(std::move(get_data_fn)) {} + LambdaSourceComponent(std::string name, get_data_fn_t get_data_fn) : + GenericSourceComponent(std::move(name)), + m_get_data_fn(std::move(get_data_fn)) + {} ~LambdaSourceComponent() override = default; private: diff --git a/cpp/mrc/include/mrc/node/operators/broadcast.hpp b/cpp/mrc/include/mrc/node/operators/broadcast.hpp index 553739c1c..dbb2c0598 100644 --- a/cpp/mrc/include/mrc/node/operators/broadcast.hpp +++ b/cpp/mrc/include/mrc/node/operators/broadcast.hpp @@ -34,6 +34,8 @@ namespace mrc::node { class BroadcastTypeless : public edge::IWritableProviderBase, public edge::IWritableAcceptorBase { public: + BroadcastTypeless(std::string name = std::string()) : m_name(std::move(name)) {} + std::shared_ptr get_writable_edge_handle() const override { auto* self = const_cast(this); @@ -141,6 +143,7 @@ class BroadcastTypeless : public edge::IWritableProviderBase, public edge::IWrit } private: + std::string m_name; std::mutex m_mutex; std::vector> m_upstream_handles; std::vector> m_downstream_handles; @@ -199,18 +202,18 @@ class Broadcast : public WritableProvider, public edge::IWritableAcceptor Broadcast(bool deep_copy = false) { - auto edge = std::make_shared(*this, deep_copy); - - // Save to avoid casting - m_edge = edge; + init_edge(deep_copy); + } - WritableProvider::init_owned_edge(edge); + Broadcast(std::string name, bool deep_copy = false) : m_name(std::move(name)) + { + init_edge(deep_copy); } ~Broadcast() { // Debug print - VLOG(10) << "Destroying TestBroadcast"; + VLOG(10) << "Destroying Broadcast " << m_name; } void set_writable_edge_handle(std::shared_ptr ingress) override @@ -227,11 +230,22 @@ class Broadcast : public WritableProvider, public edge::IWritableAcceptor void on_complete() { - VLOG(10) << "TestBroadcast completed"; + VLOG(10) << "Broadcast completed " << m_name; } private: + void init_edge(bool deep_copy) + { + auto edge = std::make_shared(*this, deep_copy); + + // Save to avoid casting + m_edge = edge; + + WritableProvider::init_owned_edge(edge); + } + std::weak_ptr m_edge; + std::string m_name; }; } // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/queue.hpp b/cpp/mrc/include/mrc/node/queue.hpp index 81038feaf..2bb554203 100644 --- a/cpp/mrc/include/mrc/node/queue.hpp +++ b/cpp/mrc/include/mrc/node/queue.hpp @@ -23,17 +23,28 @@ #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" +#include + namespace mrc::node { +using namespace std::string_literals; + template class Queue : public WritableProvider, public ReadableProvider { public: - Queue() + Queue(const std::string& name = std::string()) : + SinkProperties(name + "-sink"s), + SourceProperties(name + "-source"s) { this->set_channel(std::make_unique>()); } - ~Queue() override = default; + + ~Queue() override + { + SinkProperties::release_edge_connection(); + SourceProperties::release_edge_connection(); + }; void set_channel(std::unique_ptr> channel) { diff --git a/cpp/mrc/include/mrc/node/rx_node.hpp b/cpp/mrc/include/mrc/node/rx_node.hpp index 34a48fd0d..a467dee9d 100644 --- a/cpp/mrc/include/mrc/node/rx_node.hpp +++ b/cpp/mrc/include/mrc/node/rx_node.hpp @@ -36,6 +36,7 @@ #include #include #include +#include namespace mrc::node { @@ -50,11 +51,14 @@ class RxNode : public RxSinkBase, // function defining the stream, i.e. operations linking Sink -> Source using stream_fn_t = std::function(const rxcpp::observable&)>; - RxNode(); + RxNode(std::string name = std::string()); template RxNode(OpsT&&... ops); + template + RxNode(std::string name, OpsT&&... ops); + ~RxNode() override = default; template @@ -68,13 +72,20 @@ class RxNode : public RxSinkBase, void make_stream(stream_fn_t fn); + void set_name(std::string name) + { + RxSinkBase::m_name = name; + RxSourceBase::m_name = std::move(name); + } + + void on_shutdown_critical_section() final; + private: // the following method(s) are moved to private from their original scopes to prevent access from deriving classes using RxSinkBase::observable; using RxSourceBase::observer; void do_subscribe(rxcpp::composite_subscription& subscription) final; - void on_shutdown_critical_section() final; void on_stop(const rxcpp::subscription& subscription) override; void on_kill(const rxcpp::subscription& subscription) final; @@ -85,7 +96,11 @@ class RxNode : public RxSinkBase, }; template -RxNode::RxNode() : +RxNode::RxNode(std::string name) : + RxSinkBase{name}, + RxSourceBase{name}, + SinkProperties(name), + SourceProperties(name), m_stream([](const rxcpp::observable& obs) { // Default to just returning the input return obs; @@ -99,6 +114,15 @@ RxNode::RxNode(OpsT&&... ops) pipe(std::forward(ops)...); } +template +template +RxNode::RxNode(std::string name, OpsT&&... ops) : + RxSinkBase{name}, + RxSourceBase{name} +{ + pipe(std::forward(ops)...); +} + template void RxNode::make_stream(stream_fn_t fn) { @@ -140,7 +164,7 @@ void RxNode::on_kill(const rxcpp::subscription& subsc template void RxNode::on_shutdown_critical_section() { - DVLOG(10) << runnable::Context::get_runtime_context().info() << " releasing source channel"; + DVLOG(10) << "releasing source channel"; RxSourceBase::release_edge_connection(); } @@ -182,7 +206,7 @@ class RxNodeComponent : public WritableProvider, public WritableAcceptor public: using stream_fn_t = std::function(const rxcpp::observable&)>; - RxNodeComponent() + RxNodeComponent(std::string name = std::string()) : m_name(std::move(name)) { auto edge = std::make_shared>(m_subject.get_subscriber()); @@ -194,12 +218,23 @@ class RxNodeComponent : public WritableProvider, public WritableAcceptor this->make_stream(stream_fn); } + RxNodeComponent(std::string name, stream_fn_t stream_fn) : RxNodeComponent(std::move(name)) + { + this->make_stream(stream_fn); + } + template RxNodeComponent(OpsT&&... ops) : RxNodeComponent() { this->pipe(std::forward(ops)...); } + template + RxNodeComponent(std::string name, OpsT&&... ops) : RxNodeComponent(std::move(name)) + { + this->pipe(std::forward(ops)...); + } + template RxNodeComponent& pipe(OpsT&&... ops) { @@ -247,6 +282,7 @@ class RxNodeComponent : public WritableProvider, public WritableAcceptor private: rxcpp::subjects::subject m_subject; rxcpp::subscription m_subject_subscription; + std::string m_name; }; } // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/rx_sink.hpp b/cpp/mrc/include/mrc/node/rx_sink.hpp index 43c7fc2dd..b93f65a07 100644 --- a/cpp/mrc/include/mrc/node/rx_sink.hpp +++ b/cpp/mrc/include/mrc/node/rx_sink.hpp @@ -84,7 +84,8 @@ class RxSink : public RxSinkBase, public RxRunnable, public RxProlo using on_error_fn_t = std::function; using on_complete_fn_t = std::function; - RxSink() = default; + RxSink(std::string name = std::string()) : RxSinkBase(name), SinkProperties(name) {} + ~RxSink() override = default; template @@ -93,6 +94,12 @@ class RxSink : public RxSinkBase, public RxRunnable, public RxProlo set_observer(std::forward(args)...); } + template + RxSink(std::string name, ArgsT&&... args) : RxSinkBase(name), SinkProperties(name) + { + set_observer(std::forward(args)...); + } + void set_observer(observer_t observer); template @@ -174,11 +181,12 @@ class RxSinkComponent : public WritableProvider RxSinkComponent() { - auto edge = std::make_shared>(); - - m_sink_edge = edge; + init_edge(); + } - WritableProvider::init_owned_edge(edge); + RxSinkComponent(std::string name = std::string()) : m_name(std::move(name)) + { + init_edge(); } ~RxSinkComponent() = default; @@ -186,12 +194,12 @@ class RxSinkComponent : public WritableProvider template RxSinkComponent(ArgsT&&... args) : RxSinkComponent() { - // auto edge = std::make_shared>(); - - // m_sink_edge = edge; - - // WritableProvider::init_owned_edge(edge); + set_observer(std::forward(args)...); + } + template + RxSinkComponent(std::string name, ArgsT&&... args) : RxSinkComponent(std::move(name)) + { set_observer(std::forward(args)...); } @@ -204,7 +212,17 @@ class RxSinkComponent : public WritableProvider } private: + void init_edge() + { + auto edge = std::make_shared>(); + + m_sink_edge = edge; + + WritableProvider::init_owned_edge(edge); + } + std::weak_ptr> m_sink_edge; + std::string m_name; // observer_t m_observer; }; diff --git a/cpp/mrc/include/mrc/node/rx_sink_base.hpp b/cpp/mrc/include/mrc/node/rx_sink_base.hpp index 89d4a8c8f..294d0b98f 100644 --- a/cpp/mrc/include/mrc/node/rx_sink_base.hpp +++ b/cpp/mrc/include/mrc/node/rx_sink_base.hpp @@ -47,10 +47,11 @@ class RxSinkBase : public WritableProvider, public ReadableAcceptor, publi void sink_remove_watcher(std::shared_ptr watcher); protected: - RxSinkBase(); + RxSinkBase(std::string name = std::string()); ~RxSinkBase() override = default; const rxcpp::observable& observable() const; + std::string m_name; private: // this is our channel reader progress engine @@ -61,7 +62,9 @@ class RxSinkBase : public WritableProvider, public ReadableAcceptor, publi }; template -RxSinkBase::RxSinkBase() : +RxSinkBase::RxSinkBase(std::string name) : + SinkChannelOwner(name), + m_name(std::move(name)), m_observable(rxcpp::observable<>::create([this](rxcpp::subscriber s) { progress_engine(s); })) diff --git a/cpp/mrc/include/mrc/node/rx_source.hpp b/cpp/mrc/include/mrc/node/rx_source.hpp index 034d669ff..6e90623a3 100644 --- a/cpp/mrc/include/mrc/node/rx_source.hpp +++ b/cpp/mrc/include/mrc/node/rx_source.hpp @@ -38,6 +38,7 @@ #include #include #include +#include #include namespace mrc::node { @@ -51,8 +52,9 @@ template class RxSource : public RxSourceBase, public RxRunnable, public RxEpilogueTap { public: - RxSource() = default; + RxSource(std::string name = std::string()) : RxSourceBase(name), SourceProperties(name) {} RxSource(rxcpp::observable observable); + RxSource(std::string name, rxcpp::observable observable); ~RxSource() override = default; void set_observable(rxcpp::observable observable); @@ -72,6 +74,14 @@ RxSource::RxSource(rxcpp::observable observable) : RxSourceBase< set_observable(observable); } +template +RxSource::RxSource(std::string name, rxcpp::observable observable) : + RxSourceBase(name), + SourceProperties(name) +{ + set_observable(observable); +} + template void RxSource::on_shutdown_critical_section() { diff --git a/cpp/mrc/include/mrc/node/rx_source_base.hpp b/cpp/mrc/include/mrc/node/rx_source_base.hpp index 58876d3c6..4af410e71 100644 --- a/cpp/mrc/include/mrc/node/rx_source_base.hpp +++ b/cpp/mrc/include/mrc/node/rx_source_base.hpp @@ -55,10 +55,11 @@ class RxSourceBase : public ReadableProvider, void source_remove_watcher(std::shared_ptr watcher); protected: - RxSourceBase(); + RxSourceBase(std::string name = std::string()); ~RxSourceBase() override = default; const rxcpp::observer& observer() const; + std::string m_name; private: // // the following methods are moved to private from their original scopes to prevent access from deriving classes @@ -68,7 +69,9 @@ class RxSourceBase : public ReadableProvider, }; template -RxSourceBase::RxSourceBase() : +RxSourceBase::RxSourceBase(std::string name) : + SourceChannelOwner(name), + m_name(std::move(name)), m_observer(rxcpp::make_observer_dynamic( [this](T data) { this->watcher_epilogue(WatchableEvent::sink_on_data, true, &data); diff --git a/cpp/mrc/include/mrc/node/sink_channel_owner.hpp b/cpp/mrc/include/mrc/node/sink_channel_owner.hpp index 8997e3a8d..712afc014 100644 --- a/cpp/mrc/include/mrc/node/sink_channel_owner.hpp +++ b/cpp/mrc/include/mrc/node/sink_channel_owner.hpp @@ -42,7 +42,7 @@ class SinkChannelOwner : public virtual SinkProperties } protected: - SinkChannelOwner() = default; + SinkChannelOwner(std::string name = std::string()) : SinkProperties(std::move(name)){}; void do_set_channel(edge::EdgeChannel& edge_channel) { diff --git a/cpp/mrc/include/mrc/node/sink_properties.hpp b/cpp/mrc/include/mrc/node/sink_properties.hpp index af56b0fe0..5cad278da 100644 --- a/cpp/mrc/include/mrc/node/sink_properties.hpp +++ b/cpp/mrc/include/mrc/node/sink_properties.hpp @@ -26,6 +26,7 @@ #include #include +#include #include namespace mrc::node { @@ -109,7 +110,7 @@ class SinkProperties : public edge::EdgeHolder, public SinkPropertiesBase } protected: - SinkProperties() + SinkProperties(std::string name = std::string()) : edge::EdgeHolder(std::move(name)) { // Set the default edge to be a null one in case no connection is made this->init_connected_edge(std::make_shared>()); diff --git a/cpp/mrc/include/mrc/node/source_channel_owner.hpp b/cpp/mrc/include/mrc/node/source_channel_owner.hpp index 226492e5e..823f9f21b 100644 --- a/cpp/mrc/include/mrc/node/source_channel_owner.hpp +++ b/cpp/mrc/include/mrc/node/source_channel_owner.hpp @@ -34,6 +34,7 @@ template class SourceChannelOwner : public virtual SourceProperties { public: + SourceChannelOwner(std::string name = std::string()) : SourceProperties(std::move(name)){}; ~SourceChannelOwner() override = default; void set_channel(std::unique_ptr> channel) @@ -44,8 +45,6 @@ class SourceChannelOwner : public virtual SourceProperties } protected: - SourceChannelOwner() = default; - void do_set_channel(edge::EdgeChannel& edge_channel) { // Create 2 edges, one for reading and writing. On connection, persist the other to allow the node to still use diff --git a/cpp/mrc/include/mrc/node/source_properties.hpp b/cpp/mrc/include/mrc/node/source_properties.hpp index 12166eb43..44ef87b11 100644 --- a/cpp/mrc/include/mrc/node/source_properties.hpp +++ b/cpp/mrc/include/mrc/node/source_properties.hpp @@ -23,6 +23,7 @@ #include "mrc/edge/edge_writable.hpp" #include "mrc/node/forward.hpp" #include "mrc/type_traits.hpp" +#include "mrc/types.hpp" // for Mutex #include "mrc/utils/type_utils.hpp" #include @@ -111,7 +112,7 @@ class SourceProperties : public edge::EdgeHolder, public SourcePropertiesBase } protected: - SourceProperties() + SourceProperties(std::string name = std::string()) : edge::EdgeHolder(std::move(name)) { // Set the default edge to be a null one in case no connection is made this->init_connected_edge(std::make_shared>()); @@ -208,37 +209,68 @@ template class ForwardingEgressProvider : public ReadableProvider { protected: + struct State + { + Mutex m_mutex; + bool m_is_destroyed{false}; + }; + class ForwardingEdge : public edge::IEdgeReadable { public: - ForwardingEdge(ForwardingEgressProvider& parent) : m_parent(parent) {} + ForwardingEdge(ForwardingEgressProvider& parent, std::shared_ptr state) : + m_parent(parent), + m_state(std::move(state)) + {} ~ForwardingEdge() = default; channel::Status await_read(T& t) override { - return m_parent.get_next(t); + std::lock_guardm_mutex)> lock(m_state->m_mutex); + if (!(m_state->m_is_destroyed)) + { + return m_parent.get_next(t); + } + + return channel::Status::closed; } private: ForwardingEgressProvider& m_parent; + std::shared_ptr m_state; }; - ForwardingEgressProvider() + ForwardingEgressProvider() : m_state(std::make_shared()) { - auto inner_edge = std::make_shared(*this); + auto inner_edge = std::make_shared(*this, m_state); - inner_edge->add_disconnector([this]() { - // Only call the on_complete if we have been connected - this->on_complete(); + inner_edge->add_disconnector([this, state = m_state]() { + std::lock_guardm_mutex)> lock(state->m_mutex); + if (!(state->m_is_destroyed)) + { + // Only call the on_complete if we have been connected and `this` is still alive + this->on_complete(); + } }); ReadableProvider::init_owned_edge(inner_edge); } + ~ForwardingEgressProvider() + { + SourceProperties::disconnect(); + { + std::lock_guardm_mutex)> lock(m_state->m_mutex); + m_state->m_is_destroyed = true; + } + } + virtual channel::Status get_next(T& t) = 0; virtual void on_complete() {} + + std::shared_ptr m_state; }; } // namespace mrc::node diff --git a/cpp/mrc/include/mrc/segment/builder.hpp b/cpp/mrc/include/mrc/segment/builder.hpp index a35f571c9..46bd756ab 100644 --- a/cpp/mrc/include/mrc/segment/builder.hpp +++ b/cpp/mrc/include/mrc/segment/builder.hpp @@ -364,7 +364,7 @@ class IBuilder template std::shared_ptr> IBuilder::construct_object(std::string name, ArgsT&&... args) { - auto uptr = std::make_unique(std::forward(args)...); + auto uptr = std::make_unique(name, std::forward(args)...); return make_object(std::move(name), std::move(uptr)); } @@ -376,13 +376,13 @@ std::shared_ptr> IBuilder::make_object(std::string name, std::un if constexpr (std::is_base_of_v) { - segment_object = std::make_shared>(std::move(node)); - this->add_object(name, segment_object); + segment_object = std::make_shared>(name, std::move(node)); + this->add_object(std::move(name), segment_object); } else { - segment_object = std::make_shared>(std::move(node)); - this->add_object(name, segment_object); + segment_object = std::make_shared>(name, std::move(node)); + this->add_object(std::move(name), segment_object); } CHECK(segment_object); diff --git a/cpp/mrc/include/mrc/segment/component.hpp b/cpp/mrc/include/mrc/segment/component.hpp index 3e25f9b63..adef7b7dc 100644 --- a/cpp/mrc/include/mrc/segment/component.hpp +++ b/cpp/mrc/include/mrc/segment/component.hpp @@ -23,6 +23,7 @@ #include #include +#include #include namespace mrc::segment { @@ -32,8 +33,17 @@ class Component final : public Object { public: Component(std::unique_ptr resource) : m_resource(std::move(resource)) {} + Component(std::string name, std::unique_ptr resource) : + m_resource(std::move(resource)), + m_name(std::move(name)) + {} ~Component() final = default; + void destroy() final + { + m_resource.reset(); + } + private: ResourceT* get_object() const final { @@ -42,6 +52,7 @@ class Component final : public Object } std::unique_ptr m_resource; + std::string m_name; }; } // namespace mrc::segment diff --git a/cpp/mrc/include/mrc/segment/ingress_port.hpp b/cpp/mrc/include/mrc/segment/ingress_port.hpp index fec6d469e..50b796ef8 100644 --- a/cpp/mrc/include/mrc/segment/ingress_port.hpp +++ b/cpp/mrc/include/mrc/segment/ingress_port.hpp @@ -85,6 +85,13 @@ class IngressPort : public Object>, public IngressPortBase manifold->add_output(m_segment_address, m_source.get()); } + void destroy() final + { + DVLOG(10) << "Destroying ingress port " << this->type_name(); + m_source->on_shutdown_critical_section(); + m_source.reset(); + } + SegmentAddress m_segment_address; PortName m_port_name; std::unique_ptr> m_source; diff --git a/cpp/mrc/include/mrc/segment/object.hpp b/cpp/mrc/include/mrc/segment/object.hpp index 2ccc80094..6e6652797 100644 --- a/cpp/mrc/include/mrc/segment/object.hpp +++ b/cpp/mrc/include/mrc/segment/object.hpp @@ -74,6 +74,8 @@ struct ObjectProperties virtual runnable::LaunchOptions& launch_options() = 0; virtual const runnable::LaunchOptions& launch_options() const = 0; + + virtual void destroy(){}; }; inline ObjectProperties::~ObjectProperties() = default; @@ -152,6 +154,8 @@ template class Object : public virtual ObjectProperties { public: + Object(std::string name = std::string()) : m_name(std::move(name)){}; + ObjectT& object(); std::string name() const final; diff --git a/cpp/mrc/include/mrc/segment/runnable.hpp b/cpp/mrc/include/mrc/segment/runnable.hpp index ab5b590ca..e582e649e 100644 --- a/cpp/mrc/include/mrc/segment/runnable.hpp +++ b/cpp/mrc/include/mrc/segment/runnable.hpp @@ -41,11 +41,28 @@ class Runnable : public Object, public runnable::Launchable Runnable(ArgsT&&... args) : m_node(std::make_unique(std::forward(args)...)) {} + template + Runnable(std::string name, ArgsT&&... args) : + Object(std::move(name)), + m_node(std::make_unique(std::forward(args)...)) + {} + Runnable(std::unique_ptr node) : m_node(std::move(node)) { CHECK(m_node); } + void destroy() final + { + DVLOG(10) << "Destroying runnable " << this->type_name(); + m_node.reset(); + } + + Runnable(std::string name, std::unique_ptr node) : Object(std::move(name)), m_node(std::move(node)) + { + CHECK(m_node); + } + private: NodeT* get_object() const final; std::unique_ptr prepare_launcher(runnable::LaunchControl& launch_control) final; diff --git a/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp b/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp index dddd73a3c..e88d4dac5 100644 --- a/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp +++ b/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp @@ -37,6 +37,7 @@ #include #include +#include // for lock_guard #include #include #include @@ -89,7 +90,7 @@ void PipelineInstance::join_segment(const SegmentAddress& address) search->second->service_await_join(); } -void PipelineInstance::stop_segment(const SegmentAddress& address) +void PipelineInstance::stop_segment(const SegmentAddress& address, bool kill) { auto search = m_segments.find(address); CHECK(search != m_segments.end()); @@ -97,10 +98,9 @@ void PipelineInstance::stop_segment(const SegmentAddress& address) auto [id, rank] = segment_address_decode(address); const auto& segdef = m_definition->find_segment(id); - for (const auto& name : segdef->ingress_port_names()) + if (kill) { - DVLOG(3) << "Dropping IngressPort for " << ::mrc::segment::info(address) << " on manifold " << name; - // manifold(name).drop_output(address); + search->second->shutdown(); } search->second->service_stop(); @@ -111,6 +111,7 @@ void PipelineInstance::create_segment(const SegmentAddress& address, std::uint32 // perform our allocations on the numa domain of the intended target // CHECK_LT(partition_id, m_resources->host_resources().size()); CHECK_LT(partition_id, resources().partition_count()); + DVLOG(10) << "Enqueing Creation of segment " << ::mrc::segment::info(address); resources() .partition(partition_id) .runnable() @@ -121,7 +122,8 @@ void PipelineInstance::create_segment(const SegmentAddress& address, std::uint32 auto [id, rank] = segment_address_decode(address); auto definition = m_definition->find_segment(id); - auto segment = std::make_unique(definition, rank, *this, partition_id); + DVLOG(10) << "Creating segment " << definition->name() << " - " << ::mrc::segment::info(address); + auto segment = std::make_unique(definition, rank, *this, partition_id); for (const auto& name : definition->egress_port_names()) { @@ -149,6 +151,7 @@ void PipelineInstance::create_segment(const SegmentAddress& address, std::uint32 segment->attach_manifold(manifold); } + DVLOG(10) << "Created segment " << definition->name() << " - " << ::mrc::segment::info(address); m_segments[address] = std::move(segment); }) .get(); @@ -200,12 +203,25 @@ void PipelineInstance::do_service_stop() void PipelineInstance::do_service_kill() { + std::lock_guard guard(m_kill_mux); + DVLOG(10) << "pipeline::PipelineInstance - killing " << m_manifolds.size() << " manifolds - " << m_segments.size() + << " segments"; mark_joinable(); + for (const auto& [name, manifold] : m_manifolds) + { + DVLOG(10) << "pipeline::PipelineInstance - killing manifold " << name; + manifold->shutdown(); + } + + m_manifolds.clear(); + for (auto& [id, segment] : m_segments) { - stop_segment(id); + stop_segment(id, true); segment->service_kill(); } + + m_segments.clear(); } void PipelineInstance::do_service_await_join() @@ -228,6 +244,7 @@ void PipelineInstance::do_service_await_join() if (first_exception) { LOG(ERROR) << "pipeline::PipelineInstance - an exception was caught while awaiting on segments - rethrowing"; + do_service_kill(); std::rethrow_exception(std::move(first_exception)); } } diff --git a/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp b/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp index 7dc51e38e..d17b88a9b 100644 --- a/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp +++ b/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp @@ -51,7 +51,7 @@ class PipelineInstance final : public Service, public PipelineResources // we need to stage those object that are created into some struct/container so we can mass start them after all // object have been created void create_segment(const SegmentAddress& address, std::uint32_t partition_id); - void stop_segment(const SegmentAddress& address); + void stop_segment(const SegmentAddress& address, bool kill = false); void join_segment(const SegmentAddress& address); void remove_segment(const SegmentAddress& address); @@ -85,6 +85,7 @@ class PipelineInstance final : public Service, public PipelineResources bool m_joinable{false}; Promise m_joinable_promise; SharedFuture m_joinable_future; + Mutex m_kill_mux; }; } // namespace mrc::pipeline diff --git a/cpp/mrc/src/internal/segment/builder_definition.cpp b/cpp/mrc/src/internal/segment/builder_definition.cpp index b11614328..74dca81fe 100644 --- a/cpp/mrc/src/internal/segment/builder_definition.cpp +++ b/cpp/mrc/src/internal/segment/builder_definition.cpp @@ -28,7 +28,7 @@ #include "mrc/modules/properties/persistent.hpp" // IWYU pragma: keep #include "mrc/modules/segment_modules.hpp" #include "mrc/node/port_registry.hpp" -#include "mrc/runnable/launchable.hpp" +#include "mrc/runnable/launchable.hpp" // for Launchable #include "mrc/segment/egress_port.hpp" // IWYU pragma: keep #include "mrc/segment/ingress_port.hpp" // IWYU pragma: keep #include "mrc/segment/object.hpp" @@ -284,11 +284,32 @@ void BuilderDefinition::initialize() << ", Segment Rank: " << m_rank << ". Exception message:\n" << e.what(); - // Rethrow after logging + shutdown(); + // Rethrow after logging std::rethrow_exception(std::current_exception()); } } +void BuilderDefinition::shutdown() +{ + DVLOG(10) << "Shutting down segment: " << m_definition->name(); + for (auto& [name, obj_prop] : m_objects) + { + if (obj_prop->is_source() && !obj_prop->is_sink()) + { + DVLOG(10) << "Destroying: " << name; + obj_prop->destroy(); + } + } + + m_ingress_ports.clear(); + m_egress_ports.clear(); + m_nodes.clear(); + m_objects.clear(); + + DVLOG(10) << "Shutting down segment: " << m_definition->name() << " - done"; +} + const std::map>& BuilderDefinition::nodes() const { return m_nodes; diff --git a/cpp/mrc/src/internal/segment/builder_definition.hpp b/cpp/mrc/src/internal/segment/builder_definition.hpp index aa0c96140..ff2adc8b4 100644 --- a/cpp/mrc/src/internal/segment/builder_definition.hpp +++ b/cpp/mrc/src/internal/segment/builder_definition.hpp @@ -119,6 +119,7 @@ class BuilderDefinition : public IBuilder const SegmentDefinition& definition() const; void initialize(); + void shutdown(); const std::map>& nodes() const; const std::map>& egress_ports() const; diff --git a/cpp/mrc/src/internal/segment/segment_instance.cpp b/cpp/mrc/src/internal/segment/segment_instance.cpp index 53f66b804..cbbdb0acb 100644 --- a/cpp/mrc/src/internal/segment/segment_instance.cpp +++ b/cpp/mrc/src/internal/segment/segment_instance.cpp @@ -331,4 +331,12 @@ std::shared_ptr SegmentInstance::create_manifold(const Port return nullptr; } +void SegmentInstance::shutdown() +{ + std::lock_guard lock(m_mutex); + DVLOG(10) << m_name << " - " << info() << " - shutting down segment"; + do_service_kill(); + m_builder->shutdown(); +} + } // namespace mrc::segment diff --git a/cpp/mrc/src/internal/segment/segment_instance.hpp b/cpp/mrc/src/internal/segment/segment_instance.hpp index addd38dd1..7d18fdd4e 100644 --- a/cpp/mrc/src/internal/segment/segment_instance.hpp +++ b/cpp/mrc/src/internal/segment/segment_instance.hpp @@ -56,6 +56,7 @@ class SegmentInstance final : public Service std::shared_ptr create_manifold(const PortName& name); void attach_manifold(std::shared_ptr manifold); + void shutdown(); protected: const std::string& info() const; diff --git a/cpp/mrc/src/public/manifold/manifold.cpp b/cpp/mrc/src/public/manifold/manifold.cpp index a1a3cca25..3b4aba3eb 100644 --- a/cpp/mrc/src/public/manifold/manifold.cpp +++ b/cpp/mrc/src/public/manifold/manifold.cpp @@ -50,6 +50,8 @@ const std::string& Manifold::info() const return m_info; } +void Manifold::shutdown() {} + void Manifold::add_input(const SegmentAddress& address, edge::IWritableAcceptorBase* input_source) { DVLOG(3) << "manifold " << this->port_name() << ": connecting to upstream segment " << segment::info(address); diff --git a/cpp/mrc/src/tests/test_pipeline.cpp b/cpp/mrc/src/tests/test_pipeline.cpp index 0f23a6fa2..3785b7444 100644 --- a/cpp/mrc/src/tests/test_pipeline.cpp +++ b/cpp/mrc/src/tests/test_pipeline.cpp @@ -199,7 +199,7 @@ TEST_F(TestPipeline, Queue) auto segment = pipeline->make_segment("seg_1", [](segment::IBuilder& s) { auto source = s.make_object("source", test::nodes::infinite_int_rx_source()); - auto queue = s.make_object("queue", std::make_unique>()); + auto queue = s.make_object("queue", std::make_unique>("queue")); auto sink = s.make_object("sink", test::nodes::int_sink()); s.make_edge(source, queue); s.make_edge(queue, sink); diff --git a/cpp/mrc/tests/test_edges.cpp b/cpp/mrc/tests/test_edges.cpp index 91c6d4e09..cf68141ce 100644 --- a/cpp/mrc/tests/test_edges.cpp +++ b/cpp/mrc/tests/test_edges.cpp @@ -54,7 +54,7 @@ // IWYU thinks we need vector for make_segment // IWYU pragma: no_include -using namespace std::chrono_literals; +using namespace std::literals; TEST_CLASS(Edges); @@ -1002,6 +1002,7 @@ template class TestEdgeHolder : public edge::EdgeHolder { public: + TestEdgeHolder(std::string name) : edge::EdgeHolder(std::move(name)) {} bool has_active_connection() const { return this->check_active_connection(false); @@ -1016,11 +1017,26 @@ class TestEdgeHolder : public edge::EdgeHolder { this->init_owned_edge(std::move(edge)); } + + void call_init_connected_edge(std::shared_ptr> edge) + { + this->init_connected_edge(std::move(edge)); + } +}; + +template +class TestEdge : public edge::Edge +{ + public: + void call_connect() + { + this->connect(); + } }; TEST_F(TestEdges, EdgeHolderIsConnected) { - TestEdgeHolder edge_holder; + TestEdgeHolder edge_holder("test_holder"s); auto edge = std::make_shared>(); EXPECT_FALSE(edge_holder.has_active_connection()); @@ -1030,4 +1046,36 @@ TEST_F(TestEdges, EdgeHolderIsConnected) edge_holder.call_release_edge_connection(); EXPECT_FALSE(edge_holder.has_active_connection()); } + +TEST_F(TestEdges, EdgeHolderName) +{ + TestEdgeHolder edge_holder("test_holder"s); + EXPECT_EQ(edge_holder.name(), "test_holder"s); +} + +TEST_F(TestEdges, EdgeHolderConnectRelase) +{ + TestEdgeHolder edge_holder("test_holder"s); + auto edge = std::make_shared>(); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge_holder.call_init_connected_edge(std::make_shared>()); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge_holder.call_init_owned_edge(edge); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge->call_connect(); + EXPECT_TRUE(edge_holder.has_active_connection()); + + edge_holder.call_release_edge_connection(); + + // EdgeHolder is disconnected, but someone is still holding a reference to the edge + EXPECT_TRUE(edge_holder.has_active_connection()); + + edge.reset(); + + EXPECT_FALSE(edge_holder.has_active_connection()); +} + } // namespace mrc diff --git a/cpp/mrc/tests/test_pipeline.cpp b/cpp/mrc/tests/test_pipeline.cpp index 6d1bc4499..9ded8d65a 100644 --- a/cpp/mrc/tests/test_pipeline.cpp +++ b/cpp/mrc/tests/test_pipeline.cpp @@ -16,9 +16,9 @@ */ #include "mrc/node/rx_sink.hpp" -#include "mrc/node/rx_sink_base.hpp" +#include "mrc/node/rx_sink_base.hpp" // for RxSinkBase #include "mrc/node/rx_source.hpp" -#include "mrc/node/rx_source_base.hpp" +#include "mrc/node/rx_source_base.hpp" // for RxSourceBase #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -37,6 +37,7 @@ #include #include #include +#include #include namespace mrc { @@ -104,8 +105,6 @@ TEST_F(TestPipeline, DuplicateSegments) TEST_F(TestPipeline, TwoSegment) { - GTEST_SKIP() << "#185"; - std::atomic next_count = 0; std::atomic complete_count = 0; @@ -161,6 +160,131 @@ TEST_F(TestPipeline, TwoSegment) LOG(INFO) << "Done" << std::endl; } +TEST_F(TestPipeline, SegmentInitErrorHandling) +{ + // Test to reproduce issue #360 + auto pipeline = mrc::make_pipeline(); + + auto seg = pipeline->make_segment("seg_1", [](segment::IBuilder& seg) { + auto rx_source = seg.make_source("rx_source", [](rxcpp::subscriber s) { + FAIL() << "This should not be called"; + }); + + auto rx_sink = seg.make_sink("rx_sink", + rxcpp::make_observer_dynamic( + [&](float x) { + FAIL() << "This should not be " + "called"; + }, + [&]() { + FAIL() << "This should not be " + "called"; + })); + + seg.make_edge(rx_source, rx_sink); + + throw std::runtime_error("Error in initializer"); + }); + + Executor exec(std::move(m_options)); + + exec.register_pipeline(std::move(pipeline)); + + exec.start(); + + EXPECT_THROW(exec.join(), std::runtime_error); +} + +TEST_F(TestPipeline, SegmentInitErrorHandlingFirstSeg) +{ + // Test to reproduce issue #360 + auto pipeline = mrc::make_pipeline(); + + auto seg_1 = + pipeline->make_segment("seg_1", segment::EgressPorts({"float_port"}), [](segment::IBuilder& seg) { + auto rx_source = seg.make_source("rx_source", [](rxcpp::subscriber s) { + FAIL() << "This should not be called"; + }); + + auto my_float_egress = seg.get_egress("float_port"); + + seg.make_edge(rx_source, my_float_egress); + throw std::runtime_error("Error in initializer"); + }); + + auto seg_2 = pipeline->make_segment("seg_2", + segment::IngressPorts({"float_port"}), + [&](segment::IBuilder& seg) { + auto my_float_ingress = seg.get_ingress("float_port"); + + auto rx_sink = seg.make_sink("rx_sink", + rxcpp::make_observer_dynamic( + [&](float x) { + FAIL() << "This should not be " + "called"; + }, + [&]() { + FAIL() << "This should not be " + "called"; + })); + + seg.make_edge(my_float_ingress, rx_sink); + }); + + Executor exec(std::move(m_options)); + + exec.register_pipeline(std::move(pipeline)); + + exec.start(); + + EXPECT_THROW(exec.join(), std::runtime_error); +} + +TEST_F(TestPipeline, SegmentInitErrorHandlingSecondSeg) +{ + // Test to reproduce issue #360 + auto pipeline = mrc::make_pipeline(); + + auto seg_1 = + pipeline->make_segment("seg_1", segment::EgressPorts({"float_port"}), [](segment::IBuilder& seg) { + auto rx_source = seg.make_source("rx_source", [](rxcpp::subscriber s) { + FAIL() << "This should not be called"; + }); + + auto my_float_egress = seg.get_egress("float_port"); + + seg.make_edge(rx_source, my_float_egress); + }); + + auto seg_2 = pipeline->make_segment("seg_2", + segment::IngressPorts({"float_port"}), + [&](segment::IBuilder& seg) { + auto my_float_ingress = seg.get_ingress("float_port"); + + auto rx_sink = seg.make_sink("rx_sink", + rxcpp::make_observer_dynamic( + [&](float x) { + FAIL() << "This should not be " + "called"; + }, + [&]() { + FAIL() << "This should not be " + "called"; + })); + + seg.make_edge(my_float_ingress, rx_sink); + throw std::runtime_error("Error in initializer"); + }); + + Executor exec(std::move(m_options)); + + exec.register_pipeline(std::move(pipeline)); + + exec.start(); + + EXPECT_THROW(exec.join(), std::runtime_error); +} + /* TEST_F(TestPipeline, TwoSegmentManualTag) { diff --git a/docs/quickstart/cpp/common/include/nodes.hpp b/docs/quickstart/cpp/common/include/nodes.hpp index d2e5eaf89..64706146e 100644 --- a/docs/quickstart/cpp/common/include/nodes.hpp +++ b/docs/quickstart/cpp/common/include/nodes.hpp @@ -19,12 +19,14 @@ #include +#include + namespace mrc::quickstart::cpp::common { class IntSource : public mrc::node::RxSource { public: - IntSource(); + IntSource(std::string name = std::string()); }; } // namespace mrc::quickstart::cpp::common diff --git a/docs/quickstart/cpp/common/src/nodes.cpp b/docs/quickstart/cpp/common/src/nodes.cpp index bf19a26dc..cb9a73c10 100644 --- a/docs/quickstart/cpp/common/src/nodes.cpp +++ b/docs/quickstart/cpp/common/src/nodes.cpp @@ -21,13 +21,13 @@ namespace mrc::quickstart::cpp::common { -IntSource::IntSource() : - mrc::node::RxSource(rxcpp::observable<>::create([](rxcpp::subscriber s) { - s.on_next(1); - s.on_next(2); - s.on_next(3); - s.on_completed(); - })) +IntSource::IntSource(std::string name) : + mrc::node::RxSource(std::move(name), rxcpp::observable<>::create([](rxcpp::subscriber s) { + s.on_next(1); + s.on_next(2); + s.on_next(3); + s.on_completed(); + })) {} } // namespace mrc::quickstart::cpp::common diff --git a/docs/quickstart/hybrid/mrc_qs_hybrid/common/nodes.cpp b/docs/quickstart/hybrid/mrc_qs_hybrid/common/nodes.cpp index 10b060ae6..ab5f655a7 100644 --- a/docs/quickstart/hybrid/mrc_qs_hybrid/common/nodes.cpp +++ b/docs/quickstart/hybrid/mrc_qs_hybrid/common/nodes.cpp @@ -39,7 +39,7 @@ class DataObjectSource : public mrc::pymrc::PythonSource, std::shared_ptr>; public: - DataObjectNode() : PythonNode(base_t::op_factory_from_sub_fn(build())) {} + DataObjectNode(std::string name = std::string()) : + PythonNode(std::move(name), base_t::op_factory_from_sub_fn(build())) + {} private: subscribe_fn_t build() @@ -96,7 +98,8 @@ class DataObjectSink : public mrc::pymrc::PythonSink> { public: - MyDataObjectSource(size_t count) : PythonSource(build()), m_count(count) {} + MyDataObjectSource(std::string name, size_t count) : PythonSource(std::move(name), build()), m_count(count) {} private: subscriber_fn_t build() @@ -62,7 +62,9 @@ class MyDataObjectNode using base_t = mrc::pymrc::PythonNode, std::shared_ptr>; public: - MyDataObjectNode() : PythonNode(base_t::op_factory_from_sub_fn(build())) {} + MyDataObjectNode(std::string name = std::string()) : + PythonNode(std::move(name), base_t::op_factory_from_sub_fn(build())) + {} private: subscribe_fn_t build() @@ -89,7 +91,9 @@ class MyDataObjectNode class MyDataObjectSink : public mrc::pymrc::PythonSink> { public: - MyDataObjectSink() : PythonSink(build_on_next(), build_on_complete()) {} + MyDataObjectSink(std::string name = std::string()) : + PythonSink(std::move(name), build_on_next(), build_on_complete()) + {} private: on_next_fn_t build_on_next() diff --git a/python/mrc/_pymrc/include/pymrc/node.hpp b/python/mrc/_pymrc/include/pymrc/node.hpp index f5d72e7c3..91346b706 100644 --- a/python/mrc/_pymrc/include/pymrc/node.hpp +++ b/python/mrc/_pymrc/include/pymrc/node.hpp @@ -32,6 +32,8 @@ #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/sink_properties.hpp" +#include "mrc/node/source_properties.hpp" #include "mrc/runnable/context.hpp" #include @@ -295,6 +297,8 @@ class PythonSink : public node::RxSink, using typename base_t::observer_t; using base_t::base_t; + + PythonSink(std::string name = std::string()) : node::SinkProperties(name), base_t(name) {} }; template @@ -325,6 +329,12 @@ class PythonNode : public node::RxNode, using base_t::base_t; + PythonNode(std::string name = std::string()) : + node::SinkProperties(name), + node::SourceProperties(name), + base_t(name) + {} + protected: static auto op_factory_from_sub_fn(subscribe_fn_t sub_fn) { @@ -380,11 +390,12 @@ class PythonSource : public node::RxSource, using base_t::base_t; - PythonSource(const subscriber_fn_t& f) : - base_t(rxcpp::observable<>::create([f](rxcpp::subscriber& s) { - // Call the wrapped subscriber function - f(s); - })) + PythonSource(std::string name, const subscriber_fn_t& f) : + node::SourceProperties(name), + base_t(name, rxcpp::observable<>::create([f](rxcpp::subscriber& s) { + // Call the wrapped subscriber function + f(s); + })) {} }; @@ -396,7 +407,10 @@ class PythonSourceComponent : public node::LambdaSourceComponent, using base_t = node::LambdaSourceComponent; public: - using base_t::base_t; + PythonSourceComponent(std::string name, const typename base_t::get_data_fn_t& f) : + node::SourceProperties(name), + base_t(name, f) + {} }; class SegmentObjectProxy diff --git a/python/mrc/tests/test_edges.cpp b/python/mrc/tests/test_edges.cpp index ccac5a2d7..731b4b0fe 100644 --- a/python/mrc/tests/test_edges.cpp +++ b/python/mrc/tests/test_edges.cpp @@ -150,7 +150,7 @@ class TestSource : public pymrc::PythonSource>, public TestSo using base_t = pymrc::PythonSource>; TestSource(std::string name, pymrc::PyHolder counter, size_t msg_count = 5) : - base_t(), + base_t(name), TestSourceImpl(std::move(name), std::move(counter), msg_count) { this->set_observable(this->build()); @@ -164,7 +164,7 @@ class TestSourceComponent : public pymrc::PythonSourceComponent>; TestSourceComponent(std::string name, pymrc::PyHolder counter, size_t msg_count = 5) : - base_t(build()), + base_t(name, build()), TestSourceImpl(std::move(name), std::move(counter), msg_count) {} @@ -227,6 +227,7 @@ class TestNode : public pymrc::PythonNode, std::shared_ptr public: TestNode(std::string name, pymrc::PyHolder counter, size_t msg_count = 5) : + base_t(name), TestNodeImpl(std::move(name), std::move(counter)) { this->make_stream(this->build_operator()); @@ -241,6 +242,7 @@ class TestNodeComponent : public pymrc::PythonNodeComponent, public: TestNodeComponent(std::string name, pymrc::PyHolder counter, size_t msg_count = 5) : + base_t(name), TestNodeImpl(std::move(name), std::move(counter)) { this->make_stream(this->build_operator()); @@ -278,8 +280,11 @@ class TestSinkImpl : public PythonTestNodeMixin template class TestSink : public pymrc::PythonSink>, public TestSinkImpl { + using base_t = pymrc::PythonSink>; + public: TestSink(std::string name, pymrc::PyHolder counter, size_t msg_count = 5) : + base_t(name), TestSinkImpl(std::move(name), std::move(counter)) { this->set_observer(this->build()); @@ -289,8 +294,11 @@ class TestSink : public pymrc::PythonSink>, public TestSinkIm template class TestSinkComponent : public pymrc::PythonSinkComponent>, public TestSinkImpl { + using base_t = pymrc::PythonSinkComponent>; + public: TestSinkComponent(std::string name, pymrc::PyHolder counter, size_t msg_count = 5) : + base_t(name), TestSinkImpl(std::move(name), std::move(counter)) { this->set_observer(this->build()); @@ -310,7 +318,7 @@ GENERATE_NODE_TYPES(TestSinkComponent, SinkComponent); std::shared_ptr>>(py_mod, #class_name) \ .def(py::init<>( \ [](mrc::segment::IBuilder& parent, const std::string& name, py::dict counter, size_t msg_count) { \ - auto stage = parent.construct_object(name, name, std::move(counter), msg_count); \ + auto stage = parent.construct_object(name, std::move(counter), msg_count); \ return stage; \ }), \ py::arg("parent"), \ diff --git a/python/tests/test_edges.py b/python/tests/test_edges.py index 98ed11d0e..a9c359518 100644 --- a/python/tests/test_edges.py +++ b/python/tests/test_edges.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +import logging import typing import pytest @@ -23,9 +24,12 @@ import mrc.core.operators as ops import mrc.tests.test_edges_cpp as m +mrc.logging.init_logging("test_edges") +mrc.logging.set_level(logging.INFO) -@pytest.fixture -def ex_runner(): + +@pytest.fixture(params=[mrc.core.options.EngineType.Thread, mrc.core.options.EngineType.Fiber], ids=["thread", "fiber"]) +def ex_runner(request: pytest.FixtureRequest): def run_exec(segment_init): pipeline = mrc.Pipeline() @@ -36,6 +40,7 @@ def run_exec(segment_init): # Set to 1 thread options.topology.user_cpuset = "0-0" + options.engine_factories.default_engine_type = request.param executor = mrc.Executor(options) @@ -431,14 +436,15 @@ def fail_if_more_derived_type(combo: typing.Tuple): @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) @pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink2_py"]) @pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) -@pytest.mark.parametrize("source_type,sink1_type,sink2_type", - gen_parameters("source", - "sink1", - "sink2", - is_fail_fn=fail_if_more_derived_type, - values={ - "base": m.Base, "derived": m.DerivedA - })) +@pytest.mark.parametrize( + "source_type,sink1_type,sink2_type", + gen_parameters("source", + "sink1", + "sink2", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) def test_source_to_broadcast_to_sinks(run_segment, sink1_component: bool, sink2_component: bool, @@ -504,12 +510,10 @@ def segment_init(seg: mrc.Builder): @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) -@pytest.mark.parametrize("source_type", - gen_parameters("source", - is_fail_fn=lambda _: False, - values={ - "base": m.Base, "derived": m.DerivedA - })) +@pytest.mark.parametrize( + "source_type", gen_parameters("source", is_fail_fn=lambda _: False, values={ + "base": m.Base, "derived": m.DerivedA + })) def test_source_to_null(run_segment, source_cpp: bool, source_type: type): def segment_init(seg: mrc.Builder): @@ -522,24 +526,24 @@ def segment_init(seg: mrc.Builder): assert results == expected_node_counts -@pytest.mark.parametrize("source_cpp,node_cpp", - gen_parameters("source", "node", is_fail_fn=lambda _: False, values={ - "cpp": True, "py": False - })) -@pytest.mark.parametrize("source_type,node_type", - gen_parameters("source", - "node", - is_fail_fn=fail_if_more_derived_type, - values={ - "base": m.Base, "derived": m.DerivedA - })) -@pytest.mark.parametrize("source_component,node_component", - gen_parameters("source", - "node", - is_fail_fn=lambda x: x[0] and x[1], - values={ - "run": False, "com": True - })) +@pytest.mark.parametrize( + "source_cpp,node_cpp", + gen_parameters("source", "node", is_fail_fn=lambda _: False, values={ + "cpp": True, "py": False + })) +@pytest.mark.parametrize( + "source_type,node_type", + gen_parameters("source", + "node", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) +@pytest.mark.parametrize( + "source_component,node_component", + gen_parameters("source", "node", is_fail_fn=lambda x: x[0] and x[1], values={ + "run": False, "com": True + })) def test_source_to_node_to_null(run_segment, source_cpp: bool, node_cpp: bool, diff --git a/python/tests/test_pipeline.py b/python/tests/test_pipeline.py index 7f70068e6..49b5a891f 100644 --- a/python/tests/test_pipeline.py +++ b/python/tests/test_pipeline.py @@ -19,7 +19,8 @@ # from functools import partial # import numpy as np -# import pytest + +import pytest import mrc import mrc.tests.test_edges_cpp as m @@ -445,9 +446,49 @@ def on_complete(): executor.join() +def test_segment_init_error(): + """ + Test for issue #360 + """ + + def gen_data(): + yield 1 + + def init1(builder: mrc.Builder): + source = builder.make_source("source", gen_data) + egress = builder.get_egress("b") + builder.make_edge(source, egress) + raise RuntimeError("Test for #360") + + def init2(builder: mrc.Builder): + + def on_next(input): + pass + + ingress = builder.get_ingress("b") + sink = builder.make_sink("sink", on_next) + + builder.make_edge(ingress, sink) + + pipe = mrc.Pipeline() + + pipe.make_segment("TestSegment1", [], [("b", int, False)], init1) + pipe.make_segment("TestSegment2", [("b", int, False)], [], init2) + + options = mrc.Options() + + executor = mrc.Executor(options) + executor.register_pipeline(pipe) + + with pytest.raises(RuntimeError): + executor.start() + executor.join() + + if (__name__ in ("__main__", )): test_dynamic_port_creation_good() test_dynamic_port_creation_bad() test_ingress_egress_custom_type_construction() test_dynamic_port_get_ingress_egress() test_dynamic_port_with_type_get_ingress_egress() + test_segment_init_error()