From d8962d67f48c957d7cbe15f86bcb70c0f7b6074b Mon Sep 17 00:00:00 2001
From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com>
Date: Thu, 11 Jan 2024 23:50:07 +0000
Subject: [PATCH] RegexFullMatch operator (#18002)
### Description
### Motivation and Context
Closes https://github.com/microsoft/onnxruntime/issues/17594.
---
docs/OperatorKernels.md | 1 +
.../providers/cpu/cpu_execution_provider.cc | 2 +
.../providers/cpu/text/regex_full_match.cc | 35 ++++++
.../providers/cpu/text/regex_full_match.h | 20 +++
.../cpu/text/regex_full_match_test.cc | 119 ++++++++++++++++++
.../onnx_backend_test_series_filters.jsonc | 3 -
6 files changed, 177 insertions(+), 3 deletions(-)
create mode 100644 onnxruntime/core/providers/cpu/text/regex_full_match.cc
create mode 100644 onnxruntime/core/providers/cpu/text/regex_full_match.h
create mode 100644 onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index c856d12141c9c..5e38789b65137 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -305,6 +305,7 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
+|RegexFullMatch|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(string)
**T2** = tensor(bool)|
|Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)|
|||13|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index ba7738b405795..9cd0b3d0620af 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -990,6 +990,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
@@ -2449,6 +2450,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
#endif
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/cpu/text/regex_full_match.cc b/onnxruntime/core/providers/cpu/text/regex_full_match.cc
new file mode 100644
index 0000000000000..cc4a5a9ae4e61
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/text/regex_full_match.cc
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "regex_full_match.h"
+#include "core/common/common.h"
+
+namespace onnxruntime {
+ONNX_CPU_OPERATOR_KERNEL(
+ RegexFullMatch,
+ 20,
+ KernelDefBuilder()
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ RegexFullMatch);
+
+RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info), re_{info.GetAttr("pattern")} {
+ ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern());
+}
+
+Status RegexFullMatch::Compute(OpKernelContext* context) const {
+ const auto* input_tensor = context->Input(0);
+ const auto input_data = input_tensor->template DataAsSpan();
+ auto* output_tensor = context->Output(0, input_tensor->Shape());
+ auto output_data = output_tensor->template MutableDataAsSpan();
+ auto output_iter = output_data.begin();
+ auto input_iter = input_data.begin();
+ while (input_iter != input_data.end()) {
+ *output_iter = RE2::FullMatch(*input_iter, re_);
+ input_iter++;
+ output_iter++;
+ }
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/text/regex_full_match.h b/onnxruntime/core/providers/cpu/text/regex_full_match.h
new file mode 100644
index 0000000000000..0d3f1f4b4b824
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/text/regex_full_match.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+#include "re2/re2.h"
+
+namespace onnxruntime {
+
+class RegexFullMatch final : public OpKernel {
+ public:
+ explicit RegexFullMatch(const OpKernelInfo& info);
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ RE2 re_;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc b/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
new file mode 100644
index 0000000000000..4aa5a0d44b678
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
@@ -0,0 +1,119 @@
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+namespace onnxruntime {
+namespace test {
+
+static void RunTest(const std::initializer_list& dims, const std::initializer_list& input, const std::string& pattern, const std::initializer_list& output) {
+ OpTester test("RegexFullMatch", 20, kOnnxDomain);
+ test.AddAttribute("pattern", pattern);
+ test.AddInput("Input", dims, input);
+ test.AddOutput("Output", dims, output);
+ test.Run();
+}
+
+TEST(RegexFullMatch, WebsiteMatch) {
+ RunTest({3, 1}, {"www.google.com", "www.facebook.com", "www.bbc.co.uk"}, R"(www\.[\w.-]+\.\bcom\b)", {true, true, false});
+}
+
+TEST(RegexFullMatch, EmailMatch) {
+ RunTest({2, 2}, {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"}, R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))", {true, false, false, true});
+}
+
+TEST(RegexFullMatch, MultibyteMatch) {
+ RunTest({1, 2}, {"ä", "a"}, "ä", {true, false});
+ RunTest({
+ 1,
+ },
+ {"une cédille like in Besançon"}, R"(.*Besançon.*)", {
+ true,
+ });
+ RunTest({
+ 1,
+ },
+ {"une cédille like in Besançon"}, R"(.*Besancon.*)", {
+ false,
+ });
+ RunTest({
+ 1,
+ },
+ {"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", {
+ true,
+ });
+ RunTest({
+ 1,
+ },
+ {"Mit freundlichen Grüßen"}, R"(.*Grußen$)", {
+ false,
+ });
+ RunTest({
+ 3,
+ },
+ {"HПонедельник", "Понедельник", "недельник"}, R"(^Понед.*)", {
+ false,
+ true,
+ false,
+ });
+ RunTest({
+ 3,
+ },
+ {"thank you", "どうもありがとうございます", "こんにちは世界"}, R"(^こんにちは世界.*)", {
+ false,
+ false,
+ true,
+ });
+ RunTest({
+ 3,
+ },
+ {"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"}, R"(.+नमस्ते$)", {
+ false,
+ false,
+ true,
+ });
+ RunTest({
+ 3,
+ },
+ {"你好,你好吗?", "你好呀", "你好呀!"}, R"(^你好.*\?$)", {
+ true,
+ false,
+ false,
+ });
+}
+
+TEST(RegexFullMatch, InvalidPattern) {
+ OpTester test("RegexFullMatch", 20, kOnnxDomain);
+ test.AddAttribute("pattern", R"([a-z)");
+ test.AddInput("Input", {
+ 1,
+ },
+ {
+ "abcdef",
+ });
+ test.AddOutput("Output", {
+ 1,
+ },
+ {
+ false,
+ });
+ test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z");
+}
+
+TEST(RegexFullMatch, NonUtf8Pattern) {
+ uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00};
+ OpTester test("RegexFullMatch", 20, kOnnxDomain);
+ test.AddAttribute("pattern", std::string((char*)invalid_bytes, sizeof(invalid_bytes)));
+ test.AddInput("Input", {
+ 1,
+ },
+ {
+ "abcd",
+ });
+ test.AddOutput("Output", {
+ 1,
+ },
+ {
+ false,
+ });
+ test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern");
+}
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index fb33ef0a1da3c..c2ca5f860a107 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -248,9 +248,6 @@
"^test_image_decoder_decode_pnm_rgb",
"^test_image_decoder_decode_tiff_rgb",
"^test_image_decoder_decode_webp_rgb",
- "^test_regex_full_match_basic",
- "^test_regex_full_match_email_domain",
- "^test_regex_full_match_empty",
"^test_string_split_basic",
"^test_string_split_consecutive_delimiters",
"^test_string_split_empty_string_delimiter",