From 182c525416eb5cbace8df52b6809a77ffc91545d Mon Sep 17 00:00:00 2001
From: guyang3532 <62738430+guyang3532@users.noreply.github.com>
Date: Fri, 1 Dec 2023 19:27:50 +0800
Subject: [PATCH] Support MatMulBnb4 in PaddingElimination (#18646)

Also support the Cast pattern between a graph input and the embedding
node for sparsity inspection.
---
 .../compute_optimizer/padding_elimination.cc |  3 +-
 .../training/ortmodule/_runtime_inspector.py | 32 +++++++++++++------
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
index 2d75a02004ff2..d42af92c7c66d 100644
--- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
+++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
@@ -282,7 +282,8 @@ void IterateSubgraphFromNode(Graph& graph,
       ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
       subgraph.insert(cur->MutableOutputDefs()[0]);
       PushAllOutputNode(graph, to_visit, cur, visited);
-    } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
+    } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13}) ||
+               graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMulBnb4", {1}, kMSDomain)) {
       if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
         // If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
         // The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index cfd2e25e13e26..05a5f30683824 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -157,12 +157,7 @@ def _initialize_embedding_padding_inspector(self, model, user_input_names):
         self._embedding_graph_input_to_padding_idx_map.clear()
 
         for node in model.graph.node:
-            if not (
-                node.domain == "org.pytorch.aten"
-                and node.op_type == "ATen"
-                and node.input[1] in user_input_names
-                and len(node.input) >= 3
-            ):
+            if not (node.domain == "org.pytorch.aten" and node.op_type == "ATen" and len(node.input) >= 3):
                 continue
 
             found = [attr for attr in node.attribute if attr.name == "operator"]
@@ -194,10 +189,29 @@ def _initialize_embedding_padding_inspector(self, model, user_input_names):
             if padding_idx < 0:
                 continue
 
-            if node.input[1] not in self._embedding_graph_input_to_padding_idx_map:
-                self._embedding_graph_input_to_padding_idx_map[node.input[1]] = set()
+            # Given the input arg of the embedding node, find the corresponding user input that feeds it.
+            # Iterate the args recursively when a subgraph pattern sits between the input and the embedding,
+            # such as Input -> Cast -> Cast -> Embedding.
+            # TODO: This is a workaround for the case where the embedding input is a chain of Cast nodes, as found
+            # in Llama-2. We need a general way to handle all types of subgraph patterns between input and embedding.
+            def _get_embedding_graph_input(node_arg):
+                if node_arg in user_input_names:
+                    return node_arg
+                input_node = self._try_get_node_from_its_output(node_arg)
+                if input_node is not None and input_node.op_type == "Cast":
+                    return _get_embedding_graph_input(input_node.input[0])
+                else:
+                    self._logger.warning(f"Cannot find embedding input {node_arg}")
+                    return None
+
+            embedding_graph_input = _get_embedding_graph_input(node.input[1])
+            if embedding_graph_input is None:
+                continue
+
+            if embedding_graph_input not in self._embedding_graph_input_to_padding_idx_map:
+                self._embedding_graph_input_to_padding_idx_map[embedding_graph_input] = set()
 
-            self._embedding_graph_input_to_padding_idx_map[node.input[1]].add(padding_idx)
+            self._embedding_graph_input_to_padding_idx_map[embedding_graph_input].add(padding_idx)
 
     def _initialize_loss_label_padding_inspector(self, model, user_input_names):
         """Register loss label input padding inspector.
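
Note (editorial, not part of the patch): the sketch below restates the new
_get_embedding_graph_input helper as a standalone function, to show the
Input -> Cast* -> Embedding walk in isolation. The function name
trace_to_graph_input and its iterative form are hypothetical; the patched
code resolves producers with self._try_get_node_from_its_output rather than
building an explicit output-to-node map.

    # Hypothetical standalone sketch of the Cast-chain walk, assuming an
    # onnx.ModelProto and the set of user input names are available.
    import onnx

    def trace_to_graph_input(model: onnx.ModelProto, node_arg: str, user_input_names: set):
        """Follow Cast producers backwards from node_arg to a user graph input, if any."""
        # Map every node output name to the node that produces it.
        output_to_node = {out: n for n in model.graph.node for out in n.output}
        while node_arg not in user_input_names:
            producer = output_to_node.get(node_arg)
            if producer is None or producer.op_type != "Cast":
                return None  # some pattern other than Input -> Cast* -> Embedding
            node_arg = producer.input[0]  # step back through the Cast
        return node_arg

For a graph wired as input_ids -> Cast(to=INT32) -> Cast(to=INT64) -> ATen
embedding, trace_to_graph_input(model, indices_arg, {"input_ids"}) returns
"input_ids", which then keys _embedding_graph_input_to_padding_idx_map.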