diff --git a/runtime/executor/test/kernel_integration_test.cpp b/runtime/executor/test/kernel_integration_test.cpp index bec291018a..ccf9611520 100644 --- a/runtime/executor/test/kernel_integration_test.cpp +++ b/runtime/executor/test/kernel_integration_test.cpp @@ -90,8 +90,7 @@ struct KernelControl { // TensorMeta(ScalarType::Float, contiguous), // other // TensorMeta(ScalarType::Float, contiguous), // out // TensorMeta(ScalarType::Float, contiguous)}; // out (repeated) - KernelKey key = torch::executor::KernelKey( - "v0/\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01\xff"); + KernelKey key = torch::executor::KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1"); Kernel kernel = torch::executor::Kernel( "aten::add.out", key, KernelControl::kernel_hook); Error err = torch::executor::register_kernels({kernel}); diff --git a/runtime/executor/test/kernel_resolution_test.cpp b/runtime/executor/test/kernel_resolution_test.cpp index 898cd1fd65..100babe21d 100644 --- a/runtime/executor/test/kernel_resolution_test.cpp +++ b/runtime/executor/test/kernel_resolution_test.cpp @@ -103,8 +103,7 @@ TEST_F(KernelResolutionTest, ResolveKernelKeySuccess) { // TensorMeta(ScalarType::Float, contiguous), // TensorMeta(ScalarType::Float, contiguous), // TensorMeta(ScalarType::Float, contiguous)}; - KernelKey key = KernelKey( - "v0/\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01\xff"); + KernelKey key = KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1"); Kernel kernel_1 = Kernel( "aten::add.out", key, [](KernelRuntimeContext& context, EValue** stack) { (void)context; diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index c439d73375..629077ca7e 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -91,33 +91,50 @@ bool hasOpsFn(const char* name, ArrayRef kernel_key) { return getOperatorRegistry().hasOpsFn(name, kernel_key); } -static void make_kernel_key_string(ArrayRef key, char* buf) { +static int copy_char_as_number_to_buf(char num, char* buf) { + if ((char)num < 10) { + *buf = '0' + (char)num; + buf += 1; + return 1; + } else { + *buf = '0' + ((char)num) / 10; + buf += 1; + *buf = '0' + ((char)num) % 10; + buf += 1; + return 2; + } +} + +void make_kernel_key_string(ArrayRef key, char* buf); + +void make_kernel_key_string(ArrayRef key, char* buf) { if (key.empty()) { // If no tensor is present in an op, kernel key does not apply - *buf = 0xff; return; } - strncpy(buf, "v0/", 3); + strncpy(buf, "v1/", 3); buf += 3; for (size_t i = 0; i < key.size(); i++) { auto& meta = key[i]; - *buf = (char)meta.dtype_; - buf += 1; + buf += copy_char_as_number_to_buf((char)meta.dtype_, buf); *buf = ';'; buf += 1; - memcpy(buf, (char*)meta.dim_order_.data(), meta.dim_order_.size()); - buf += meta.dim_order_.size(); - *buf = (i < (key.size() - 1)) ? '|' : 0xff; + for (int j = 0; j < meta.dim_order_.size(); j++) { + buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf); + if (j != meta.dim_order_.size() - 1) { + *buf = ','; + buf += 1; + } + } + *buf = (i < (key.size() - 1)) ? '|' : 0x00; buf += 1; } } -constexpr int BUF_SIZE = 307; - bool OperatorRegistry::hasOpsFn( const char* name, ArrayRef meta_list) { - char buf[BUF_SIZE] = {0}; + char buf[KernelKey::MAX_SIZE] = {0}; make_kernel_key_string(meta_list, buf); KernelKey kernel_key = KernelKey(buf); @@ -140,7 +157,7 @@ const OpFunction& getOpsFn(const char* name, ArrayRef kernel_key) { const OpFunction& OperatorRegistry::getOpsFn( const char* name, ArrayRef meta_list) { - char buf[BUF_SIZE] = {0}; + char buf[KernelKey::MAX_SIZE] = {0}; make_kernel_key_string(meta_list, buf); KernelKey kernel_key = KernelKey(buf); diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index 9c05afaa6f..ae477fbde5 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -102,21 +102,18 @@ struct TensorMeta { * registered. * * The format of a kernel key data is a string: - * "v/|...\xff" - * Size: Up to 307 1 1 1 (18 +1) * 16 + * "v/|..." + * Size: Up to 691 1 1 1 (42 +1) * 16 * Assuming max number of tensors is 16 ^ - * Kernel key version is v0 for now. If the kernel key format changes, + * Kernel key version is v1 for now. If the kernel key format changes, * update the version to avoid breaking pre-existing kernel keys. - * Example: v0/0x07;0x00 0x01 0x02 0x03 \xff + * Example: v1/7;0,1,2,3 * The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3 * - * The string is a byte array and contains non-printable characters. It must - * be terminated with a '\xff' so 0xff cannot be a scalar type. - * - * Each tensor_meta has the following format: ";" - * Size: Up to 18 1 1 16 - * Assuming that the max number of dims is 16 ^ - * Example: 0x07;0x00 0x01 0x02 0x03 for [double; 0, 1, 2, 3] + * Each tensor_meta has the following format: ";" + * Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2 + * for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example: + * 7;0,1,2,3 for [double; 0, 1, 2, 3] * * IMPORTANT: * Users should not construct a kernel key manually. Instead, it should be @@ -129,7 +126,7 @@ struct KernelKey { /* implicit */ KernelKey(const char* kernel_key_data) : kernel_key_data_(kernel_key_data), is_fallback_(false) {} - constexpr static char TERMINATOR = 0xff; + constexpr static int MAX_SIZE = 691; bool operator==(const KernelKey& other) const { return this->equals(other); @@ -146,16 +143,7 @@ struct KernelKey { if (is_fallback_) { return true; } - size_t i; - for (i = 0; kernel_key_data_[i] != TERMINATOR && - other.kernel_key_data_[i] != TERMINATOR; - i++) { - if (kernel_key_data_[i] != other.kernel_key_data_[i]) { - return false; - } - } - return kernel_key_data_[i] == TERMINATOR && - other.kernel_key_data_[i] == TERMINATOR; + return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0; } bool is_fallback() const { diff --git a/runtime/kernel/test/operator_registry_test.cpp b/runtime/kernel/test/operator_registry_test.cpp index 697f51804c..e3b99c1241 100644 --- a/runtime/kernel/test/operator_registry_test.cpp +++ b/runtime/kernel/test/operator_registry_test.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -43,27 +44,7 @@ TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) { ET_EXPECT_DEATH({ auto res = register_kernels(kernels_array); }, ""); } -void make_kernel_key( - std::vector>> - tensors, - char* buf) { - char* start = buf; - strncpy(buf, "v0/", 3); - buf += 3; - for (size_t i = 0; i < tensors.size(); i++) { - auto& tensor = tensors[i]; - *buf = (char)tensor.first; - buf += 1; - *buf = ';'; - buf += 1; - memcpy(buf, (char*)tensor.second.data(), tensor.second.size()); - buf += tensor.second.size(); - *buf = (i < (tensors.size() - 1)) ? '|' : 0xff; - buf += 1; - } -} - -constexpr int BUF_SIZE = 307; +constexpr int BUF_SIZE = KernelKey::MAX_SIZE; TEST_F(OperatorRegistryTest, KernelKeyEquals) { char buf_long_contiguous[BUF_SIZE]; diff --git a/runtime/kernel/test/targets.bzl b/runtime/kernel/test/targets.bzl index de00865e70..f7bbfc21e4 100644 --- a/runtime/kernel/test/targets.bzl +++ b/runtime/kernel/test/targets.bzl @@ -13,6 +13,7 @@ def define_common_targets(): srcs = [ "operator_registry_test.cpp", ], + headers = ["test_util.h"], deps = [ "//executorch/runtime/kernel:operator_registry", "//executorch/runtime/kernel:kernel_runtime_context", diff --git a/runtime/kernel/test/test_util.h b/runtime/kernel/test/test_util.h new file mode 100644 index 0000000000..02078c6015 --- /dev/null +++ b/runtime/kernel/test/test_util.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace torch { +namespace executor { +void make_kernel_key_string(ArrayRef key, char* buf); + +inline void make_kernel_key( + std::vector>> + tensors, + char* buf) { + std::vector meta; + for (auto& t : tensors) { + ArrayRef dim_order( + t.second.data(), t.second.size()); + meta.emplace_back(t.first, dim_order); + } + auto meatadata = ArrayRef(meta.data(), meta.size()); + make_kernel_key_string(meatadata, buf); +} + +} // namespace executor +} // namespace torch