From d0260f730d1e81a8730eb9126a839e3d3bc63ff3 Mon Sep 17 00:00:00 2001 From: Hoyt Koepke Date: Sat, 21 Dec 2019 16:26:44 -0700 Subject: [PATCH 1/3] Working prototype of alternate model management system. This class proposes a new method for managing extensions. It is meant to replace the bulky and unweildy macro system in use currently. The new code is contained in model_server_v2. --- src/CMakeLists.txt | 3 +- src/model_server/lib/variant.hpp | 8 +- src/model_server_v2/CMakeLists.txt | 17 + src/model_server_v2/demo.cpp | 105 +++++ src/model_server_v2/method_parameters.hpp | 112 +++++ src/model_server_v2/method_registry.hpp | 137 ++++++ src/model_server_v2/method_wrapper.hpp | 398 ++++++++++++++++++ src/model_server_v2/model_base.cpp | 17 + src/model_server_v2/model_base.hpp | 143 +++++++ src/model_server_v2/model_server.cpp | 67 +++ src/model_server_v2/model_server.hpp | 288 +++++++++++++ src/model_server_v2/registration.hpp | 68 +++ .../annotation/object_detection.hpp | 2 +- 13 files changed, 1362 insertions(+), 3 deletions(-) create mode 100644 src/model_server_v2/CMakeLists.txt create mode 100644 src/model_server_v2/demo.cpp create mode 100644 src/model_server_v2/method_parameters.hpp create mode 100644 src/model_server_v2/method_registry.hpp create mode 100644 src/model_server_v2/method_wrapper.hpp create mode 100644 src/model_server_v2/model_base.cpp create mode 100644 src/model_server_v2/model_base.hpp create mode 100644 src/model_server_v2/model_server.cpp create mode 100644 src/model_server_v2/model_server.hpp create mode 100644 src/model_server_v2/registration.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7a68f607aa..de8679d3c5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -7,7 +7,8 @@ subdirs( ml toolkits visualization - model_server) + model_server + model_server_v2) if(TC_BUILD_PYTHON) diff --git a/src/model_server/lib/variant.hpp b/src/model_server/lib/variant.hpp index 8b14a629db..bd58c6358f 100644 --- a/src/model_server/lib/variant.hpp +++ b/src/model_server/lib/variant.hpp @@ -14,7 +14,6 @@ #include #include #include -#include namespace turi { class model_base; @@ -308,6 +307,13 @@ template inline variant_type to_variant(const T& f) { return variant_converter::type>().set(f); } + +/** Overload for the case when we're trying to wrap a void return + * type in a template. + */ +inline variant_type to_variant() { + return variant_type(); +} } // namespace turi namespace turi { diff --git a/src/model_server_v2/CMakeLists.txt b/src/model_server_v2/CMakeLists.txt new file mode 100644 index 0000000000..778460b573 --- /dev/null +++ b/src/model_server_v2/CMakeLists.txt @@ -0,0 +1,17 @@ +make_library( model_server_v2 + SOURCES + model_server.cpp + model_base.cpp + REQUIRES + unity_shared) + + +make_executable(demo + SOURCES demo.cpp + REQUIRES model_server_v2 unity_shared) + + + + + + diff --git a/src/model_server_v2/demo.cpp b/src/model_server_v2/demo.cpp new file mode 100644 index 0000000000..76039f1ed0 --- /dev/null +++ b/src/model_server_v2/demo.cpp @@ -0,0 +1,105 @@ +#include +#include +#include + +using namespace turi; +using namespace turi::v2; + +/** Demo model. + * + */ +class demo_model : public turi::v2::model_base { + + public: + + /** The name of the model. + * + */ + const char* name() const { return "demo_model"; } + + /** Sets up the options and the registration. + * + * The registration is done in the constructor, without the use of macros. + */ + demo_model() { + register_method("add", &demo_model::add, "x", "y"); + register_method("concat_strings", &demo_model::append_strings, "s1", "s2"); + + // Defaults are specified inline + register_method("increment", &demo_model::increment, "x", Parameter("delta", 1)); + } + + /** Add two numbers. + * + * Const is fine. + */ + size_t add(size_t x, size_t y) const { + return x + y; + } + + /** Append two strings with a + + */ + std::string append_strings(const std::string& s1, const std::string& s2) const + { + return s1 + "+" + s2; + } + + /** Incerment a value. + */ + size_t increment(size_t x, size_t delta) const { + return x + delta; + } + +}; + +/** Registration for a model is just a single macro in the header. + * + * This automatically loads and registers the model when the library is loaded. + * This registration is trivially cheap. + */ +REGISTER_MODEL(demo_model); + + +void hello_world(const std::string& greeting) { + std::cout << "Hello, world!! " << greeting << std::endl; +} + +/** Registration for a function is just a single macro in a source file or header. + * + * This automatically loads and registers the function when the library is loaded. + */ +REGISTER_FUNCTION(hello_world, "greeting"); + + + +int main(int argc, char** argv) { + + + auto dm = model_server().create_model("demo_model"); + + std::string name = variant_get_value(dm->call_method("name")); + + std::cout << "Demoing model = " << name << std::endl; + + size_t result = variant_get_value(dm->call_method("add", 5, 9)); + + std::cout << "5 + 9 = " << result << std::endl; + + std::string s_res = variant_get_value(dm->call_method("concat_strings", "A", "B")); + + std::cout << "Concat A, +, B: " << s_res << std::endl; + + // Delta default is 1 + size_t inc_value = variant_get_value(dm->call_method("increment", 5)); + + std::cout << "Incremented 5: " << inc_value << std::endl; + + + // Call the registered function. + std::cout << "Calling hello_world." << std::endl; + model_server().call_function("hello_world", "This works!"); + + + return 0; + +} diff --git a/src/model_server_v2/method_parameters.hpp b/src/model_server_v2/method_parameters.hpp new file mode 100644 index 0000000000..b5b83132b0 --- /dev/null +++ b/src/model_server_v2/method_parameters.hpp @@ -0,0 +1,112 @@ +/* Copyright © 2019 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#ifndef TURI_METHOD_PARAMETERS_HPP_ +#define TURI_METHOD_PARAMETERS_HPP_ + +#include +#include +#include + +namespace turi { +namespace v2 { + + +/** Struct to hold information about the user specified parameter of a method. + * + * Includes information about a possible default value. + */ +struct Parameter { + + Parameter() {} + + // Allow implicit cast from string here.. + Parameter(const std::string& n) : name(n) {} + Parameter(std::string&& n) : name(std::move(n)) {} + + // Specify parameter with a default value + Parameter(const std::string& n, const variant_type& v) + : name(n), has_default(true), default_value(v) + {} + + Parameter(std::string&& n, variant_type&& v) + : name(std::move(n)), has_default(true), default_value(std::move(v)) + {} + + + // TODO: expand this out into a proper container class. + + + // Name + std::string name; + + // Optional default value + bool has_default = false; + variant_type default_value; +}; + + + +template +void validate_parameter_list(const std::vector& params) { + + // Validates that the parameter list works with the given types of + // the function. + if(sizeof...(FuncParams) != params.size()) { + throw std::invalid_argument("Mismatch in number of specified parameters."); + } + + // TODO: validate uniqueness of names. + // TODO: validate defaults can be cast to proper types. +} + + +//////////////////////////////////////////////////////////////////////////////// + +// How the arguments are bundled up and packaged. +struct argument_pack { + std::vector ordered_arguments; + variant_map_type named_arguments; +}; + +/** Method for resolving incoming arguments to a method. + * + */ +template +void resolve_method_arguments(std::array& arg_v, + const std::vector& parameter_list, + const argument_pack& args) { + + size_t n_ordered = args.ordered_arguments.size(); + for(size_t i = 0; i < n_ordered; ++i) { + arg_v[i] = &args.ordered_arguments[i]; + } + + // TODO: check if more ordered arguments given than are + // possible here. + size_t used_counter = n_ordered; + for(size_t i = n_ordered; i < n; ++i) { + auto it = args.named_arguments.find(parameter_list[i].name); + if(it == args.named_arguments.end()) { + if(parameter_list[i].has_default) { + arg_v[i] = &(parameter_list[i].default_value); + } else { + // TODO: intelligent error message. + throw std::string("Missing argument."); + } + } else { + arg_v[i] = &(it->second); + ++used_counter; + } + } + + // TODO: check that all the arguments have been used up. If not, + // generate a good error message. +} + +} +} + +#endif diff --git a/src/model_server_v2/method_registry.hpp b/src/model_server_v2/method_registry.hpp new file mode 100644 index 0000000000..a6472d306c --- /dev/null +++ b/src/model_server_v2/method_registry.hpp @@ -0,0 +1,137 @@ +/* Copyright © 2017 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#ifndef TURI_METHOD_REGISTRY_HPP_ +#define TURI_METHOD_REGISTRY_HPP_ + +#include +#include +#include +#include +#include + +namespace turi { +namespace v2 { + + +/** Manager all the methods in a given class / model. + * + * This class exists to manage a collection of methods associated with a given + * class. It provides an interface to call previously registered methods on + * this class by name, along with helpful error messages if the call is wrong. + * + * If the BaseClass is void, it provides a registry for standalone functions. + * TODO: implement this. + */ +template +class method_registry { + public: + + method_registry() + : m_class_name() + {} + + method_registry(const std::string& _name) + : m_class_name(_name) + {} + + /** Register a new method. + * + * See method_wrapper::create for an explanation of the arguments. + * + */ + template + void register_method(const std::string& name, RegisterMethodArgs&&... rmargs) { + + try { + + auto wrapper = method_wrapper::create(rmargs...); + + m_method_lookup[name] = wrapper; + } catch(...) { + // TODO: Expand these exceptions to make them informative. + process_exception(std::current_exception()); + } + } + + // Lookup a call function information. + std::shared_ptr > lookup(const std::string& name) const { + + // TODO: proper error message here + return m_method_lookup.at(name); + } + + /** Call a given const method registered previously. + */ + variant_type call_method(const BaseClass* inst, const std::string& name, + const argument_pack& arguments) const { + + try { + return lookup(name)->call(inst, arguments); + + } catch(...) { + process_exception(std::current_exception()); + } + } + + /** Call a given const or non-const method registered previously. + */ + variant_type call_method(BaseClass* inst, const std::string& name, + const argument_pack& arguments) const { + + try { + return lookup(name)->call(inst, arguments); + } catch(...) { + process_exception(std::current_exception()); + } + } + + private: + + [[noreturn]] void process_exception(std::exception_ptr e) const { + // TODO: Expand these exceptions to make them informative. + + std::rethrow_exception(e); + } + + // Unpack arguments. + template = 0> + inline void _arg_unpack(std::vector& dest, const Tuple& t) const { + dest[i] = to_variant(std::get(t)); + _arg_unpack(dest, t); + } + + template = 0> + inline void _arg_unpack(std::vector& dest, const Tuple& t) const { + } + + + public: + + // Call a method with the arguments explicitly. + template ::type, BaseClass>::value> = 0> + variant_type call_method(BC* inst, const std::string& name, const Args&... args) const { + + argument_pack arg_list; + arg_list.ordered_arguments.resize(sizeof...(Args)); + + _arg_unpack<0, sizeof...(Args)>(arg_list.ordered_arguments, std::make_tuple(args...)); + + return call_method(inst, name, arg_list); + } + + private: + + std::string m_class_name; + + std::unordered_map > > + m_method_lookup; +}; + +} +} + +#endif diff --git a/src/model_server_v2/method_wrapper.hpp b/src/model_server_v2/method_wrapper.hpp new file mode 100644 index 0000000000..25189eefd7 --- /dev/null +++ b/src/model_server_v2/method_wrapper.hpp @@ -0,0 +1,398 @@ +/* Copyright © 2019 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#ifndef TURI_METHOD_WRAPPER_HPP_ +#define TURI_METHOD_WRAPPER_HPP_ + +#include +#include +#include + +namespace turi { + +// Helper for enable if in templates. +// TODO: move this to a common location +template using enable_if_ = typename std::enable_if::type; + +namespace v2 { + + +/** Base class for wrapper around a specific method. + * + * This class provides an interface to call a method + * or function using generic arguments. The interface is contained in a + * templated instance of the class derived from this. It's created + * through the `create` static factory method. + * + * This class is meant to be a member of the + * + */ +template class method_wrapper { + + public: + + virtual ~method_wrapper(){} + + /** The type of the Class on which our method rests. + * + * May be , in which case it's a standalone function. + */ + typedef BaseClass class_type; + + /// Calling method for const base classes. + virtual variant_type call(const BaseClass* _inst, const argument_pack& args) const = 0; + + /// Calling method for non-const base classes. + virtual variant_type call(BaseClass* C, const argument_pack& args) const = 0; + + /// Calling method for standalone functions. + variant_type call(const argument_pack& args) const { + return call(nullptr, args); + } + + /// Returns the parameter info struct for a particular parameter + inline const Parameter& parameter_info(size_t n) const { + return m_parameter_list.at(n); + } + + /// Returns the name of the parameter + inline const std::string& parameter_name(size_t n) const { + return parameter_info(n).name; + } + + /** Factory method. + * + * Call this method to create the interface wrapper around + * the method. + */ + template + static std::shared_ptr + create(RetType(Class::*method)(FuncParams...) const, const ParamDefs&...); + + /** Overload of above factory method for non-const methods + * + */ + template + static std::shared_ptr + create(RetType(Class::*method)(FuncParams...), const ParamDefs&...); + + + /** Factory method for non-method functions. + * + */ + template + static std::shared_ptr + create(RetType(*function)(FuncParams...), const ParamDefs&...); + + protected: + + // To be called only from the instantiating class + method_wrapper(const std::vector& _parameter_list) + : m_parameter_list(_parameter_list) + { } + + // Information about the function / method + std::vector m_parameter_list; + }; + + +///////////////////////////////////////////////////////////////////////////////// +// +// Implementation details of the above. +// + +// If there is no class, then this is used instead of Class +// to denote that there is no class. +struct __NoClass {}; + +/** Child class for resolving the arguments passed into a function call. + * + * This class is mainly a container to define the types present + * during the recursive parameter expansion stage. + * + */ +template + class method_wrapper_impl : public method_wrapper { + private: + + /// The number of parameters required for the function. + static constexpr size_t N = sizeof...(FuncParams); + + /// Are we in a class instance or just a standalone function? + static constexpr bool is_method = !std::is_same::value; + + // logic below requires is_const_method to be false when is_method is false + static_assert(is_method || !is_const_method, + "is_const_method=1 when is_method =0"); + + /// Set the method type -- general class vs standalone function + template struct method_type_impl {}; + + // Non-method case + template <> struct method_type_impl<0> { + typedef RetType (*type)(FuncParams...); + }; + + // Method non-const case. + template <> struct method_type_impl<1> { + typedef RetType (Class::*type)(FuncParams...); + }; + + // Method const case. + template <> struct method_type_impl<2> { + typedef RetType (Class::*type)(FuncParams...) const; + }; + + typedef typename method_type_impl::type method_type; + + /// Function pointer to the method we're calling. + method_type m_method; + + public: + /** Constructor. + */ + method_wrapper_impl( + std::vector&& _parameter_list, method_type method) + : method_wrapper(std::move(_parameter_list)) + , m_method(method) + { + validate_parameter_list(this->m_parameter_list); + } + + private: + + + + ////////////////////////////////////////////////////////////////////////////// + // + // Calling methods + + /** A handy way to refer to the type of the nth argument. + */ + template + struct nth_param_type { + typedef typename std::tuple_element >::type raw_type; + typedef typename std::decay::type type; + }; + + /// Container for passing the calling arguments around after unpacking from argument_list + typedef std::array arg_v_type; + + ////////////////////////////////////////////////////////////////////////////// + // Entrance methods to this process. + + /** Non-const calling method. + * + */ + variant_type call(BaseClass* inst, const argument_pack& args) const override { + return _choose_call_path(inst, args); + } + + /** Const calling method. + * + */ + variant_type call(const BaseClass* inst, const argument_pack& args) const override { + return _choose_call_path(inst, args); + } + + ////////////////////////////////////////////////////////////////////////////// + // Step 1: Determine the evaluation path and Class pointer type depending + // on whether it's a const method, regular method, or function. + + template + struct _call_chooser { + static constexpr bool func_path = !is_method; + static constexpr bool const_method = is_const_method; + static constexpr bool _non_const_method = is_method && !is_const_method; + static constexpr bool bad_const_call = _non_const_method && std::is_const::value; + static constexpr bool method_path = _non_const_method && !std::is_const::value; + }; + + // If it's a regular function. + template ::func_path> = 0> + variant_type _choose_call_path(C* inst, const argument_pack& args) const { + return _call(nullptr, args); + } + + // If it's a bad const call. + template ::bad_const_call> = 0> + [[noreturn]] + variant_type _choose_call_path(C* inst, const argument_pack& args) const { + // Null implementation of the above to intercept compilation of the in-formed + // case where it's a const class and the method is non-const. + throw std::invalid_argument("Non-const method call attempted on const class pointer."); + } + + // Const method. + template ::const_method> = 0> + variant_type _choose_call_path(C* inst, const argument_pack& args) const { + return _call(dynamic_cast(inst), args); + } + + // Non-const method. + template ::method_call> = 0> + variant_type _choose_call_path(C* inst, const argument_pack& args) const { + return _call(dynamic_cast(inst), args); + } + + ////////////////////////////////////////////////////////////////////////////// + // Step 2: Unpack and resolve arguments. + + template + variant_type _call(C* inst, const argument_pack& args) const { + arg_v_type arg_v; + + // Resolve and unpack the incoming arguments. + resolve_method_arguments(arg_v, this->m_parameter_list, args); + + // Now that the argument list arg_v is filled out, we can call the + // recursive calling function and return the value. + return __call<0>(inst, arg_v); + } + + ////////////////////////////////////////////////////////////////////////////// + // Step 3: Recursively unpack the parameters into a parameter pack with the + // correct values. Checks performed here. + + template = 0> + variant_type __call(C* inst, const arg_v_type& arg_v, const Expanded&... args) const { + + // TODO: Separate out the case where the unpacking can be done by + // reference. + typedef typename nth_param_type::type arg_type; + + // TODO: Add intelligent error messages here on failure + arg_type next_arg = variant_get_value(*(arg_v[arg_idx])); + + // Call the next unpacking routine. + return __call(inst, arg_v, args..., next_arg); + } + + ////////////////////////////////////////////////////////////////////////////// + // Step 4: Call the function / method. + // + // This is the stopping case of expansion -- we've unpacked and translated all the + // arguments, now it's time to actually call the method. + + // First case: class method, void return type. + template ::value> = 0> + variant_type __call(C* inst, const arg_v_type& arg_v, const Expanded&... args) const { + (inst->*m_method)(args...); + return variant_type(); + } + + // Second case: class method, non-void return type. + template ::value> = 0> + variant_type __call(C* inst, const arg_v_type& arg_v, const Expanded&... args) const { + + return to_variant( (inst->*m_method)(args...) ); + } + + // Third case: standalone function, void return type + template ::value> = 0> + variant_type __call(C*, const arg_v_type& arg_v, const Expanded&... args) const { + + m_method(args...); + return variant_type(); + } + + // Fourth case: standalone function, non-void return type. + template ::value> = 0> + variant_type __call(C*, const arg_v_type& arg_v, const Expanded&... args) const { + + return to_variant(m_method(args...) ); + } +}; + + +////////////////////////////////////////////////////////// +// Some utility functions to help unpack the variadic arguments into +// a vector of parameters. + +template +void __unpack_parameters( + std::vector& dest, + const std::vector& vp) { + dest = vp; +} + +/** Recursive function to unpack the parameter list. + */ +template = 0> +void __unpack_parameters(std::vector& dest, const Params&... pv) { + dest[idx] = Parameter(std::get(std::make_tuple(pv...))); + __unpack_parameters(dest, pv...); +} + +/** Stopping case of recursive unpack function. + * + */ +template = 0> +void __unpack_parameters(std::vector& dest, const Params&... pv) { +} + +/** Implementation of the factory method for non-const methods. + */ +template +template +std::shared_ptr > method_wrapper::create( + RetType(Class::*method)(FuncParams...), + const ParamDefs&... param_defs) { + + std::vector params; + params.resize(sizeof...(ParamDefs)); + __unpack_parameters<0>(params, param_defs...); + + return std::make_shared > + (std::move(params), method); + }; + +/** Const overload of the method interface factory method. + */ +template +template +std::shared_ptr > method_wrapper::create( + RetType(Class::*method)(FuncParams...) const, + const ParamDefs&... param_defs) { + + std::vector params; + params.resize(sizeof...(ParamDefs)); + __unpack_parameters<0>(params, param_defs...); + + return std::make_shared > + (std::move(params), method); + }; + +/** Factory method for non-method functions. + * + */ +template + template + std::shared_ptr > + method_wrapper::create( + RetType(*function)(FuncParams...), + const ParamDefs&... param_defs) { + + std::vector params; + params.resize(sizeof...(ParamDefs)); + __unpack_parameters<0>(params, param_defs...); + + return std::make_shared > + (std::move(params), function); + + } + + +} +} + +#endif diff --git a/src/model_server_v2/model_base.cpp b/src/model_server_v2/model_base.cpp new file mode 100644 index 0000000000..0f9ebe9859 --- /dev/null +++ b/src/model_server_v2/model_base.cpp @@ -0,0 +1,17 @@ +#include + +namespace turi { +namespace v2 { + +model_base::model_base() + : m_registry(new method_registry()) +{ + register_method("name", &model_base::name); +} + + +model_base::~model_base() { } + +} +} + diff --git a/src/model_server_v2/model_base.hpp b/src/model_server_v2/model_base.hpp new file mode 100644 index 0000000000..658be4dceb --- /dev/null +++ b/src/model_server_v2/model_base.hpp @@ -0,0 +1,143 @@ +/* Copyright © 2019 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#ifndef TURI_MODEL_BASE_V2_HPP +#define TURI_MODEL_BASE_V2_HPP + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace turi { + namespace v2 { + + +/** + * The base class from which all new models must inherit. + * + * This class defines a generic object interface, listing properties and + * callable methods, so that instances can be naturally wrapped and exposed to + * other languages, such as Python. + * + * Subclasses that wish to support saving and loading should also override the + * save_impl, load_version, and get_version functions below. + */ +class EXPORT model_base { + public: + + model_base(); + + virtual ~model_base(); + + // These public member functions define the communication between model_base + // instances and the unity runtime. Subclasses define the behavior of their + // instances using the protected interface below. + + /** + * Returns the name of the toolkit class, as exposed to client code. For + * example, the Python proxy for this instance will have a type with this + * name. + * + */ + virtual const char* name() const = 0; + + /** Sets up the class given the options present. + * + * TODO: implement all of this. + */ + virtual void setup(const variant_map_type& options) { + // option_manager.update(options); + } + + /** Call one of the const methods registered using the configure() method above. + * + * `args` may either be explicit arguments or an instance of + * the argument_pack class. + */ + template + variant_type call_method(const std::string& name, const Args&... args) const; + + + /** Call one of the methods registered using the configure() method above. + * + * `args` may either be explicit arguments or an instance of + * the argument_pack class. + */ + template + variant_type call_method(const std::string& name, const Args&... args); + + + /** Register a method that can be called by name using the registry above. + * + * The format for calling this is the function name, the pointer to the method, + * then a list of names or Parameter class instances giving the names of the + * parameters. + * + * Example: + * + * // For a method "add" in class C that derives from model_base. + * register_method("add", &C::add, "x", "y"); + * + * // For a method "inc" in class C that derives from model_base, + * // with a default parameter of 1. + * register_method("inc", &C::inc, Parameter("delta", 1) ); + * + * + * See the documentation on method_wrapper<...>::create to see the format of args. + */ + template + void register_method(const std::string& name, Method&&, const Args&... args); + + + // TODO: add back in load and save routines. + + private: + std::shared_ptr > m_registry; + +}; + +/////////////////////////////////////////////////////////////////////// +// +// Implementation of above template functions. + +/** Call one of the methods registered using the configure() method above. + * + */ +template +variant_type model_base::call_method( + const std::string& name, const Args&... args) { + + return m_registry->call_method(this, name, args...); +} + +/** Const overload of the above. + * + */ +template +variant_type model_base::call_method( + const std::string& name, const Args&... args) const { + + return m_registry->call_method(this, name, args...); +} + +/** Register a method that can be called by name using the registry above. + */ +template +void model_base::register_method( + const std::string& name, Method&& m, const Args&... args) { + + m_registry->register_method(name, m, args...); +} + +} +} + +#endif diff --git a/src/model_server_v2/model_server.cpp b/src/model_server_v2/model_server.cpp new file mode 100644 index 0000000000..4c16321c0c --- /dev/null +++ b/src/model_server_v2/model_server.cpp @@ -0,0 +1,67 @@ +/* Copyright © 2017 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#include + + +namespace turi { +namespace v2 { + +EXPORT model_server_impl& model_server() { + static model_server_impl global_model_server; + return global_model_server; +} + + +model_server_impl::model_server_impl() + : m_function_registry(new method_registry) +{ + + +} + +/** Does the work of registering things with the callbacks. + */ +void model_server_impl::_process_registered_callbacks_internal() { + + std::lock_guard _lg(m_model_registration_lock); + + size_t cur_idx; + + while( (cur_idx = m_callback_last_processed_index) != m_callback_pushback_index) { + + + // Call the callback function to perform the registration, simultaneously + // zeroing out the pointer. + _registration_callback reg_f = nullptr; + size_t idx = cur_idx % m_registration_callback_list.size(); + std::swap(reg_f, m_registration_callback_list[idx]); + reg_f(*this); + + // We're done here; advance. + ++m_callback_last_processed_index; + } +} + +/** Instantiate a previously registered model by name. + */ +std::shared_ptr model_server_impl::create_model(const std::string& model_name) { + + // Make sure there aren't new models waiting on the horizon. + check_registered_callback_queue(); + + auto it = m_model_by_name.find(model_name); + + if(it == m_model_by_name.end()) { + // TODO: make this more informative. + throw std::invalid_argument("Model not recognized."); + } + + return it->second(); +} + + +} +} diff --git a/src/model_server_v2/model_server.hpp b/src/model_server_v2/model_server.hpp new file mode 100644 index 0000000000..3e7871ca35 --- /dev/null +++ b/src/model_server_v2/model_server.hpp @@ -0,0 +1,288 @@ +/* Copyright © 2017 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#ifndef TURI_MODEL_SERVER_HPP +#define TURI_MODEL_SERVER_HPP + +#include +#include +#include +#include +#include +#include +#include + + +namespace turi { +namespace v2 { + +class model_server_impl; + +/** Returns the singleton version of the model server. + * + */ +EXPORT model_server_impl& model_server(); + + +EXPORT class model_server_impl { + private: + + // Disable instantiation outside of the global instance. + model_server_impl(); + friend model_server_impl& model_server(); + + // Explicitly disable copying, etc. + model_server_impl(const model_server_impl&) = delete; + model_server_impl(model_server_impl&&) = delete; + + public: + + //////////////////////////////////////////////////////////////////////////// + // Calling models. + + /** Instantiate a previously registered model by name. + * + */ + std::shared_ptr create_model(const std::string& model_name); + + + /** Instantiate a model by type. + */ + template + std::shared_ptr create_model(); + + + /** Call a previously registered function. + */ + template + variant_type call_function( + const std::string& function_name, FunctionArgs&&... args); + + public: + + /** Registration of a function. + * + * Registers a new function that can be called through the call_function + * call above. + * + * The format of the call is name of the function, the function, and then + * a list of 0 or more parameter specs. + * + * Example: + * + * void f(int x, int y); + * register_new_function("f", f, "x", "y"); + * + * + * + * \param name The name of the function. + * \param function A pointer to the function itself. + * \param param_specs The parameter specs + */ + template + void register_new_function(const std::string& name, Function&& function, ParamSpecs&&...); + + + /** Registration of new models. + * + * A model is registered through a call to register new model, which + * instantiates it and populates the required options and method call + * lookups. Copies of these options and method call lookups are stored + * internally in a registry here so new models can be instantiated quickly. + * + * + * The new model's name() method provides the name of the model being + * registered. + * + * This method can be called at any point. + * + */ + template void register_new_model(); + + /** Fast on-load model registration. + * + * The callbacks below provide a fast method for registering new models + * on library load time. This works by first registering a callback + * using a simple callback function. + * + */ + typedef void (*_registration_callback)(model_server_impl&); + + /** Register a callback function to be processed when a model is served. + * + * Function is reentrant and fast enough to be called from a static initializer. + */ + inline void add_registration_callback(_registration_callback callback) GL_HOT_INLINE_FLATTEN; + + private: + + + ////////////////////////////////////////////////////////////////////////////// + // Registered model lookups. + typedef std::function()> model_creation_function; + std::unordered_map m_model_by_name; + + + ////////////////////////////////////////////////////////////////////////////// + // Registered function lookups. + + std::unique_ptr > m_function_registry; + + + /////////////////////////////////////////////////////////////////////////////// + // TODO: Registered function lookups. + // + + /** Lock to ensure that model registration is queued correctly. + */ + std::mutex m_model_registration_lock; + std::type_index m_last_model_registered = std::type_index(typeid(void)); + + /** An intermediate buffer of registration callbacks. + * + * These queues are used on library load to register callback functions, which + * are then processed when any model is requested to ensure that library loading + * is done efficiently. check_registered_callback_queue() should be called + * before any lookups are done to ensure that all possible lookups have been + * registered. + * + */ + std::array<_registration_callback, 512> m_registration_callback_list; + std::atomic m_callback_pushback_index; + std::atomic m_callback_last_processed_index; + + /** Process the registered callbacks. + * + * First performs a fast inline check to see if it's needed, so + * this function can be called easily. + */ + inline void check_registered_callback_queue(); + + /** Does the work of registering things with the callbacks. + */ + void _process_registered_callbacks_internal(); +}; + +///////////////////////////////////////////////////////////////////////////////// +// +// Implementations of inline functions for the model server class +// + +/** Fast inline check + */ +inline void model_server_impl::check_registered_callback_queue() { + if(m_callback_last_processed_index < m_callback_pushback_index) { + _process_registered_callbacks_internal(); + } +} + +/** Add the callback to the registration function. + * + * This works by putting the callback function into a round-robin queue to avoid + * potential allocations or deallocations during library load time and to + * preserve thread safety. + */ +inline void model_server_impl::add_registration_callback( + model_server_impl::_registration_callback callback) { + + + size_t insert_index_raw = (m_callback_pushback_index++); + + do { + // Check to make sure this can be safely inserted. + size_t processed_index_raw = m_callback_last_processed_index; + + // Check to make sure we aren't so far behind the number of actually + // registered callbacks that we're out of space. + if(processed_index_raw + m_registration_callback_list.size() > insert_index_raw) { + break; + } else { + // This will process the next block of insertions. + _process_registered_callbacks_internal(); + } + + } while(true); + + size_t insert_index = insert_index_raw % m_registration_callback_list.size(); + + ASSERT_TRUE(m_registration_callback_list[insert_index] == nullptr); + m_registration_callback_list[insert_index] = callback; +} + + +/** Registration of new models. +* +* A model is registered through a call to register new model, which +* instantiates it and populates the required options and method call +* lookups. Copies of these options and method call lookups are stored +* internally in a registry here so new models can be instantiated quickly. +* +* +* The new model's name() method provides the name of the model being +* registered. +* +* This method can be called at any point. +* +*/ +template void model_server_impl::register_new_model() { + + // Quick check to cut out duplicate registrations. This can + // happen, e.g. if the class or the function macros appear in a header, + // which is fine and something we are designed to handle. + // However, this means that multiple registration calls can occur for the same + // class, and this quickly filters those registrations out. + if(std::type_index(typeid(ModelClass)) == m_last_model_registered) { + return; + } + m_last_model_registered = std::type_index(typeid(ModelClass)); + + // TODO: As the registration is now performed in the constructor, + // a base instantiated version of the class should be held, then + // subsequent model creations should simply use the copy constructor to + // instantiate them. This means the entire method registry is not + // duplicated. For now, just go through this way. + const std::string& name = ModelClass().name(); + + model_creation_function mcf = [=](){ return this->create_model(); }; + + m_model_by_name.insert({name, mcf}); +} + + +/** Instantiate a previously registered model by type. + */ +template + std::shared_ptr model_server_impl::create_model() { + + // Make sure there aren't new models waiting on the horizon. + check_registered_callback_queue(); + + return std::make_shared(); +} + + +/** Register a new function. + */ +template + void model_server_impl::register_new_function( + const std::string& name, Function&& function, ParamSpecs&&... param_specs) { + m_function_registry->register_method(name, function, param_specs...); +} + +/** Call the function. + */ +template + variant_type model_server_impl::call_function( + const std::string& function_name, FunctionArgs&&... args) { + + // Make sure there aren't new functions waiting on the horizon. + check_registered_callback_queue(); + + return m_function_registry->call_method(nullptr, function_name, args...); +} + +} +} // End turi namespace +#endif diff --git a/src/model_server_v2/registration.hpp b/src/model_server_v2/registration.hpp new file mode 100644 index 0000000000..82b9246170 --- /dev/null +++ b/src/model_server_v2/registration.hpp @@ -0,0 +1,68 @@ +/* Copyright © 2017 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ +#ifndef TURI_MODEL_SERVER_V2_REGSTRATION_HPP_ +#define TURI_MODEL_SERVER_V2_REGSTRATION_HPP_ + + +#include +#include +#include +#include + +namespace turi { + namespace v2 { + +// A helper class to use a static initializer to do a lightweight registration +// of class loading at library load time. Intended to be used as a component of +// the registration macro. +class __model_server_static_class_registration_hook { + public: + inline __model_server_static_class_registration_hook( + model_server_impl::_registration_callback f) { + model_server().add_registration_callback(f); + } +}; + + + +#define REGISTER_MODEL(model) \ + static void __register_##model(model_server_impl& server) { \ + server.template register_new_model(); \ + } \ + \ + static turi::v2::__model_server_static_class_registration_hook \ + __register_##model##_hook(__register_##model) + + +// A helper class to use a static initializer to do a lightweight registration +// of class loading at library load time. Intended to be used as a component of +// the +class __model_server_static_function_registration_hook { + public: + inline __model_server_static_function_registration_hook( + model_server_impl::_registration_callback f) { + model_server().add_registration_callback(f); + } +}; + + + +#define REGISTER_NAMED_FUNCTION(name, function, ...) \ +\ + static void register_function_##function(model_server_impl& server) {\ + server.register_new_function(name, function, __VA_ARGS__);\ + } \ + __model_server_static_function_registration_hook \ +__register_function_##function##_hook(register_function_##function) + +#define REGISTER_FUNCTION(function, ...) \ + REGISTER_NAMED_FUNCTION(#function, function, __VA_ARGS__) + + +} +} + +#endif diff --git a/src/visualization/annotation/object_detection.hpp b/src/visualization/annotation/object_detection.hpp index df45e3c656..650aece20c 100644 --- a/src/visualization/annotation/object_detection.hpp +++ b/src/visualization/annotation/object_detection.hpp @@ -60,4 +60,4 @@ create_object_detection_annotation(const std::shared_ptr &data, } // namespace annotate } // namespace turi -#endif \ No newline at end of file +#endif From 35d6b0fd9eb863c328676deb639fe159f3c9f86c Mon Sep 17 00:00:00 2001 From: Hoyt Koepke Date: Sat, 21 Dec 2019 16:46:10 -0700 Subject: [PATCH 2/3] Adjust comment. --- src/model_server_v2/method_registry.hpp | 1 - src/model_server_v2/method_wrapper.hpp | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/model_server_v2/method_registry.hpp b/src/model_server_v2/method_registry.hpp index a6472d306c..c217ebbd49 100644 --- a/src/model_server_v2/method_registry.hpp +++ b/src/model_server_v2/method_registry.hpp @@ -23,7 +23,6 @@ namespace v2 { * this class by name, along with helpful error messages if the call is wrong. * * If the BaseClass is void, it provides a registry for standalone functions. - * TODO: implement this. */ template class method_registry { diff --git a/src/model_server_v2/method_wrapper.hpp b/src/model_server_v2/method_wrapper.hpp index 25189eefd7..dbe9a2a8f2 100644 --- a/src/model_server_v2/method_wrapper.hpp +++ b/src/model_server_v2/method_wrapper.hpp @@ -304,7 +304,7 @@ template ::value> = 0> variant_type __call(C*, const arg_v_type& arg_v, const Expanded&... args) const { From 03bcd61bb0f99d8ab59835b5a6bdd9b7a1f58e9c Mon Sep 17 00:00:00 2001 From: Hoyt Koepke Date: Mon, 23 Dec 2019 20:23:31 -0700 Subject: [PATCH 3/3] Bugfixes on model server v2. --- src/model_server_v2/demo.cpp | 1 + src/model_server_v2/method_parameters.hpp | 1 + src/model_server_v2/method_registry.hpp | 1 + src/model_server_v2/method_wrapper.hpp | 15 ++++++++------- src/model_server_v2/model_base.cpp | 1 + src/model_server_v2/model_base.hpp | 1 + src/model_server_v2/model_server.cpp | 1 + src/model_server_v2/model_server.hpp | 6 ++++-- src/model_server_v2/registration.hpp | 13 +++++++------ 9 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/model_server_v2/demo.cpp b/src/model_server_v2/demo.cpp index 76039f1ed0..f346a9a40d 100644 --- a/src/model_server_v2/demo.cpp +++ b/src/model_server_v2/demo.cpp @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/src/model_server_v2/method_parameters.hpp b/src/model_server_v2/method_parameters.hpp index b5b83132b0..430edcddb7 100644 --- a/src/model_server_v2/method_parameters.hpp +++ b/src/model_server_v2/method_parameters.hpp @@ -6,6 +6,7 @@ #ifndef TURI_METHOD_PARAMETERS_HPP_ #define TURI_METHOD_PARAMETERS_HPP_ +#include #include #include #include diff --git a/src/model_server_v2/method_registry.hpp b/src/model_server_v2/method_registry.hpp index c217ebbd49..588f6fdd00 100644 --- a/src/model_server_v2/method_registry.hpp +++ b/src/model_server_v2/method_registry.hpp @@ -6,6 +6,7 @@ #ifndef TURI_METHOD_REGISTRY_HPP_ #define TURI_METHOD_REGISTRY_HPP_ +#include #include #include #include diff --git a/src/model_server_v2/method_wrapper.hpp b/src/model_server_v2/method_wrapper.hpp index dbe9a2a8f2..b675f985b0 100644 --- a/src/model_server_v2/method_wrapper.hpp +++ b/src/model_server_v2/method_wrapper.hpp @@ -6,6 +6,7 @@ #ifndef TURI_METHOD_WRAPPER_HPP_ #define TURI_METHOD_WRAPPER_HPP_ +#include #include #include #include @@ -206,11 +207,11 @@ template struct _call_chooser { - static constexpr bool func_path = !is_method; - static constexpr bool const_method = is_const_method; - static constexpr bool _non_const_method = is_method && !is_const_method; - static constexpr bool bad_const_call = _non_const_method && std::is_const::value; - static constexpr bool method_path = _non_const_method && !std::is_const::value; + static constexpr int func_path = !is_method; + static constexpr int const_method_path = is_const_method; + static constexpr int _non_const_method = is_method && !is_const_method; + static constexpr int bad_const_call = _non_const_method && std::is_const::value; + static constexpr int method_path = _non_const_method && !std::is_const::value; }; // If it's a regular function. @@ -229,13 +230,13 @@ template ::const_method> = 0> + template ::const_method_path> = 0> variant_type _choose_call_path(C* inst, const argument_pack& args) const { return _call(dynamic_cast(inst), args); } // Non-const method. - template ::method_call> = 0> + template ::method_path> = 0> variant_type _choose_call_path(C* inst, const argument_pack& args) const { return _call(dynamic_cast(inst), args); } diff --git a/src/model_server_v2/model_base.cpp b/src/model_server_v2/model_base.cpp index 0f9ebe9859..7ec1d28911 100644 --- a/src/model_server_v2/model_base.cpp +++ b/src/model_server_v2/model_base.cpp @@ -1,3 +1,4 @@ +#include #include namespace turi { diff --git a/src/model_server_v2/model_base.hpp b/src/model_server_v2/model_base.hpp index 658be4dceb..2dfe5569e5 100644 --- a/src/model_server_v2/model_base.hpp +++ b/src/model_server_v2/model_base.hpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include diff --git a/src/model_server_v2/model_server.cpp b/src/model_server_v2/model_server.cpp index 4c16321c0c..09fdb75a73 100644 --- a/src/model_server_v2/model_server.cpp +++ b/src/model_server_v2/model_server.cpp @@ -3,6 +3,7 @@ * Use of this source code is governed by a BSD-3-clause license that can * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause */ +#include #include diff --git a/src/model_server_v2/model_server.hpp b/src/model_server_v2/model_server.hpp index 3e7871ca35..42845d4429 100644 --- a/src/model_server_v2/model_server.hpp +++ b/src/model_server_v2/model_server.hpp @@ -3,9 +3,10 @@ * Use of this source code is governed by a BSD-3-clause license that can * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause */ -#ifndef TURI_MODEL_SERVER_HPP -#define TURI_MODEL_SERVER_HPP +#ifndef TURI_MODEL_SERVER_V2_HPP +#define TURI_MODEL_SERVER_V2_HPP +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include namespace turi { diff --git a/src/model_server_v2/registration.hpp b/src/model_server_v2/registration.hpp index 82b9246170..db112f0640 100644 --- a/src/model_server_v2/registration.hpp +++ b/src/model_server_v2/registration.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace turi { @@ -21,15 +22,15 @@ namespace turi { class __model_server_static_class_registration_hook { public: inline __model_server_static_class_registration_hook( - model_server_impl::_registration_callback f) { - model_server().add_registration_callback(f); + turi::v2::model_server_impl::_registration_callback f) { + turi::v2::model_server().add_registration_callback(f); } }; #define REGISTER_MODEL(model) \ - static void __register_##model(model_server_impl& server) { \ + static void __register_##model(turi::v2::model_server_impl& server) { \ server.template register_new_model(); \ } \ \ @@ -43,8 +44,8 @@ class __model_server_static_class_registration_hook { class __model_server_static_function_registration_hook { public: inline __model_server_static_function_registration_hook( - model_server_impl::_registration_callback f) { - model_server().add_registration_callback(f); + turi::v2::model_server_impl::_registration_callback f) { + turi::v2::model_server().add_registration_callback(f); } }; @@ -52,7 +53,7 @@ class __model_server_static_function_registration_hook { #define REGISTER_NAMED_FUNCTION(name, function, ...) \ \ - static void register_function_##function(model_server_impl& server) {\ + static void register_function_##function(turi::v2::model_server_impl& server) {\ server.register_new_function(name, function, __VA_ARGS__);\ } \ __model_server_static_function_registration_hook \