Skip to content

Commit

Permalink
#7699: update moreh softmax backward dim validation
Browse files Browse the repository at this point in the history
  • Loading branch information
hschoi4448 committed Apr 23, 2024
1 parent 5c0d6bb commit 06f3ff3
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ void MorehSoftmaxBackward::validate_with_output_tensors(const std::vector<Tensor
TT_ASSERT(output_grad_tensor.get_dtype() == DataType::BFLOAT16 || output_grad_tensor.get_dtype() == DataType::BFLOAT8_B);

// validate parameters
TT_ASSERT(this->dim >= 0 || this->dim <= 3, "Only dim [0,1,2,3] supported");
auto rank = output_tensor.get_legacy_shape().rank();

TT_ASSERT(this->dim >= 0 && this->dim < rank, fmt::format("dim {} should be less than output tensor rank {}", this->dim, rank));

if(output_tensors.empty() || !output_tensors.at(0).has_value()){
// If the user decided to not use any optional output tensors, then this would be empty or would be a nullptr.
Expand Down

0 comments on commit 06f3ff3

Please sign in to comment.