Support MatMulBnb4 in PaddingElimination (#18646)
Also support a Cast pattern between the input and the embedding node for sparsity inspection.
guyang3532 authored Dec 1, 2023 · 1 parent ccfea55 · commit 182c525
Showing 2 changed files with 25 additions and 10 deletions.
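For context, the sparsity inspector previously required the embedding's indices argument (node.input[1]) to be a user input directly. Below is a minimal sketch, built with the onnx helper API, of the graph shape this commit now also handles; the tensor names (input_ids, embeddings, etc.) are illustrative, not from the commit:

# Toy graph: user input -> Cast -> ATen embedding, the Llama-2 pattern this
# commit teaches the sparsity inspector to trace through.
from onnx import TensorProto, helper

input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch", "seqlen"])
cast = helper.make_node("Cast", ["input_ids"], ["input_ids_i64"], to=TensorProto.INT64)
embedding = helper.make_node(
    "ATen",
    ["weight", "input_ids_i64", "padding_idx"],  # indices are node.input[1]
    ["embeddings"],
    domain="org.pytorch.aten",
    operator="embedding",
)
graph = helper.make_graph(
    [cast, embedding],
    "toy_embedding_graph",
    [input_ids],
    [helper.make_tensor_value_info("embeddings", TensorProto.FLOAT, None)],
)  # weight/padding_idx initializers omitted for brevity

# Before this change, the inspector skipped the embedding node here because
# "input_ids_i64" is not a user input; now it walks the Cast chain back to
# "input_ids".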
@@ -282,7 +282,8 @@ void IterateSubgraphFromNode(Graph& graph,
       ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
       subgraph.insert(cur->MutableOutputDefs()[0]);
       PushAllOutputNode(graph, to_visit, cur, visited);
-    } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
+    } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13}) ||
+               graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMulBnb4", {1}, kMSDomain)) {
       if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
         // If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
         // The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
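The rank check described in the comment above is what lets PaddingElimination flatten the leading [batch_size, seqlen] dims through MatMul-like nodes. A rough numpy illustration of that shape rule (the shapes are made up for the example; MatMulBnb4 itself consumes a 4-bit packed weight rather than a dense array):

import numpy as np

# [batch_size, seqlen, hidden]: rank > 2, so the first two dims can be
# flattened to [valid_tokens, hidden] and restored after the MatMul.
x = np.zeros((4, 128, 768))
w = np.zeros((768, 3072))  # for MatMulBnb4, a packed quantized weight instead
y = x @ w                  # output keeps the leading dims: (4, 128, 3072)
assert y.shape == (4, 128, 3072)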
@@ -157,12 +157,7 @@ def _initialize_embedding_padding_inspector(self, model, user_input_names):
         self._embedding_graph_input_to_padding_idx_map.clear()
 
         for node in model.graph.node:
-            if not (
-                node.domain == "org.pytorch.aten"
-                and node.op_type == "ATen"
-                and node.input[1] in user_input_names
-                and len(node.input) >= 3
-            ):
+            if not (node.domain == "org.pytorch.aten" and node.op_type == "ATen" and len(node.input) >= 3):
                 continue
 
             found = [attr for attr in node.attribute if attr.name == "operator"]
@@ -194,10 +189,29 @@ def _initialize_embedding_padding_inspector(self, model, user_input_names):
             if padding_idx < 0:
                 continue
 
-            if node.input[1] not in self._embedding_graph_input_to_padding_idx_map:
-                self._embedding_graph_input_to_padding_idx_map[node.input[1]] = set()
+            # Given the input arg of the embedding node, find the corresponding user input that feeds it.
+            # Iterate the args recursively if some subgraph pattern sits between the input and the embedding,
+            # such as Input -> Cast -> Cast -> Embedding.
+            # TODO: This is a workaround for the case where the embedding input is a chain of Cast nodes, as
+            # found in Llama-2. We need a general way to handle all subgraph patterns between input and embedding.
+            def _get_embedding_graph_input(node_arg):
+                if node_arg in user_input_names:
+                    return node_arg
+                input_node = self._try_get_node_from_its_output(node_arg)
+                if input_node is not None and input_node.op_type == "Cast":
+                    return _get_embedding_graph_input(input_node.input[0])
+                else:
+                    self._logger.warning(f"Cannot find embedding input {node_arg}")
+                    return None
+
+            embedding_graph_input = _get_embedding_graph_input(node.input[1])
+            if embedding_graph_input is None:
+                continue
+
+            if embedding_graph_input not in self._embedding_graph_input_to_padding_idx_map:
+                self._embedding_graph_input_to_padding_idx_map[embedding_graph_input] = set()
 
-            self._embedding_graph_input_to_padding_idx_map[node.input[1]].add(padding_idx)
+            self._embedding_graph_input_to_padding_idx_map[embedding_graph_input].add(padding_idx)
 
     def _initialize_loss_label_padding_inspector(self, model, user_input_names):
         """Register loss label input padding inspector.
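For readers skimming the diff, here is a self-contained sketch of the new Cast-chain walk. The names (trace_to_graph_input, producers) are hypothetical stand-ins for the inspector's internals, where producers plays the role of _try_get_node_from_its_output:

# Self-contained sketch of the Cast-chain walk (hypothetical names, not the
# inspector's actual API). producers maps an output name to the node producing it.
from types import SimpleNamespace

def trace_to_graph_input(arg, producers, user_inputs):
    if arg in user_inputs:
        return arg
    node = producers.get(arg)
    if node is not None and node.op_type == "Cast":
        return trace_to_graph_input(node.input[0], producers, user_inputs)
    return None  # anything other than Input -> Cast* -> Embedding: give up

producers = {
    "ids_i64": SimpleNamespace(op_type="Cast", input=["ids_i32"]),
    "ids_i32": SimpleNamespace(op_type="Cast", input=["input_ids"]),
}
assert trace_to_graph_input("ids_i64", producers, {"input_ids"}) == "input_ids"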
