From b9080be305d71368d7cfcc9f8bf79bb2425f6e1b Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 23 Nov 2023 12:01:10 -0800 Subject: [PATCH 1/3] add and support svd Signed-off-by: Liqun Fu --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 10 + .../contrib_ops/cpu/linalg_cholesky.cc | 56 +++++ onnxruntime/contrib_ops/cpu/linalg_cholesky.h | 32 +++ onnxruntime/contrib_ops/cpu/linalg_inv.cc | 46 ++++ onnxruntime/contrib_ops/cpu/linalg_inv.h | 29 +++ onnxruntime/contrib_ops/cpu/linalg_solve.cc | 54 +++++ onnxruntime/contrib_ops/cpu/linalg_solve.h | 32 +++ onnxruntime/contrib_ops/cpu/linalg_svd.cc | 101 +++++++++ onnxruntime/contrib_ops/cpu/linalg_svd.h | 36 ++++ .../core/graph/contrib_ops/contrib_defs.cc | 196 ++++++++++++++++++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 8 + .../test/contrib_ops/linalg_svd_test.cc | 140 +++++++++++++ .../python/test_linalg_ops_with_pytorch.py | 144 +++++++++++++ 13 files changed, 884 insertions(+) create mode 100644 onnxruntime/contrib_ops/cpu/linalg_cholesky.cc create mode 100644 onnxruntime/contrib_ops/cpu/linalg_cholesky.h create mode 100644 onnxruntime/contrib_ops/cpu/linalg_inv.cc create mode 100644 onnxruntime/contrib_ops/cpu/linalg_inv.h create mode 100644 onnxruntime/contrib_ops/cpu/linalg_solve.cc create mode 100644 onnxruntime/contrib_ops/cpu/linalg_solve.h create mode 100644 onnxruntime/contrib_ops/cpu/linalg_svd.cc create mode 100644 onnxruntime/contrib_ops/cpu/linalg_svd.h create mode 100644 onnxruntime/test/contrib_ops/linalg_svd_test.cc create mode 100644 onnxruntime/test/python/test_linalg_ops_with_pytorch.py diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f9d9b13f0fedc..7d7c644c56450 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -11,6 +11,11 @@ namespace contrib { class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSVD); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSVD); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSolve); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgInv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgCholesky); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, WhisperBeamSearch); @@ -248,6 +253,11 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { // add more kernels here BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc b/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc new file mode 100644 index 0000000000000..18fe99e5f3ff6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "contrib_ops/cpu/linalg_cholesky.h" +#include "core/framework/framework_common.h" +#include "core/framework/tensorprotoutils.h" +#include + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgCholesky, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgCholesky); + +#pragma warning(disable : 4189) +Status LinalgCholesky::Compute(OpKernelContext* context) const { + Status status = Status::OK(); + const Tensor* A = context->Input(0); + const TensorShape& a_shape = A->Shape(); + assert(a_shape.NumDimensions() == 2); + assert(a_shape[0] == a_shape[1]); + + TensorShape X_shape = { a_shape[0], a_shape[1] }; + Tensor* X = context->Output(0, X_shape); + + const Eigen::StorageOptions option = Eigen::RowMajor; + Eigen::Map> a_matrix(A->Data(), narrow(a_shape[0]), narrow(a_shape[1])); + Eigen::Map> x_matrix(X->MutableData(), narrow(a_shape[0]), narrow(a_shape[1])); + Eigen::LLT lltOfA(a_matrix); + if (lltOfA.info() == Eigen::Success) { + if (this->upper_) { + x_matrix = lltOfA.matrixU(); + } + else { + x_matrix = lltOfA.matrixL(); + } + } + else { + assert(false); + } + return status; +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_cholesky.h b/onnxruntime/contrib_ops/cpu/linalg_cholesky.h new file mode 100644 index 0000000000000..1ba5cc39b8459 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_cholesky.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/common/common.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" + +namespace onnxruntime { + +namespace contrib { + +class LinalgCholesky : public OpKernel { + public: + LinalgCholesky(const OpKernelInfo& info) + : OpKernel(info) { + int64_t upper; + ORT_ENFORCE(info.GetAttr("upper", &upper).IsOK()); + upper_ = upper != 0; + } + + Status Compute(OpKernelContext* context) const override; + +private: + bool upper_ = false; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_inv.cc b/onnxruntime/contrib_ops/cpu/linalg_inv.cc new file mode 100644 index 0000000000000..1cb5d9a14b9fe --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_inv.cc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "contrib_ops/cpu/linalg_inv.h" +#include "core/framework/framework_common.h" +#include "core/framework/tensorprotoutils.h" +#include + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgInv, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgInv); + +#pragma warning(disable : 4189) +Status LinalgInv::Compute(OpKernelContext* context) const { + Status status = Status::OK(); + const Tensor* A = context->Input(0); + const TensorShape& a_shape = A->Shape(); + assert(a_shape.NumDimensions() == 2); + assert(a_shape[0] == a_shape[1]); + + TensorShape X_shape = { a_shape[1], a_shape[0] }; + Tensor* X = context->Output(0, X_shape); + + const Eigen::StorageOptions option = Eigen::RowMajor; + Eigen::Map> a_matrix(A->Data(), narrow(a_shape[0]), narrow(a_shape[1])); + + Eigen::Map> x_matrix(X->MutableData(), narrow(a_shape[1]), narrow(a_shape[0])); + x_matrix = a_matrix.inverse(); + return status; +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_inv.h b/onnxruntime/contrib_ops/cpu/linalg_inv.h new file mode 100644 index 0000000000000..b79cb8c3f9b77 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_inv.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/common/common.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" + +namespace onnxruntime { + +namespace contrib { + +class LinalgInv : public OpKernel { + public: + LinalgInv(const OpKernelInfo& info) + : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; + +private: + bool left_ = true; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_solve.cc b/onnxruntime/contrib_ops/cpu/linalg_solve.cc new file mode 100644 index 0000000000000..258a5bb452246 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_solve.cc @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "contrib_ops/cpu/linalg_solve.h" +#include "core/framework/framework_common.h" +#include "core/framework/tensorprotoutils.h" +#include + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgSolve, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgSolve); + +#pragma warning(disable : 4189) +Status LinalgSolve::Compute(OpKernelContext* context) const { + Status status = Status::OK(); + const Tensor* A = context->Input(0); + const TensorShape& a_shape = A->Shape(); + assert(a_shape.NumDimensions() == 2); + const Tensor* B = context->Input(1); + const TensorShape& b_shape = B->Shape(); + assert(b_shape.NumDimensions() == 2); + + int64_t n = a_shape[0]; + assert(a_shape[1] == n && b_shape[0] == n && b_shape[1] == n); + + TensorShape X_shape = { n, n }; + Tensor* X = context->Output(0, X_shape); + + const Eigen::StorageOptions option = Eigen::RowMajor; + Eigen::Map> a_matrix(A->Data(), narrow(n), narrow(n)); + Eigen::Map> b_matrix(B->Data(), narrow(n), narrow(n)); + + Eigen::BDCSVD svd(a_matrix, Eigen::ComputeThinU | Eigen::ComputeThinV); + + Eigen::Map> x_matrix(X->MutableData(), narrow(n), narrow(n)); + x_matrix = svd.solve(b_matrix); + return status; +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_solve.h b/onnxruntime/contrib_ops/cpu/linalg_solve.h new file mode 100644 index 0000000000000..44dfe107a7782 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_solve.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/common/common.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" + +namespace onnxruntime { + +namespace contrib { + +class LinalgSolve : public OpKernel { + public: + LinalgSolve(const OpKernelInfo& info) + : OpKernel(info) { + int64_t left; + ORT_ENFORCE(info.GetAttr("left", &left).IsOK()); + left_ = left != 0; + } + + Status Compute(OpKernelContext* context) const override; + +private: + bool left_ = false; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_svd.cc b/onnxruntime/contrib_ops/cpu/linalg_svd.cc new file mode 100644 index 0000000000000..da7ac2a303e69 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_svd.cc @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "contrib_ops/cpu/linalg_svd.h" +#include "core/framework/framework_common.h" +#include "core/framework/tensorprotoutils.h" +#include +#include + + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgSVD, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgSVD); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgSVD, + 1, + double, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgSVD); + +template +void compute_svd(const T* a_data, T* u_data, T* s_data, T* v_data, int64_t m, int64_t n, int64_t k, bool full_matrices) { + Eigen::Map> a_map(a_data, narrow(m), narrow(n)); + + Eigen::Map> u_map(u_data, m, full_matrices ? m : k); + Eigen::Map> s_map(s_data, k, 1); + Eigen::Map> v_map(v_data, narrow(full_matrices ? n : k), narrow(n)); + + // Compute the SVD + unsigned int computationOptions = full_matrices ? (Eigen::ComputeFullU | Eigen::ComputeFullV) : (Eigen::ComputeThinU | Eigen::ComputeThinV); + Eigen::JacobiSVD> svd(a_map, computationOptions); + + // Assign the computed matrices to the pre-allocated memory + u_map = svd.matrixU(); + s_map = svd.singularValues(); + v_map = svd.matrixV().transpose(); +} + +template +Status LinalgSVD::Compute(OpKernelContext* context) const { + Status status = Status::OK(); + const Tensor* A = context->Input(0); + const TensorShape& a_shape = A->Shape(); + int64_t dimensions = A->Shape().NumDimensions(); + ORT_ENFORCE(dimensions == 2 || dimensions == 3, "data must be 2D or 3D tensor"); + + int64_t batch = 1, m, n, k; + m = a_shape[dimensions - 2]; + n = a_shape[dimensions - 1]; + k = std::min(m, n); + + TensorShape u_shape, s_shape, v_shape; + if (dimensions == 3) { + batch = a_shape[0]; + u_shape = {batch, m, full_matrices_ ? m : k}; + s_shape = {batch, k}; + v_shape = {batch, full_matrices_ ? n : k, n}; + } else { + u_shape = {m, full_matrices_ ? m : k}; + s_shape = {k}; + v_shape = {full_matrices_ ? n : k, n}; + } + Tensor* U = context->Output(0, u_shape); + Tensor* S = context->Output(1, s_shape); + Tensor* V = context->Output(2, v_shape); + + int64_t a_single_batch_size = A->Shape().SizeFromDimension(dimensions - 2); + int64_t u_single_batch_size = U->Shape().SizeFromDimension(dimensions - 2); + int64_t s_single_batch_size = S->Shape().SizeFromDimension(S->Shape().NumDimensions() - 1); + int64_t v_single_batch_size = V->Shape().SizeFromDimension(dimensions - 2); + + std::function fn = [&](ptrdiff_t batch_num) { + const T* a_data = A->template Data() + batch_num * a_single_batch_size; + T* u_data = U->template MutableData() + batch_num * u_single_batch_size; + T* s_data = S->template MutableData() + batch_num * s_single_batch_size; + T* v_data = V->template MutableData() + batch_num * v_single_batch_size; + + compute_svd(a_data, u_data, s_data, v_data, m, n, k, full_matrices_); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(batch), std::move(fn), 0); + return status; +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_svd.h b/onnxruntime/contrib_ops/cpu/linalg_svd.h new file mode 100644 index 0000000000000..2ed87ab7925fa --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/linalg_svd.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/common/common.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" + +namespace onnxruntime { + +namespace contrib { + +template +class LinalgSVD : public OpKernel { + public: + LinalgSVD(const OpKernelInfo& info) + : OpKernel(info) { + int64_t full_matrices; + ORT_ENFORCE(info.GetAttr("full_matrices", &full_matrices).IsOK()); + full_matrices_ = full_matrices != 0; + //int64_t compute_uv; + //ORT_ENFORCE(info.GetAttr("compute_uv", &compute_uv).IsOK()); + //compute_uv_ = compute_uv != 0; + } + + Status Compute(OpKernelContext* context) const override; + +private: + bool full_matrices_ = true; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4c0d78f0ee297..bb20ee2f99464 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2712,6 +2712,202 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); })); +ONNX_MS_OPERATOR_SET_SCHEMA(LinalgSVD, 1, + OpSchema() + .SetDoc(R"DOC(For internal use.)DOC") + .Attr( + "full_matrices", + "", + AttributeProto::INT, + static_cast(1)) + .Input( + 0, + "A", + "", + "T") + .Output( + 0, + "U", + "", + "T") + .Output( + 1, + "S", + "", + "T") + .Output( + 2, + "Vh", + "", + "T") + .TypeConstraint( + "T", + {"tensor(float)", "tensor(double)"}, + "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); + int64_t full_matrices = ctx.getAttribute("full_matrices")->i(); + + const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); + const auto& M = A_shape.dim(A_shape.dim_size() - 2); + const auto& N = A_shape.dim(A_shape.dim_size() - 1); + if (!M.has_dim_value() || !N.has_dim_value()) { + // cannot do shape inference without knowing dimension values + return; + } + const auto& K = M.dim_value() < N.dim_value() ? M : N; + auto* u_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + auto* s_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + auto* v_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); + if (A_shape.dim_size() == 3) { + const auto& batch_dim = A_shape.dim(0); + *u_shape->add_dim() = batch_dim; + *s_shape->add_dim() = batch_dim; + *v_shape->add_dim() = batch_dim; + } + *u_shape->add_dim() = M; + *u_shape->add_dim() = full_matrices ? M : K; + *s_shape->add_dim() = K; + *v_shape->add_dim() = full_matrices ? N : K; + *v_shape->add_dim() = N; + })); + +ONNX_MS_OPERATOR_SET_SCHEMA(LinalgCholesky, 1, + OpSchema() + .SetDoc(R"DOC(For internal use.)DOC") + .Attr( + "upper", + "", + AttributeProto::INT, + static_cast(0)) + .Input( + 0, + "A", + "", + "T") + .Output( + 0, + "X", + "", + "T") + .TypeConstraint( + "T", + {"tensor(float)"}, + "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + *output_shape->add_dim() = A_shape.dim(0); + *output_shape->add_dim() = A_shape.dim(1); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA(LinalgInv, 1, + OpSchema() + .SetDoc(R"DOC(For internal use.)DOC") + .Input( + 0, + "A", + "", + "T") + .Output( + 0, + "X", + "", + "T") + .TypeConstraint( + "T", + {"tensor(float)"}, + "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + *output_shape->add_dim() = A_shape.dim(0); + *output_shape->add_dim() = A_shape.dim(1); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA(LinalgSolve, 1, + OpSchema() + .SetDoc(R"DOC(For internal use.)DOC") + .Attr( + "left", + "", + AttributeProto::INT, + static_cast(0)) + .Input( + 0, + "A", + "", + "T") + .Input( + 1, + "B", + "", + "T") + .Output( + 0, + "X", + "", + "T") + .TypeConstraint( + "T", + {"tensor(float)"}, + "") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); + const TensorShapeProto& B_shape = ctx.getInputType(1)->tensor_type().shape(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + ///////////////////////////////// + // for now only for np.linalg.solve(_s, _d) + ///////////////////////////////// + *output_shape->add_dim() = A_shape.dim(0); + *output_shape->add_dim() = B_shape.dim(1); + + // int64_t A_rank = A_shape.dim_size(); + // int64_t B_rank = B_shape.dim_size(); + // assert(A_rank == 3); // shape A mush be (*, n, n) + ////assert(A_shape.dim(1) == A_shape.dim(2)); + + // auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + //*output_shape->add_dim() = A_shape.dim(0); + // if (B_rank == 3) { + // // B (*, n, k) case => (*, n, k) + // //assert(A_shape.dim(2) == B_shape.dim(1)); + // *output_shape->add_dim() = A_shape.dim(1); + // *output_shape->add_dim() = B_shape.dim(2); + // } + // else if (B_rank == 1) { + // // B (n,) case => (*, n) + // //assert(A_shape.dim(2) == B_shape.dim(0)); + // *output_shape->add_dim() = A_shape.dim(1); + // } + // else if (B_rank == 2) { + // // B (*, n) or (n, k) cases + // if (/*B_shape.dim(0) == A_shape.dim(0) &&*/ A_shape.dim(2).dim_value() == B_shape.dim(1).dim_value()) { + // // B (*, n) => (*, n) + // *output_shape->add_dim() = A_shape.dim(1); + // } + // else if (A_shape.dim(1).dim_value() == B_shape.dim(0).dim_value()) { + // // B (n, k) => (*, n, k) + // *output_shape->add_dim() = A_shape.dim(1); + // *output_shape->add_dim() = B_shape.dim(1); + // } + // else { + // assert(false); + // } + // } + // else { + // assert(false); + // } + })); + static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, int64_t N, diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 5eef1b33a24dd..456008f597bb3 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -45,6 +45,10 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QOrderedAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QOrderedLongformerAttention); // Others +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinalgSVD); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinalgSolve); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinalgInv); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinalgCholesky); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Attention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BeamSearch); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WhisperBeamSearch); @@ -152,6 +156,10 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/linalg_svd_test.cc b/onnxruntime/test/contrib_ops/linalg_svd_test.cc new file mode 100644 index 0000000000000..47605dc92dcce --- /dev/null +++ b/onnxruntime/test/contrib_ops/linalg_svd_test.cc @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "core/util/math.h" + +namespace onnxruntime { +namespace test { + +template +class LinalgSVDContribOpTest : public ::testing::Test { +}; + +using LinalgTypes = ::testing::Types; +TYPED_TEST_SUITE(LinalgSVDContribOpTest, LinalgTypes); + +// DO NOT EDIT following test cases.they are generated with: +// in test_linalg_ops_with_pytorch.py, set generate_testcases to True to print C++ test cases +// python onnxruntime/test/python/test_linalg_ops_with_pytorch.py -k TestLinalgOps.test_linalg_svd +TYPED_TEST(LinalgSVDContribOpTest, batch_full_matrices) { + OpTester test("LinalgSVD", 1, kMSDomain); + test.AddAttribute("full_matrices", (int64_t)1); + test.AddInput("A", {2, 3, 4}, { + -1.125840f, -1.152360f, -0.250579f, -0.433879f, + 0.848710f, 0.692009f, -0.316013f, -2.115219f, + 0.468096f, -0.157712f, 1.443660f, 0.266049f, + 0.166455f, 0.874382f, -0.143474f, -0.111609f, + 0.931827f, 1.259009f, 2.004981f, 0.053737f, + 0.618057f, -0.412802f, -0.841065f, -2.316042f + }); + test.AddOutput("U", {2, 3, 3}, { + 0.190744f, 0.773181f, -0.604820f, + -0.969842f, 0.053195f, -0.237860f, + 0.151736f, -0.631950f, -0.760010f, + 0.078401f, -0.181647f, 0.980233f, + 0.702239f, -0.687852f, -0.183633f, + -0.707612f, -0.702755f, -0.073631f + }); + test.AddOutput("S", {2, 3}, {{ + 2.456875f, 1.861905f, 1.231135f, + 2.889926f, 2.222110f, 0.797447f + }}); + test.AddOutput("Vh", {2, 4, 4}, { + -0.393522f, -0.372374f, 0.194451f, 0.817720f, + -0.602149f, -0.405233f, -0.603078f, -0.330906f, + 0.100150f, 0.529781f, -0.707051f, 0.457582f, + 0.687406f, -0.645334f, -0.313951f, 0.111593f, + 0.079611f, 0.430731f, 0.689247f, 0.577123f, + -0.497517f, -0.330651f, -0.342920f, 0.724951f, + -0.067035f, 0.822999f, -0.560399f, 0.064283f, + 0.861188f, -0.166776f, -0.305446f, 0.370463f + }); + test.Run(); +} + +TYPED_TEST(LinalgSVDContribOpTest, batch_no_full_matrices) { + OpTester test("LinalgSVD", 1, kMSDomain); + test.AddAttribute("full_matrices", (int64_t)0); + test.AddInput("A", {2, 3, 4}, { + -1.125840f, -1.152360f, -0.250579f, -0.433879f, + 0.848710f, 0.692009f, -0.316013f, -2.115219f, + 0.468096f, -0.157712f, 1.443660f, 0.266049f, + 0.166455f, 0.874382f, -0.143474f, -0.111609f, + 0.931827f, 1.259009f, 2.004981f, 0.053737f, + 0.618057f, -0.412802f, -0.841065f, -2.316042f + }); + test.AddOutput("U", {2, 3, 3}, { + 0.190744f, 0.773181f, -0.604820f, + -0.969842f, 0.053195f, -0.237860f, + 0.151736f, -0.631950f, -0.760010f, + 0.078401f, -0.181647f, 0.980233f, + 0.702239f, -0.687852f, -0.183633f, + -0.707612f, -0.702755f, -0.073631f + }); + test.AddOutput("S", {2, 3}, {{ + 2.456875f, 1.861905f, 1.231135f, + 2.889926f, 2.222110f, 0.797447f + }}); + test.AddOutput("Vh", {2, 3, 4}, { + -0.393522f, -0.372374f, 0.194451f, 0.817720f, + -0.602149f, -0.405233f, -0.603078f, -0.330906f, + 0.100150f, 0.529781f, -0.707051f, 0.457582f, + 0.079611f, 0.430731f, 0.689247f, 0.577123f, + -0.497517f, -0.330651f, -0.342920f, 0.724951f, + -0.067035f, 0.822999f, -0.560399f, 0.064283f + }); + test.Run(); +} + +TYPED_TEST(LinalgSVDContribOpTest, no_batch_full_matrices) { + OpTester test("LinalgSVD", 1, kMSDomain); + test.AddAttribute("full_matrices", (int64_t)1); + test.AddInput("A", {3, 4}, { + 1.540996f, -0.293429f, -2.178789f, 0.568431f, + -1.084522f, -1.398595f, 0.403347f, 0.838026f, + -0.719258f, -0.403344f, -0.596635f, 0.182036f + }); + test.AddOutput("U", {3, 3}, { + -0.928314f, 0.342269f, -0.145207f, + 0.371614f, 0.841924f, -0.391236f, + 0.011654f, 0.417151f, 0.908762f + }); + test.AddOutput("S", {3}, {{ + 2.862108f, 1.985799f, 0.679939f + }}); + test.AddOutput("Vh", {4, 4}, { + -0.643559f, -0.088063f, 0.756623f, -0.074819f, + -0.345297f, -0.728271f, -0.329858f, 0.491514f, + -0.666373f, 0.328333f, -0.564209f, -0.360296f, + -0.150161f, 0.595033f, 0.019583f, 0.789306f + }); + test.Run(); +} + +TYPED_TEST(LinalgSVDContribOpTest, no_batch_no_full_matrices) { + OpTester test("LinalgSVD", 1, kMSDomain); + test.AddAttribute("full_matrices", (int64_t)0); + test.AddInput("A", {3, 4}, { + 1.540996f, -0.293429f, -2.178789f, 0.568431f, + -1.084522f, -1.398595f, 0.403347f, 0.838026f, + -0.719258f, -0.403344f, -0.596635f, 0.182036f + }); + test.AddOutput("U", {3, 3}, { + -0.928314f, 0.342269f, -0.145207f, + 0.371614f, 0.841924f, -0.391236f, + 0.011654f, 0.417151f, 0.908762f + }); + test.AddOutput("S", {3}, {{ + 2.862108f, 1.985799f, 0.679939f + }}); + test.AddOutput("Vh", {3, 4}, { + -0.643559f, -0.088063f, 0.756623f, -0.074819f, + -0.345297f, -0.728271f, -0.329858f, 0.491514f, + -0.666373f, 0.328333f, -0.564209f, -0.360296f + }); + test.Run(); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/test_linalg_ops_with_pytorch.py b/onnxruntime/test/python/test_linalg_ops_with_pytorch.py new file mode 100644 index 0000000000000..88b498d7d6103 --- /dev/null +++ b/onnxruntime/test/python/test_linalg_ops_with_pytorch.py @@ -0,0 +1,144 @@ +import unittest +import onnx +from onnx import helper +import onnxruntime as ort +import torch + +import numpy as np +import parameterized + +# set generate_testcases to True to print C++ test cases +generate_testcases = True + +def create_model(op, inputs, outputs, opset_version, node_kwargs): + # create an onnx model with the given op + input_names = [i.name for i in inputs] + output_names = [o.name for o in outputs] + node = helper.make_node(op, input_names, output_names, **node_kwargs) + graph = helper.make_graph([node], f"test_graph_with_linalg_{op}", inputs, outputs) + + # TODO: remove onnx opset import + opset_imports = [ + onnx.helper.make_opsetid("", opset_version), + onnx.helper.make_opsetid("com.microsoft", 1)] + meta = { + "ir_version": 9, + "opset_imports": opset_imports, + "producer_name": "onnxruntime test", + } + model = onnx.helper.make_model(graph, **meta) + onnx.checker.check_model(model) + return model + +def create_svd_model(use_batch, full_matrices): + a_shape = ["B"] if use_batch else [] + if full_matrices: + a_shape = [*a_shape, "M", "N"] + u_shape = [*a_shape, "M", "M"] + s_shape = [*a_shape, "N"] + v_shape = [*a_shape, "N", "N"] + else: + a_shape = [*a_shape, "M", "N"] + u_shape = [*a_shape, "M", "K"] + s_shape = [*a_shape, "K"] + v_shape = [*a_shape, "K", "N"] + + a_value_info = helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, a_shape) + u_value_info = helper.make_tensor_value_info("U", onnx.TensorProto.FLOAT, u_shape) + s_value_info = helper.make_tensor_value_info("S", onnx.TensorProto.FLOAT, s_shape) + v_value_info = helper.make_tensor_value_info("V", onnx.TensorProto.FLOAT, v_shape) + + onnx_model = create_model( + "LinalgSVD", + [a_value_info], + [u_value_info, s_value_info, v_value_info], + opset_version=17, + node_kwargs={"full_matrices": full_matrices, "domain": "com.microsoft",}, + ) + return onnx_model + +def normalize_signs(a, b): + signs = np.sign(a[0]) * np.sign(b[0]) + return a, b * signs, signs == -1 + +def validate_base_equal(actual, expected): + sign_changes = 0 + if len(expected.shape) > 2: + for actual_, expected_ in zip(actual, expected): + validate_base_equal(actual_, expected_) + return + for i in range(expected.shape[1]): + actual[:, i], expected[:, i], sign_changed = normalize_signs(actual[:, i], expected[:, i]) + if sign_changed: + sign_changes += 1 + np.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-7) + assert(sign_changes % 2 == 0) + +def format_tensor(tensor): + # Reshape the tensor and convert it to a list of lists + tensor_list = tensor.reshape(-1, tensor.shape[-1]).tolist() + + # Format each row of the tensor as a string + tensor_str = ',\n '.join([', '.join([f"{val:.6f}f" for val in row]) for row in tensor_list]) + + return '{\n ' + tensor_str + '\n }' + +class TestLinalgOps(unittest.TestCase): + def setUp(self): + self.opset_version = 17 + + @parameterized.parameterized.expand([ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_linalg_svd(self, use_batch, full_matrices): + torch.manual_seed(0) + if use_batch: + A = torch.randn(2, 3, 4) + else: + A = torch.randn(3, 4) + U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices) + + onnx_model = create_svd_model(use_batch=use_batch, full_matrices=full_matrices) + session = ort.InferenceSession(onnx_model.SerializeToString()) + input_name = session.get_inputs()[0].name + output_names = [output.name for output in session.get_outputs()] + input_data = {input_name: A.numpy()} + actual_u, actual_s, actual_vh = session.run(output_names, input_data) + + expected_u = U.numpy() + expected_s = S.numpy() + expected_vh = Vh.numpy() + validate_base_equal(actual_u, expected_u) + if len(expected_vh.shape) == 3: + validate_base_equal(actual_vh.transpose((0, -1, -2)), expected_vh.transpose((0, -1, -2))) + else: + validate_base_equal(actual_vh.transpose(), expected_vh.transpose()) + np.testing.assert_allclose(actual_s, expected_s, rtol=1e-5, atol=1e-7) + + if generate_testcases: + # Print the C++ test case + A_str = format_tensor(A) + U_str = format_tensor(actual_u) + S_str = format_tensor(actual_s) + Vh_str = format_tensor(actual_vh) + + batch_str = 'batch' if use_batch else 'no_batch' + full_matrices_str = 'full_matrices' if full_matrices else 'no_full_matrices' + test_case_name = f'{batch_str}_{full_matrices_str}' + + print(f'TYPED_TEST(LinalgSVDContribOpTest, {test_case_name}) {{\n' + f' OpTester test("LinalgSVD", 1, kMSDomain);\n' + f' test.AddAttribute("full_matrices", (int64_t){"1" if full_matrices else "0"});\n' + f' test.AddInput("A", {{{", ".join(map(str, A.shape))}}}, {A_str});\n' + f' test.AddOutput("U", {{{", ".join(map(str, U.shape))}}}, {U_str});\n' + f' test.AddOutput("S", {{{", ".join(map(str, S.shape))}}}, {{{S_str}}});\n' + f' test.AddOutput("Vh", {{{", ".join(map(str, Vh.shape))}}}, {Vh_str});\n' + f' test.Run();\n' + f'}}') + + +if __name__ == '__main__': + unittest.main() From ddec688eb61b696e09c115bbe055a90af97dd8cb Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Fri, 24 Nov 2023 20:37:27 -0800 Subject: [PATCH 2/3] solve kernel and tests Signed-off-by: Liqun Fu --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + onnxruntime/contrib_ops/cpu/linalg_solve.cc | 136 ++++++- onnxruntime/contrib_ops/cpu/linalg_solve.h | 3 +- .../core/graph/contrib_ops/contrib_defs.cc | 149 +++++--- .../test/contrib_ops/linalg_cholesky_test.cc | 111 ++++++ .../test/contrib_ops/linalg_solve_test.cc | 361 ++++++++++++++++++ .../python/test_linalg_ops_with_pytorch.py | 174 +++++++++ 7 files changed, 867 insertions(+), 69 deletions(-) create mode 100644 onnxruntime/test/contrib_ops/linalg_cholesky_test.cc create mode 100644 onnxruntime/test/contrib_ops/linalg_solve_test.cc diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 7d7c644c56450..e2feec38540fa 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -14,6 +14,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSVD); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSVD); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSolve); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSolve); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgInv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgCholesky); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention); @@ -256,6 +257,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/linalg_solve.cc b/onnxruntime/contrib_ops/cpu/linalg_solve.cc index 258a5bb452246..9966698cf6e60 100644 --- a/onnxruntime/contrib_ops/cpu/linalg_solve.cc +++ b/onnxruntime/contrib_ops/cpu/linalg_solve.cc @@ -22,32 +22,140 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( float, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - LinalgSolve); + LinalgSolve); + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgSolve, + 1, + double, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgSolve); + +template +void solve(const T* a_data, const T* b_data, T* x_data, bool left, int64_t n, int64_t k) { + const Eigen::StorageOptions option = Eigen::RowMajor; + Eigen::Map> b_matrix(b_data, narrow(n), narrow(k)); + Eigen::Map> x_matrix(x_data, narrow(n), narrow(k)); + + if (left) { + Eigen::Map> a_matrix(a_data, narrow(n), narrow(n)); + Eigen::BDCSVD> svd(a_matrix, Eigen::ComputeThinU | Eigen::ComputeThinV); + x_matrix = svd.solve(b_matrix); + } else { + Eigen::Map> a_matrix(a_data, narrow(k), narrow(k)); + auto a_matrix_transposed = a_matrix.transpose(); + auto b_matrix_transposed = b_matrix.transpose(); + Eigen::BDCSVD> svd(a_matrix_transposed, Eigen::ComputeThinU | Eigen::ComputeThinV); + Eigen::Matrix x_matrix_transposed_result = svd.solve(b_matrix_transposed); + x_matrix = x_matrix_transposed_result.transpose(); + } +} #pragma warning(disable : 4189) -Status LinalgSolve::Compute(OpKernelContext* context) const { +template +Status LinalgSolve::Compute(OpKernelContext* context) const { Status status = Status::OK(); const Tensor* A = context->Input(0); const TensorShape& a_shape = A->Shape(); - assert(a_shape.NumDimensions() == 2); + assert(a_shape.NumDimensions() == 2 || a_shape.NumDimensions() == 3); + bool has_batch = a_shape.NumDimensions() == 3; const Tensor* B = context->Input(1); const TensorShape& b_shape = B->Shape(); - assert(b_shape.NumDimensions() == 2); - int64_t n = a_shape[0]; - assert(a_shape[1] == n && b_shape[0] == n && b_shape[1] == n); + int64_t batch = has_batch ? a_shape[0] : 1, n = 1, k = 1; + bool b_as_a_vector = b_shape.NumDimensions() == 1; + bool broadcast; + int64_t n_or_k = a_shape[1]; + if (left_) { + n = n_or_k; + } else { + k = n_or_k; + } - TensorShape X_shape = { n, n }; - Tensor* X = context->Output(0, X_shape); + if (has_batch) { + ORT_ENFORCE(a_shape[1] == a_shape[2], "A should be square matrix: ", a_shape); + if (b_shape.NumDimensions() == 1) { + b_as_a_vector = true; + broadcast = true; + } else if (b_shape.NumDimensions() == 2) { + if (b_shape[0] == a_shape[0] && b_shape[1] == a_shape[1]) { // A has shape (*, n/k, n/k) and B has shape(*, n/k) + b_as_a_vector = true; + broadcast = false; + } else if (left_ && b_shape[0] == a_shape[1]) { // A has shape (*, n, n) and B has shape (n, k) + b_as_a_vector = false; + broadcast = true; + k = b_shape[1]; + } else if (!left_ && b_shape[1] == a_shape[1]) { // A has shape (*, k, k) and B has shape (n, k) + b_as_a_vector = false; + broadcast = true; + n = b_shape[0]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "B shape does not mach A shape.", b_shape, a_shape); + } + } else { // b_shape.NumDimensions() == 3 + ORT_ENFORCE(b_shape[0] == a_shape[0], "A and B shall have the same batch size"); + b_as_a_vector = false; + broadcast = false; + if (left_) { // A: (*, n, n), B: (*, n, k) + ORT_ENFORCE(b_shape[1] == a_shape[1], "A and B shall have matching size at dim 1: ", b_shape[1], "vs", a_shape[1]); + k = b_shape[2]; + } else { // A: (*, k, k), B: (*, n, k) + ORT_ENFORCE(b_shape[2] == a_shape[2], "A and B shall have matching size at dim 2: ", b_shape[2], "vs", a_shape[2]); + n = b_shape[1]; + } + } + } else { // !has_batch + ORT_ENFORCE(a_shape[0] == a_shape[1], "A should be square matrix: ", a_shape); + broadcast = false; + if (b_shape.NumDimensions() == 1) { // A: (n/k. n/k), B: (n/k,) + ORT_ENFORCE(b_shape[0] == a_shape[0], "A and B shall have matching size at dim 2: ", b_shape[2], "vs", a_shape[2]); + b_as_a_vector = true; + } else if (b_shape.NumDimensions() == 2) { // A: (n/k. n/k), B: (n, k) + b_as_a_vector = false; + if (left_) { // A: (n, n), B: (n, k) + k = b_shape[1]; + } else { // A: (k, k), B: (n, k) + n = b_shape[0]; + } + } + } - const Eigen::StorageOptions option = Eigen::RowMajor; - Eigen::Map> a_matrix(A->Data(), narrow(n), narrow(n)); - Eigen::Map> b_matrix(B->Data(), narrow(n), narrow(n)); + std::vector x_dims; + if (has_batch) { + x_dims.push_back(batch); + } + if (b_as_a_vector) { + if (left_) { + x_dims.push_back(n); + } else { + x_dims.push_back(k); + } + } else { + x_dims.push_back(n); + x_dims.push_back(k); + } + TensorShape x_shape(x_dims); + Tensor* X = context->Output(0, x_shape); - Eigen::BDCSVD svd(a_matrix, Eigen::ComputeThinU | Eigen::ComputeThinV); + if (batch == 1) { + const T* a_data = A->Data(); + const T* b_data = B->Data(); + T* x_data = X->MutableData(); + solve(a_data, b_data, x_data, left_, n, k); + } else { + int64_t a_single_batch_size = a_shape.SizeFromDimension(a_shape.NumDimensions() - 2); + int64_t b_single_batch_size = broadcast ? 0 : b_shape.SizeFromDimension(1); + int64_t x_single_batch_size = x_shape.SizeFromDimension(1); + std::function fn = [&](ptrdiff_t batch_num) { + const T* a_data = A->Data() + batch_num * a_single_batch_size; + const T* b_data = B->Data() + (broadcast ? 0 : batch_num * b_single_batch_size); + T* x_data = X->MutableData() + batch_num * x_single_batch_size; + solve(a_data, b_data, x_data, left_, n, k); + }; + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(batch), std::move(fn), 0); + } - Eigen::Map> x_matrix(X->MutableData(), narrow(n), narrow(n)); - x_matrix = svd.solve(b_matrix); return status; } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/linalg_solve.h b/onnxruntime/contrib_ops/cpu/linalg_solve.h index 44dfe107a7782..c4a15a2cd130e 100644 --- a/onnxruntime/contrib_ops/cpu/linalg_solve.h +++ b/onnxruntime/contrib_ops/cpu/linalg_solve.h @@ -13,6 +13,7 @@ namespace onnxruntime { namespace contrib { +template class LinalgSolve : public OpKernel { public: LinalgSolve(const OpKernelInfo& info) @@ -25,7 +26,7 @@ class LinalgSolve : public OpKernel { Status Compute(OpKernelContext* context) const override; private: - bool left_ = false; + bool left_ = true; }; } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index bb20ee2f99464..344441ff6f9ba 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2800,9 +2800,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(LinalgCholesky, 1, ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - - *output_shape->add_dim() = A_shape.dim(0); - *output_shape->add_dim() = A_shape.dim(1); + for (auto& dim : A_shape.dim()) + *output_shape->add_dim() = dim; })); ONNX_MS_OPERATOR_SET_SCHEMA(LinalgInv, 1, @@ -2831,6 +2830,96 @@ ONNX_MS_OPERATOR_SET_SCHEMA(LinalgInv, 1, *output_shape->add_dim() = A_shape.dim(1); })); +void linalg_solve_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + const TensorShapeProto& a_shape = ctx.getInputType(0)->tensor_type().shape(); + const TensorShapeProto& b_shape = ctx.getInputType(1)->tensor_type().shape(); + int64_t left = ctx.getAttribute("left")->i(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + int64_t a_rank = a_shape.dim_size(); + assert(a_rank == 2 || a_rank == 3); // shape A mush be (*, n, n) + int64_t b_rank = b_shape.dim_size(); + if (a_rank == 3) { // has batch + // assert(a_shape.dim(1).dim_value() == a_shape.dim(2).dim_value()); + if (left) { + // A: (b, n, n) + if (b_rank == 1) { + // B: (n,) => X: (b, n) + // assert(b_shape.dim(0).dim_value() == a_shape.dim(1).dim_value()); + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + } else if (b_rank == 2) { + if (b_shape.dim(0).dim_value() == a_shape.dim(1).dim_value()) { + // B: (n, k) => X: (b, n, k) + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + *output_shape->add_dim() = b_shape.dim(1); + } else if (b_shape.dim(0).dim_value() == a_shape.dim(0).dim_value()) { + // B: (b, n) => X: (b, n) + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + } + } else if (b_rank == 3) { + // B: (b, n, k) => X: (b, n, k) + // assert(b_shape.dim(1) == a_shape.dim(1)); + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + *output_shape->add_dim() = b_shape.dim(2); + } + } else { + // A: (b, k, k) + if (b_rank == 1) { + // B: (k,) => X: (b, k) + // assert(b_shape.dim(0) == a_shape.dim(1)); + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + } else if (b_rank == 2) { + if (b_shape.dim(1).dim_value() == a_shape.dim(1).dim_value()) { + // B: (n, k) => X: (b, n, k) + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = b_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + } else if (b_shape.dim(0).dim_value() == a_shape.dim(0).dim_value()) { + // B: (b, k) => X: (b, k) + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(1); + } + } else if (b_rank == 3) { + // B: (b, n, k) => X: (b, n, k) + // assert(b_shape.dim(1) == a_shape.dim(1)); + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = b_shape.dim(1); + *output_shape->add_dim() = a_shape.dim(1); + } + } + } else { // a_rank == 2, no batch + assert(b_rank == 1 || b_rank == 2); + if (left) { + // A: (n, n) + if (b_rank == 1) { + // B: (n,) => X: (n,) + *output_shape->add_dim() = a_shape.dim(0); + } else if (b_rank == 2) { + // B: (n, k) => X: (n, k) + *output_shape->add_dim() = a_shape.dim(0); + *output_shape->add_dim() = b_shape.dim(1); + } + } else { + // A: (k, k) + if (b_rank == 1) { + // B: (k,) => X: (k,) + // assert(b_shape.dim(0) == a_shape.dim(0)); + *output_shape->add_dim() = a_shape.dim(0); + } else if (b_rank == 2) { + // B: (n, k) => X: (n, k) + *output_shape->add_dim() = b_shape.dim(0); + *output_shape->add_dim() = a_shape.dim(0); + } + } + } +} + ONNX_MS_OPERATOR_SET_SCHEMA(LinalgSolve, 1, OpSchema() .SetDoc(R"DOC(For internal use.)DOC") @@ -2838,7 +2927,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(LinalgSolve, 1, "left", "", AttributeProto::INT, - static_cast(0)) + static_cast(1)) .Input( 0, "A", @@ -2856,57 +2945,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(LinalgSolve, 1, "T") .TypeConstraint( "T", - {"tensor(float)"}, + {"tensor(float)", "tensor(double)"}, "") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); - const TensorShapeProto& B_shape = ctx.getInputType(1)->tensor_type().shape(); - auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - - ///////////////////////////////// - // for now only for np.linalg.solve(_s, _d) - ///////////////////////////////// - *output_shape->add_dim() = A_shape.dim(0); - *output_shape->add_dim() = B_shape.dim(1); - - // int64_t A_rank = A_shape.dim_size(); - // int64_t B_rank = B_shape.dim_size(); - // assert(A_rank == 3); // shape A mush be (*, n, n) - ////assert(A_shape.dim(1) == A_shape.dim(2)); - - // auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - //*output_shape->add_dim() = A_shape.dim(0); - // if (B_rank == 3) { - // // B (*, n, k) case => (*, n, k) - // //assert(A_shape.dim(2) == B_shape.dim(1)); - // *output_shape->add_dim() = A_shape.dim(1); - // *output_shape->add_dim() = B_shape.dim(2); - // } - // else if (B_rank == 1) { - // // B (n,) case => (*, n) - // //assert(A_shape.dim(2) == B_shape.dim(0)); - // *output_shape->add_dim() = A_shape.dim(1); - // } - // else if (B_rank == 2) { - // // B (*, n) or (n, k) cases - // if (/*B_shape.dim(0) == A_shape.dim(0) &&*/ A_shape.dim(2).dim_value() == B_shape.dim(1).dim_value()) { - // // B (*, n) => (*, n) - // *output_shape->add_dim() = A_shape.dim(1); - // } - // else if (A_shape.dim(1).dim_value() == B_shape.dim(0).dim_value()) { - // // B (n, k) => (*, n, k) - // *output_shape->add_dim() = A_shape.dim(1); - // *output_shape->add_dim() = B_shape.dim(1); - // } - // else { - // assert(false); - // } - // } - // else { - // assert(false); - // } - })); + .TypeAndShapeInferenceFunction(linalg_solve_shape_infer)); static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, diff --git a/onnxruntime/test/contrib_ops/linalg_cholesky_test.cc b/onnxruntime/test/contrib_ops/linalg_cholesky_test.cc new file mode 100644 index 0000000000000..4a52865b020e0 --- /dev/null +++ b/onnxruntime/test/contrib_ops/linalg_cholesky_test.cc @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "core/util/math.h" + +namespace onnxruntime { +namespace test { + +template +class LinalgCholeskyContribOpTest : public ::testing::Test { +}; + +using LinalgTypes = ::testing::Types; +TYPED_TEST_SUITE(LinalgCholeskyContribOpTest, LinalgTypes); + +// DO NOT EDIT following test cases.they are generated with: +// in test_linalg_ops_with_pytorch.py, set generate_testcases to True to print C++ test cases +// python onnxruntime/test/python/test_linalg_ops_with_pytorch.py -k TestLinalgOps.test_linalg_cholesky + +TYPED_TEST(LinalgCholeskyContribOpTest, no_batch_lower) { + OpTester test("LinalgCholesky", 1, kMSDomain); + test.AddAttribute("upper", (int64_t)0); + test.AddInput("A", {4, 4}, { + 3.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 6.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 2.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 3.854497f + }); + test.AddOutput("L", {4, 4}, { + 1.961247f, 0.000000f, 0.000000f, 0.000000f, + -0.385480f, 2.573832f, 0.000000f, 0.000000f, + 0.444409f, -0.463038f, 1.582848f, 0.000000f, + -0.883979f, 0.305986f, -0.426929f, 1.672477f + }); + test.Run(); +} + +TYPED_TEST(LinalgCholeskyContribOpTest, no_batch_upper) { + OpTester test("LinalgCholesky", 1, kMSDomain); + test.AddAttribute("upper", (int64_t)1); + test.AddInput("A", {4, 4}, { + 3.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 6.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 2.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 3.854497f + }); + test.AddOutput("L", {4, 4}, { + 1.961247f, -0.385480f, 0.444409f, -0.883979f, + 0.000000f, 2.573832f, -0.463038f, 0.305986f, + 0.000000f, 0.000000f, 1.582848f, -0.426929f, + 0.000000f, 0.000000f, 0.000000f, 1.672477f + }); + test.Run(); +} + +TYPED_TEST(LinalgCholeskyContribOpTest, batch_lower) { + OpTester test("LinalgCholesky", 1, kMSDomain); + test.AddAttribute("upper", (int64_t)0); + test.AddInput("A", {2, 4, 4}, { + 3.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 6.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 2.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 3.854497f, + 6.656603f, 3.104256f, 0.025553f, -4.702889f, + 3.104256f, 7.327088f, 1.758924f, -3.521260f, + 0.025553f, 1.758924f, 1.969322f, -0.205388f, + -4.702889f, -3.521260f, -0.205388f, 7.054066f + }); + test.AddOutput("L", {2, 4, 4}, { + 1.961247f, 0.000000f, 0.000000f, 0.000000f, + -0.385480f, 2.573832f, 0.000000f, 0.000000f, + 0.444409f, -0.463038f, 1.582848f, 0.000000f, + -0.883979f, 0.305986f, -0.426930f, 1.672477f, + 2.580039f, 0.000000f, 0.000000f, 0.000000f, + 1.203182f, 2.424756f, 0.000000f, 0.000000f, + 0.009904f, 0.720488f, 1.204210f, 0.000000f, + -1.822797f, -0.547727f, 0.172142f, 1.844407f + }); + test.Run(); +} + +TYPED_TEST(LinalgCholeskyContribOpTest, batch_upper) { + OpTester test("LinalgCholesky", 1, kMSDomain); + test.AddAttribute("upper", (int64_t)1); + test.AddInput("A", {2, 4, 4}, { + 3.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 6.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 2.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 3.854497f, + 6.656603f, 3.104256f, 0.025553f, -4.702889f, + 3.104256f, 7.327088f, 1.758924f, -3.521260f, + 0.025553f, 1.758924f, 1.969322f, -0.205388f, + -4.702889f, -3.521260f, -0.205388f, 7.054066f + }); + test.AddOutput("L", {2, 4, 4}, { + 1.961247f, -0.385480f, 0.444409f, -0.883979f, + 0.000000f, 2.573832f, -0.463038f, 0.305986f, + 0.000000f, 0.000000f, 1.582848f, -0.426930f, + 0.000000f, 0.000000f, 0.000000f, 1.672477f, + 2.580039f, 1.203182f, 0.009904f, -1.822797f, + 0.000000f, 2.424756f, 0.720488f, -0.547727f, + 0.000000f, 0.000000f, 1.204210f, 0.172142f, + 0.000000f, 0.000000f, 0.000000f, 1.844407f + }); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/linalg_solve_test.cc b/onnxruntime/test/contrib_ops/linalg_solve_test.cc new file mode 100644 index 0000000000000..a14c17e0c105f --- /dev/null +++ b/onnxruntime/test/contrib_ops/linalg_solve_test.cc @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "core/util/math.h" + +namespace onnxruntime { +namespace test { + +template +class LinalgSolveContribOpTest : public ::testing::Test { +}; + +using LinalgTypes = ::testing::Types; +TYPED_TEST_SUITE(LinalgSolveContribOpTest, LinalgTypes); + +// DO NOT EDIT following test cases.they are generated with: +// in test_linalg_ops_with_pytorch.py, set generate_testcases to True to print C++ test cases +// python onnxruntime/test/python/test_linalg_ops_with_pytorch.py -k TestLinalgOps.test_linalg_solve +TYPED_TEST(LinalgSolveContribOpTest, no_batch_no_left_no_boardcast_no_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)0); + test.AddInput("A", {3, 3}, { + 7.207892f, 4.241426f, 1.942765f, + 4.241426f, 3.455372f, 0.326367f, + 1.942765f, 0.326367f, 1.382308f + }); + test.AddInput("B", {4, 3}, { + -0.403344f, -0.596635f, 0.182036f, + -0.856675f, 1.100604f, -1.071187f, + 0.122701f, -0.566317f, 0.373115f, + -0.891995f, -1.509108f, 0.370394f + }); + test.AddOutput("X", {4, 3}, {{ + 0.235647f, -0.453186f, -0.092502f, + -3.583405f, 4.413022f, 3.219445f, + 1.368743f, -1.726304f, -1.246193f, + 1.550905f, -2.209154f, -1.390180f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, no_batch_no_left_no_boardcast_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)0); + test.AddInput("A", {3, 3}, { + 7.207892f, 4.241426f, 1.942765f, + 4.241426f, 3.455372f, 0.326367f, + 1.942765f, 0.326367f, 1.382308f + }); + test.AddInput("B", {1, 3}, { + -0.403344f, -0.596635f, 0.182036f + }); + test.AddOutput("X", {1, 3}, {{ + 0.235647f, -0.453186f, -0.092502f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, no_batch_left_no_boardcast_no_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)1); + test.AddInput("A", {4, 4}, { + 2.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 5.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 1.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 2.854497f + }); + test.AddInput("B", {4, 3}, { + -0.719258f, -0.403344f, -0.596635f, + 0.182036f, -0.856675f, 1.100604f, + -1.071187f, 0.122701f, -0.566317f, + 0.373115f, -0.891995f, -1.509108f + }); + test.AddOutput("X", {4, 3}, {{ + -0.225903f, -0.515750f, -0.785838f, + -0.112631f, -0.137781f, 0.198996f, + -0.699208f, -0.218630f, -0.657229f, + -0.258434f, -0.663969f, -1.363283f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, no_batch_left_no_boardcast_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)1); + test.AddInput("A", {4, 4}, { + 2.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 5.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 1.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 2.854497f + }); + test.AddInput("B", {4}, { + -0.719258f, -0.403344f, -0.596635f, 0.182036f + }); + test.AddOutput("X", {4}, {{ + -0.312115f, -0.167263f, -0.444982f, -0.248349f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_no_left_no_boardcast_no_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)0); + test.AddInput("A", {2, 3, 3}, { + 2.916542f, -2.464638f, -1.485766f, + -2.464638f, 3.406585f, 0.110846f, + -1.485766f, 0.110846f, 3.769273f, + 0.088604f, -0.320561f, 0.202483f, + -0.320561f, 5.342275f, 1.497411f, + 0.202483f, 1.497411f, 3.165195f + }); + test.AddInput("B", {2, 4, 3}, { + -0.492677f, 0.248415f, 0.439696f, + 0.112411f, 0.640792f, 0.441156f, + -0.215863f, -0.742548f, -0.573077f, + -0.555358f, 0.594323f, 1.541943f, + 0.507334f, -0.591033f, -0.569248f, + 0.919971f, -0.069073f, -0.494925f, + -1.495915f, -0.193837f, 0.445512f, + 1.325275f, -1.629326f, -0.549744f + }); + test.AddOutput("X", {2, 4, 3}, {{ + -0.249439f, -0.108246f, 0.021512f, + 1.203938f, 1.040889f, 0.560996f, + -1.573670f, -1.332656f, -0.733155f, + 0.727643f, 0.678912f, 0.675938f, + 13.462979f, 1.140217f, -1.580517f, + 25.272738f, 2.306375f, -2.864220f, + -40.813610f, -3.754412f, 4.527832f, + 32.358044f, 2.611879f, -3.479328f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_no_left_no_boardcast_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)0); + test.AddInput("A", {2, 3, 3}, { + 2.916542f, -2.464638f, -1.485766f, + -2.464638f, 3.406585f, 0.110846f, + -1.485766f, 0.110846f, 3.769273f, + 0.088604f, -0.320561f, 0.202483f, + -0.320561f, 5.342275f, 1.497411f, + 0.202483f, 1.497411f, 3.165195f + }); + test.AddInput("B", {2, 1, 3}, { + 0.408716f, 1.421418f, 0.149397f, + -0.670860f, -0.214186f, -0.431969f + }); + test.AddOutput("X", {2, 1, 3}, {{ + 2.423967f, 2.140646f, 0.932159f, + -16.843258f, -1.515485f, 1.657973f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_no_left_boardcast_no_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)0); + test.AddInput("A", {2, 3, 3}, { + 2.916542f, -2.464638f, -1.485766f, + -2.464638f, 3.406585f, 0.110846f, + -1.485766f, 0.110846f, 3.769273f, + 0.088604f, -0.320561f, 0.202483f, + -0.320561f, 5.342275f, 1.497411f, + 0.202483f, 1.497411f, 3.165195f + }); + test.AddInput("B", {4, 3}, { + 0.408716f, 1.421418f, 0.149397f, + -0.670860f, -0.214186f, -0.431969f, + -0.707878f, -0.106434f, -1.242732f, + -0.476232f, -0.685918f, -1.505142f + }); + test.AddOutput("X", {2, 4, 3}, {{ + 2.423967f, 2.140646f, 0.932159f, + -1.617166f, -1.209568f, -0.716484f, + -2.049185f, -1.478216f, -1.093974f, + -2.506701f, -1.971671f, -1.329424f, + 13.611500f, 1.514484f, -1.540033f, + -16.843258f, -1.515485f, 1.657973f, + -15.300617f, -1.270853f, 1.187406f, + -9.874021f, -0.881538f, 0.573173f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_no_left_boardcast_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)0); + test.AddInput("A", {2, 3, 3}, { + 2.916542f, -2.464638f, -1.485766f, + -2.464638f, 3.406585f, 0.110846f, + -1.485766f, 0.110846f, 3.769273f, + 0.088604f, -0.320561f, 0.202483f, + -0.320561f, 5.342275f, 1.497411f, + 0.202483f, 1.497411f, 3.165195f + }); + test.AddInput("B", {1, 3}, { + 0.408716f, 1.421418f, 0.149397f + }); + test.AddOutput("X", {2, 1, 3}, {{ + 2.423967f, 2.140646f, 0.932159f, + 13.611500f, 1.514484f, -1.540033f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_left_no_boardcast_no_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)1); + test.AddInput("A", {2, 4, 4}, { + 2.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 5.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 1.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 2.854497f, + 5.656603f, 3.104256f, 0.025553f, -4.702889f, + 3.104256f, 6.327088f, 1.758924f, -3.521260f, + 0.025553f, 1.758924f, 0.969322f, -0.205388f, + -4.702889f, -3.521260f, -0.205388f, 6.054066f + }); + test.AddInput("B", {2, 4, 3}, { + -0.613583f, 0.031593f, -0.492677f, + 0.248415f, 0.439696f, 0.112411f, + 0.640792f, 0.441156f, 0.205526f, + -0.450330f, -0.573077f, -0.555358f, + 0.594323f, 1.541943f, 0.507334f, + -0.591033f, -1.325326f, 0.188554f, + -0.069073f, -0.494925f, -1.495915f, + -0.193837f, 0.445512f, 1.325275f + }); + test.AddOutput("X", {2, 4, 3}, {{ + -0.524159f, -0.190020f, -0.465158f, + 0.150501f, 0.166460f, 0.067125f, + 0.466158f, 0.264487f, 0.066190f, + -0.337954f, -0.269829f, -0.475542f, + 0.380720f, 1.101162f, -0.025249f, + -0.539868f, -0.609932f, 2.319154f, + 0.894120f, 0.693820f, -5.462131f, + -0.019942f, 0.597768f, 1.362889f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_left_no_boardcast_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)1); + test.AddInput("A", {2, 4, 4}, { + 2.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 5.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 1.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 2.854497f, + 5.656603f, 3.104256f, 0.025553f, -4.702889f, + 3.104256f, 6.327088f, 1.758924f, -3.521260f, + 0.025553f, 1.758924f, 0.969322f, -0.205388f, + -4.702889f, -3.521260f, -0.205388f, 6.054066f + }); + test.AddInput("B", {2, 4}, { + -0.566317f, 0.373115f, -0.891995f, -1.509108f, + 0.370394f, 1.456503f, 0.939810f, 0.774849f + }); + test.AddOutput("X", {2, 4}, {{ + -0.749342f, 0.003339f, -1.015990f, -1.415893f, + 0.574198f, -0.054876f, 1.177326f, 0.582058f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_left_boardcast_no_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)1); + test.AddInput("A", {2, 4, 4}, { + 2.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 5.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 1.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 2.854497f, + 5.656603f, 3.104256f, 0.025553f, -4.702889f, + 3.104256f, 6.327088f, 1.758924f, -3.521260f, + 0.025553f, 1.758924f, 0.969322f, -0.205388f, + -4.702889f, -3.521260f, -0.205388f, 6.054066f + }); + test.AddInput("B", {4, 3}, { + -0.566317f, 0.373115f, -0.891995f, + -1.509108f, 0.370394f, 1.456503f, + 0.939810f, 0.774849f, 0.191869f, + 1.263795f, -1.290435f, -0.791103f + }); + test.AddOutput("X", {2, 4, 3}, {{ + 0.034407f, -0.244951f, -0.766153f, + -0.226307f, 0.215456f, 0.354410f, + 0.904993f, 0.321732f, 0.195547f, + 0.936803f, -0.549595f, -0.799650f, + 0.880404f, 0.072070f, -1.015060f, + -2.001246f, -1.122906f, 0.889210f, + 4.553026f, 2.682681f, -1.484702f, + -0.116869f, -0.719277f, -0.452360f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} + +TYPED_TEST(LinalgSolveContribOpTest, batch_left_boardcast_b_as_vector) { + OpTester test("LinalgSolve", 1, kMSDomain); + test.AddAttribute("left", (int64_t)1); + test.AddInput("A", {2, 4, 4}, { + 2.846490f, -0.756021f, 0.871596f, -1.733702f, + -0.756021f, 5.773203f, -1.363091f, 1.128313f, + 0.871596f, -1.363091f, 1.917311f, -1.210296f, + -1.733702f, 1.128313f, -1.210296f, 2.854497f, + 5.656603f, 3.104256f, 0.025553f, -4.702889f, + 3.104256f, 6.327088f, 1.758924f, -3.521260f, + 0.025553f, 1.758924f, 0.969322f, -0.205388f, + -4.702889f, -3.521260f, -0.205388f, 6.054066f + }); + test.AddInput("B", {4}, { + -0.566317f, 0.373115f, -0.891995f, -1.509108f + }); + test.AddOutput("X", {2, 4}, {{ + -0.749342f, 0.003339f, -1.015990f, -1.415893f, + -1.325651f, 1.274423f, -3.335771f, -0.650975f + }}, + false, + 1e-3f, + 1e-3f); + test.Run(); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/test_linalg_ops_with_pytorch.py b/onnxruntime/test/python/test_linalg_ops_with_pytorch.py index 88b498d7d6103..74aa69d5b9f47 100644 --- a/onnxruntime/test/python/test_linalg_ops_with_pytorch.py +++ b/onnxruntime/test/python/test_linalg_ops_with_pytorch.py @@ -57,6 +57,40 @@ def create_svd_model(use_batch, full_matrices): ) return onnx_model +def create_solve_model(use_batch, left, boardcast, b_as_vector): + a_shape = ["B"] if use_batch else [] + b_shape = ["B"] if use_batch and not boardcast else [] + x_shape = ["B"] if use_batch else [] + if left: + a_shape = [*a_shape, "N", "N"] + if b_as_vector: + b_shape = [*b_shape, "N"] + x_shape = [*x_shape, "N"] + else: + b_shape = [*b_shape, "N", "K"] + x_shape = [*x_shape, "N", "K"] + else: + a_shape = [*a_shape, "K", "K"] + if b_as_vector: + b_shape = [*b_shape, 1, "K"] + x_shape = [*x_shape, 1, "K"] + else: + b_shape = [*b_shape, "N", "K"] + x_shape = [*x_shape, "N", "K"] + + a_value_info = helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, a_shape) + b_value_info = helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, b_shape) + x_value_info = helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, x_shape) + + onnx_model = create_model( + "LinalgSolve", + [a_value_info, b_value_info], + [x_value_info], + opset_version=17, + node_kwargs={"left": left, "domain": "com.microsoft",}, + ) + return onnx_model + def normalize_signs(a, b): signs = np.sign(a[0]) * np.sign(b[0]) return a, b * signs, signs == -1 @@ -87,6 +121,146 @@ class TestLinalgOps(unittest.TestCase): def setUp(self): self.opset_version = 17 + @parameterized.parameterized.expand([ + (False, False), + (False, True), + (True, False), + (True, True), + ]) + def test_linalg_cholesky(self, use_batch, upper): + torch.manual_seed(0) + batch = 2 + n = 4 + A = torch.randn(batch, n, n) if use_batch else torch.randn(n, n) + A = A @ A.transpose(-2, -1) + torch.eye(n) + L = torch.linalg.cholesky(A, upper=upper) + + onnx_model = create_model( + "LinalgCholesky", + [helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, ["B", "N", "N"] if use_batch else ["N", "N"])], + [helper.make_tensor_value_info("L", onnx.TensorProto.FLOAT, ["B", "N", "N"] if use_batch else ["N", "N"])], + opset_version=self.opset_version, + node_kwargs={"upper": upper, "domain": "com.microsoft",}, + ) + # session = ort.InferenceSession(onnx_model.SerializeToString()) + # input_name = session.get_inputs()[0].name + # output_names = [output.name for output in session.get_outputs()] + # input_data = {input_name: A.numpy()} + # actual_l = session.run(output_names, input_data) + # expected_l = L.numpy() + # validate_base_equal(actual_l[0], expected_l) + + if generate_testcases: + # Print the C++ test case + A_str = format_tensor(A) + L_str = format_tensor(L) + + batch_str = 'batch' if use_batch else 'no_batch' + upper_str = 'upper' if upper else 'lower' + test_case_name = f'{batch_str}_{upper_str}' + + print(f'TYPED_TEST(LinalgCholeskyContribOpTest, {test_case_name}) {{\n' + f' OpTester test("LinalgCholesky", 1, kMSDomain);\n' + f' test.AddAttribute("upper", (int64_t){"1" if upper else "0"});\n' + f' test.AddInput("A", {{{", ".join(map(str, A.shape))}}}, {A_str});\n' + f' test.AddOutput("L", {{{", ".join(map(str, L.shape))}}}, {L_str});\n' + f' test.Run();\n' + f'}}') + + + @parameterized.parameterized.expand([ + (False, False, False, False), + (False, False, False, True), + (False, True, False, False), + (False, True, False, True), + (True, False, False, False), + (True, False, False, True), + (True, False, True, False), + (True, False, True, True), + (True, True, False, False), + (True, True, False, True), + (True, True, True, False), + (True, True, True, True), + ]) + def test_linalg_solve(self, use_batch, left, boardcast, b_as_vector): + def create_invertable_matrix(shape): + A = torch.randn(*shape) + if len(shape) == 3: + return torch.matmul(A, A.transpose(-2, -1)) + else: + return torch.matmul(A, A.t()) + + torch.manual_seed(0) + batch = 2 + n = 4 + k = 3 + if left: + A = create_invertable_matrix((batch, n, n)) if use_batch else create_invertable_matrix((n, n)) + else: + A = create_invertable_matrix((batch, k, k)) if use_batch else create_invertable_matrix((k, k)) + if use_batch: + if boardcast: + if b_as_vector: + if left: + B = torch.randn(n) + else: + B = torch.randn(1, k) + else: + B = torch.randn(n, k) + else: + if b_as_vector: + if left: + B = torch.randn(batch, n) + else: + B = torch.randn(batch, 1, k) + else: + B = torch.randn(batch, n, k) + else: + assert boardcast is False, "boardcast shall not set for non-batch mode" + if b_as_vector: + if left: + B = torch.randn(n) + else: + B = torch.randn(1, k) + else: + B = torch.randn(n, k) + + X = torch.linalg.solve(A, B, left=left) + + onnx_model = create_solve_model(use_batch=use_batch, left=left, boardcast=boardcast, b_as_vector=b_as_vector) + session = ort.InferenceSession(onnx_model.SerializeToString()) + input_names = [input.name for input in session.get_inputs()] + output_names = [output.name for output in session.get_outputs()] + input_data = {input_names[0]: A.numpy(), input_names[1]: B.numpy()} + actual_x = session.run(output_names, input_data) + expected_x = X.numpy() + np.testing.assert_allclose(actual_x[0], expected_x, rtol=1e-5, atol=1e-7) + if generate_testcases: + # Print the C++ test case + A_str = format_tensor(A) + B_str = format_tensor(B) + X_str = format_tensor(X) + + batch_str = 'batch' if use_batch else 'no_batch' + left_str = '_left' if left else '_no_left' + boardcast_str = '_boardcast' if boardcast else '_no_boardcast' + b_as_vector_str = '_b_as_vector' if b_as_vector else '_no_b_as_vector' + test_case_name = f'{batch_str}{left_str}{boardcast_str}{b_as_vector_str}' + + print(f'TYPED_TEST(LinalgSolveContribOpTest, {test_case_name}) {{\n' + f' OpTester test("LinalgSolve", 1, kMSDomain);\n' + f' test.AddAttribute("left", (int64_t){"1" if left else "0"});\n' + f' test.AddInput("A", {{{", ".join(map(str, A.shape))}}}, {A_str});\n' + f' test.AddInput("B", {{{", ".join(map(str, B.shape))}}}, {B_str});\n' + f' test.AddOutput("X", {{{", ".join(map(str, X.shape))}}}, {{{X_str}}},\n' + f' false,\n' + f' 1e-3f,\n' + f' 1e-3f);\n' + f' test.Run();\n' + f'}}') + + + @parameterized.parameterized.expand([ (True, True), (True, False), From d9f6b98ee08b6cc3813536fcb2d35bf3516f98fd Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Sat, 25 Nov 2023 12:21:18 -0800 Subject: [PATCH 3/3] linalg_cholesky Signed-off-by: Liqun Fu --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../contrib_ops/cpu/linalg_cholesky.cc | 75 +++++++++++++------ onnxruntime/contrib_ops/cpu/linalg_cholesky.h | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 2 +- .../python/test_linalg_ops_with_pytorch.py | 16 ++-- 5 files changed, 65 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index e2feec38540fa..446434fc53705 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -17,6 +17,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSolve); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgInv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgCholesky); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgCholesky); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, WhisperBeamSearch); @@ -260,6 +261,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc b/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc index 18fe99e5f3ff6..dd7042e2b1bd8 100644 --- a/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc +++ b/onnxruntime/contrib_ops/cpu/linalg_cholesky.cc @@ -22,35 +22,66 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( float, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - LinalgCholesky); + LinalgCholesky); -#pragma warning(disable : 4189) -Status LinalgCholesky::Compute(OpKernelContext* context) const { - Status status = Status::OK(); - const Tensor* A = context->Input(0); - const TensorShape& a_shape = A->Shape(); - assert(a_shape.NumDimensions() == 2); - assert(a_shape[0] == a_shape[1]); - - TensorShape X_shape = { a_shape[0], a_shape[1] }; - Tensor* X = context->Output(0, X_shape); +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + LinalgCholesky, + 1, + double, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + LinalgCholesky); +template +Status cholesky(const T* a_data, T* l_data, int64_t n, bool upper) { const Eigen::StorageOptions option = Eigen::RowMajor; - Eigen::Map> a_matrix(A->Data(), narrow(a_shape[0]), narrow(a_shape[1])); - Eigen::Map> x_matrix(X->MutableData(), narrow(a_shape[0]), narrow(a_shape[1])); - Eigen::LLT lltOfA(a_matrix); + Eigen::Map> a_matrix(a_data, narrow(n), narrow(n)); + Eigen::Map> l_matrix(l_data, narrow(n), narrow(n)); + Eigen::LLT> lltOfA(a_matrix); if (lltOfA.info() == Eigen::Success) { - if (this->upper_) { - x_matrix = lltOfA.matrixU(); - } - else { - x_matrix = lltOfA.matrixL(); + if (upper) { + l_matrix = lltOfA.matrixU(); + } else { + l_matrix = lltOfA.matrixL(); } + return Status::OK(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input matrix A is not decomposable with Cholesky."); } - else { - assert(false); +} + +#pragma warning(disable : 4189) +template +Status LinalgCholesky::Compute(OpKernelContext* context) const { + const Tensor* A = context->Input(0); + const TensorShape& a_shape = A->Shape(); + int64_t a_rank = a_shape.NumDimensions(); + assert(a_rank == 2 || a_rank == 3); + int64_t batch = a_rank == 2 ? 1 : a_shape[0]; + + assert(a_shape[a_rank - 1] == a_shape[a_rank - 2]); + + Tensor* L = context->Output(0, a_shape); + + if (batch == 1) { + return cholesky(A->Data(), L->MutableData(), a_shape[a_rank - 1], upper_); + } else { + std::mutex status_mutex; + Status summary_status = Status::OK(); + int64_t single_batch_size = a_shape.SizeFromDimension(a_rank - 2); + std::function fn = [&](ptrdiff_t batch_num) { + const T* a_data = A->Data() + batch_num * single_batch_size; + T* l_data = L->MutableData() + batch_num * single_batch_size; + Status status = cholesky(a_data, l_data, a_shape[a_rank - 1], upper_); + if (!status.IsOK()) { + // let the main function return any unsuccessful status + std::lock_guard lock(status_mutex); + summary_status = status; + } + }; + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(batch), std::move(fn), 0); + return summary_status; } - return status; } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/linalg_cholesky.h b/onnxruntime/contrib_ops/cpu/linalg_cholesky.h index 1ba5cc39b8459..0aafe8dcfc8d5 100644 --- a/onnxruntime/contrib_ops/cpu/linalg_cholesky.h +++ b/onnxruntime/contrib_ops/cpu/linalg_cholesky.h @@ -13,6 +13,7 @@ namespace onnxruntime { namespace contrib { +template class LinalgCholesky : public OpKernel { public: LinalgCholesky(const OpKernelInfo& info) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 344441ff6f9ba..19cef331602d6 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2794,7 +2794,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(LinalgCholesky, 1, "T") .TypeConstraint( "T", - {"tensor(float)"}, + {"tensor(float)", "tensor(double)"}, "") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/test/python/test_linalg_ops_with_pytorch.py b/onnxruntime/test/python/test_linalg_ops_with_pytorch.py index 74aa69d5b9f47..d10a5de3af8e4 100644 --- a/onnxruntime/test/python/test_linalg_ops_with_pytorch.py +++ b/onnxruntime/test/python/test_linalg_ops_with_pytorch.py @@ -140,15 +140,15 @@ def test_linalg_cholesky(self, use_batch, upper): [helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, ["B", "N", "N"] if use_batch else ["N", "N"])], [helper.make_tensor_value_info("L", onnx.TensorProto.FLOAT, ["B", "N", "N"] if use_batch else ["N", "N"])], opset_version=self.opset_version, - node_kwargs={"upper": upper, "domain": "com.microsoft",}, + node_kwargs={"upper": 1 if upper else 0, "domain": "com.microsoft",}, ) - # session = ort.InferenceSession(onnx_model.SerializeToString()) - # input_name = session.get_inputs()[0].name - # output_names = [output.name for output in session.get_outputs()] - # input_data = {input_name: A.numpy()} - # actual_l = session.run(output_names, input_data) - # expected_l = L.numpy() - # validate_base_equal(actual_l[0], expected_l) + session = ort.InferenceSession(onnx_model.SerializeToString()) + input_name = session.get_inputs()[0].name + output_names = [output.name for output in session.get_outputs()] + input_data = {input_name: A.numpy()} + actual_l = session.run(output_names, input_data) + expected_l = L.numpy() + np.testing.assert_allclose(actual_l[0], expected_l, rtol=1e-5, atol=1e-7) if generate_testcases: # Print the C++ test case