From 3078acf1a62ef4ef40a9f833edb7a3a9e3038f7a Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 17 Oct 2023 21:39:03 +0100 Subject: [PATCH] RegexFullMatch operator --- cmake/deps.txt | 6 +-- .../providers/cpu/cpu_execution_provider.cc | 5 +++ .../core/providers/cpu/nn/regex_full_match.cc | 41 +++++++++++++++++++ .../core/providers/cpu/nn/regex_full_match.h | 17 ++++++++ .../providers/cpu/nn/regex_full_match_test.cc | 24 +++++++++++ .../onnx_backend_test_series_filters.jsonc | 3 -- 6 files changed, 90 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/nn/regex_full_match.cc create mode 100644 onnxruntime/core/providers/cpu/nn/regex_full_match.h create mode 100644 onnxruntime/test/providers/cpu/nn/regex_full_match_test.cc diff --git a/cmake/deps.txt b/cmake/deps.txt index ff07803013071..5266626880bd2 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -26,9 +26,9 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip; fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 -google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 -googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 -googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 +google_nsync;https://github.com/google/nsync/archive/refs/tags/1.23.0.zip;f3233450cf7156fc0bedd1b0e884eddec264897c +googletest;https://github.com/google/googletest/archive/519beb0e52c842729b4b53731d27c0e0c32ab4a2.zip;4b3c37972e4c1bef1185d46f702082f8772ee73f +googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f60c7ddac5c05..ca31e7f03ca6f 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -2418,6 +2419,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 20 BuildKernelCreateInfo, +<<<<<<< HEAD BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2447,6 +2449,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif BuildKernelCreateInfo, +======= + BuildKernelCreateInfo, +>>>>>>> 4f27bde807 (RegexFullMatch operator) }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/nn/regex_full_match.cc b/onnxruntime/core/providers/cpu/nn/regex_full_match.cc new file mode 100644 index 0000000000000..3c51d30b24313 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/regex_full_match.cc @@ -0,0 +1,41 @@ +#include "regex_full_match.h" +#include "core/common/common.h" + +namespace onnxruntime { +ONNX_CPU_OPERATOR_KERNEL( + RegexFullMatch, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + RegexFullMatch); + +RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(info.GetAttr("pattern", &pattern_).IsOK()); + ORT_ENFORCE(RE2(pattern_).ok(), "Invalid pattern: ", pattern_); +} + +Status RegexFullMatch::Compute(OpKernelContext* context) const { + RE2 re(pattern_); + const auto* input_tensor = context->Input(0); + if (nullptr == input_tensor) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input count mismatch"); + } + auto* output_tensor = context->Output(0, input_tensor->Shape()); + if (nullptr == output_tensor) { + return Status(common::ONNXRUNTIME, common::FAIL, "Output count mismatch"); + } + const auto input_data = input_tensor->template DataAsSpan(); + auto output_data = output_tensor->template MutableDataAsSpan(); + const auto N = input_tensor->Shape().Size(); + auto output_iter = output_data.begin(); + auto input_iter = input_data.begin(); + while (input_iter != output_data.end()) { + *output_iter = RE2::FullMatch(*input_iter, re); + input_iter++; + output_iter++; + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/regex_full_match.h b/onnxruntime/core/providers/cpu/nn/regex_full_match.h new file mode 100644 index 0000000000000..49290532a304e --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/regex_full_match.h @@ -0,0 +1,17 @@ +#pragma once + +#include "core/framework/op_kernel.h" +#include "re2/re2.h" + +namespace onnxruntime { + +class RegexFullMatch final : public OpKernel { + public: + explicit RegexFullMatch(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + std::string pattern_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/regex_full_match_test.cc b/onnxruntime/test/providers/cpu/nn/regex_full_match_test.cc new file mode 100644 index 0000000000000..8d429fd8dfc09 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/regex_full_match_test.cc @@ -0,0 +1,24 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunTest(const std::initializer_list& dims, const std::initializer_list& input, const std::string& pattern, const std::initializer_list& output) { + OpTester test("RegexFullMatch", 20, kOnnxDomain); + test.AddAttribute("pattern", pattern); + test.AddInput("Input", dims, input); + test.AddOutput("Output", dims, output); + test.Run(); +} + +TEST(RegexFullMatch, WebsiteMatch) { + RunTest({3, 1}, {"www.google.com", "www.facebook.com", "www.bbc.co.uk"}, R"(www\.[\w.-]+\.\bcom\b)", {true, true, false}); +} + +TEST(RegexFullMatch, EmailMatch) { + RunTest({2, 2}, {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"}, R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))", {true, false, false, true}); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 3a13e39702904..3db497fa92315 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -248,9 +248,6 @@ "^test_image_decoder_decode_pnm_rgb", "^test_image_decoder_decode_tiff_rgb", "^test_image_decoder_decode_webp_rgb", - "^test_regex_full_match_basic", - "^test_regex_full_match_email_domain", - "^test_regex_full_match_empty", "^test_string_concat_broadcasting", "^test_string_concat", "^test_string_concat_empty_string",