Skip to content

Commit

Permalink
separate implementation for MLFloat16 inside layer_norm_impl
Browse files Browse the repository at this point in the history
  • Loading branch information
amarin16 committed Oct 2, 2024
1 parent 11eb7fb commit 46775a7
Showing 1 changed file with 99 additions and 129 deletions.
228 changes: 99 additions & 129 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,75 +15,119 @@ namespace onnxruntime {

namespace {

// Overload for a double output pointer: returns the pointer it was given and allocates nothing.
// `num_elems` is accepted only so the call has the same shape as the other overloads; it is unused.
ORT_FORCEINLINE double* OnlyCreateBufferIfMLFloat16(double* p_output, [[maybe_unused]] size_t num_elems) {
return p_output;
}
// Normalizes a single row of `norm_size` elements for T = float or double (enforced by the enable_if).
//
// X_data / Y_data:   input and output buffers; the row handled is at offset task_idx * norm_size.
// scale_data:        per-element scale, indexed 0..norm_size-1.
// bias_data:         per-element bias; may be nullptr, and is ignored when `simplified` is true.
// epsilon:           added under the square root when computing the denominator.
// simplified:        when true, the mean is not subtracted and no bias is added.
// mean_data / inv_std_dev_data: optional per-row outputs written at index task_idx when non-null.
//
// NOTE(review): this text appears to be a flattened diff view, so definitions of the pre-commit
// helpers (OnlyCreateBufferIfMLFloat16, ConvertMLFloat16ToFloatBufferIfNeeded) are interleaved with
// the function body below — confirm against the real file before editing.
template <typename T,
typename U,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void ComputeJob(
const T* X_data,
const T* scale_data,
const T* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
float epsilon,
bool simplified,
T* Y_data,
U* mean_data,
U* inv_std_dev_data) {
const T* p_input = X_data + task_idx * norm_size;
T* p_output = Y_data + task_idx * norm_size;

ORT_FORCEINLINE float* OnlyCreateBufferIfMLFloat16(float* p_output, [[maybe_unused]] size_t num_elems) {
return p_output;
}
// Accumulators: `mean` holds the running sum, `mean_square` the running sum of squares.
T mean(0.0f);
T mean_square(0.0f);

// MLFloat16 overload: allocates a float scratch buffer with new[] unless p_output is null.
ORT_FORCEINLINE float* OnlyCreateBufferIfMLFloat16(MLFloat16* p_output, size_t num_elems) {
return p_output == nullptr ? nullptr : new float[num_elems];
}

template <typename T>
ORT_FORCEINLINE std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded(
[[maybe_unused]] const T* p_input, [[maybe_unused]] size_t num_elems);
// First pass: copy the row into the output and accumulate sum and sum of squares.
for (int64_t h = 0; h < norm_size; h++) {
p_output[h] = p_input[h];
mean += p_input[h];
mean_square += p_input[h] * p_input[h];
}

template <typename T>
ORT_FORCEINLINE std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded(
[[maybe_unused]] const std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>* p_input,
[[maybe_unused]] size_t num_elems) {
return nullptr;
}
// From here on `mean_square` is reused as the normalization denominator:
// sqrt(E[x^2] + epsilon) when simplified, sqrt(E[x^2] - mean^2 + epsilon) otherwise.
mean = mean / norm_size;
if (simplified) {
mean_square = sqrt(mean_square / norm_size + epsilon);
} else {
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

template <>
std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded<MLFloat16>(const MLFloat16* p_input, size_t num_elems) {
if (!p_input) {
return nullptr;
// Second pass: normalize in place in the output row, then apply scale (and bias when present).
for (int64_t h = 0; h < norm_size; h++) {
if (simplified) {
p_output[h] = p_output[h] / mean_square * scale_data[h];
} else if (nullptr == bias_data) {
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
}
}

// Efficiently convert all the MLFloat16 values to floats.
std::shared_ptr<std::vector<float>> vec = std::make_shared<std::vector<float>>(num_elems);
MlasConvertHalfToFloatBuffer(p_input, &(*vec)[0], num_elems);
if (mean_data != nullptr) {
// ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
mean_data[task_idx] = gsl::narrow_cast<float>(mean);
}

return vec;
// The stored value is the reciprocal of the denominator computed above.
if (inv_std_dev_data != nullptr) {
inv_std_dev_data[task_idx] = gsl::narrow_cast<float>(1 / mean_square);
}
}

void ConvertFloatBufferToMLFloat16(const float* output_buffer, MLFloat16* p_output, size_t num_elems) {
if (!output_buffer || !p_output) {
return;
// MLFloat16 overload of ComputeJob: normalizes a single row of `norm_size` half-precision elements.
// The row, scale and bias are converted to float, all arithmetic is done in float, and the result is
// converted back to MLFloat16 in one call at the end.
//
// X_data / Y_data:   input and output buffers; the row handled is at offset task_idx * norm_size.
// scale_data:        per-element scale, norm_size elements.
// bias_data:         per-element bias; may be nullptr, and is ignored when `simplified` is true.
// epsilon:           added under the square root when computing the denominator.
// simplified:        when true, the mean is not subtracted and no bias is added.
// mean_data / inv_std_dev_data: optional per-row outputs written at index task_idx when non-null.
//
// NOTE(review): this text appears to be a flattened diff view; lines belonging to removed pre-commit
// helpers are interleaved with the body below and are intentionally left untouched.
template <typename U>
void ComputeJob(
const MLFloat16* X_data,
const MLFloat16* scale_data,
const MLFloat16* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
float epsilon,
bool simplified,
MLFloat16* Y_data,
U* mean_data,
U* inv_std_dev_data) {
const MLFloat16* p_input = X_data + task_idx * norm_size;
MLFloat16* p_output = Y_data + task_idx * norm_size;

// Accumulators: `mean` holds the running sum, `mean_square` the running sum of squares.
float mean(0.0f);
float mean_square(0.0f);

// Widen the whole input row to float in one call.
std::vector<float> float_input(norm_size);
MlasConvertHalfToFloatBuffer(p_input, &float_input[0], norm_size);

// First pass: copy the row into a float scratch output and accumulate sum and sum of squares.
std::vector<float> float_output(norm_size);
for (int64_t h = 0; h < norm_size; h++) {
float_output[h] = float_input[h];
mean += float_input[h];
mean_square += float_input[h] * float_input[h];
}

MlasConvertFloatToHalfBuffer(output_buffer, p_output, num_elems);
}
// From here on `mean_square` is reused as the normalization denominator:
// sqrt(E[x^2] + epsilon) when simplified, sqrt(E[x^2] - mean^2 + epsilon) otherwise.
mean = mean / norm_size;
if (simplified) {
mean_square = sqrt(mean_square / norm_size + epsilon);
} else {
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(float val) {
return val;
}
std::vector<float> float_scale(norm_size);
MlasConvertHalfToFloatBuffer(scale_data, &float_scale[0], norm_size);
std::vector<float> float_bias(norm_size);
// bias_data is optional (see the nullptr check in the loop below); converting from a null
// source would read through a null pointer, so only convert when a bias was supplied.
if (bias_data != nullptr) {
MlasConvertHalfToFloatBuffer(bias_data, &float_bias[0], norm_size);
}

ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(double val) {
// ONNX spec doesn't support 'double' for 'Ret' so when 'T' == double, 'Ret' == float and we need to narrow
return gsl::narrow_cast<float>(val);
}
// Second pass: normalize in place in the float scratch row, then apply scale (and bias when present).
for (int64_t h = 0; h < norm_size; h++) {
if (simplified) {
float_output[h] = float_output[h] / mean_square * float_scale[h];
} else if (nullptr == bias_data) {
float_output[h] = (float_output[h] - mean) / mean_square * float_scale[h];
} else {
float_output[h] = (float_output[h] - mean) / mean_square * float_scale[h] + float_bias[h];
}
}

// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, float>
ConvertToMLFloat16IfNeeded(float val) {
return val;
}
// Narrow the finished row back to half precision in a single call.
MlasConvertFloatToHalfBuffer(&float_output[0], p_output, static_cast<size_t>(norm_size));

template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, MLFloat16>
ConvertToMLFloat16IfNeeded(float val) {
return MLFloat16(val);
}
if (mean_data != nullptr) {
// The mean was accumulated in float; it is stored through an MLFloat16 conversion.
mean_data[task_idx] = MLFloat16(mean);
}

template <typename T>
ORT_FORCEINLINE constexpr double ConvertToMLFloat16IfNeeded(double val) {
return val;
// The stored value is the reciprocal of the denominator computed above.
if (inv_std_dev_data != nullptr) {
inv_std_dev_data[task_idx] = MLFloat16(1 / mean_square);
}
}

} // namespace
Expand Down Expand Up @@ -180,82 +224,8 @@ Status LayerNormImpl::ComputeWithoutContext(
concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
[&](ptrdiff_t task_idx) {
const T* p_input = X_data + task_idx * norm_size;
T* p_output = Y_data + task_idx * norm_size;

using DoubleOrFloat = typename std::conditional<
std::is_same<T, double>::value, // If T is double
double, // Use double
float // Otherwise, use float (covers float and MLFloat16)
>::type;

DoubleOrFloat mean(0.0f);
DoubleOrFloat mean_square(0.0f);

std::shared_ptr<std::vector<float>> float_input = ConvertMLFloat16ToFloatBufferIfNeeded<T>(
p_input, static_cast<size_t>(norm_size));
const DoubleOrFloat* converted_input =
float_input == nullptr
? reinterpret_cast<const DoubleOrFloat*>(p_input)
: reinterpret_cast<const DoubleOrFloat*>(&(*float_input)[0]);

// If T is float or double, then output_buffer will be the same as p_output, so we don't allocate new memory.
// If T is MLFloat16, then we allocate norm_size floats in output_buffer.
DoubleOrFloat* output_buffer = static_cast<DoubleOrFloat*>(
OnlyCreateBufferIfMLFloat16(p_output, static_cast<size_t>(norm_size)));

for (int64_t h = 0; h < norm_size; h++) {
output_buffer[h] = converted_input[h];
mean += converted_input[h];
mean_square += converted_input[h] * converted_input[h];
}

mean = mean / norm_size;
if (simplified) {
mean_square = sqrt(mean_square / norm_size + epsilon);
} else {
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

std::shared_ptr<std::vector<float>> float_scale = ConvertMLFloat16ToFloatBufferIfNeeded<T>(
scale_data, static_cast<size_t>(norm_size));
const DoubleOrFloat* converted_scale =
float_scale == nullptr
? reinterpret_cast<const DoubleOrFloat*>(scale_data)
: reinterpret_cast<const DoubleOrFloat*>(&(*float_scale)[0]);
std::shared_ptr<std::vector<float>> float_bias = ConvertMLFloat16ToFloatBufferIfNeeded<T>(
bias_data, static_cast<size_t>(norm_size));
const DoubleOrFloat* converted_bias =
float_bias == nullptr
? reinterpret_cast<const DoubleOrFloat*>(bias_data)
: reinterpret_cast<const DoubleOrFloat*>(&(*float_bias)[0]);

for (int64_t h = 0; h < norm_size; h++) {
if (simplified) {
output_buffer[h] = output_buffer[h] / mean_square * converted_scale[h];
} else if (nullptr == bias_data) {
output_buffer[h] = (output_buffer[h] - mean) / mean_square * converted_scale[h];
} else {
output_buffer[h] = (output_buffer[h] - mean) / mean_square * converted_scale[h] + converted_bias[h];
}
}

if (std::is_same_v<decltype(p_output), MLFloat16*>) {
ConvertFloatBufferToMLFloat16(
reinterpret_cast<float*>(output_buffer),
reinterpret_cast<MLFloat16*>(p_output),
static_cast<size_t>(norm_size));
delete[] output_buffer;
}

if (mean_data != nullptr) {
// ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
mean_data[task_idx] = ConvertToMLFloat16IfNeeded<U>(ConvertToFloatIfNeeded(mean));
}

if (inv_std_dev_data != nullptr) {
inv_std_dev_data[task_idx] = ConvertToMLFloat16IfNeeded<U>(ConvertToFloatIfNeeded(1 / mean_square));
}
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, epsilon, simplified,
Y_data, mean_data, inv_std_dev_data);
},
0);

Expand Down

0 comments on commit 46775a7

Please sign in to comment.