Skip to content

Commit

Permalink
Add check in case CPU inputs changed
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Oct 7, 2023
1 parent b809a5b commit bc322c1
Showing 1 changed file with 35 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,36 +99,62 @@ namespace Dml
{
ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount());

bool recompiledNeeded = m_compiledExecutionPlanOperator == nullptr;
bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr;

for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex)
{
const auto& input = kernelContext->RequiredInput<onnxruntime::Tensor>(inputIndex);
const std::string& inputName = m_subgraphInputs[inputIndex]->Name();
auto iter = m_inferredInputShapes.find(inputName);
auto shapeIter = m_inferredInputShapes.find(inputName);

if (iter == m_inferredInputShapes.end())
if (shapeIter == m_inferredInputShapes.end())
{
m_inferredInputShapes[inputName] = input.Shape();
recompiledNeeded = true;
recompileNeeded = true;
}
else if (iter->second != input.Shape())
else if (shapeIter->second != input.Shape())
{
iter->second = input.Shape();
recompiledNeeded = true;
shapeIter->second = input.Shape();
recompileNeeded = true;
}

// If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list
if (input.Location().device.Type() == OrtDevice::CPU)
{
// TODO (pavignol): Force recompile if CPU data changed
auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName);

// We can only avoid recompiling the graph when every CPU input is byte-for-byte identical to the data previously recorded for it
auto initializerIter = m_isInitializerTransferable.find(inputName);

if (initializerIter != m_isInitializerTransferable.end())
{
if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length())
{
for (int i = 0; i < inputProto.raw_data().length(); ++i)
{
if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i])
{
recompileNeeded = true;
break;
}
}
}
else
{
recompileNeeded = true;
}
}
else
{
recompileNeeded = true;
}

m_ownedCpuInputs.push_back(std::make_unique<ONNX_NAMESPACE::TensorProto>(std::move(inputProto)));
m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false);
}
}

if (recompiledNeeded)
if (recompileNeeded)
{
// Go through all the node args and replace their shapes with the real ones
for (auto& nodeArg : m_intermediateNodeArgs)
Expand Down

0 comments on commit bc322c1

Please sign in to comment.