diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 9dea120949bdf..fa1315d90e1b6 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -40,7 +40,7 @@ REGISTER_KERNEL_TYPED(MLFloat16) namespace { template || std::is_same_v, void>> -void ComputeJob( +Status ComputeJob( const T* input_data, const T* skip_data, const T* gamma_data, @@ -94,9 +94,11 @@ void ComputeJob( p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; } } + + return Status::OK(); } -void ComputeJob( +Status ComputeJob( const MLFloat16* input_data, const MLFloat16* skip_data, const MLFloat16* gamma_data, @@ -117,23 +119,45 @@ void ComputeJob( float mean(0.0f); float mean_square(0.0f); - const size_t num_elems = static_cast(hidden_size); - float* float_input = new float[num_elems]; - MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems); - float* float_skip = new float[num_elems]; - MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems); + + float* float_input = nullptr; + try { + float_input = new float[num_elems]; + MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert input data to float: ", e.what()); + } + + float* float_skip = nullptr; + try { + float_skip = new float[num_elems]; + MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert skip data to float: ", e.what()); + } + float* float_bias = nullptr; if (bias_data != nullptr) { - float_bias = new float[num_elems]; - MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems); + try { + float_bias = new float[num_elems]; + MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert bias data to float: ", e.what()); + } + } + + float* float_output = nullptr; + try { + float_output = new float[num_elems]; + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to allocate memory for float output.", e.what()); } - float* float_output = new float[num_elems]; for (size_t h = 0; h < num_elems; h++) { float val = float_input[h] + float_skip[h]; - if (nullptr != bias_data) { + if (nullptr != float_bias) { val += float_bias[h]; } @@ -141,12 +165,17 @@ void ComputeJob( mean += val; mean_square += val * val; } + if (float_bias != nullptr) { delete[] float_bias; } if (nullptr != p_skip_input_bias_add_output) { - MlasConvertFloatToHalfBuffer(float_output, p_skip_input_bias_add_output, num_elems); + try { + MlasConvertFloatToHalfBuffer(float_output, p_skip_input_bias_add_output, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert skip_input_bias_add_output data to MLFLoat16: ", e.what()); + } } mean = mean / hidden_size; @@ -157,23 +186,43 @@ void ComputeJob( } float* float_gamma = float_input; // overwrite float_input with gamma values, since they have the same size - MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems); - float* float_beta = float_skip; // overwrite float_skip with beta values, since they have the same size - MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems); + try { + MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert gamma data to float: ", e.what()); + } + + float* float_beta = nullptr; // overwrite float_skip with beta values, since they have the same size + if (beta_data) { + float_beta = float_skip; + try { + MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert beta data to float: ", e.what()); + } + } + for (size_t h = 0; h < num_elems; h++) { if (simplified) { float_output[h] = float_output[h] / mean_square * float_gamma[h]; - } else if (nullptr == beta_data) { + } else if (nullptr == float_beta) { float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h]; } else { float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h] + float_beta[h]; } } delete[] float_gamma; // also deletes float_input - delete[] float_beta; // also deletes float_skip + delete[] float_skip; // also deletes float_beta if used + + try { + MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert float output data to MLFLoat16: ", e.what()); + } - MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems); delete[] float_output; + + return Status::OK(); } } // namespace @@ -211,27 +260,43 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1); const T* input_data = input->Data(); + if (!input_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The input data should not be null."); + } const T* skip_data = skip->Data(); + if (!skip_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The skip data should not be null."); + } const T* gamma_data = gamma->Data(); + if (!gamma_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The gamma data should not be null."); + } const T* beta_data = beta == nullptr ? nullptr : beta->Data(); const T* bias_data = bias == nullptr ? nullptr : bias->Data(); T* output_data = output->MutableData(); + if (!output_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The output data pointer should not be null."); + } // For inferencing, we support one more optional output which is the sum of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData(); const int64_t& skip_size = skip->Shape().Size(); + auto return_status = Status::OK(); concurrency::ThreadPool::TryBatchParallelFor( p_ctx->GetOperatorThreadPool(), static_cast(task_count), [&](ptrdiff_t task_idx) { - ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, epsilon_, - simplified, output_data, skip_input_bias_add_output_data); + auto status = ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, + skip_size, epsilon_, simplified, output_data, skip_input_bias_add_output_data); + if (status != Status::OK()) { + return_status = status; + } }, 0); - return Status::OK(); + return return_status; } } // namespace contrib diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 71dd5ab803263..4bd042ac59de3 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -18,7 +18,7 @@ namespace { template || std::is_same_v, void>> -void ComputeJob( +Status ComputeJob( const T* X_data, const T* scale_data, const T* bias_data, @@ -66,10 +66,12 @@ void ComputeJob( if (inv_std_dev_data != nullptr) { inv_std_dev_data[task_idx] = gsl::narrow_cast(1 / mean_square); } + + return Status::OK(); } template -void ComputeJob( +Status ComputeJob( const MLFloat16* X_data, const MLFloat16* scale_data, const MLFloat16* bias_data, @@ -87,10 +89,21 @@ void ComputeJob( float mean_square(0.0f); const size_t num_elems = static_cast(norm_size); - float* float_input = new float[num_elems]; - MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems); + float* float_input = nullptr; + try { + float_input = new float[num_elems]; + MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert input data to float: ", e.what()); + } + + float* float_output = nullptr; + try { + float_output = new float[num_elems]; + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to allocate memory for float output.", e.what()); + } - float* float_output = new float[num_elems]; for (size_t h = 0; h < num_elems; h++) { float_output[h] = float_input[h]; mean += float_input[h]; @@ -105,9 +118,22 @@ void ComputeJob( } float* float_scale = float_input; // overwrite float_input with scale values, since they have the same size - MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems); - float* float_bias = new float[num_elems]; - MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems); + try { + MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert scale data to float: ", e.what()); + } + + float* float_bias = nullptr; + if (bias_data) { + try { + float_bias = new float[num_elems]; + MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert bias data to float: ", e.what()); + } + } + for (size_t h = 0; h < num_elems; h++) { if (simplified) { float_output[h] = float_output[h] / mean_square * float_scale[h]; @@ -118,9 +144,16 @@ void ComputeJob( } } delete[] float_scale; // also deletes float_input - delete[] float_bias; + if (float_bias) { + delete[] float_bias; + } + + try { + MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems); + } catch (const std::exception& e) { + return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert float output data to MLFLoat16: ", e.what()); + } - MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems); delete[] float_output; if (mean_data != nullptr) { @@ -131,6 +164,8 @@ void ComputeJob( if (inv_std_dev_data != nullptr) { inv_std_dev_data[task_idx] = MLFloat16(1 / mean_square); } + + return Status::OK(); } } // namespace @@ -148,7 +183,13 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo const Tensor* scale = p_ctx->Input(1); const Tensor* bias = p_ctx->Input(2); const T* X_data = X->Data(); + if (!X_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The input data should not be null."); + } const T* scale_data = scale->Data(); + if (!scale_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The scale data should not be null."); + } const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data(); const TensorShape& x_shape = X->Shape(); @@ -156,6 +197,9 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo const TensorShape& bias_shape = bias->Shape(); Tensor* Y = p_ctx->Output(0, x_shape); T* Y_data = Y->MutableData(); + if (!Y_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The output data pointer should not be null."); + } const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions()); @@ -226,11 +270,15 @@ Status LayerNormImpl::ComputeWithoutContext( scale_size, " and bias size of ", bias_size); } + auto return_status = Status::OK(); concurrency::ThreadPool::TryBatchParallelFor( thread_pool, static_cast(norm_count), [&](ptrdiff_t task_idx) { - ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, epsilon, simplified, - Y_data, mean_data, inv_std_dev_data); + auto status = ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, epsilon, simplified, + Y_data, mean_data, inv_std_dev_data); + if (status != Status::OK()) { + return_status = status; + } }, 0);