Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RegexFullMatch operator #18002

Merged
merged 9 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment

Expand Down Expand Up @@ -2447,6 +2448,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
35 changes: 35 additions & 0 deletions onnxruntime/core/providers/cpu/text/regex_full_match.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "regex_full_match.h"

Check warning on line 4 in onnxruntime/core/providers/cpu/text/regex_full_match.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/cpu/text/regex_full_match.cc:4: Include the directory when naming header files [build/include_subdir] [4]
#include "core/common/common.h"

namespace onnxruntime {
// Kernel registration for the RegexFullMatch operator, version 20, implemented by the RegexFullMatch class.
// Type constraint "T1" is limited to string tensors and "T2" to bool tensors.
ONNX_CPU_OPERATOR_KERNEL(
RegexFullMatch,
20,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
RegexFullMatch);

// Constructs the kernel: re_ is built from the node's "pattern" string attribute.
// Construction is rejected through ORT_ENFORCE when re_.ok() is false; the failure text carries the pattern.
RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info), re_{info.GetAttr<std::string>("pattern")} {
ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern());
}

// Applies RE2::FullMatch with the stored pattern to every string of input 0 and stores one bool per
// string in output 0. Output 0 is requested with the input tensor's shape, so both spans are walked in
// lockstep and each result lands at the same flat position as the string it describes.
// NOTE(review): cpplint (build/include_what_you_use) asks this file to add #include <string> because
// std::string is named in this function.
Status RegexFullMatch::Compute(OpKernelContext* context) const {
  const auto* source = context->Input<Tensor>(0);
  const auto texts = source->DataAsSpan<std::string>();

  auto* destination = context->Output(0, source->Shape());
  auto matches = destination->MutableDataAsSpan<bool>();

  auto result = matches.begin();
  for (const std::string& text : texts) {
    *result = RE2::FullMatch(text, re_);
    ++result;
  }
  return Status::OK();
}

} // namespace onnxruntime
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/cpu/text/regex_full_match.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"
#include "re2/re2.h"

namespace onnxruntime {

// Kernel class for the RegexFullMatch operator, derived from OpKernel.
// The constructor takes the node's OpKernelInfo; Compute overrides the OpKernel entry point and is const.
class RegexFullMatch final : public OpKernel {
public:
explicit RegexFullMatch(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

private:
// RE2 object owned by the kernel instance.
RE2 re_;
};

} // namespace onnxruntime
119 changes: 119 additions & 0 deletions onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {

static void RunTest(const std::initializer_list<int64_t>& dims, const std::initializer_list<std::string>& input, const std::string& pattern, const std::initializer_list<bool>& output) {

Check warning on line 6 in onnxruntime/test/providers/cpu/text/regex_full_match_test.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/providers/cpu/text/regex_full_match_test.cc:6: Lines should be <= 120 characters long [whitespace/line_length] [2]
OpTester test("RegexFullMatch", 20, kOnnxDomain);
test.AddAttribute("pattern", pattern);
test.AddInput<std::string>("Input", dims, input);
test.AddOutput<bool>("Output", dims, output);
test.Run();
}

// Host names on a 3x1 tensor: the two ".com" hosts are expected to match, the ".co.uk" host is not.
TEST(RegexFullMatch, WebsiteMatch) {
  RunTest({3, 1},
          {"www.google.com", "www.facebook.com", "www.bbc.co.uk"},
          R"(www\.[\w.-]+\.\bcom\b)",
          {true, true, false});
}

// E-mail addresses on a 2x2 tensor: the pattern accepts only gmail.com and yahoo.com domains, so the
// hotmail.com address and the free-text entry are expected to be rejected.
TEST(RegexFullMatch, EmailMatch) {
  RunTest({2, 2},
          {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"},
          R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))",
          {true, false, false, true});
}

// Patterns and inputs containing multi-byte characters (Latin with diacritics, Cyrillic, Japanese,
// Devanagari, Chinese). Each call pairs one pattern with its inputs and the expected per-element result.
TEST(RegexFullMatch, MultibyteMatch) {
  RunTest({1, 2}, {"ä", "a"}, "ä", {true, false});

  // Same sentence, pattern with and without the cedilla.
  RunTest({1}, {"une cédille like in Besançon"}, R"(.*Besançon.*)", {true});
  RunTest({1}, {"une cédille like in Besançon"}, R"(.*Besancon.*)", {false});

  // Same sentence, pattern with and without the umlaut.
  RunTest({1}, {"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", {true});
  RunTest({1}, {"Mit freundlichen Grüßen"}, R"(.*Grußen$)", {false});

  RunTest({3}, {"HПонедельник", "Понедельник", "недельник"}, R"(^Понед.*)", {false, true, false});
  RunTest({3},
          {"thank you", "どうもありがとうございます", "こんにちは世界"},
          R"(^こんにちは世界.*)",
          {false, false, true});
  RunTest({3},
          {"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"},
          R"(.+नमस्ते$)",
          {false, false, true});
  RunTest({3}, {"你好,你好吗?", "你好呀", "你好呀!"}, R"(^你好.*\?$)", {true, false, false});
}

// A pattern with an unterminated character class is expected to make the run fail with a message that
// contains the offending pattern text.
TEST(RegexFullMatch, InvalidPattern) {
  OpTester test("RegexFullMatch", 20, kOnnxDomain);
  test.AddAttribute("pattern", R"([a-z)");
  test.AddInput<std::string>("Input", {1}, {"abcdef"});
  test.AddOutput<bool>("Output", {1}, {false});
  test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z");
}

// A pattern built from raw bytes that do not form valid UTF-8 is expected to make the run fail with an
// "Invalid regex pattern" message.
TEST(RegexFullMatch, NonUtf8Pattern) {
  const uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00};
  // The explicit length keeps every byte of the array, including the trailing 0x00, in the pattern string.
  const std::string pattern(reinterpret_cast<const char*>(invalid_bytes), sizeof(invalid_bytes));

  OpTester test("RegexFullMatch", 20, kOnnxDomain);
  test.AddAttribute("pattern", pattern);
  test.AddInput<std::string>("Input", {1}, {"abcd"});
  test.AddOutput<bool>("Output", {1}, {false});
  test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern");
}
} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,6 @@
"^test_image_decoder_decode_pnm_rgb",
"^test_image_decoder_decode_tiff_rgb",
"^test_image_decoder_decode_webp_rgb",
"^test_regex_full_match_basic",
"^test_regex_full_match_email_domain",
"^test_regex_full_match_empty",
"^test_string_concat_broadcasting",
"^test_string_concat",
"^test_string_concat_empty_string",
Expand Down
Loading