diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index ba022533a1e94..adb4fd131119f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -347,19 +347,23 @@ namespace Dml::GraphDescBuilder
                 // This is a highly inefficient approach to generating constant nodes. It duplicates constant data
                 // across the graph input as well as every consumer's unique constant node. However it is currently
                 // only used for small inputs.
-
-                // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
-                // nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point
-                // values that have been de-duplicated with conversion to scalars in kernels.
-                uint32_t c_maxConstNodeDataSize = 1024 * 1024;
+                uint32_t c_maxConstNodeDataSize = 8;

                 ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());

-                if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize)
+                auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
+                std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors();
+                DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
+
+                if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
                 {
+                    // The tensor description's size should be no larger than the constant input unless it was rounded to
+                    // the required alignment.
+                    assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes);
+                    size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), tensorDesc->totalTensorSizeInBytes);
                     auto data = static_cast<const uint8_t*>(constantInput->GetData());
-                    std::vector<uint8_t> tensorData(data, data + constantInput->GetTensorByteSize());
-
+                    std::vector<uint8_t> tensorData(data, data + minimumConstantSize);
+
                     NodeInfo nodeInfo = {};
                     nodeInfo.nodeDef = std::move(tensorData);
                     graphNodes.push_back(std::move(nodeInfo));
@@ -379,9 +383,6 @@ namespace Dml::GraphDescBuilder
                     edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
                     graphInputEdges.push_back(edge);

-                    auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
-                    std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
-                    DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
                     tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
                 }
             }
@@ -445,7 +446,7 @@ namespace Dml::GraphDescBuilder
             // TODO: Change as new header is ingested
             if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
                 dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
-
+
             // TODO: Change as new header is ingested
             if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
                 dmlDesc.Type = (DML_OPERATOR_TYPE) 170;