Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix layernorm and softmax axis after upstream (microsoft#17255)
### Fix layernorm and softmax axis after upstream For Gather (the slicing is a scalar), the output rank is small than its inputs. When we upstream this kind of Gather before softmax or layernorm, we should also update the axis attribute. Otherwise, the axis might be out-of-date and incorrect for the updated rank. ``` File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception raise exception File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward self._build_graph(graph_transformer_config) File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper result = func(graph_execution_manager, *args, **kwargs) File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper result = func(graph_execution_manager, *args, **kwargs) File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph super()._build_graph(graph_transformer_config) File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph self._graph_builder.build(config) RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3 ```
- Loading branch information