Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into stringsplit
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 11, 2024
2 parents 8e757a9 + d8962d6 commit 0b92c5c
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|RegexFullMatch|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(string)<br/> **T2** = tensor(bool)|
|Relu|*in* X:**T**<br> *out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)|
|||13|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment
Expand Down Expand Up @@ -2450,6 +2451,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit)>,
};

Expand Down
35 changes: 35 additions & 0 deletions onnxruntime/core/providers/cpu/text/regex_full_match.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "regex_full_match.h"
#include "core/common/common.h"

namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(
RegexFullMatch,
20,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
RegexFullMatch);

RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info), re_{info.GetAttr<std::string>("pattern")} {
ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern());
}

Status RegexFullMatch::Compute(OpKernelContext* context) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto input_data = input_tensor->template DataAsSpan<std::string>();
auto* output_tensor = context->Output(0, input_tensor->Shape());
auto output_data = output_tensor->template MutableDataAsSpan<bool>();
auto output_iter = output_data.begin();
auto input_iter = input_data.begin();
while (input_iter != input_data.end()) {
*output_iter = RE2::FullMatch(*input_iter, re_);
input_iter++;
output_iter++;
}
return Status::OK();
}

} // namespace onnxruntime
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/cpu/text/regex_full_match.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"
#include "re2/re2.h"

namespace onnxruntime {

class RegexFullMatch final : public OpKernel {
public:
explicit RegexFullMatch(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

private:
RE2 re_;
};

} // namespace onnxruntime
119 changes: 119 additions & 0 deletions onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {

static void RunTest(const std::initializer_list<int64_t>& dims, const std::initializer_list<std::string>& input, const std::string& pattern, const std::initializer_list<bool>& output) {
OpTester test("RegexFullMatch", 20, kOnnxDomain);
test.AddAttribute("pattern", pattern);
test.AddInput<std::string>("Input", dims, input);
test.AddOutput<bool>("Output", dims, output);
test.Run();
}

TEST(RegexFullMatch, WebsiteMatch) {
RunTest({3, 1}, {"www.google.com", "www.facebook.com", "www.bbc.co.uk"}, R"(www\.[\w.-]+\.\bcom\b)", {true, true, false});
}

TEST(RegexFullMatch, EmailMatch) {
RunTest({2, 2}, {"[email protected]", "[email protected]", "not email", "[email protected]"}, R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))", {true, false, false, true});
}

TEST(RegexFullMatch, MultibyteMatch) {
RunTest({1, 2}, {"ä", "a"}, "ä", {true, false});
RunTest({
1,
},
{"une cédille like in Besançon"}, R"(.*Besançon.*)", {
true,
});
RunTest({
1,
},
{"une cédille like in Besançon"}, R"(.*Besancon.*)", {
false,
});
RunTest({
1,
},
{"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", {
true,
});
RunTest({
1,
},
{"Mit freundlichen Grüßen"}, R"(.*Grußen$)", {
false,
});
RunTest({
3,
},
{"HПонедельник", "Понедельник", "недельник"}, R"(^Понед.*)", {
false,
true,
false,
});
RunTest({
3,
},
{"thank you", "どうもありがとうございます", "こんにちは世界"}, R"(^こんにちは世界.*)", {
false,
false,
true,
});
RunTest({
3,
},
{"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"}, R"(.+नमस्ते$)", {
false,
false,
true,
});
RunTest({
3,
},
{"你好,你好吗?", "你好呀", "你好呀!"}, R"(^你好.*\?$)", {
true,
false,
false,
});
}

TEST(RegexFullMatch, InvalidPattern) {
OpTester test("RegexFullMatch", 20, kOnnxDomain);
test.AddAttribute("pattern", R"([a-z)");
test.AddInput<std::string>("Input", {
1,
},
{
"abcdef",
});
test.AddOutput<bool>("Output", {
1,
},
{
false,
});
test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z");
}

TEST(RegexFullMatch, NonUtf8Pattern) {
uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00};
OpTester test("RegexFullMatch", 20, kOnnxDomain);
test.AddAttribute("pattern", std::string((char*)invalid_bytes, sizeof(invalid_bytes)));
test.AddInput<std::string>("Input", {
1,
},
{
"abcd",
});
test.AddOutput<bool>("Output", {
1,
},
{
false,
});
test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern");
}
} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,6 @@
"^test_image_decoder_decode_pnm_rgb",
"^test_image_decoder_decode_tiff_rgb",
"^test_image_decoder_decode_webp_rgb",
"^test_regex_full_match_basic",
"^test_regex_full_match_email_domain",
"^test_regex_full_match_empty",
"^test_reduce_l1_empty_set_cuda",
"^test_reduce_l1_empty_set_expanded_cuda",
"^test_reduce_l2_empty_set_cuda",
Expand Down

0 comments on commit 0b92c5c

Please sign in to comment.