diff --git a/cpp/mrc/include/mrc/node/node_parent.hpp b/cpp/mrc/include/mrc/node/node_parent.hpp new file mode 100644 index 000000000..51f1f36be --- /dev/null +++ b/cpp/mrc/include/mrc/node/node_parent.hpp @@ -0,0 +1,41 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::node { + +template +class NodeParent +{ + public: + using child_types_t = std::tuple; + + virtual std::tuple>...> get_children_refs() const = 0; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/operators/zip.hpp b/cpp/mrc/include/mrc/node/operators/zip.hpp index b73321fd7..f06a39657 100644 --- a/cpp/mrc/include/mrc/node/operators/zip.hpp +++ b/cpp/mrc/include/mrc/node/operators/zip.hpp @@ -20,8 +20,11 @@ #include "mrc/channel/buffered_channel.hpp" #include "mrc/channel/channel.hpp" #include "mrc/channel/status.hpp" +#include "mrc/node/node_parent.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" +#include "mrc/types.hpp" +#include "mrc/utils/string_utils.hpp" #include "mrc/utils/tuple_utils.hpp" #include "mrc/utils/type_utils.hpp" @@ -42,8 +45,16 @@ namespace mrc::node { +class ZipBase +{ + public: + virtual ~ZipBase() = default; +}; + template -class Zip : public WritableAcceptor> +class Zip : public ZipBase, + public WritableAcceptor>, + public NodeParent...> { template using queue_t = BufferedChannel; @@ -63,6 +74,13 @@ class Zip : public WritableAcceptor> return std::make_tuple(std::make_unique>(channel_size)...); } + template + static std::tuple>>...> + build_child_pairs(Zip* self, std::index_sequence /*unused*/) + { + return std::make_tuple(std::make_pair(MRC_CONCAT_STR("sink[" << Is << "]"), std::ref(self->get_sink()))...); + } + template channel::Status tuple_pop_each(queues_tuple_type& queues_tuple, output_t& output_tuple) { @@ -89,12 +107,18 @@ class Zip : public WritableAcceptor> m_queue_counts.fill(0); } - virtual ~Zip() = default; + ~Zip() override = default; template - std::shared_ptr>> get_sink() const + edge::IWritableProvider>& get_sink() const + { + return *std::get(m_upstream_holders); + } + + std::tuple>>...> get_children_refs() + const override { - return std::get(m_upstream_holders); + return build_child_pairs(const_cast(this), std::index_sequence_for{}); } protected: @@ -242,7 +266,7 @@ class Zip : public WritableAcceptor> } } - boost::fibers::mutex m_mutex; + mutable Mutex m_mutex; // Once an upstream is closed, this is set representing the max number of values in a queue before its closed size_t m_max_queue_count{std::numeric_limits::max()}; diff --git a/cpp/mrc/include/mrc/segment/component.hpp b/cpp/mrc/include/mrc/segment/component.hpp index 3e25f9b63..462a49bff 100644 --- a/cpp/mrc/include/mrc/segment/component.hpp +++ b/cpp/mrc/include/mrc/segment/component.hpp @@ -31,7 +31,13 @@ template class Component final : public Object { public: - Component(std::unique_ptr resource) : m_resource(std::move(resource)) {} + Component(std::unique_ptr resource) : + ObjectProperties(Object::build_state()), + Object(), + m_resource(std::move(resource)) + { + this->init_children(); + } ~Component() final = default; private: diff --git a/cpp/mrc/include/mrc/segment/egress_port.hpp b/cpp/mrc/include/mrc/segment/egress_port.hpp index 7fe52a5ce..909d0e1b2 100644 --- a/cpp/mrc/include/mrc/segment/egress_port.hpp +++ b/cpp/mrc/include/mrc/segment/egress_port.hpp @@ -59,10 +59,14 @@ class EgressPort final : public Object>, public: EgressPort(SegmentAddress address, PortName name) : + ObjectProperties(Object>::build_state()), m_segment_address(address), m_port_name(std::move(name)), m_sink(std::make_unique>()) - {} + { + // Must call after constructing Object + this->init_children(); + } private: node::RxSinkBase* get_object() const final diff --git a/cpp/mrc/include/mrc/segment/ingress_port.hpp b/cpp/mrc/include/mrc/segment/ingress_port.hpp index fec6d469e..8757f5a70 100644 --- a/cpp/mrc/include/mrc/segment/ingress_port.hpp +++ b/cpp/mrc/include/mrc/segment/ingress_port.hpp @@ -53,10 +53,14 @@ class IngressPort : public Object>, public IngressPortBase public: IngressPort(SegmentAddress address, PortName name) : + ObjectProperties(Object>::build_state()), m_segment_address(address), m_port_name(std::move(name)), m_source(std::make_unique>()) - {} + { + // Must call after constructing Object + this->init_children(); + } private: node::RxSourceBase* get_object() const final diff --git a/cpp/mrc/include/mrc/segment/object.hpp b/cpp/mrc/include/mrc/segment/object.hpp index 2ccc80094..e7b291960 100644 --- a/cpp/mrc/include/mrc/segment/object.hpp +++ b/cpp/mrc/include/mrc/segment/object.hpp @@ -19,28 +19,64 @@ #include "mrc/channel/ingress.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_readable.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/forward.hpp" +#include "mrc/node/node_parent.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" #include "mrc/node/type_traits.hpp" #include "mrc/runnable/launch_options.hpp" #include "mrc/runnable/runnable.hpp" #include "mrc/segment/forward.hpp" +#include "mrc/type_traits.hpp" +#include "mrc/utils/tuple_utils.hpp" +#include +#include #include #include #include +#include namespace mrc::segment { -struct ObjectProperties +template +class SharedObject; + +template +class ReferencedObject; + +struct ObjectPropertiesState { + std::string name; + std::string type_name; + bool is_writable_acceptor; + bool is_writable_provider; + bool is_readable_acceptor; + bool is_readable_provider; +}; + +class ObjectProperties +{ + public: virtual ~ObjectProperties() = 0; - virtual void set_name(const std::string& name) = 0; - virtual std::string name() const = 0; - virtual std::string type_name() const = 0; + virtual void set_name(const std::string& name) + { + m_state->name = name; + } + + virtual std::string name() const + { + return m_state->name; + } + + virtual std::string type_name() const + { + return m_state->type_name; + } virtual bool is_sink() const = 0; virtual bool is_source() const = 0; @@ -48,10 +84,22 @@ struct ObjectProperties virtual std::type_index sink_type(bool ignore_holder = false) const = 0; virtual std::type_index source_type(bool ignore_holder = false) const = 0; - virtual bool is_writable_acceptor() const = 0; - virtual bool is_writable_provider() const = 0; - virtual bool is_readable_acceptor() const = 0; - virtual bool is_readable_provider() const = 0; + bool is_writable_acceptor() const + { + return m_state->is_writable_acceptor; + } + bool is_writable_provider() const + { + return m_state->is_writable_provider; + } + bool is_readable_acceptor() const + { + return m_state->is_readable_acceptor; + } + bool is_readable_provider() const + { + return m_state->is_readable_provider; + } virtual edge::IWritableAcceptorBase& writable_acceptor_base() = 0; virtual edge::IWritableProviderBase& writable_provider_base() = 0; @@ -74,6 +122,20 @@ struct ObjectProperties virtual runnable::LaunchOptions& launch_options() = 0; virtual const runnable::LaunchOptions& launch_options() const = 0; + + virtual std::shared_ptr get_child(const std::string& name) const = 0; + virtual std::map> get_children() const = 0; + + protected: + ObjectProperties(std::shared_ptr state) : m_state(std::move(state)) {} + + std::shared_ptr get_state() const + { + return m_state; + } + + private: + std::shared_ptr m_state; }; inline ObjectProperties::~ObjectProperties() = default; @@ -147,15 +209,33 @@ edge::IReadableProvider& ObjectProperties::readable_provider_typed() } // Object - template -class Object : public virtual ObjectProperties +class Object : public virtual ObjectProperties, public std::enable_shared_from_this> { + protected: + static std::shared_ptr build_state() + { + auto state = std::make_shared(); + + state->type_name = std::string(::mrc::type_name()); + state->is_writable_acceptor = std::is_base_of_v; + state->is_writable_provider = std::is_base_of_v; + state->is_readable_acceptor = std::is_base_of_v; + state->is_readable_provider = std::is_base_of_v; + + return state; + } + public: + // Object(const Object& other) : m_name(other.m_name), m_launch_options(other.m_launch_options) {} + // Object(Object&&) = delete; + // Object& operator=(const Object&) = delete; + // Object& operator=(Object&&) = delete; + ObjectT& object(); - std::string name() const final; - std::string type_name() const final; + // std::string name() const final; + // std::string type_name() const final; bool is_source() const final; bool is_sink() const final; @@ -163,10 +243,10 @@ class Object : public virtual ObjectProperties std::type_index sink_type(bool ignore_holder) const final; std::type_index source_type(bool ignore_holder) const final; - bool is_writable_acceptor() const final; - bool is_writable_provider() const final; - bool is_readable_acceptor() const final; - bool is_readable_provider() const final; + // bool is_writable_acceptor() const final; + // bool is_writable_provider() const final; + // bool is_readable_acceptor() const final; + // bool is_readable_provider() const final; edge::IWritableAcceptorBase& writable_acceptor_base() final; edge::IWritableProviderBase& writable_provider_base() final; @@ -198,15 +278,106 @@ class Object : public virtual ObjectProperties return m_launch_options; } + std::shared_ptr get_child(const std::string& name) const override + { + CHECK(m_children.contains(name)) << "Child " << name << " not found in " << this->name(); + + if (auto child = m_children.at(name).lock()) + { + return child; + } + + auto* mutable_this = const_cast(this); + + // Otherwise, we need to build one + auto child = mutable_this->m_create_children_fns.at(name)(); + + mutable_this->m_children[name] = child; + + return child; + } + + std::map> get_children() const override + { + std::map> children; + + for (const auto& [name, child] : m_children) + { + children[name] = this->get_child(name); + } + + return children; + } + + template + requires std::derived_from + std::shared_ptr> as() const + { + auto shared_object = std::make_shared>(*const_cast(this)); + + return shared_object; + } + protected: - // Move to protected to allow only the IBuilder to set the name - void set_name(const std::string& name) override; + Object() : ObjectProperties(build_state()) + { + LOG(INFO) << "Creating Object '" << this->name() << "' with type: " << this->type_name(); + } - private: - std::string m_name{}; + template + requires std::derived_from + Object(const Object& other) : + ObjectProperties(other), + m_launch_options(other.m_launch_options), + m_children(other.m_children), + m_create_children_fns(other.m_create_children_fns) + { + LOG(INFO) << "Copying Object '" << this->name() << "' from type: " << other.type_name() + << " to type: " << this->type_name(); + } + + void init_children() + { + if constexpr (is_base_of_template::value) + { + using child_types_t = typename ObjectT::child_types_t; + // Get the name/reference pairs from the NodeParent + auto children_ref_pairs = this->object().get_children_refs(); + + // Finally, convert the tuple of name/ChildObject pairs into a map + utils::tuple_for_each( + children_ref_pairs, + [this](std::pair>& pair, + size_t idx) { + // auto child_obj = std::make_shared>(this->shared_from_this(), + // pair.second); + + // m_children.emplace(std::move(pair.first), std::move(child_obj)); + + m_children.emplace(pair.first, std::weak_ptr()); + + m_create_children_fns.emplace(pair.first, [this, obj_ref = pair.second]() { + return std::make_shared>(this->shared_from_this(), obj_ref); + }); + }); + } + } + + // // Move to protected to allow only the IBuilder to set the name + // void set_name(const std::string& name) override; + + private: virtual ObjectT* get_object() const = 0; + runnable::LaunchOptions m_launch_options; + + std::map> m_children; + std::map()>> m_create_children_fns; + + // Allows converting to base classes + template + friend class Object; }; template @@ -222,23 +393,23 @@ ObjectT& Object::object() return *node; } -template -void Object::set_name(const std::string& name) -{ - m_name = name; -} +// template +// void Object::set_name(const std::string& name) +// { +// m_name = name; +// } -template -std::string Object::name() const -{ - return m_name; -} +// template +// std::string Object::name() const +// { +// return m_name; +// } -template -std::string Object::type_name() const -{ - return std::string(::mrc::type_name()); -} +// template +// std::string Object::type_name() const +// { +// return std::string(::mrc::type_name()); +// } template bool Object::is_source() const @@ -276,83 +447,130 @@ std::type_index Object::source_type(bool ignore_holder) const return base->source_type(ignore_holder); } -template -bool Object::is_writable_acceptor() const -{ - return std::is_base_of_v; -} +// template +// bool Object::is_writable_acceptor() const +// { +// return std::is_base_of_v; +// } + +// template +// bool Object::is_writable_provider() const +// { +// return std::is_base_of_v; +// } + +// template +// bool Object::is_readable_acceptor() const +// { +// return std::is_base_of_v; +// } + +// template +// bool Object::is_readable_provider() const +// { +// return std::is_base_of_v; +// } template -bool Object::is_writable_provider() const +edge::IWritableAcceptorBase& Object::writable_acceptor_base() { - return std::is_base_of_v; -} + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase"; + // throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase"); + // } -template -bool Object::is_readable_acceptor() const -{ - return std::is_base_of_v; + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IIngressAcceptorBase"; + return *base; } template -bool Object::is_readable_provider() const +edge::IWritableProviderBase& Object::writable_provider_base() { - return std::is_base_of_v; + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IIngressProviderBase"; + // throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase"); + // } + + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IWritableProviderBase"; + return *base; } template -edge::IWritableAcceptorBase& Object::writable_acceptor_base() +edge::IReadableAcceptorBase& Object::readable_acceptor_base() { - if constexpr (!std::is_base_of_v) - { - LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase"; - throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase"); - } + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase"; + // throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase"); + // } - auto* base = dynamic_cast(get_object()); - CHECK(base); + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IReadableAcceptorBase"; return *base; } template -edge::IWritableProviderBase& Object::writable_provider_base() +edge::IReadableProviderBase& Object::readable_provider_base() { - if constexpr (!std::is_base_of_v) - { - LOG(ERROR) << type_name() << " is not a IIngressProviderBase"; - throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase"); - } + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IEgressProviderBase"; + // throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase"); + // } - auto* base = dynamic_cast(get_object()); - CHECK(base); + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IReadableProviderBase"; return *base; } template -edge::IReadableAcceptorBase& Object::readable_acceptor_base() +class SharedObject final : public Object { - if constexpr (!std::is_base_of_v) + public: + SharedObject(std::shared_ptr owner, std::reference_wrapper resource) : + ObjectProperties(Object::build_state()), + m_owner(std::move(owner)), + m_resource(std::move(resource)) + {} + ~SharedObject() final = default; + + private: + ObjectT* get_object() const final { - LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase"; - throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase"); + return &m_resource.get(); } - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + std::shared_ptr m_owner; + std::reference_wrapper m_resource; +}; template -edge::IReadableProviderBase& Object::readable_provider_base() +class ReferencedObject final : public Object { - if constexpr (!std::is_base_of_v) + public: + template + requires std::derived_from + ReferencedObject(Object& other) : + ObjectProperties(other), + Object(other), + m_owner(other.shared_from_this()), + m_resource(other.object()) + {} + + ~ReferencedObject() final = default; + + private: + ObjectT* get_object() const final { - LOG(ERROR) << type_name() << " is not a IEgressProviderBase"; - throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase"); + return &m_resource.get(); } - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + std::shared_ptr m_owner; + std::reference_wrapper m_resource; +}; + } // namespace mrc::segment diff --git a/cpp/mrc/include/mrc/segment/runnable.hpp b/cpp/mrc/include/mrc/segment/runnable.hpp index ab5b590ca..b40e01a00 100644 --- a/cpp/mrc/include/mrc/segment/runnable.hpp +++ b/cpp/mrc/include/mrc/segment/runnable.hpp @@ -37,15 +37,20 @@ template class Runnable : public Object, public runnable::Launchable { public: - template - Runnable(ArgsT&&... args) : m_node(std::make_unique(std::forward(args)...)) - {} - - Runnable(std::unique_ptr node) : m_node(std::move(node)) + Runnable(std::unique_ptr node) : + ObjectProperties(Object::build_state()), + Object(), + m_node(std::move(node)) { CHECK(m_node); + + this->init_children(); } + template + Runnable(ArgsT&&... args) : Runnable(std::make_unique(std::forward(args)...)) + {} + private: NodeT* get_object() const final; std::unique_ptr prepare_launcher(runnable::LaunchControl& launch_control) final; diff --git a/cpp/mrc/src/internal/segment/builder_definition.cpp b/cpp/mrc/src/internal/segment/builder_definition.cpp index e631c3f1e..13e211483 100644 --- a/cpp/mrc/src/internal/segment/builder_definition.cpp +++ b/cpp/mrc/src/internal/segment/builder_definition.cpp @@ -164,7 +164,7 @@ std::shared_ptr BuilderDefinition::get_egress(std::string name void BuilderDefinition::init_module(std::shared_ptr smodule) { - this->ns_push(smodule); + this->module_push(smodule); VLOG(2) << "Initializing module: " << m_namespace_prefix; smodule->m_module_instance_registered_namespace = m_namespace_prefix; smodule->initialize(*this); @@ -177,7 +177,8 @@ void BuilderDefinition::init_module(std::shared_ptr // Just save to a vector to keep it alive m_modules.push_back(persist); } - this->ns_pop(); + + this->module_pop(smodule); } void BuilderDefinition::register_module_input(std::string input_name, std::shared_ptr object) @@ -366,6 +367,24 @@ void BuilderDefinition::add_object(const std::string& name, std::shared_ptr<::mr // Save by the original name m_egress_ports[local_name] = egress_port; } + + // Now register any child objects + auto children = object->get_children(); + + if (!children.empty()) + { + // Push the namespace for this object + this->ns_push(local_name); + + for (auto& [child_name, child_object] : children) + { + // Add the child object + this->add_object(child_name, child_object); + } + + // Pop the namespace for this object + this->ns_pop(local_name); + } } std::shared_ptr<::mrc::segment::IngressPortBase> BuilderDefinition::get_ingress_base(const std::string& name) @@ -402,20 +421,43 @@ std::function BuilderDefinition::make_throughput_counter(con }; } -void BuilderDefinition::ns_push(std::shared_ptr smodule) +std::string BuilderDefinition::module_push(std::shared_ptr smodule) { m_module_stack.push_back(smodule); - m_namespace_stack.push_back(smodule->component_prefix()); + + return this->ns_push(smodule->component_prefix()); +} + +std::string BuilderDefinition::module_pop(std::shared_ptr smodule) +{ + CHECK_EQ(smodule, m_module_stack.back()) + << "Namespace stack mismatch. Expected " << m_module_stack.back()->component_prefix() << " but got " + << smodule->component_prefix(); + + m_module_stack.pop_back(); + + return this->ns_pop(smodule->component_prefix()); +} + +std::string BuilderDefinition::ns_push(const std::string& name) +{ + m_namespace_stack.push_back(name); m_namespace_prefix = std::accumulate(m_namespace_stack.begin(), m_namespace_stack.end(), std::string(""), ::accum_merge); + + return m_namespace_prefix; } -void BuilderDefinition::ns_pop() +std::string BuilderDefinition::ns_pop(const std::string& name) { - m_module_stack.pop_back(); + CHECK_EQ(name, m_namespace_stack.back()) + << "Namespace stack mismatch. Expected " << m_namespace_stack.back() << " but got " << name; + m_namespace_stack.pop_back(); m_namespace_prefix = std::accumulate(m_namespace_stack.begin(), m_namespace_stack.end(), std::string(""), ::accum_merge); + + return m_namespace_prefix; } } // namespace mrc::segment diff --git a/cpp/mrc/src/internal/segment/builder_definition.hpp b/cpp/mrc/src/internal/segment/builder_definition.hpp index aa0c96140..fa8d0ece3 100644 --- a/cpp/mrc/src/internal/segment/builder_definition.hpp +++ b/cpp/mrc/src/internal/segment/builder_definition.hpp @@ -135,8 +135,11 @@ class BuilderDefinition : public IBuilder // Local methods bool has_object(const std::string& name) const; - void ns_push(std::shared_ptr smodule); - void ns_pop(); + std::string module_push(std::shared_ptr smodule); + std::string module_pop(std::shared_ptr smodule); + + std::string ns_push(const std::string& name); + std::string ns_pop(const std::string& name); // definition std::shared_ptr m_definition; diff --git a/cpp/mrc/tests/test_edges.cpp b/cpp/mrc/tests/test_edges.cpp index 8b022c5b9..5276b54e6 100644 --- a/cpp/mrc/tests/test_edges.cpp +++ b/cpp/mrc/tests/test_edges.cpp @@ -1030,8 +1030,8 @@ TEST_F(TestEdges, Zip) auto sink = std::make_shared>>(); - mrc::make_edge(*source1, *zip->get_sink<0>()); - mrc::make_edge(*source2, *zip->get_sink<1>()); + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); mrc::make_edge(*zip, *sink); source1->run(); @@ -1057,8 +1057,8 @@ TEST_F(TestEdges, ZipEarlyClose) auto sink = std::make_shared>>(); - mrc::make_edge(*source1, *zip->get_sink<0>()); - mrc::make_edge(*source2, *zip->get_sink<1>()); + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); mrc::make_edge(*zip, *sink); source1->run(); @@ -1077,8 +1077,8 @@ TEST_F(TestEdges, ZipLateClose) auto sink = std::make_shared>>(); - mrc::make_edge(*source1, *zip->get_sink<0>()); - mrc::make_edge(*source2, *zip->get_sink<1>()); + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); mrc::make_edge(*zip, *sink); source1->run(); @@ -1094,6 +1094,36 @@ TEST_F(TestEdges, ZipLateClose) })); } +TEST_F(TestEdges, ZipEarlyReset) +{ + // Have one source emit different counts than the other + auto source1 = std::make_shared>(4); + auto source2 = std::make_shared>(3); + + auto zip = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); + mrc::make_edge(*zip, *sink); + + // After the edges have been made, reset the zip to ensure that it can be kept alive by its children + zip.reset(); + + source1->run(); + source2->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 0}, + std::tuple{1, 1}, + std::tuple{2, 2}, + })); +} + TEST_F(TestEdges, WithLatestFrom) { auto source1 = std::make_shared>(5); diff --git a/python/mrc/core/node.cpp b/python/mrc/core/node.cpp index bbbdfe658..af682fdaa 100644 --- a/python/mrc/core/node.cpp +++ b/python/mrc/core/node.cpp @@ -20,6 +20,7 @@ #include "pymrc/utils.hpp" #include "mrc/node/operators/broadcast.hpp" +#include "mrc/node/operators/zip.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" #include "mrc/utils/string_utils.hpp" @@ -58,6 +59,26 @@ PYBIND11_MODULE(node, py_mod) return node; })); + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, "Zip") + .def(py::init<>([](mrc::segment::IBuilder& builder, std::string name, size_t count) { + // std::shared_ptr node; + + if (count == 2) + { + return builder.construct_object>(name)->as(); + } + else + { + py::print("Unsupported count!"); + throw std::runtime_error("Unsupported count!"); + } + })) + .def("get_sink", [](mrc::segment::Object& self, size_t index) { + return self.get_child(MRC_CONCAT_STR("sink[" << index << "]")); + }); + py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "." << mrc_VERSION_PATCH); } diff --git a/python/tests/test_edges.py b/python/tests/test_edges.py index 98ed11d0e..75b41cea6 100644 --- a/python/tests/test_edges.py +++ b/python/tests/test_edges.py @@ -252,6 +252,16 @@ def add_broadcast(seg: mrc.Builder, *upstream: mrc.SegmentObject): return node +def add_zip(seg: mrc.Builder, *upstream: mrc.SegmentObject): + + node = mrc.core.node.Zip(seg, "Zip", len(upstream)) + + for i, u in enumerate(upstream): + seg.make_edge(u, node.get_sink(i)) + + return node + + # THIS TEST IS CAUSING ISSUES WHEN RUNNING ALL TESTS TOGETHER # @dataclasses.dataclass @@ -557,3 +567,18 @@ def segment_init(seg: mrc.Builder): results = run_segment(segment_init) assert results == expected_node_counts + + +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +def test_multi_source_to_zip_to_sink(run_segment, source_cpp: bool): + + def segment_init(seg: mrc.Builder): + + source1 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="1") + source2 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="2") + zip = add_zip(seg, source1, source2) + add_sink(seg, zip, is_cpp=False, data_type=m.Base, is_component=False) + + results = run_segment(segment_init) + + assert results == expected_node_counts