WIP Liqun/linalg #18621

Open · wants to merge 3 commits into base: main
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -11,6 +11,13 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSVD);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSVD);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSolve);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSolve);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgInv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgCholesky);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgCholesky);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, WhisperBeamSearch);
@@ -248,6 +255,13 @@

// add more kernels here
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSVD)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSVD)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgSolve)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgSolve)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgInv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinalgCholesky)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, LinalgCholesky)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, WhisperBeamSearch)>,
87 changes: 87 additions & 0 deletions onnxruntime/contrib_ops/cpu/linalg_cholesky.cc
@@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

// Licensed under the MIT License.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif

#include "contrib_ops/cpu/linalg_cholesky.h"

#include <mutex>
#include <utility>

#include <Eigen/Dense>

#include "core/framework/framework_common.h"
#include "core/framework/tensorprotoutils.h"


using namespace ONNX_NAMESPACE;

using namespace onnxruntime::common;


namespace onnxruntime {
namespace contrib {

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
LinalgCholesky,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
LinalgCholesky<float>);

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
LinalgCholesky,
1,
double,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
LinalgCholesky<double>);

template <typename T>
Status cholesky(const T* a_data, T* l_data, int64_t n, bool upper) {
const Eigen::StorageOptions option = Eigen::RowMajor;
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> a_matrix(a_data, narrow<size_t>(n), narrow<size_t>(n));

Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> l_matrix(l_data, narrow<size_t>(n), narrow<size_t>(n));

Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> lltOfA(a_matrix);
if (lltOfA.info() == Eigen::Success) {
if (upper) {
l_matrix = lltOfA.matrixU();
} else {
l_matrix = lltOfA.matrixL();
}
return Status::OK();
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input matrix A is not decomposable with Cholesky.");
}
}

#pragma warning(disable : 4189)
template <typename T>
Status LinalgCholesky<T>::Compute(OpKernelContext* context) const {
const Tensor* A = context->Input<Tensor>(0);
const TensorShape& a_shape = A->Shape();
int64_t a_rank = a_shape.NumDimensions();
assert(a_rank == 2 || a_rank == 3);
int64_t batch = a_rank == 2 ? 1 : a_shape[0];

assert(a_shape[a_rank - 1] == a_shape[a_rank - 2]);

Tensor* L = context->Output(0, a_shape);

if (batch == 1) {
return cholesky(A->Data<T>(), L->MutableData<T>(), a_shape[a_rank - 1], upper_);
} else {
std::mutex status_mutex;
Status summary_status = Status::OK();
int64_t single_batch_size = a_shape.SizeFromDimension(a_rank - 2);
std::function<void(ptrdiff_t)> fn = [&](ptrdiff_t batch_num) {
const T* a_data = A->Data<T>() + batch_num * single_batch_size;
T* l_data = L->MutableData<T>() + batch_num * single_batch_size;
Status status = cholesky(a_data, l_data, a_shape[a_rank - 1], upper_);
if (!status.IsOK()) {
// let the main function return any unsuccessful status
std::lock_guard<std::mutex> lock(status_mutex);
summary_status = status;
}
};
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow<size_t>(batch), std::move(fn), 0);

return summary_status;
}
}
} // namespace contrib
} // namespace onnxruntime

#ifdef _MSC_VER
#pragma warning(pop)
#endif
33 changes: 33 additions & 0 deletions onnxruntime/contrib_ops/cpu/linalg_cholesky.h
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

// Licensed under the MIT License.

#pragma once
#include <functional>

#include "core/common/common.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"

namespace onnxruntime {

namespace contrib {

template <typename T>
class LinalgCholesky : public OpKernel {
public:
explicit LinalgCholesky(const OpKernelInfo& info)
: OpKernel(info) {
int64_t upper;
ORT_ENFORCE(info.GetAttr<int64_t>("upper", &upper).IsOK());
upper_ = upper != 0;
}

Status Compute(OpKernelContext* context) const override;

private:

bool upper_ = false;
};

} // namespace contrib
} // namespace onnxruntime
46 changes: 46 additions & 0 deletions onnxruntime/contrib_ops/cpu/linalg_inv.cc
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

// Licensed under the MIT License.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif

#include "contrib_ops/cpu/linalg_inv.h"

#include <Eigen/Dense>

#include "core/framework/framework_common.h"
#include "core/framework/tensorprotoutils.h"


using namespace ONNX_NAMESPACE;

using namespace onnxruntime::common;


namespace onnxruntime {
namespace contrib {

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
LinalgInv,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
LinalgInv);

#pragma warning(disable : 4189)
Status LinalgInv::Compute(OpKernelContext* context) const {
Status status = Status::OK();
const Tensor* A = context->Input<Tensor>(0);
const TensorShape& a_shape = A->Shape();
assert(a_shape.NumDimensions() == 2);
assert(a_shape[0] == a_shape[1]);

TensorShape X_shape = { a_shape[1], a_shape[0] };
Tensor* X = context->Output(0, X_shape);

const Eigen::StorageOptions option = Eigen::RowMajor;
Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, option>> a_matrix(A->Data<float>(), narrow<size_t>(a_shape[0]), narrow<size_t>(a_shape[1]));


Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, option>> x_matrix(X->MutableData<float>(), narrow<size_t>(a_shape[1]), narrow<size_t>(a_shape[0]));

x_matrix = a_matrix.inverse();
return status;
}
} // namespace contrib
} // namespace onnxruntime

#ifdef _MSC_VER
#pragma warning(pop)
#endif
29 changes: 29 additions & 0 deletions onnxruntime/contrib_ops/cpu/linalg_inv.h
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

// Licensed under the MIT License.

#pragma once
#include <functional>

#include "core/common/common.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"

namespace onnxruntime {

namespace contrib {

class LinalgInv : public OpKernel {
public:
explicit LinalgInv(const OpKernelInfo& info)
: OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;

private:

bool left_ = true;
};

} // namespace contrib
} // namespace onnxruntime
162 changes: 162 additions & 0 deletions onnxruntime/contrib_ops/cpu/linalg_solve.cc
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

// Licensed under the MIT License.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif

#include "contrib_ops/cpu/linalg_solve.h"

#include <vector>

#include <Eigen/Dense>

#include "core/framework/framework_common.h"
#include "core/framework/tensorprotoutils.h"


using namespace ONNX_NAMESPACE;

using namespace onnxruntime::common;


namespace onnxruntime {
namespace contrib {

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
LinalgSolve,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
LinalgSolve<float>);

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
LinalgSolve,
1,
double,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
LinalgSolve<double>);

template <typename T>
void solve(const T* a_data, const T* b_data, T* x_data, bool left, int64_t n, int64_t k) {
const Eigen::StorageOptions option = Eigen::RowMajor;
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> b_matrix(b_data, narrow<size_t>(n), narrow<size_t>(k));

Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> x_matrix(x_data, narrow<size_t>(n), narrow<size_t>(k));


if (left) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> a_matrix(a_data, narrow<size_t>(n), narrow<size_t>(n));

Eigen::BDCSVD<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> svd(a_matrix, Eigen::ComputeThinU | Eigen::ComputeThinV);

x_matrix = svd.solve(b_matrix);
} else {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option>> a_matrix(a_data, narrow<size_t>(k), narrow<size_t>(k));

auto a_matrix_transposed = a_matrix.transpose();
auto b_matrix_transposed = b_matrix.transpose();
Eigen::BDCSVD<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> svd(a_matrix_transposed, Eigen::ComputeThinU | Eigen::ComputeThinV);

Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, option> x_matrix_transposed_result = svd.solve(b_matrix_transposed);

x_matrix = x_matrix_transposed_result.transpose();
}
}

#pragma warning(disable : 4189)
template <typename T>
Status LinalgSolve<T>::Compute(OpKernelContext* context) const {
Status status = Status::OK();
const Tensor* A = context->Input<Tensor>(0);
const TensorShape& a_shape = A->Shape();
assert(a_shape.NumDimensions() == 2 || a_shape.NumDimensions() == 3);
bool has_batch = a_shape.NumDimensions() == 3;
const Tensor* B = context->Input<Tensor>(1);
const TensorShape& b_shape = B->Shape();

int64_t batch = has_batch ? a_shape[0] : 1, n = 1, k = 1;
bool b_as_a_vector = b_shape.NumDimensions() == 1;
bool broadcast;
int64_t n_or_k = a_shape[1];
if (left_) {
n = n_or_k;
} else {
k = n_or_k;
}

if (has_batch) {
ORT_ENFORCE(a_shape[1] == a_shape[2], "A should be square matrix: ", a_shape);
if (b_shape.NumDimensions() == 1) {
b_as_a_vector = true;
broadcast = true;
} else if (b_shape.NumDimensions() == 2) {
if (b_shape[0] == a_shape[0] && b_shape[1] == a_shape[1]) { // A has shape (*, n/k, n/k) and B has shape(*, n/k)
b_as_a_vector = true;
broadcast = false;
} else if (left_ && b_shape[0] == a_shape[1]) { // A has shape (*, n, n) and B has shape (n, k)
b_as_a_vector = false;
broadcast = true;
k = b_shape[1];
} else if (!left_ && b_shape[1] == a_shape[1]) { // A has shape (*, k, k) and B has shape (n, k)
b_as_a_vector = false;
broadcast = true;
n = b_shape[0];
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "B shape does not match A shape.", b_shape, a_shape);
}
} else { // b_shape.NumDimensions() == 3

ORT_ENFORCE(b_shape[0] == a_shape[0], "A and B shall have the same batch size");
b_as_a_vector = false;
broadcast = false;
if (left_) { // A: (*, n, n), B: (*, n, k)

ORT_ENFORCE(b_shape[1] == a_shape[1], "A and B shall have matching size at dim 1: ", b_shape[1], "vs", a_shape[1]);

k = b_shape[2];
} else { // A: (*, k, k), B: (*, n, k)
ORT_ENFORCE(b_shape[2] == a_shape[2], "A and B shall have matching size at dim 2: ", b_shape[2], "vs", a_shape[2]);

n = b_shape[1];
}
}
} else { // !has_batch

ORT_ENFORCE(a_shape[0] == a_shape[1], "A should be square matrix: ", a_shape);
broadcast = false;
if (b_shape.NumDimensions() == 1) { // A: (n/k, n/k), B: (n/k,)
ORT_ENFORCE(b_shape[0] == a_shape[0], "A and B shall have matching size at dim 0: ", b_shape[0], " vs ", a_shape[0]);
b_as_a_vector = true;
} else if (b_shape.NumDimensions() == 2) { // A: (n/k, n/k), B: (n, k)
b_as_a_vector = false;
if (left_) {  // A: (n, n), B: (n, k)

onnxruntime/contrib_ops/cpu/linalg_solve.cc:116:  At least two spaces is best between code and comments  [whitespace/comments] [2]
k = b_shape[1];
} else {  // A: (k, k), B: (n, k)

onnxruntime/contrib_ops/cpu/linalg_solve.cc:118:  At least two spaces is best between code and comments  [whitespace/comments] [2]
n = b_shape[0];
}
}
}

std::vector<int64_t> x_dims;

onnxruntime/contrib_ops/cpu/linalg_solve.cc:124:  Add #include <vector> for vector<>  [build/include_what_you_use] [4]
if (has_batch) {
x_dims.push_back(batch);
}
if (b_as_a_vector) {
if (left_) {
x_dims.push_back(n);
} else {
x_dims.push_back(k);
}

onnxruntime/contrib_ops/cpu/linalg_solve.cc:133:  Line ends in whitespace.  Consider deleting these extra spaces.  [whitespace/end_of_line] [4]
} else {
x_dims.push_back(n);
x_dims.push_back(k);
}
TensorShape x_shape(x_dims);
Tensor* X = context->Output(0, x_shape);

if (batch == 1) {
const T* a_data = A->Data<T>();
const T* b_data = B->Data<T>();
T* x_data = X->MutableData<T>();
solve(a_data, b_data, x_data, left_, n, k);
} else {
int64_t a_single_batch_size = a_shape.SizeFromDimension(a_shape.NumDimensions() - 2);
int64_t b_single_batch_size = broadcast ? 0 : b_shape.SizeFromDimension(1);
int64_t x_single_batch_size = x_shape.SizeFromDimension(1);
std::function<void(ptrdiff_t)> fn = [&](ptrdiff_t batch_num) {
const T* a_data = A->Data<T>() + batch_num * a_single_batch_size;
const T* b_data = B->Data<T>() + (broadcast ? 0 : batch_num * b_single_batch_size);
T* x_data = X->MutableData<T>() + batch_num * x_single_batch_size;
solve(a_data, b_data, x_data, left_, n, k);
};
concurrency::ThreadPool::TryBatchParallelFor(
    context->GetOperatorThreadPool(), narrow<size_t>(batch), std::move(fn), 0);

onnxruntime/contrib_ops/cpu/linalg_solve.cc:156:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
onnxruntime/contrib_ops/cpu/linalg_solve.cc:156:  Add #include <utility> for move  [build/include_what_you_use] [4]
}

return status;
}
} // namespace contrib
} // namespace onnxruntime
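The shape bookkeeping in the kernel above is easier to follow in isolation. The standalone helper below is a hypothetical sketch (the name `SolveOutputShape` and its error handling are illustrative, not part of the ONNX Runtime kernel): it derives X's shape for A·X = B (`left == true`) or X·A = B (`left == false`) under the same rules — A may carry a leading batch dimension, a matrix B without its own batch dimension is broadcast across a batched A, and X always takes B's per-sample shape.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical sketch: derive the shape of X in AX = B (left) or XA = B.
//   A: (n, n) or (batch, n, n); B: (m,), (m, k), or (batch, m, k).
std::vector<int64_t> SolveOutputShape(const std::vector<int64_t>& a,
                                      const std::vector<int64_t>& b,
                                      bool left) {
  if (a.size() != 2 && a.size() != 3) throw std::invalid_argument("A must be rank 2 or 3");
  const bool a_batched = a.size() == 3;
  const int64_t n = a[a.size() - 1];
  if (a[a.size() - 2] != n) throw std::invalid_argument("A must be square");

  // Per-sample shape of B, with any leading batch dimension split off.
  // A 2-D B next to a batched A is treated as a broadcast matrix.
  std::vector<int64_t> b_core(b);
  if (a_batched && b.size() == 3) {  // B: (batch, m, k)
    if (b[0] != a[0]) throw std::invalid_argument("batch size mismatch");
    b_core.erase(b_core.begin());
  }
  if (b_core.size() != 1 && b_core.size() != 2)
    throw std::invalid_argument("B must be a vector or a matrix per sample");

  // A acts on B's rows for AX = B and on B's columns for XA = B;
  // a 1-D B is itself the acted-on vector.
  const int64_t acted = b_core.size() == 1 ? b_core[0] : (left ? b_core[0] : b_core[1]);
  if (acted != n) throw std::invalid_argument("A and B sizes do not match");

  // X takes B's per-sample shape, batched along A's batch dimension.
  std::vector<int64_t> x;
  if (a_batched) x.push_back(a[0]);
  x.insert(x.end(), b_core.begin(), b_core.end());
  return x;
}
```

For instance, a batched A of shape (2, 4, 4) with a broadcast B of shape (4, 3) yields X of shape (2, 4, 3), matching the `broadcast` path taken by the batched loop above.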