diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8092c26da651a..67bfe48327e14 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -421,7 +421,7 @@ Do not modify directly.*
|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
+|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
|Unique|*in* X:**T**
*out* Y:**T**
*out* indices:**tensor(int64)**
*out* inverse_indices:**tensor(int64)**
*out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**
or
*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/tensor/trilu.cc b/onnxruntime/core/providers/cpu/tensor/trilu.cc
index 91e429ef60d91..017bbcd44904e 100644
--- a/onnxruntime/core/providers/cpu/tensor/trilu.cc
+++ b/onnxruntime/core/providers/cpu/tensor/trilu.cc
@@ -31,7 +31,7 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
14,
kCpuExecutionProvider,
- KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()),
+ KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()),
Trilu);
template
@@ -110,6 +110,9 @@ Status Trilu::Compute(OpKernelContext* ctx) const {
case sizeof(double):
status = TriluImpl(X, Y, k_val, up);
break;
+ case sizeof(bool):
+ status = TriluImpl(X, Y, k_val, up);
+ break;
default:
ORT_THROW("Unsupported input data type of ", data_type);
}
diff --git a/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc b/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc
index f0b5d6afa9c7b..f1d1d94343e6f 100644
--- a/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc
@@ -62,63 +62,54 @@ TEST(TriluOpTest, two_by_two_long_lower) {
test.Run();
}
+TEST(TriluOpTest, two_by_two_bool_upper) {
+ OpTester test("Trilu", 14, kOnnxDomain);
+ int64_t up = 1;
+ test.AddAttribute("upper", up);
+ test.AddInput("X", {2, 2},
+ {true, true,
+ true, true});
+ test.AddOutput("Y", {2, 2},
+ {true, true,
+ false, true});
+ test.Run();
+}
+
+TEST(TriluOpTest, three_by_three_bool_lower) {
+ OpTester test("Trilu", 14, kOnnxDomain);
+ int64_t up = 0;
+ test.AddAttribute("upper", up);
+ test.AddInput("X", {3, 3},
+ // include a couple of false values to check they are copied
+ {true, true, true,
+ true, false, true,
+ true, true, false});
+ test.AddOutput("Y", {3, 3},
+ {true, false, false,
+ true, false, false,
+ true, true, false});
+ test.Run();
+}
+
TEST(TriluOpTest, three_dim_float_upper) {
OpTester test("Trilu", 14, kOnnxDomain);
test.AddInput("X", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.AddInput("k", {1}, {1});
test.AddOutput("Y", {2, 3, 4},
- {
- 0.f,
- 1.f,
- 5.f,
- 8.f,
- 0.f,
- 0.f,
- 2.f,
- 4.f,
- 0.f,
- 0.f,
- 0.f,
- 3.f,
- 0.f,
- 6.f,
- 2.f,
- 1.f,
- 0.f,
- 0.f,
- 5.f,
- 8.f,
- 0.f,
- 0.f,
- 0.f,
- 4.f,
- });
+ {0.f, 1.f, 5.f, 8.f,
+ 0.f, 0.f, 2.f, 4.f,
+ 0.f, 0.f, 0.f, 3.f,
+
+ 0.f, 6.f, 2.f, 1.f,
+ 0.f, 0.f, 5.f, 8.f,
+ 0.f, 0.f, 0.f, 4.f});
test.Run();
}
@@ -127,60 +118,22 @@ TEST(TriluOpTest, three_dim_float_lower) {
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput("X", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.AddInput("k", {1}, {1});
test.AddOutput("Y", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 0.f,
- 0.f,
- 4.f,
- 3.f,
- 2.f,
- 0.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 0.f,
- 0.f,
- 4.f,
- 1.f,
- 5.f,
- 0.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 0.f, 0.f,
+ 4.f, 3.f, 2.f, 0.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 0.f, 0.f,
+ 4.f, 1.f, 5.f, 0.f,
+ 4.f, 3.f, 2.f, 4.f});
test.Run();
}
@@ -189,60 +142,22 @@ TEST(TriluOpTest, neg_k_float_upper) {
int64_t up = 1;
test.AddAttribute("upper", up);
test.AddInput("X", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.AddInput("k", {1}, {-1});
test.AddOutput("Y", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 0.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 0.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 0.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 0.f, 3.f, 2.f, 4.f});
test.Run();
}
@@ -251,120 +166,44 @@ TEST(TriluOpTest, neg_k_float_lower) {
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput("X", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.AddInput("k", {1}, {-1});
test.AddOutput("Y", {2, 3, 4},
- {
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 4.f,
- 0.f,
- 0.f,
- 0.f,
- 6.f,
- 1.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 4.f,
- 0.f,
- 0.f,
- 0.f,
- 4.f,
- 3.f,
- 0.f,
- 0.f,
- });
+ {0.f, 0.f, 0.f, 0.f,
+ 4.f, 0.f, 0.f, 0.f,
+ 6.f, 1.f, 0.f, 0.f,
+
+ 0.f, 0.f, 0.f, 0.f,
+ 4.f, 0.f, 0.f, 0.f,
+ 4.f, 3.f, 0.f, 0.f});
test.Run();
}
TEST(TriluTest, small_k_float_upper) {
OpTester test("Trilu", 14, kOnnxDomain);
test.AddInput("X", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.AddInput("k", {1}, {-5});
test.AddOutput("Y", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.Run();
}
@@ -373,60 +212,22 @@ TEST(TriluOpTest, small_k_float_lower) {
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput("X", {2, 3, 4},
- {
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- 6.f,
- 1.f,
- 2.f,
- 3.f,
- 1.f,
- 6.f,
- 2.f,
- 1.f,
- 4.f,
- 1.f,
- 5.f,
- 8.f,
- 4.f,
- 3.f,
- 2.f,
- 4.f,
- });
+ {4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f,
+ 6.f, 1.f, 2.f, 3.f,
+
+ 1.f, 6.f, 2.f, 1.f,
+ 4.f, 1.f, 5.f, 8.f,
+ 4.f, 3.f, 2.f, 4.f});
test.AddInput("k", {1}, {-5});
test.AddOutput("Y", {2, 3, 4},
- {
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- 0.f,
- });
+ {0.f, 0.f, 0.f, 0.f,
+ 0.f, 0.f, 0.f, 0.f,
+ 0.f, 0.f, 0.f, 0.f,
+
+ 0.f, 0.f, 0.f, 0.f,
+ 0.f, 0.f, 0.f, 0.f,
+ 0.f, 0.f, 0.f, 0.f});
test.Run();
}