From 3ecf48e3b5ea63a0a7a24e13fc5da98edd5b0b68 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 6 Jun 2024 15:21:34 +1000 Subject: [PATCH] Add support for Trilu. (#20917) ### Description Trilu is used by phi-3 when exported with torch.onnx.export. ### Motivation and Context --- docs/OperatorKernels.md | 2 +- .../core/providers/cpu/tensor/trilu.cc | 5 +- .../providers/cpu/tensor/trilu_op_test.cc | 425 +++++------------- 3 files changed, 118 insertions(+), 314 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8092c26da651a..67bfe48327e14 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -421,7 +421,7 @@ Do not modify directly.* |Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Trilu|*in* input:**T**
*in* k:**tensor(int64)**<br>*out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
+|Trilu|*in* input:**T**<br>*in* k:**tensor(int64)**<br>*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
 |Unique|*in* X:**T**<br>*out* Y:**T**<br>*out* indices:**tensor(int64)**<br>*out* inverse_indices:**tensor(int64)**<br>*out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
 |Unsqueeze|*in* data:**T**<br>*in* axes:**tensor(int64)**<br>*out* expanded:**T**<br><br>or<br><br>*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/tensor/trilu.cc b/onnxruntime/core/providers/cpu/tensor/trilu.cc index 91e429ef60d91..017bbcd44904e 100644 --- a/onnxruntime/core/providers/cpu/tensor/trilu.cc +++ b/onnxruntime/core/providers/cpu/tensor/trilu.cc @@ -31,7 +31,7 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 14, kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), Trilu); template @@ -110,6 +110,9 @@ Status Trilu::Compute(OpKernelContext* ctx) const { case sizeof(double): status = TriluImpl(X, Y, k_val, up); break; + case sizeof(bool): + status = TriluImpl(X, Y, k_val, up); + break; default: ORT_THROW("Unsupported input data type of ", data_type); } diff --git a/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc b/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc index f0b5d6afa9c7b..f1d1d94343e6f 100644 --- a/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc @@ -62,63 +62,54 @@ TEST(TriluOpTest, two_by_two_long_lower) { test.Run(); } +TEST(TriluOpTest, two_by_two_bool_upper) { + OpTester test("Trilu", 14, kOnnxDomain); + int64_t up = 1; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, + {true, true, + true, true}); + test.AddOutput("Y", {2, 2}, + {true, true, + false, true}); + test.Run(); +} + +TEST(TriluOpTest, three_by_three_bool_lower) { + 
OpTester test("Trilu", 14, kOnnxDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {3, 3}, + // include a couple of false values to check they are copied + {true, true, true, + true, false, true, + true, true, false}); + test.AddOutput("Y", {3, 3}, + {true, false, false, + true, false, false, + true, true, false}); + test.Run(); +} + TEST(TriluOpTest, three_dim_float_upper) { OpTester test("Trilu", 14, kOnnxDomain); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {1}); test.AddOutput("Y", {2, 3, 4}, - { - 0.f, - 1.f, - 5.f, - 8.f, - 0.f, - 0.f, - 2.f, - 4.f, - 0.f, - 0.f, - 0.f, - 3.f, - 0.f, - 6.f, - 2.f, - 1.f, - 0.f, - 0.f, - 5.f, - 8.f, - 0.f, - 0.f, - 0.f, - 4.f, - }); + {0.f, 1.f, 5.f, 8.f, + 0.f, 0.f, 2.f, 4.f, + 0.f, 0.f, 0.f, 3.f, + + 0.f, 6.f, 2.f, 1.f, + 0.f, 0.f, 5.f, 8.f, + 0.f, 0.f, 0.f, 4.f}); test.Run(); } @@ -127,60 +118,22 @@ TEST(TriluOpTest, three_dim_float_lower) { int64_t up = 0; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {1}); test.AddOutput("Y", {2, 3, 4}, - { - 4.f, - 1.f, - 0.f, - 0.f, - 4.f, - 3.f, - 2.f, - 0.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 0.f, - 0.f, - 4.f, - 1.f, - 5.f, - 0.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 0.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 0.f, 0.f, + 4.f, 1.f, 5.f, 0.f, + 
4.f, 3.f, 2.f, 4.f}); test.Run(); } @@ -189,60 +142,22 @@ TEST(TriluOpTest, neg_k_float_upper) { int64_t up = 1; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-1}); test.AddOutput("Y", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 0.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 0.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 0.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 0.f, 3.f, 2.f, 4.f}); test.Run(); } @@ -251,120 +166,44 @@ TEST(TriluOpTest, neg_k_float_lower) { int64_t up = 0; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-1}); test.AddOutput("Y", {2, 3, 4}, - { - 0.f, - 0.f, - 0.f, - 0.f, - 4.f, - 0.f, - 0.f, - 0.f, - 6.f, - 1.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 4.f, - 0.f, - 0.f, - 0.f, - 4.f, - 3.f, - 0.f, - 0.f, - }); + {0.f, 0.f, 0.f, 0.f, + 4.f, 0.f, 0.f, 0.f, + 6.f, 1.f, 0.f, 0.f, + + 0.f, 0.f, 0.f, 0.f, + 4.f, 0.f, 0.f, 0.f, + 4.f, 3.f, 0.f, 0.f}); test.Run(); } TEST(TriluTest, small_k_float_upper) { OpTester test("Trilu", 14, kOnnxDomain); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 
2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-5}); test.AddOutput("Y", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.Run(); } @@ -373,60 +212,22 @@ TEST(TriluOpTest, small_k_float_lower) { int64_t up = 0; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-5}); test.AddOutput("Y", {2, 3, 4}, - { - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - }); + {0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f}); test.Run(); }