From 274d162d93d64969e7af3040530e8c6cf93f06ae Mon Sep 17 00:00:00 2001
From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Date: Wed, 8 May 2024 16:07:02 -0700
Subject: [PATCH] Fix SparseAttention cos/sin cache dimension checks (#20609)

### Description
This PR fixes the dimension checks for the cos/sin caches used in the
rotary embeddings of the `SparseAttention` operator.

### Motivation and Context
This PR ports over the same changes from
[this PR](https://github.com/microsoft/onnxruntime/pull/20547) for
`GroupQueryAttention`.
---
 .../contrib_ops/cuda/sparse/sparse_attention_helper.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
index 416ebf1667b87..7e98b374c455e 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
@@ -202,13 +202,13 @@ Status CheckInputs(void* params,
                            "head_size shall be a multiple of 16. Got head_size = ",
                            head_size);
   }
-  if (cos_dims[0] < max_sequence_length) {
+  if (cos_dims[0] < total_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "cos_cache dimension 0 should be of max_sequence_length.");
+                           "cos_cache dimension 0 should not be less than total_sequence_length.");
   }
-  if (sin_dims[0] < max_sequence_length) {
+  if (sin_dims[0] < total_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "sin_cache dimension 0 should be of max_sequence_length.");
+                           "sin_cache dimension 0 should not be less than total_sequence_length.");
   }
   if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
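
For context, here is a minimal standalone sketch of the relaxed check: the cos/sin caches only need enough rows to cover the positions actually attended to (`total_sequence_length`, i.e. past plus new tokens), so requiring `max_sequence_length` rows was stricter than necessary. This is not the actual onnxruntime helper; `CheckRotaryCacheDims`, the `Status` struct, and the parameter names are hypothetical stand-ins for the logic in `CheckInputs()`.

```cpp
// Minimal, self-contained sketch of the relaxed rotary-cache dimension check.
// The names below are hypothetical; the real checks live in CheckInputs()
// in sparse_attention_helper.h.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Status {
  bool ok;
  std::string message;
};

Status CheckRotaryCacheDims(const std::vector<int64_t>& cos_dims,
                            const std::vector<int64_t>& sin_dims,
                            int64_t total_sequence_length,
                            int64_t head_size) {
  // Dimension 0 must cover every position actually attended to, i.e. the
  // total (past + new) sequence length -- not the maximum cache capacity.
  if (cos_dims[0] < total_sequence_length) {
    return {false, "cos_cache dimension 0 should not be less than total_sequence_length."};
  }
  if (sin_dims[0] < total_sequence_length) {
    return {false, "sin_cache dimension 0 should not be less than total_sequence_length."};
  }
  // Dimension 1 holds half the rotary dimension: at most head_size / 2
  // (for head_size a multiple of 16) and a multiple of 8.
  if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
    return {false, "cos_cache dimension 1 is invalid."};
  }
  return {true, ""};
}

int main() {
  // A cache with 128 rows passes for total_sequence_length = 100,
  // even if the model's max_sequence_length would be larger.
  Status s = CheckRotaryCacheDims({128, 32}, {128, 32},
                                  /*total_sequence_length=*/100,
                                  /*head_size=*/64);
  std::cout << (s.ok ? "ok" : s.message) << "\n";
  return 0;
}
```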