reuse allocated arrays when possible
amarin16 committed Oct 2, 2024
1 parent fd904f6 · commit a41b802
Showing 2 changed files with 22 additions and 26 deletions.
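The commit reuses scratch float buffers once their original contents are dead, instead of allocating one buffer per converted tensor. As a rough standalone sketch of that pattern (not the actual onnxruntime code: Half, ConvertHalfToFloat, and SkipLayerNormLike below are illustrative stand-ins for MLFloat16, MlasConvertHalfToFloatBuffer, and the real kernel, and the mean/variance math is omitted):

#include <cstddef>
#include <cstdint>
#include <vector>

using Half = uint16_t;  // stand-in for a 16-bit float payload (MLFloat16 in onnxruntime)

// Stand-in for MlasConvertHalfToFloatBuffer(src, dst, count); the real call
// decodes IEEE half floats, this one just widens integers for illustration.
static void ConvertHalfToFloat(const Half* src, float* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

// Pattern applied by the commit: two scratch buffers are allocated once.
// After input and skip have been folded into output, their float copies are
// no longer needed, so the same allocations are refilled with gamma and beta
// instead of allocating two more arrays.
// (Mean/variance handling is omitted; only the buffer reuse is shown.)
static void SkipLayerNormLike(const Half* input, const Half* skip,
                              const Half* gamma, const Half* beta,
                              float* output, size_t n) {
  float* scratch_a = new float[n];  // first holds input, later gamma
  float* scratch_b = new float[n];  // first holds skip, later beta

  ConvertHalfToFloat(input, scratch_a, n);
  ConvertHalfToFloat(skip, scratch_b, n);
  for (size_t i = 0; i < n; ++i) output[i] = scratch_a[i] + scratch_b[i];

  ConvertHalfToFloat(gamma, scratch_a, n);  // reuse: scratch_a now holds gamma
  ConvertHalfToFloat(beta, scratch_b, n);   // reuse: scratch_b now holds beta
  for (size_t i = 0; i < n; ++i) output[i] = output[i] * scratch_a[i] + scratch_b[i];

  delete[] scratch_a;  // exactly one delete per allocation, after both uses
  delete[] scratch_b;
}

int main() {
  const size_t n = 4;
  std::vector<Half> input(n, 1), skip(n, 2), gamma(n, 1), beta(n, 0);
  std::vector<float> output(n);
  SkipLayerNormLike(input.data(), skip.data(), gamma.data(), beta.data(), output.data(), n);
  return 0;
}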
27 changes: 12 additions & 15 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -119,22 +119,17 @@ void ComputeJob(
float mean_square(0.0f);

const size_t num_elems = static_cast<size_t>(hidden_size);
float* float_output = new float[num_elems];
float* float_input = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
float* float_skip = new float[num_elems];
float* float_gamma = new float[num_elems];
float* float_beta = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);
float* float_bias = nullptr;
if (bias_data != nullptr) {
  float_bias = new float[num_elems];
  MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
}
MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);

float* float_output = new float[num_elems];
for (size_t h = 0; h < num_elems; h++) {
  float val = float_input[h] + float_skip[h];

@@ -146,6 +141,9 @@ void ComputeJob(
  mean += val;
  mean_square += val * val;
}
if (float_bias != nullptr) {
  delete[] float_bias;
}

if (nullptr != p_skip_input_bias_add_output) {
  MlasConvertFloatToHalfBuffer(float_output, p_skip_input_bias_add_output, num_elems);
@@ -158,6 +156,10 @@ void ComputeJob(
  mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

float* float_gamma = float_input; // overwrite float_input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
float* float_beta = float_skip; // overwrite float_skip with beta values, since they have the same size
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);
for (size_t h = 0; h < num_elems; h++) {
  if (simplified) {
    float_output[h] = float_output[h] / mean_square * float_gamma[h];
@@ -167,16 +169,11 @@ void ComputeJob(
    float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h] + float_beta[h];
  }
}
delete[] float_gamma; // also deletes float_input
delete[] float_beta; // also deletes float_skip

MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
delete[] float_output;
delete[] float_input;
delete[] float_skip;
delete[] float_gamma;
delete[] float_beta;
if (float_bias != nullptr) {
  delete[] float_bias;
}
}

} // namespace
21 changes: 10 additions & 11 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -88,14 +88,9 @@ void ComputeJob(

const size_t num_elems = static_cast<size_t>(norm_size);
float* float_input = new float[num_elems];
float* float_scale = new float[num_elems];
float* float_bias = new float[num_elems];
float* float_output = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);
MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);

float* float_output = new float[num_elems];
for (size_t h = 0; h < num_elems; h++) {
  float_output[h] = float_input[h];
  mean += float_input[h];
@@ -109,6 +104,10 @@ void ComputeJob(
  mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

float* float_scale = float_input; // overwrite float_input with scale values, since they have the same size
MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems);
float* float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
for (size_t h = 0; h < num_elems; h++) {
  if (simplified) {
    float_output[h] = float_output[h] / mean_square * float_scale[h];
@@ -118,6 +117,11 @@ void ComputeJob(
    float_output[h] = (float_output[h] - mean) / mean_square * float_scale[h] + float_bias[h];
  }
}
delete[] float_scale; // also deletes float_input
delete[] float_bias;

MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
delete[] float_output;

if (mean_data != nullptr) {
  // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
@@ -127,11 +131,6 @@
if (inv_std_dev_data != nullptr) {
  inv_std_dev_data[task_idx] = MLFloat16(1 / mean_square);
}

delete[] float_input;
delete[] float_output;
delete[] float_scale;
delete[] float_bias;
}

} // namespace
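As a side note on the design (not part of this commit, purely an alternative sketch): the same "one buffer, two roles" reuse can also be expressed with std::vector-owned scratch space, which removes the manual delete[] bookkeeping. Half and ConvertHalfToFloat are again stand-ins, and the mean/variance math is omitted:

#include <cstddef>
#include <cstdint>
#include <vector>

using Half = uint16_t;  // stand-in, as in the sketch above

// Stand-in for MlasConvertHalfToFloatBuffer(src, dst, count).
static void ConvertHalfToFloat(const Half* src, float* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

// Mirrors the shape of the layer_norm_impl.cc change: the buffer that first
// holds the converted input is overwritten with the converted scale once the
// input values have been consumed; bias still gets its own buffer, as in the diff.
// (Only the buffer lifetimes are shown, not the full normalization.)
static void LayerNormLike(const Half* input, const Half* scale, const Half* bias,
                          float* output, size_t n) {
  std::vector<float> scratch(n);  // first holds input, later scale
  std::vector<float> bias_f(n);

  ConvertHalfToFloat(input, scratch.data(), n);
  for (size_t i = 0; i < n; ++i) output[i] = scratch[i];

  ConvertHalfToFloat(scale, scratch.data(), n);  // reuse the same storage for scale
  ConvertHalfToFloat(bias, bias_f.data(), n);
  for (size_t i = 0; i < n; ++i) output[i] = output[i] * scratch[i] + bias_f[i];
}  // scratch and bias_f release their storage here automatically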
