Skip to content

Commit

Permalink
[DML EP] Fix external data unpacking (#19415)
Browse files Browse the repository at this point in the history
### Description
This change
55a6694
didn't take external data into account when unpacking initializers, and
therefore crashes when trying to unpack them.
  • Loading branch information
PatriceVignola authored and rachguo committed Feb 8, 2024
1 parent 6e61306 commit ad63507
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
isConstGpuGraphInput[dmlFusedNodeInputIndex])
{
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
uint32_t c_maxConstNodeDataSize = 8;

ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());

auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors();
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
ComPtr<OnnxTensorWrapper> constantInput;

if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
{
// The tensor description's size should be no larger than the constant input unless it was rounded to
constantInput = constantCpuGraphInputGetter(arg->Name());
}

if (constantInput)
{
// The tensor description's size should be no larger than the constant input unless it was rounded to
// the required alignment.
assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes);
size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast<size_t>(tensorDesc->totalTensorSizeInBytes));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
}
ORT_CATCH_RETURN
}

template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
{
Expand Down Expand Up @@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter
m_requiredConstantCpuInputs.begin(),
m_requiredConstantCpuInputs.end(),
inputIndex) != m_requiredConstantCpuInputs.end();

// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
}
Expand Down Expand Up @@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter
OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl)
{
// The tensor may be stored as raw data or in typed fields.
if (impl->has_raw_data())
if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL)
{
THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor));
m_dataPtr = reinterpret_cast<std::byte*>(m_unpackedExternalTensor.data());
m_tensorByteSize = m_unpackedExternalTensor.size();
}
else if (impl->has_raw_data())
{
m_dataPtr = reinterpret_cast<std::byte*>(impl->mutable_raw_data()->data());
m_tensorByteSize = impl->raw_data().size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
private:
size_t m_tensorByteSize = 0;
std::unique_ptr<std::byte[]> m_unpackedTensor;
std::vector<uint8_t> m_unpackedExternalTensor;
std::byte* m_dataPtr = nullptr;

// Lifetime is managed by the caller and guaranteed to outlive this class
Expand Down

0 comments on commit ad63507

Please sign in to comment.