diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index c856d12141c9c..5e38789b65137 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -305,6 +305,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|RegexFullMatch|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(string)
**T2** = tensor(bool)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)| |||13|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index ba7738b405795..9cd0b3d0620af 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -990,6 +990,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -2449,6 +2450,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { #endif BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/text/regex_full_match.cc b/onnxruntime/core/providers/cpu/text/regex_full_match.cc new file mode 100644 index 0000000000000..cc4a5a9ae4e61 --- /dev/null +++ b/onnxruntime/core/providers/cpu/text/regex_full_match.cc @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "regex_full_match.h" +#include "core/common/common.h" + +namespace onnxruntime { +ONNX_CPU_OPERATOR_KERNEL( + RegexFullMatch, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + RegexFullMatch); + +RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info), re_{info.GetAttr("pattern")} { + ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern()); +} + +Status RegexFullMatch::Compute(OpKernelContext* context) const { + const auto* input_tensor = context->Input(0); + const auto input_data = input_tensor->template DataAsSpan(); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + auto output_data = output_tensor->template MutableDataAsSpan(); + auto output_iter = output_data.begin(); + auto input_iter = input_data.begin(); + while (input_iter != input_data.end()) { + *output_iter = RE2::FullMatch(*input_iter, re_); + input_iter++; + output_iter++; + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/text/regex_full_match.h b/onnxruntime/core/providers/cpu/text/regex_full_match.h new file mode 100644 index 0000000000000..0d3f1f4b4b824 --- /dev/null +++ b/onnxruntime/core/providers/cpu/text/regex_full_match.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "re2/re2.h" + +namespace onnxruntime { + +class RegexFullMatch final : public OpKernel { + public: + explicit RegexFullMatch(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + RE2 re_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc b/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc new file mode 100644 index 0000000000000..4aa5a0d44b678 --- /dev/null +++ b/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc @@ -0,0 +1,119 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +namespace onnxruntime { +namespace test { + +static void RunTest(const std::initializer_list& dims, const std::initializer_list& input, const std::string& pattern, const std::initializer_list& output) { + OpTester test("RegexFullMatch", 20, kOnnxDomain); + test.AddAttribute("pattern", pattern); + test.AddInput("Input", dims, input); + test.AddOutput("Output", dims, output); + test.Run(); +} + +TEST(RegexFullMatch, WebsiteMatch) { + RunTest({3, 1}, {"www.google.com", "www.facebook.com", "www.bbc.co.uk"}, R"(www\.[\w.-]+\.\bcom\b)", {true, true, false}); +} + +TEST(RegexFullMatch, EmailMatch) { + RunTest({2, 2}, {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"}, R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))", {true, false, false, true}); +} + +TEST(RegexFullMatch, MultibyteMatch) { + RunTest({1, 2}, {"ä", "a"}, "ä", {true, false}); + RunTest({ + 1, + }, + {"une cédille like in Besançon"}, R"(.*Besançon.*)", { + true, + }); + RunTest({ + 1, + }, + {"une cédille like in Besançon"}, R"(.*Besancon.*)", { + false, + }); + RunTest({ + 1, + }, + {"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", { + true, + }); + RunTest({ + 1, + }, + {"Mit freundlichen Grüßen"}, R"(.*Grußen$)", { + false, + }); + RunTest({ + 3, + }, + {"HПонедельник", "Понедельник", "недельник"}, R"(^Понед.*)", { + false, + true, + false, + }); + RunTest({ + 3, + }, + {"thank you", "どうもありがとうございます", "こんにちは世界"}, R"(^こんにちは世界.*)", { + false, + false, + true, + }); + RunTest({ + 3, + }, + {"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"}, R"(.+नमस्ते$)", { + false, + false, + true, + }); + RunTest({ + 3, + }, + {"你好,你好吗?", "你好呀", "你好呀!"}, R"(^你好.*\?$)", { + true, + false, + false, + }); +} + +TEST(RegexFullMatch, InvalidPattern) { + OpTester test("RegexFullMatch", 20, kOnnxDomain); + test.AddAttribute("pattern", R"([a-z)"); + test.AddInput("Input", { + 1, + }, + { + "abcdef", + }); + test.AddOutput("Output", { + 1, + }, + { + false, + }); + test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z"); +} + +TEST(RegexFullMatch, NonUtf8Pattern) { + uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00}; + OpTester test("RegexFullMatch", 20, kOnnxDomain); + test.AddAttribute("pattern", std::string((char*)invalid_bytes, sizeof(invalid_bytes))); + test.AddInput("Input", { + 1, + }, + { + "abcd", + }); + test.AddOutput("Output", { + 1, + }, + { + false, + }); + test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern"); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index fb33ef0a1da3c..c2ca5f860a107 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -248,9 +248,6 @@ "^test_image_decoder_decode_pnm_rgb", "^test_image_decoder_decode_tiff_rgb", "^test_image_decoder_decode_webp_rgb", - "^test_regex_full_match_basic", - "^test_regex_full_match_email_domain", - "^test_regex_full_match_empty", "^test_string_split_basic", "^test_string_split_consecutive_delimiters", "^test_string_split_empty_string_delimiter",