From d2f7a5b1286e34ea55dceccff8f17a45f2f799aa Mon Sep 17 00:00:00 2001 From: Jake Mathern Date: Mon, 11 Dec 2023 17:41:16 -0800 Subject: [PATCH] Cherry pick fix constant pow (#18785) ### Description Cherry pick https://github.com/microsoft/onnxruntime/pull/18784 --- .../src/Operators/DmlOperatorElementWise.cpp | 2 +- .../dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index f0a16da3a3c06..ec94772238cc9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,7 +479,7 @@ class DmlOperatorElementwisePow : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(1); if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) { std::vector> kernelInputIndices = {0}; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 59a1719d08ee6..c40f82a8c31c6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -605,11 +605,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return MLOperatorTensor(tensor.Get()); } - std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + std::optional TryGetConstantCpuInputTensor(uint32_t inputIndex) const { Microsoft::WRL::ComPtr tensor; ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); - if (tensor) + if (tensor && tensor->IsCpuData()) { return MLOperatorTensor(tensor.Get()); }