diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index f0a16da3a3c06..ec94772238cc9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,7 +479,7 @@ class DmlOperatorElementwisePow : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(1); if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) { std::vector> kernelInputIndices = {0}; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 59a1719d08ee6..c40f82a8c31c6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -605,11 +605,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return MLOperatorTensor(tensor.Get()); } - std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + std::optional TryGetConstantCpuInputTensor(uint32_t inputIndex) const { Microsoft::WRL::ComPtr tensor; ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); - if (tensor) + if (tensor && tensor->IsCpuData()) { return MLOperatorTensor(tensor.Get()); }