Security fuzz address sanitizer fix Bug (continue) (#21579)
Add a check of node.InputDefs()[2]->Exists() for the LayerNorm bias (follow-up to https://github.com/microsoft/onnxruntime/pull/21528/files#r1694026327).

Format the file: break long lines to stay within the 120-character limit.
tianleiwu authored Aug 2, 2024
1 parent 1391354 commit 54d6614
Showing 1 changed file with 37 additions and 16 deletions.
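
For background on the fix: in ONNX Runtime graph transformers, a trailing optional input such as the LayerNormalization bias can still occupy a slot in node.InputDefs() as a placeholder NodeArg whose Exists() returns false, so checking InputDefs().size() > 2 alone is not enough before reading the bias shape. The following is a minimal sketch of that guard pattern; TryGetLayerNormBias is a hypothetical helper written for illustration, not part of this commit.

// Hypothetical helper illustrating the guard this commit adds to
// AttentionFusion::ApplyImpl: verify both that the input slot exists
// and that the optional input is actually provided.
#include "core/graph/graph.h"

namespace onnxruntime {

// Returns the LayerNormalization bias NodeArg, or nullptr when the optional
// bias input is absent or supplied only as an empty placeholder.
const NodeArg* TryGetLayerNormBias(const Node& node) {
  const auto& input_defs = node.InputDefs();
  if (input_defs.size() <= 2) {
    return nullptr;  // no third input slot at all
  }
  if (!input_defs[2]->Exists()) {
    return nullptr;  // slot is present but holds a missing optional input
  }
  return input_defs[2];  // safe to inspect Shape() and element type
}

}  // namespace onnxruntime

The same two conditions appear inline in the fused check in the diff below: node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists().
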
53 changes: 37 additions & 16 deletions onnxruntime/core/optimizer/attention_fusion.cc
@@ -50,7 +50,8 @@ void MergeWeights(const T* q, const T* k, const T* v, std::vector<T>& result, in
 
 // Merge 2-D weights (q, k and v) by concatenating them row by row.
 template <typename T>
-void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight, std::vector<T>& result, int64_t hidden_size) {
+void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight,
+                        std::vector<T>& result, int64_t hidden_size) {
   const T* q = q_weight;
   const T* k = k_weight;
   const T* v = v_weight;
@@ -144,7 +145,8 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
   return graph_utils::AddInitializer(graph, initializer);
 }
 
-static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, const logging::Logger& logger) {
+static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type,
+                                   const logging::Logger& logger) {
   // Validate mask input shape (batch_size, sequence_length) and data type.
   // Note that batch_size and sequence_length could be symbolic.
   const TensorShapeProto* mask_shape = mask_input->Shape();
@@ -208,9 +210,11 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     Node& node = *p_node;
     ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
 
-    if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) &&  // Add node.GetOutputEdgesCount() == 5/6 for distilbert
+    // Add node.GetOutputEdgesCount() == 5/6 for distilbert
+    if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) &&
         graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) &&
-        graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.InputDefs().size() > 2) {
+        graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) &&
+        node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) {  // Bias is an optional input for LayerNorm
       // Get hidden size from layer norm bias tensor shape.
       const NodeArg& layer_norm_bias = *(node.InputDefs()[2]);
       if (!optimizer_utils::IsShapeKnownOnAllDims(layer_norm_bias, 1)) {
@@ -242,8 +246,10 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
           fused_count++;
           modified = true;
         }
-      } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) && (static_cast<size_t>(reshape_count) + shape_count) == node.GetOutputEdgesCount()) {  // GPT
-        if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1, logger)) {
+      } else if (reshape_count == 1 && (shape_count == 1 || shape_count == 3) &&
+                 (static_cast<size_t>(reshape_count) + shape_count) == node.GetOutputEdgesCount()) {  // GPT
+        if (AttentionFusionHelper::FuseGptAttention(node, graph, hidden_size, mask_int32_map, shape_count == 1,
+                                                    logger)) {
           fused_count++;
           modified = true;
         }
@@ -301,7 +307,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
     return false;
   }
 
-  if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(), q_reshape, q_transpose, num_heads, head_size, logger)) {
+  if (!AttentionFusionHelper::CheckNodesInPathQ(graph, pivot_nodes[1].get(),
+                                                q_reshape, q_transpose, num_heads, head_size, logger)) {
     DEBUG_LOG("CheckNodesInPathQ returns false");
     return false;
   }
@@ -365,7 +372,8 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
   }
 
   // Now everything is ready, we will start fusing subgraph.
-  NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(), logger);
+  NodeArg* mask_int32 = ConvertMaskToInt32(graph, mask_input, mask_int32_map, layer_norm.GetExecutionProviderType(),
+                                           logger);
   if (nullptr == mask_int32) {
     DEBUG_LOG("Failed to convert mask to int32");
     return false;
@@ -438,7 +446,8 @@ static bool FuseSubGraphQK(Node& layer_norm,
   }
 
   std::vector<NodeIndex> nodes_to_remove;
-  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
+  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes,
+                          mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
                           num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
     return false;
   }
@@ -529,7 +538,8 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
   }
 
   std::vector<NodeIndex> nodes_to_remove;
-  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes, mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
+  if (!FuseSubGraphQKImpl(layer_norm, graph, parent_path_nodes,
+                          mask_input, mask_int32_map, edges, nodes_to_remove, hidden_size,
                           num_heads, head_size, mask_nodes.mask_filter_value, logger)) {
     return false;
   }
@@ -615,7 +625,12 @@ After Fusion:
      |           |
     Add
 */
-bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer_norm, Graph& graph, int64_t hidden_size, std::map<std::string, NodeArg*>& mask_int32_map, const logging::Logger& logger) {
+bool AttentionFusion::FuseSubGraph(Node& layer_norm,
+                                   const Node& add_after_layer_norm,
+                                   Graph& graph,
+                                   int64_t hidden_size,
+                                   std::map<std::string, NodeArg*>& mask_int32_map,
+                                   const logging::Logger& logger) {
   std::vector<graph_utils::EdgeEndToMatch> parent_path{
       {0, 0, "Add", {7, 13}, kOnnxDomain},
       {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
@@ -657,7 +672,9 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
   int64_t num_heads = 0;           // will be updated in CheckNodesInPathV
   int64_t head_size = 0;           // will be updated in CheckNodesInPathV
   NodeIndex record_node_idx = 0;   // will be updated in CheckNodesInPathV if it's distilbert model
-  if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose, qkv_matmul, v_transpose, v_reshape, num_heads, head_size, hidden_size, record_node_idx, logger)) {
+  if (!AttentionFusionHelper::CheckNodesInPathV(graph, reshape, transpose,
+                                                qkv_matmul, v_transpose, v_reshape, num_heads,
+                                                head_size, hidden_size, record_node_idx, logger)) {
     DEBUG_LOG("CheckNodesInPathV return false");
     return false;
   }
@@ -672,7 +689,8 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
   }
 
   // store parent path
-  std::vector<std::reference_wrapper<const Node>> parent_path_nodes{reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul};
+  std::vector<std::reference_wrapper<const Node>> parent_path_nodes{
+      reshape, transpose, qkv_matmul, v_transpose, v_reshape, v_add, v_matmul};
 
   // Find mask nodes: Unsqueeze -> Unsqueeze -> (Cast) -> Sub -> Mul -> Add -> Softmax --> [MatMul]
   // The "Cast" node in parentheses is optional.
@@ -681,10 +699,13 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
 
   if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, qkv_matmul, mask_nodes, logger, false)) {
     NodeArg* mask_input = graph.GetNode(mask_nodes.unsqueeze_1->Index())->MutableInputDefs()[0];
-    return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
-  } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul, mask_nodes_distilbert, record_node_idx, logger)) {
+    return FuseSubGraphQK(layer_norm, graph, mask_nodes, mask_input,
+                          parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
+  } else if (AttentionFusionHelper::MatchInputMaskSubgraph(graph, layer_norm, qkv_matmul,
+                                                           mask_nodes_distilbert, record_node_idx, logger)) {
     NodeArg* mask_input = graph.GetNode(mask_nodes_distilbert.equal->Index())->MutableInputDefs()[0];
-    return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input, parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
+    return FuseSubGraphQKDistilBert(layer_norm, graph, mask_nodes_distilbert, mask_input,
+                                    parent_path_nodes, hidden_size, num_heads, head_size, mask_int32_map, logger);
   } else {
     DEBUG_LOG("Failed in match input mask subgraph");
     return false;
