Skip to content

Commit

Permalink
Fixed bug in Flatten's axis (microsoft#18645)
Browse files Browse the repository at this point in the history
Flatten's axis is in the range [-r, r] rather than [-r, r-1].
  • Loading branch information
Honry authored Dec 1, 2023
1 parent 6781b6c commit 73a2eb8
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
int64_t rank = input_shape.size();
NodeAttrHelper helper(node);
int64_t axis = helper.Get("axis", 1);
axis = HandleNegativeAxis(axis, rank);
ORT_ENFORCE(axis >= -rank && axis <= rank, "axis ", axis,
" is not in valid range [-", rank, ",", rank, "]");
if (axis < 0) {
axis += rank;
}

// Use WebNN's reshape to implement Flatten.
int64_t num_pre_axis_elements = std::accumulate(
Expand Down

0 comments on commit 73a2eb8

Please sign in to comment.