forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add template based unboxing (pytorch#1284)
Summary: Pull Request resolved: pytorch#1284 Adding a new feature to allow users to bypass codegen and register their kernels directly. This is very useful for custom kernels for custom ops. Example usage: ``` Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) { // ... return out; } Kernel my_kernel = Kernel.make_boxed_kernel("my_ns::my_op",EXECUTORCH_FN(my_op)); register_kernels({my_kernel}); ``` imported-using-ghimport Test Plan: Imported from OSS Reviewed By: iseeyuan Differential Revision: D51553099 Pulled By: larryliu0820 fbshipit-source-id: 0e3877a481d58c4e64c7767f7693537407ab27c5
- Loading branch information
1 parent
4dfb637
commit 75284d2
Showing
9 changed files
with
557 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
This header file `make_boxed_from_unboxed_functor.h` defines a template that can be used to create a boxed version of an unboxed functor. It is part of the executorch extension in the torch namespace. | ||
## Requirements | ||
This header requires C++17 or later. | ||
## Usage | ||
The template takes an unboxed function pointer and wraps it into a functor that takes `RuntimeContext` and `EValues` as inputs and returns void. The wrapped functor will unbox all inputs and forward them to the unboxed kernel. | ||
Here is an example of how to use the template: | ||
```C++ | ||
Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) { | ||
// ... | ||
return out; | ||
} | ||
Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op)); | ||
static auto res = register_kernels({my_kernel}); | ||
``` | ||
Alternatively, you can use the EXECUTORCH_LIBRARY macro to simplify the process: | ||
```C++ | ||
EXECUTORCH_LIBRARY(my_ns, "my_op", my_op); | ||
``` | ||
## Details | ||
The template uses a lot of C++17 features to convert each EValue to the inferred argument type. It checks if the first argument is `RuntimeContext`, and if so, it removes it. The call method of the `WrapUnboxedIntoFunctor` struct calls the unboxed function with the corresponding arguments. | ||
The `EXECUTORCH_LIBRARY` macro registers the kernel for the operation and stores the result in a static variable. | ||
## Note | ||
The `RuntimeContext` is a placeholder for a context that will be passed to kernels. It is currently empty, but it is planned to be used for kernel temp memory allocation and error handling in the future. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Any targets that should be shared between fbcode and xplat must be defined in | ||
# targets.bzl. This file can contain fbcode-only targets. | ||
|
||
load(":targets.bzl", "define_common_targets") | ||
|
||
oncall("executorch") | ||
|
||
define_common_targets() |
145 changes: 145 additions & 0 deletions
145
extension/kernel_util/make_boxed_from_unboxed_functor.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
//===----------------------------------------------------------------------===// | ||
/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h | ||
/// Defines a template that can be used to create a boxed version of an unboxed | ||
/// functor. | ||
/// Example usage: | ||
/// ``` | ||
/// Tensor& | ||
/// my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& | ||
/// out) { | ||
/// // ... | ||
/// return out; | ||
/// } | ||
/// | ||
/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", | ||
/// EXECUTORCH_FN(my_op)); | ||
/// static auto res = register_kernels({my_kernel}); | ||
/// ``` | ||
/// Or simply: | ||
/// ``` | ||
/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op); | ||
/// ``` | ||
/// | ||
/// The trick here is to convert each EValue to inferred argument type. This | ||
/// uses a lot of C++17 features. | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
#if __cplusplus < 201703L | ||
#error "This header requires C++17" | ||
#endif | ||
|
||
#include <executorch/extension/kernel_util/meta_programming.h> | ||
#include <executorch/extension/kernel_util/type_list.h> | ||
#include <executorch/runtime/core/evalue.h> | ||
#include <executorch/runtime/core/exec_aten/exec_aten.h> | ||
#include <executorch/runtime/kernel/operator_registry.h> | ||
#include <cstdlib> | ||
#include <memory> | ||
#include <type_traits> | ||
#include <typeinfo> | ||
|
||
namespace torch { | ||
namespace executor { | ||
|
||
class KernelRuntimeContext; // Forward declaration | ||
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove | ||
|
||
// evalue_to_arg | ||
template <class T> | ||
struct decay_if_not_tensor final { | ||
using type = std::decay_t<T>; | ||
}; | ||
template <> | ||
struct decay_if_not_tensor<exec_aten::Tensor&> final { | ||
using type = exec_aten::Tensor&; | ||
}; | ||
template <> | ||
struct decay_if_not_tensor<const exec_aten::Tensor&> final { | ||
using type = const exec_aten::Tensor&; | ||
}; | ||
|
||
template <class T> | ||
struct evalue_to_arg final { | ||
static T call(EValue& v) { | ||
return std::move(v).to<T>(); | ||
} | ||
}; | ||
|
||
template <> | ||
struct evalue_to_arg<exec_aten::Tensor&> final { | ||
static exec_aten::Tensor& call(EValue& v) { | ||
return v.toTensor(); | ||
} | ||
}; | ||
|
||
template <> | ||
struct evalue_to_arg<const exec_aten::Tensor&> final { | ||
static const exec_aten::Tensor& call(EValue& v) { | ||
return v.toTensor(); | ||
} | ||
}; | ||
// Call functor with args from stack | ||
|
||
template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes> | ||
void call_functor_with_args_from_stack_( | ||
RuntimeContext& ctx, | ||
EValue** stack, | ||
std::index_sequence<evalue_arg_indices...>, | ||
typelist<ArgTypes...>*) { | ||
(*Functor::func_ptr())( | ||
ctx, | ||
evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call( | ||
*stack[evalue_arg_indices])...); | ||
} | ||
|
||
/** | ||
* WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that | ||
* takes EValues as input and returns void. The wrapped functor will unbox all | ||
* inputs and forward them to unboxed kernel. | ||
*/ | ||
template <class FuncType> | ||
struct WrapUnboxedIntoFunctor { | ||
static_assert( | ||
is_compile_time_function_pointer<FuncType>::value, | ||
"Can't handle function other than EXECUTORCH_FN"); | ||
using TrueType = typename FuncType::FuncType; | ||
using ReturnType = typename infer_function_traits_t<TrueType>::return_type; | ||
using ArgsType = typename infer_function_traits_t<TrueType>::parameter_types; | ||
// check if the first argument is RuntimeContext, if so, remove it | ||
static constexpr bool first_arg_is_context = std::is_same< | ||
RuntimeContext, | ||
std::remove_reference_t<head_with_default_t<void, ArgsType>>>::value; | ||
using ContextRemovedArgsType = std::conditional_t< | ||
first_arg_is_context, | ||
drop_if_nonempty_t<ArgsType, 1>, | ||
ArgsType>; | ||
|
||
static void call(RuntimeContext& ctx, EValue** stack) { | ||
constexpr size_t num_inputs = size<ContextRemovedArgsType>::value; | ||
return call_functor_with_args_from_stack_<FuncType>( | ||
ctx, | ||
stack, | ||
std::make_index_sequence<num_inputs>(), | ||
static_cast<ContextRemovedArgsType*>(nullptr)); | ||
} | ||
}; | ||
|
||
template <typename FuncType> | ||
static Kernel make_boxed_kernel(const char* name, FuncType) { | ||
return Kernel(name, WrapUnboxedIntoFunctor<FuncType>::call); | ||
} | ||
|
||
#define EXECUTORCH_LIBRARY(ns, op_name, func) \ | ||
static auto res_##ns = register_kernels( \ | ||
make_boxed_kernel(#ns "::" op_name, EXECUTORCH_FN(func))) | ||
} // namespace executor | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
#if __cplusplus < 201703L | ||
#error "This header requires C++17" | ||
#endif | ||
|
||
#include <executorch/extension/kernel_util/type_list.h> | ||
#include <cstdlib> | ||
#include <memory> | ||
#include <type_traits> | ||
#include <typeinfo> | ||
|
||
namespace torch { | ||
namespace executor { | ||
|
||
// Check if a given type is a function | ||
template <class T> | ||
struct is_function_type : std::false_type {}; | ||
template <class Result, class... Args> | ||
struct is_function_type<Result(Args...)> : std::true_type {}; | ||
template <class T> | ||
using is_function_type_t = typename is_function_type<T>::type; | ||
|
||
// A compile-time wrapper around a function pointer | ||
template <class FuncType_, FuncType_* func_ptr_> | ||
struct CompileTimeFunctionPointer final { | ||
static_assert( | ||
is_function_type<FuncType_>::value, | ||
"EXECUTORCH_FN can only wrap function types."); | ||
using FuncType = FuncType_; | ||
|
||
static constexpr FuncType* func_ptr() { | ||
return func_ptr_; | ||
} | ||
}; | ||
|
||
// Check if a given type is a compile-time function pointer | ||
template <class T> | ||
struct is_compile_time_function_pointer : std::false_type {}; | ||
template <class FuncType, FuncType* func_ptr> | ||
struct is_compile_time_function_pointer< | ||
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {}; | ||
|
||
#define EXECUTORCH_FN_TYPE(func) \ | ||
CompileTimeFunctionPointer< \ | ||
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \ | ||
func> | ||
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() | ||
|
||
/** | ||
* strip_class: helper to remove the class type from pointers to `operator()`. | ||
*/ | ||
template <typename T> | ||
struct strip_class {}; | ||
template <typename Class, typename Result, typename... Args> | ||
struct strip_class<Result (Class::*)(Args...)> { | ||
using type = Result(Args...); | ||
}; | ||
template <typename Class, typename Result, typename... Args> | ||
struct strip_class<Result (Class::*)(Args...) const> { | ||
using type = Result(Args...); | ||
}; | ||
template <typename T> | ||
using strip_class_t = typename strip_class<T>::type; | ||
|
||
/** | ||
* Access information about result type or arguments from a function type. | ||
* Example: | ||
* using A = function_traits<int (float, double)>::return_type // A == int | ||
* using A = function_traits<int (float, double)>::parameter_types::tuple_type | ||
* // A == tuple<float, double> | ||
*/ | ||
template <class Func> | ||
struct function_traits { | ||
static_assert( | ||
!std::is_same<Func, Func>::value, | ||
"In function_traits<Func>, Func must be a plain function type."); | ||
}; | ||
template <class Result, class... Args> | ||
struct function_traits<Result(Args...)> { | ||
using func_type = Result(Args...); | ||
using return_type = Result; | ||
using parameter_types = typelist<Args...>; | ||
static constexpr auto number_of_parameters = sizeof...(Args); | ||
}; | ||
|
||
/** | ||
* infer_function_traits: creates a `function_traits` type for a simple | ||
* function (pointer) or functor (lambda/struct). Currently does not support | ||
* class methods. | ||
*/ | ||
template <typename Functor> | ||
struct infer_function_traits { | ||
using type = function_traits<strip_class_t<decltype(&Functor::operator())>>; | ||
}; | ||
template <typename Result, typename... Args> | ||
struct infer_function_traits<Result (*)(Args...)> { | ||
using type = function_traits<Result(Args...)>; | ||
}; | ||
template <typename Result, typename... Args> | ||
struct infer_function_traits<Result(Args...)> { | ||
using type = function_traits<Result(Args...)>; | ||
}; | ||
template <typename T> | ||
using infer_function_traits_t = typename infer_function_traits<T>::type; | ||
|
||
} // namespace executor | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|
||
def define_common_targets(): | ||
"""Defines targets that should be shared between fbcode and xplat. | ||
The directory containing this targets.bzl file should also contain both | ||
TARGETS and BUCK files that call this function. | ||
""" | ||
|
||
runtime.cxx_library( | ||
name = "kernel_util", | ||
srcs = [], | ||
exported_headers = [ | ||
"make_boxed_from_unboxed_functor.h", | ||
"meta_programming.h", | ||
"type_list.h", | ||
], | ||
visibility = [ | ||
"//executorch/...", | ||
"@EXECUTORCH_CLIENTS", | ||
], | ||
exported_deps = [ | ||
"//executorch/runtime/core:core", | ||
"//executorch/runtime/core:evalue", | ||
"//executorch/runtime/kernel:kernel_includes", | ||
"//executorch/runtime/kernel:kernel_runtime_context", | ||
"//executorch/runtime/kernel:operator_registry", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Any targets that should be shared between fbcode and xplat must be defined in | ||
# targets.bzl. This file can contain fbcode-only targets. | ||
|
||
load(":targets.bzl", "define_common_targets") | ||
|
||
oncall("executorch") | ||
|
||
define_common_targets() |
Oops, something went wrong.