diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index d516537e25949..cf26e0acfa557 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -56,7 +56,7 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { info.GetAttrOrDefault("largest", &largest_, 1); info.GetAttrOrDefault("sorted", &sorted_, 1); if (!inputk) { - info.GetAttrOrDefault("k", &K_, 0); + info.GetAttrOrDefault("k", &attr_k_, 0); } } @@ -67,7 +67,7 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { static_cast(tensor_I->MutableDataRaw()), \ elem_nums_cuda, \ elem_nums.size(), \ - axis, K_, largest_, sorted_, N, dimension) + axis, k_value, largest_, sorted_, N, dimension) template Status TopK::ComputeInternal(OpKernelContext* ctx) const { @@ -77,19 +77,29 @@ Status TopK::ComputeInternal(OpKernelContext* ctx) const { int32_t axis = static_cast(axis_ < 0 ? rank + axis_ : axis_); ORT_ENFORCE(axis > -1 && axis < rank); + int64_t k_value = 0; if (inputk) { auto tensor_K = ctx->Input(1); ORT_ENFORCE(nullptr != tensor_K); - K_ = *tensor_K->Data(); - ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]); + k_value = *tensor_K->Data(); + } else { // from attribute + k_value = attr_k_; } - auto output_shape = tensor_X->Shape(); - output_shape[axis] = K_; + // Now that we know the value of 'K' and the input shape, + // make a final validation before going to the implementation + const auto& input_shape = tensor_X->Shape(); + if ((k_value < 0) || (k_value > input_shape.GetDims()[axis])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Value of K outside range. K value: ", k_value, + ". Input shape: ", input_shape, " . Axis: ", axis); + } + + auto output_shape = input_shape; + output_shape[axis] = k_value; auto tensor_V = ctx->Output(0, output_shape); auto tensor_I = ctx->Output(1, output_shape); - if (0 == K_) { + if (output_shape.Size() == 0) { // Bail out early if the output is going to be empty return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/topk.h b/onnxruntime/core/providers/cuda/math/topk.h index 9dec13ad2a930..5731df3130c5a 100644 --- a/onnxruntime/core/providers/cuda/math/topk.h +++ b/onnxruntime/core/providers/cuda/math/topk.h @@ -17,7 +17,7 @@ class TopK final : public CudaKernel { int64_t axis_; int64_t largest_; int64_t sorted_; - mutable int64_t K_; + int64_t attr_k_; }; } // namespace cuda } // namespace onnxruntime