Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Oct 7, 2023
1 parent f850a2d commit 339abde
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,14 @@ namespace Windows::AI::MachineLearning::Adapter
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
EdgeShapes outputShapes;
const std::unordered_map<std::string, std::vector<uint32_t>>* inferredOutputShapes;
};

using GraphNodeFactory = std::function<void(
const onnxruntime::Node& node,
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
const EdgeShapes* inputShapesOverrides,
/*out*/ EdgeShapes* outputShapes,
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,25 +492,23 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
const EdgeShapes* inputShapesOverrides,
/*out*/ EdgeShapes* outputShapes,
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
)
{
onnxruntime::ProtoHelperNodeContext nodeContext(node);
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);

// Use the same list of required constant inputs for the shape inferrer and the kernel.
EdgeShapes outputShapes;
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, outputShapes);

graphNodeCreateInfo->outputShapes = outputShapes;
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes);

// Create the kernel while allowing input shape and output shape queries according to options
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
&protoHelper,
executionHandle,
true,
inputShapesOverrides,
&outputShapes,
outputShapes,
&defaultAttributesCapture,
graphNodeCreateInfo,
constantCpuInputCapture,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ namespace Dml::GraphDescBuilder
return tensor;
};

DmlGraphNodeCreateInfo graphNodeCreateInfo;
EdgeShapes inputShapesOverrides(node.InputDefs().size());

// Override the input shapes with shapes that were previously inferred
Expand All @@ -269,19 +268,21 @@ namespace Dml::GraphDescBuilder
}
}

graphNodeCreateInfo.inferredOutputShapes = &inferredOutputShapes;
EdgeShapes outputShapes;
DmlGraphNodeCreateInfo graphNodeCreateInfo;
graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory(
node,
constantCpuNodeInputGetter,
executionHandle,
&inputShapesOverrides,
/*out*/ &outputShapes,
/*out*/ &graphNodeCreateInfo
);

ORT_THROW_HR_IF(E_UNEXPECTED, graphNodeCreateInfo.outputShapes.EdgeCount() != node.OutputDefs().size());
ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size());
for (int i = 0; i < node.OutputDefs().size(); ++i)
{
inferredOutputShapes[node.OutputDefs()[i]->Name()] = graphNodeCreateInfo.outputShapes.GetShape(i);
inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i);
}

// Create a map between operatorGraphNodeIndex to mainGraphNodeIndex.
Expand Down Expand Up @@ -380,7 +381,7 @@ namespace Dml::GraphDescBuilder
operatorGraphOutputEdge.FromNodeOutputIndex
};

nodeOutputShapes[arg->Name()] = graphNodeCreateInfo.outputShapes;
nodeOutputShapes[arg->Name()] = outputShapes;
}
}

Expand Down

0 comments on commit 339abde

Please sign in to comment.