Commit

handle errors
amarin16 committed Oct 3, 2024
1 parent f93ccb7 commit 4be0255
Showing 2 changed files with 146 additions and 33 deletions.
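Both files apply the same error-handling pattern: the per-task ComputeJob helpers now return a Status instead of void, exceptions thrown while allocating or converting the intermediate float buffers are caught and turned into a Status, and the thread-pool callback records any failing Status so the kernel's Compute method can return it. The following is a minimal, self-contained sketch of that flow, not the actual ONNX Runtime code: Status, ComputeJobSketch, and RunBatch are hypothetical stand-ins, and a plain loop replaces concurrency::ThreadPool::TryBatchParallelFor for brevity.

#include <cstddef>
#include <exception>
#include <string>
#include <vector>

// Hypothetical stand-in for the ONNX Runtime Status type.
struct Status {
  bool ok;
  std::string message;
  static Status OK() { return Status{true, std::string()}; }
  static Status Error(std::string msg) { return Status{false, std::move(msg)}; }
};

// One task: allocate a float scratch buffer and fill it, reporting any
// allocation failure as a Status instead of letting the exception escape.
Status ComputeJobSketch(const float* input, std::size_t num_elems, std::vector<float>& scratch) {
  try {
    scratch.assign(input, input + num_elems);  // may throw std::bad_alloc
  } catch (const std::exception& e) {
    return Status::Error(std::string("Failed to allocate scratch buffer: ") + e.what());
  }
  return Status::OK();
}

// Dispatcher: run every task and keep the last failing Status, the same way
// the TryBatchParallelFor lambdas in the diff below update return_status.
Status RunBatch(const float* data, std::size_t task_count, std::size_t num_elems) {
  Status return_status = Status::OK();
  for (std::size_t task_idx = 0; task_idx < task_count; ++task_idx) {
    std::vector<float> scratch;
    Status status = ComputeJobSketch(data + task_idx * num_elems, num_elems, scratch);
    if (!status.ok) {
      return_status = status;
    }
  }
  return return_status;
}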
107 changes: 86 additions & 21 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -40,7 +40,7 @@ REGISTER_KERNEL_TYPED(MLFloat16)
namespace {

template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void ComputeJob(
Status ComputeJob(
const T* input_data,
const T* skip_data,
const T* gamma_data,
@@ -94,9 +94,11 @@ void ComputeJob(
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
}
}

return Status::OK();
}

void ComputeJob(
Status ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const MLFloat16* gamma_data,
@@ -117,36 +119,63 @@ void ComputeJob(

float mean(0.0f);
float mean_square(0.0f);

const size_t num_elems = static_cast<size_t>(hidden_size);
float* float_input = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
float* float_skip = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);

float* float_input = nullptr;
try {
float_input = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert input data to float: ", e.what());
}

float* float_skip = nullptr;
try {
float_skip = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert skip data to float: ", e.what());
}

float* float_bias = nullptr;
if (bias_data != nullptr) {
float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
try {
float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert bias data to float: ", e.what());
}
}

float* float_output = nullptr;
try {
float_output = new float[num_elems];
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to allocate memory for float output.", e.what());
}

float* float_output = new float[num_elems];
for (size_t h = 0; h < num_elems; h++) {
float val = float_input[h] + float_skip[h];

if (nullptr != bias_data) {
if (nullptr != float_bias) {
val += float_bias[h];
}

float_output[h] = val;
mean += val;
mean_square += val * val;
}

if (float_bias != nullptr) {
delete[] float_bias;
}

if (nullptr != p_skip_input_bias_add_output) {
MlasConvertFloatToHalfBuffer(float_output, p_skip_input_bias_add_output, num_elems);
try {
MlasConvertFloatToHalfBuffer(float_output, p_skip_input_bias_add_output, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert skip_input_bias_add_output data to MLFLoat16: ", e.what());
}
}

mean = mean / hidden_size;
@@ -157,23 +186,43 @@ void ComputeJob(
}

float* float_gamma = float_input; // overwrite float_input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
float* float_beta = float_skip; // overwrite float_skip with beta values, since they have the same size
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);
try {
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert gamma data to float: ", e.what());
}

float* float_beta = nullptr; // overwrite float_skip with beta values, since they have the same size
if (beta_data) {
float_beta = float_skip;
try {
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert beta data to float: ", e.what());
}
}

for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
float_output[h] = float_output[h] / mean_square * float_gamma[h];
} else if (nullptr == beta_data) {
} else if (nullptr == float_beta) {
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h];
} else {
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h] + float_beta[h];
}
}
delete[] float_gamma; // also deletes float_input
delete[] float_beta; // also deletes float_skip
delete[] float_skip; // also deletes float_beta if used

try {
MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert float output data to MLFLoat16: ", e.what());
}

MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
delete[] float_output;

return Status::OK();
}

} // namespace
@@ -211,27 +260,43 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);

const T* input_data = input->Data<T>();
if (!input_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The input data should not be null.");
}
const T* skip_data = skip->Data<T>();
if (!skip_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The skip data should not be null.");
}
const T* gamma_data = gamma->Data<T>();
if (!gamma_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The gamma data should not be null.");
}
const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();

T* output_data = output->MutableData<T>();
if (!output_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The output data pointer should not be null.");
}

// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const int64_t& skip_size = skip->Shape().Size();

auto return_status = Status::OK();
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, epsilon_,
simplified, output_data, skip_input_bias_add_output_data);
auto status = ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size,
skip_size, epsilon_, simplified, output_data, skip_input_bias_add_output_data);
if (status != Status::OK()) {
return_status = status;
}
},
0);

return Status::OK();
return return_status;
}

} // namespace contrib
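For reference, the per-row arithmetic that each SkipLayerNorm ComputeJob task performs, independent of the new error handling, looks roughly like the float-only sketch below. It mirrors the arithmetic visible in the diff above (sum input, skip, and optional bias; accumulate mean and mean-square; then scale by gamma and shift by beta). NormalizeSkipLayerNormRow is a hypothetical name, and the sqrt(... + epsilon) regularization of mean_square is assumed, since that hunk is collapsed above.

#include <cmath>
#include <cstddef>

// Float-only sketch of one ComputeJob task over hidden_size elements.
void NormalizeSkipLayerNormRow(const float* input, const float* skip, const float* bias,
                               const float* gamma, const float* beta,
                               std::size_t hidden_size, float epsilon, float* output) {
  float mean = 0.0f;
  float mean_square = 0.0f;

  // Sum input and skip (plus optional bias), accumulating mean and mean-square.
  for (std::size_t h = 0; h < hidden_size; ++h) {
    float val = input[h] + skip[h];
    if (bias != nullptr) {
      val += bias[h];
    }
    output[h] = val;
    mean += val;
    mean_square += val * val;
  }

  mean = mean / hidden_size;
  mean_square = std::sqrt(mean_square / hidden_size - mean * mean + epsilon);  // assumed regularization

  // Normalize, then apply gamma (scale) and optional beta (shift).
  for (std::size_t h = 0; h < hidden_size; ++h) {
    if (beta == nullptr) {
      output[h] = (output[h] - mean) / mean_square * gamma[h];
    } else {
      output[h] = (output[h] - mean) / mean_square * gamma[h] + beta[h];
    }
  }
}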
72 changes: 60 additions & 12 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -18,7 +18,7 @@ namespace {
template <typename T,
typename U,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void ComputeJob(
Status ComputeJob(
const T* X_data,
const T* scale_data,
const T* bias_data,
@@ -66,10 +66,12 @@ void ComputeJob(
if (inv_std_dev_data != nullptr) {
inv_std_dev_data[task_idx] = gsl::narrow_cast<float>(1 / mean_square);
}

return Status::OK();
}

template <typename U>
void ComputeJob(
Status ComputeJob(
const MLFloat16* X_data,
const MLFloat16* scale_data,
const MLFloat16* bias_data,
@@ -87,10 +89,21 @@ void ComputeJob(
float mean_square(0.0f);

const size_t num_elems = static_cast<size_t>(norm_size);
float* float_input = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
float* float_input = nullptr;
try {
float_input = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert input data to float: ", e.what());
}

float* float_output = nullptr;
try {
float_output = new float[num_elems];
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to allocate memory for float output.", e.what());
}

float* float_output = new float[num_elems];
for (size_t h = 0; h < num_elems; h++) {
float_output[h] = float_input[h];
mean += float_input[h];
@@ -105,9 +118,22 @@ void ComputeJob(
}

float* float_scale = float_input; // overwrite float_input with scale values, since they have the same size
MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems);
float* float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
try {
MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert scale data to float: ", e.what());
}

float* float_bias = nullptr;
if (bias_data) {
try {
float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert bias data to float: ", e.what());
}
}

for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
float_output[h] = float_output[h] / mean_square * float_scale[h];
@@ -118,9 +144,16 @@ void ComputeJob(
}
}
delete[] float_scale; // also deletes float_input
delete[] float_bias;
if (float_bias) {
delete[] float_bias;
}

try {
MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
} catch (const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Failed to convert float output data to MLFLoat16: ", e.what());
}

MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
delete[] float_output;

if (mean_data != nullptr) {
@@ -131,6 +164,8 @@ void ComputeJob(
if (inv_std_dev_data != nullptr) {
inv_std_dev_data[task_idx] = MLFloat16(1 / mean_square);
}

return Status::OK();
}

} // namespace
@@ -148,14 +183,23 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
const Tensor* scale = p_ctx->Input<Tensor>(1);
const Tensor* bias = p_ctx->Input<Tensor>(2);
const T* X_data = X->Data<T>();
if (!X_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The input data should not be null.");
}
const T* scale_data = scale->Data<T>();
if (!scale_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The scale data should not be null.");
}
const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

const TensorShape& x_shape = X->Shape();
const TensorShape& scale_shape = scale->Shape();
const TensorShape& bias_shape = bias->Shape();
Tensor* Y = p_ctx->Output(0, x_shape);
T* Y_data = Y->MutableData<T>();
if (!Y_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The output data pointer should not be null.");
}

const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions());

@@ -226,11 +270,15 @@ Status LayerNormImpl::ComputeWithoutContext(
scale_size, " and bias size of ", bias_size);
}

auto return_status = Status::OK();
concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
[&](ptrdiff_t task_idx) {
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, epsilon, simplified,
Y_data, mean_data, inv_std_dev_data);
auto status = ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, epsilon, simplified,
Y_data, mean_data, inv_std_dev_data);
if (status != Status::OK()) {
return_status = status;
}
},
0);

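The LayerNorm kernel's per-row job differs from the SkipLayerNorm one mainly in its optional mean and inverse-standard-deviation outputs and its simplified (RMSNorm-style) path. A float-only sketch of that task is below; NormalizeLayerNormRow is a hypothetical name, and the handling of epsilon and of the simplified path's variance is assumed, since those hunks are collapsed above.

#include <cmath>
#include <cstddef>

// Float-only sketch of one LayerNorm ComputeJob task over norm_size elements.
void NormalizeLayerNormRow(const float* X, const float* scale, const float* bias,
                           std::size_t norm_size, float epsilon, bool simplified,
                           float* Y, float* mean_out, float* inv_std_dev_out) {
  float mean = 0.0f;
  float mean_square = 0.0f;
  for (std::size_t h = 0; h < norm_size; ++h) {
    Y[h] = X[h];
    mean += X[h];
    mean_square += X[h] * X[h];
  }

  mean = mean / norm_size;
  // Assumed: the simplified path skips the mean subtraction, as in RMSNorm.
  mean_square = std::sqrt(mean_square / norm_size - (simplified ? 0.0f : mean * mean) + epsilon);

  for (std::size_t h = 0; h < norm_size; ++h) {
    if (simplified) {
      Y[h] = Y[h] / mean_square * scale[h];
    } else if (bias == nullptr) {
      Y[h] = (Y[h] - mean) / mean_square * scale[h];
    } else {
      Y[h] = (Y[h] - mean) / mean_square * scale[h] + bias[h];
    }
  }

  // Optional per-row statistics, mirroring mean_data and inv_std_dev_data above.
  if (mean_out != nullptr) {
    *mean_out = mean;
  }
  if (inv_std_dev_out != nullptr) {
    *inv_std_dev_out = 1.0f / mean_square;
  }
}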
