Skip to content

Commit

Permalink
Cherry pick fix constant pow (#18785)
Browse files Browse the repository at this point in the history
### Description
Cherry pick #18784
  • Loading branch information
Jamather authored and jeffbloo committed Jan 4, 2024
1 parent 107d749 commit d2f7a5b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ class DmlOperatorElementwisePow : public DmlOperator
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);

auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1);
auto constExpTensor = kernelInfo.TryGetConstantCpuInputTensor(1);
if (constExpTensor && constExpTensor->GetTotalElementCount() == 1)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,11 +605,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes
return MLOperatorTensor(tensor.Get());
}

std::optional<MLOperatorTensor> TryGetConstantInputTensor(uint32_t inputIndex) const
std::optional<MLOperatorTensor> TryGetConstantCpuInputTensor(uint32_t inputIndex) const
{
Microsoft::WRL::ComPtr<IMLOperatorTensor> tensor;
ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor));
if (tensor)
if (tensor && tensor->IsCpuData())
{
return MLOperatorTensor(tensor.Get());
}
Expand Down

0 comments on commit d2f7a5b

Please sign in to comment.