Skip to content

Commit

Permalink
RegexFullMatch operator
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 5, 2024
1 parent 447a3a7 commit 3078acf
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 6 deletions.
6 changes: 3 additions & 3 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.23.0.zip;f3233450cf7156fc0bedd1b0e884eddec264897c
googletest;https://github.com/google/googletest/archive/519beb0e52c842729b4b53731d27c0e0c32ab4a2.zip;4b3c37972e4c1bef1185d46f702082f8772ee73f
googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment

Expand Down Expand Up @@ -2418,6 +2419,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {

// Opset 20
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, bool, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, ReduceMax)>,
Expand Down Expand Up @@ -2447,6 +2449,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
41 changes: 41 additions & 0 deletions onnxruntime/core/providers/cpu/nn/regex_full_match.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "regex_full_match.h"
#include "core/common/common.h"

namespace onnxruntime {
// Registers the RegexFullMatch class as the CPU kernel named "RegexFullMatch"
// at version 20 (presumably the ONNX opset version -- confirm against the
// operator schema). The type constraints restrict "T1" to tensors of
// std::string and "T2" to tensors of bool.
ONNX_CPU_OPERATOR_KERNEL(
RegexFullMatch,
20,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
RegexFullMatch);

// Constructs the kernel from its node attributes.
//
// Reads the string attribute "pattern" into pattern_ and enforces that the
// read succeeded, then enforces that RE2 can compile the pattern. The RE2
// object built here is a temporary used only for validation; it is not kept.
// NOTE(review): ORT_ENFORCE presumably throws when its condition is false --
// confirm against core/common/common.h.
RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<std::string>("pattern", &pattern_).IsOK());
ORT_ENFORCE(RE2(pattern_).ok(), "Invalid pattern: ", pattern_);
}

// Evaluates the kernel: for every string in input 0, stores in output 0
// whether the whole string matches pattern_ (RE2::FullMatch).
//
// Output 0 is requested with the same shape as input 0, so the two spans are
// walked in lockstep.
//
// Returns a FAIL status if input 0 or output 0 is unavailable, otherwise
// Status::OK().
Status RegexFullMatch::Compute(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  if (nullptr == input_tensor) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Input count mismatch");
  }

  auto* output_tensor = context->Output(0, input_tensor->Shape());
  if (nullptr == output_tensor) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Output count mismatch");
  }

  // TODO(review): the pattern is recompiled on every call; it could be compiled
  // once and cached on the kernel, which requires a class member change.
  const RE2 re(pattern_);

  const auto input_data = input_tensor->DataAsSpan<std::string>();
  auto output_data = output_tensor->MutableDataAsSpan<bool>();

  // Iterate over the *input* span. The loop must be bounded by the input's own
  // end; comparing an input iterator with the output span's end() mixes
  // iterators of two unrelated ranges.
  auto output_iter = output_data.begin();
  for (const std::string& text : input_data) {
    *output_iter = RE2::FullMatch(text, re);
    ++output_iter;
  }
  return Status::OK();
}

} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/cpu/nn/regex_full_match.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include "core/framework/op_kernel.h"
#include "re2/re2.h"

namespace onnxruntime {

// Kernel class for the RegexFullMatch operator. It derives from OpKernel and
// holds the regular-expression pattern as a string.
class RegexFullMatch final : public OpKernel {
public:
// Builds the kernel from `info`; the definition reads the "pattern" attribute
// into pattern_ and enforces that RE2 accepts it.
explicit RegexFullMatch(const OpKernelInfo& info);
// Applies RE2::FullMatch with pattern_ to the strings of input 0 and stores
// the bool results in output 0, which is requested with the input's shape.
Status Compute(OpKernelContext* context) const override;

private:
// Regular-expression source text taken from the "pattern" attribute.
std::string pattern_;
};

} // namespace onnxruntime
24 changes: 24 additions & 0 deletions onnxruntime/test/providers/cpu/nn/regex_full_match_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

static void RunTest(const std::initializer_list<int64_t>& dims, const std::initializer_list<std::string>& input, const std::string& pattern, const std::initializer_list<bool>& output) {
OpTester test("RegexFullMatch", 20, kOnnxDomain);
test.AddAttribute("pattern", pattern);
test.AddInput<std::string>("Input", dims, input);
test.AddOutput<bool>("Output", dims, output);
test.Run();
}

// A 3x1 column of host names checked against a pattern that only accepts
// names ending in ".com"; the ".co.uk" entry is expected to be rejected.
TEST(RegexFullMatch, WebsiteMatch) {
  RunTest(
      {3, 1},
      {"www.google.com", "www.facebook.com", "www.bbc.co.uk"},
      R"(www\.[\w.-]+\.\bcom\b)",
      {true, true, false});
}

// A 2x2 input of candidate e-mail addresses checked against a pattern that
// accepts only the yahoo.com and gmail.com domains. The inputs must contain
// real addresses (with '@'), otherwise none of them can match and the
// expected {true, false, false, true} cannot hold.
TEST(RegexFullMatch, EmailMatch) {
  RunTest(
      {2, 2},
      {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"},
      R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))",
      {true, false, false, true});
}

} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,6 @@
"^test_image_decoder_decode_pnm_rgb",
"^test_image_decoder_decode_tiff_rgb",
"^test_image_decoder_decode_webp_rgb",
"^test_regex_full_match_basic",
"^test_regex_full_match_email_domain",
"^test_regex_full_match_empty",
"^test_string_concat_broadcasting",
"^test_string_concat",
"^test_string_concat_empty_string",
Expand Down

0 comments on commit 3078acf

Please sign in to comment.