[Experimental] Add a path to fallback more nodes to CPUs. #19769
Conversation
Force-pushed from 8da6b51 to ffa61d7
Shape-related nodes don't only start with `Shape` or `Size`. In a dynamo-captured ONNX model, a shape-related chain can start with a graph input. A new transform is added to fall back `all` nodes that can be reversely traversed from a `shape-like` variable. Some `shape-like` variables are listed below:
- all inputs of Range
- 2nd input of Reshape
- 2nd input of Unsqueeze
- 1st input of ConstantOfShape
- 2nd-to-last inputs of Slice

Fix header; Remove unused variable; Versioning shape inputs; Fix
Force-pushed from ffa61d7 to ed79ec7
Force-pushed from 08ac5f3 to f896fb8
orttraining/orttraining/test/python/orttraining_test_aggressive_cpu_fallback.py (code-scanning annotations marked as fixed)
Force-pushed from e9de2b8 to 5abdb86
Fix typo; Write to fixed place; Remove unused imports; Run it; Fix; Change test location
Force-pushed from 5abdb86 to af9319b
```diff
@@ -39,6 +39,7 @@ steps:
   timeoutInMinutes: 60

   # Entry point for all ort training api tests
+  # TODO: move onnxscript installation to CI image.
```
I am not sure when it will be the right time as onnxscript is a relatively new tool.
```cpp
std::unordered_map<std::string, std::unordered_map<int64_t, std::vector<size_t>>> shape_related_inputs_in_nodes = {
    // 2nd input of Expand-13 is a shape-related input.
    {"Expand", {{13 /* since version */, {1} /* shape inputs' indices */}}},
```
Instead of reverse traversal from this pre-specified list of ops (which requires periodic maintenance: updating it as new ops are added to the ONNX standard, as op versions are revised, as shape input indices change across op version revisions, etc.), can the reverse traversal start from a provider-assigned node requiring a specific input on CPU? Usually any input needed on CPU by a provider node is "shape like", and this information is available in the kernel def of the node. That seems like a more "automated" way than the pre-cooked list approach.
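A hedged sketch of that idea follows. `RequiresInputOnCpu` is a hypothetical stand-in for the real kernel-def query (ORT kernel defs record which inputs a kernel expects in CPU memory), so only the overall shape of the loop is meant literally:

```cpp
#include <unordered_set>

#include "core/graph/graph.h"

namespace onnxruntime {

// Hypothetical helper: would consult the kernel def of the provider-assigned
// kernel for `node` and report whether `input_index` must live in CPU memory.
bool RequiresInputOnCpu(const Node& node, size_t input_index);

// Seed the reverse traversal from inputs that provider kernels need on CPU,
// instead of from a hand-maintained op list.
std::unordered_set<const NodeArg*> CollectCpuInputSeeds(const Graph& graph) {
  std::unordered_set<const NodeArg*> seeds;
  for (const Node& node : graph.Nodes()) {
    const auto& input_defs = node.InputDefs();
    for (size_t i = 0; i < input_defs.size(); ++i) {
      // Any input a kernel wants in CPU memory is usually "shape like",
      // so its producer chain is a candidate for CPU fallback.
      if (RequiresInputOnCpu(node, i)) {
        seeds.insert(input_defs[i]);
      }
    }
  }
  return seeds;
}

}  // namespace onnxruntime
```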
```cpp
// shape = onnx::Concat(s0, s1)
// reshaped = onnx::Reshape(x, shape)
// Then, the shape-producing node is Concat.
std::unordered_set<const Node*> shape_producing_nodes;
```
InlinedHashSet
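The suggestion is to use ORT's small-size-optimized hash set instead of `std::unordered_set`; a minimal sketch of the proposed change, assuming the usual `core/common/inlined_containers.h` header, would be:

```cpp
#include "core/common/inlined_containers.h"

// InlinedHashSet stores small sets inline and avoids per-element heap
// allocations, which is the usual reason it is preferred inside ORT.
InlinedHashSet<const Node*> shape_producing_nodes;
```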
```cpp
// 2. finds `shape` is a shape-related variable since Reshape's 2nd input is a shape-related input,
// 3. and then records the producer of `shape` (i.e., `Concat`).
for (auto& input_index : shape_input_indices) {
  auto input = node.InputDefs().at(input_index);
```
Should we check if this is an initializer?
What is the difference? From the perspective of finding shape-related nodes, a graph input and an initializer are the same. I am not sure if ORT has different assumptions somewhere.
```cpp
// Stop the traversal when a "Shape" node is found.
graph.ReverseDFSFrom(
    start_nodes,
    [&shape_related_node_indices](const Node* n) {
```
Are there nodes where shape is just one of the outputs, but the rest of the computation should be done on device?
I am not aware of any examples. If you are looking for an op producing both CPU and GPU outputs, attention could be a case when it wants to pass the forward's random seed (an int64 scalar) to the backward.
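To make the traversal concrete, here is a hedged sketch of how the `ReverseDFSFrom` excerpt above could continue; the exact callback signatures (enter/leave/comparator/stop) are an assumption based on the excerpt, not a quote of the PR's code:

```cpp
// Sketch only: collect every node reachable backwards from the
// shape-producing seeds, but stop at "Shape" nodes so the tensor
// computation feeding Shape stays on the device.
std::unordered_set<NodeIndex> shape_related_node_indices;
graph.ReverseDFSFrom(
    start_nodes,
    // enter: called once per visited node; record it for CPU fallback.
    [&shape_related_node_indices](const Node* n) {
      shape_related_node_indices.insert(n->Index());
    },
    /*leave*/ nullptr,
    /*comparator*/ nullptr,
    // stop: do not traverse past a Shape node; its input is a real tensor.
    [](const Node* /*from*/, const Node* to) {
      return to->OpType() == "Shape";
    });
```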
```diff
@@ -39,6 +43,132 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
 }
 }  // namespace

+std::unordered_set<NodeIndex> GetShapeRelatedNodes(const onnxruntime::GraphViewer& viewer) {
+  // Conceptually, this function traverses from shape-consuming nodes
+  // to fall back all their upstream nodes to CPU. Consider a graph
```
Maybe add some TODOs to enhance this for situations where it won't work:
(1) There is no shape "consumer" at all, i.e. the "shape like" output eventually becomes a graph output (a rare corner case, but there are definitely models like this).
(2) Cases where the shape subgraph is split across graph levels: the main graph has some portion of the shape nodes and a subgraph has another portion. In this case the "shape consumer" at the main-graph level will be a subgraph-containing node (If/Loop/Scan), and the shape info may be consumed "explicitly" (as a graph input to If/Loop/Scan) or implicitly by the node, i.e. not as an explicit graph input but through some node in the subgraph referencing the main-graph node output(s). A sketch of detecting the implicit case follows.
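For case (2), a hedged sketch of how the implicit consumption could at least be detected; `ContainsSubgraph` and `ImplicitInputDefs` are existing `Node` accessors to the best of my knowledge, and the rest is illustrative:

```cpp
// Sketch only: find main-graph values that flow into nested subgraphs, since
// they may be shape consumers that a node-level op-type scan would miss.
for (const Node& node : graph.Nodes()) {
  if (!node.ContainsSubgraph()) continue;
  // Explicit inputs of If/Loop/Scan are covered by the normal per-op scan;
  // implicit inputs are outer-scope values referenced inside the subgraph.
  for (const NodeArg* outer_scope_value : node.ImplicitInputDefs()) {
    // If outer_scope_value is shape-like, its producer chain should also be
    // considered for CPU fallback (left as a TODO, as suggested above).
    (void)outer_scope_value;
  }
}
```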
```cpp
// 1st input of ConstantOfShape is a shape-related input.
{"ConstantOfShape", {{9, {0}}, {20, {0}}, {21, {0}}}},
// 2nd to 5th inputs of Slice-13 are shape-related inputs.
{"Slice", {{13, {1, 2, 3, 4}}}}};
```
I don't know if operator Range is inlined but it could be considered as consuming a shape as well.
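For context, a hedged sketch of how such a versioned map could be consulted per node, assuming `Node::OpType()` and `Node::SinceVersion()` and picking the newest entry that does not exceed the node's opset version (not the PR's exact code):

```cpp
// Sketch only: find the shape-related input indices registered for this
// node's op type and opset version.
const auto op_it = shape_related_inputs_in_nodes.find(node.OpType());
if (op_it != shape_related_inputs_in_nodes.end()) {
  const int64_t node_version = node.SinceVersion();
  const std::vector<size_t>* shape_input_indices = nullptr;
  int64_t best_version = -1;
  // Pick the newest "since version" not exceeding the node's version,
  // mirroring how ONNX resolves op schemas.
  for (const auto& [since_version, indices] : op_it->second) {
    if (since_version <= node_version && since_version > best_version) {
      best_version = since_version;
      shape_input_indices = &indices;
    }
  }
  // shape_input_indices, if non-null, lists which inputs are shape-related.
}
```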
```cpp
      to_stop);
  }

  return shape_related_node_indices;
```
What happens if a shape input is on CUDA when this algorithm is moved to CPU?
It will fall back the producer of the shape input and its upstream nodes to CPU.
Force-pushed and now I can't update this branch anymore. #19875 continues the work.
Shape-related nodes don't only start with `Shape` or `Size`. In a dynamo-captured ONNX model, a shape-related chain can start with a graph input. A new transform is added to fall back `all` nodes which can be reversely traversed from a `shape-like` variable. Some `shape-like` variables are listed below:

- all inputs of Range
- 2nd input of Reshape
- 2nd input of Unsqueeze
- 1st input of ConstantOfShape
- 2nd-to-last inputs of Slice

For example, the comment below explains the desired CPU ops for a `Reshape`.

This PR fixes my llama model + AtenOp. The running time is reduced from 4.x sec to 0.6 sec. The side effect of this change on other graph transformers is still unclear, so it's off by default. To enable it, set `ORT_AGGRESSIVE_CPU_FALLBACK=1`. Ideally, we should fall back all `small computation nodes` (nodes with small inputs/outputs) to CPU, but shape information is not available for each `NodeArg`. We should also improve shape inference in ORT in the future.

The old `GetCpuPreferredNodes` traverses the graph `topologically` from CPU-output-`generating` nodes and tries to place downstream nodes on CPU when possible. This PR is different since it traverses the graph `reversely topologically`, starting with CPU-output-`consuming` nodes.
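As a usage note, the opt-in is a plain environment variable; a minimal sketch of the check (the variable name comes from the description above, the parsing details are illustrative):

```cpp
#include <cstdlib>
#include <cstring>

// Returns true when the user opted into the aggressive CPU fallback path.
bool AggressiveCpuFallbackEnabled() {
  const char* value = std::getenv("ORT_AGGRESSIVE_CPU_FALLBACK");
  return value != nullptr && std::strcmp(value, "1") == 0;
}
```

Launched as, e.g., `ORT_AGGRESSIVE_CPU_FALLBACK=1 python train.py`, the new transform runs; without the variable, behavior is unchanged.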