From bc322c1cdaf9938e2cf3bd8725465f6725db19b8 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 23:00:18 -0700 Subject: [PATCH] Add check in case CPU inputs changed --- .../src/DmlRuntimeFusedGraphKernel.cpp | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 2c9d27b0546f6..fd26a90be87e7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -99,36 +99,62 @@ namespace Dml { ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); - bool recompiledNeeded = m_compiledExecutionPlanOperator == nullptr; + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) { const auto& input = kernelContext->RequiredInput(inputIndex); const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); - auto iter = m_inferredInputShapes.find(inputName); + auto shapeIter = m_inferredInputShapes.find(inputName); - if (iter == m_inferredInputShapes.end()) + if (shapeIter == m_inferredInputShapes.end()) { m_inferredInputShapes[inputName] = input.Shape(); - recompiledNeeded = true; + recompileNeeded = true; } - else if (iter->second != input.Shape()) + else if (shapeIter->second != input.Shape()) { - iter->second = input.Shape(); - recompiledNeeded = true; + shapeIter->second = input.Shape(); + recompileNeeded = true; } // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list if (input.Location().device.Type() == OrtDevice::CPU) { - // TODO (pavignol): Force recompile if CPU data changed auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); } } - if (recompiledNeeded) + if (recompileNeeded) { // Go through all the node args and replace their shapes with the real ones for (auto& nodeArg : m_intermediateNodeArgs)