Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add microbenchmark for layer normalization and improve latency #22223

Merged
merged 36 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2b8cd17
Add microbenchmark for layer normalization
amarin16 Sep 25, 2024
0c89631
fix warnings
amarin16 Sep 25, 2024
bca13ca
initialize test input data at compile time
amarin16 Sep 26, 2024
680cf4f
remove unused specialization that fails on pipeline
amarin16 Sep 26, 2024
f0df526
fix build on linux
amarin16 Sep 30, 2024
87725c3
convert all inputs to float efficiently if needed
amarin16 Sep 30, 2024
8aa80da
convert output buffer efficiently in layer_norm_impl
amarin16 Sep 30, 2024
295d652
convert output buffer efficiently in skip_layer_norm
amarin16 Sep 30, 2024
405a0a0
add inline and fix some lint issues
amarin16 Sep 30, 2024
245f298
fix some lint errors
amarin16 Sep 30, 2024
f398b64
fix warning
amarin16 Sep 30, 2024
a483ca4
maybe_unused
amarin16 Oct 1, 2024
19d225a
Fix bug
amarin16 Oct 1, 2024
05b5037
separate MLFloat16 implementation in skip_layer_norm
amarin16 Oct 1, 2024
ab2e5f2
fix linter issues
amarin16 Oct 1, 2024
63e9644
fix precision warning
amarin16 Oct 1, 2024
11eb7fb
cast
amarin16 Oct 2, 2024
46775a7
separate implementation for MLFloat16 inside layer_norm_impl
amarin16 Oct 2, 2024
fd904f6
don't use vectors
amarin16 Oct 2, 2024
a41b802
reuse allocated arrays when possible
amarin16 Oct 2, 2024
6aece95
make_unique instead of new
amarin16 Oct 2, 2024
766c4b2
Revert "make_unique instead of new" for latency
amarin16 Oct 2, 2024
cb55d4b
lint
amarin16 Oct 2, 2024
2895f37
fix bug
amarin16 Oct 2, 2024
f93ccb7
fix bug
amarin16 Oct 2, 2024
4be0255
handle errors
amarin16 Oct 3, 2024
48ce979
remove checks on tensor data
amarin16 Oct 3, 2024
3d6b990
remove try/catch due to -fno-exceptions
amarin16 Oct 3, 2024
f04aac0
Prepack scale and bias in layer_norm_impl
amarin16 Oct 9, 2024
1eaa63f
Prepack skip, gamma, beta, bias in skip_layer_norm
amarin16 Oct 9, 2024
26ddc6c
return void from ComputeJob
amarin16 Oct 9, 2024
3231cff
lint
amarin16 Oct 9, 2024
2a37a92
Use GenerateArrayWithRandomValue in microbenchmark
amarin16 Oct 14, 2024
d8b11ab
Use allocator instead of new
amarin16 Oct 14, 2024
402b65d
lint
amarin16 Oct 14, 2024
57c3e63
switch to IAllocator::MakeUniquePtr
amarin16 Oct 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
${BENCHMARK_DIR}/gelu.cc
${BENCHMARK_DIR}/activation.cc
${BENCHMARK_DIR}/quantize.cc
${BENCHMARK_DIR}/reduceminmax.cc)
${BENCHMARK_DIR}/reduceminmax.cc
${BENCHMARK_DIR}/layer_normalization.cc)
target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc)
target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE)
if(WIN32)
Expand Down
318 changes: 221 additions & 97 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/framework/tensor.h"
#include "core/mlas/inc/mlas.h"
#include "core/util/math_cpuonly.h"
#include "core/providers/common.h"
#include "core/platform/threadpool.h"
Expand Down Expand Up @@ -36,52 +37,207 @@
REGISTER_KERNEL_TYPED(double)
REGISTER_KERNEL_TYPED(MLFloat16)

// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
template <typename T, typename Ret>
ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);

template <>
ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
return val.ToFloat();
}

template <>
ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
namespace {

template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void ComputeJob(
const T* input_data,

Check warning on line 44 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:44: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const T* skip_data,

Check warning on line 45 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:45: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const T* gamma_data,

Check warning on line 46 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:46: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const T* beta_data,

Check warning on line 47 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:47: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const T* bias_data,

Check warning on line 48 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:48: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const IAllocatorUniquePtr<float>& skip_fp32,

Check warning on line 49 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:49: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const IAllocatorUniquePtr<float>& gamma_fp32,

Check warning on line 50 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:50: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const IAllocatorUniquePtr<float>& beta_fp32,

Check warning on line 51 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:51: Do not indent within a namespace. [whitespace/indent_namespace] [4]
const IAllocatorUniquePtr<float>& bias_fp32,

Check warning on line 52 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:52: Do not indent within a namespace. [whitespace/indent_namespace] [4]
ptrdiff_t task_idx,

Check warning on line 53 in onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/contrib_ops/cpu/skip_layer_norm.cc:53: Do not indent within a namespace. [whitespace/indent_namespace] [4]
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
T* output_data,
T* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(skip_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(gamma_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(beta_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(alloc);

auto offset = task_idx * hidden_size;
const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

T mean(0.0f);
T mean_square(0.0f);

for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
T val = p_input[h] + p_skip[h];

if (nullptr != bias_data) {
val += bias_data[h];
}

if (nullptr != p_skip_input_bias_add_output) {
p_skip_input_bias_add_output[h] = val;
}

p_output[h] = val;
mean += val;
mean_square += val * val;
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
if (simplified) {
p_output[h] = p_output[h] / mean_square * gamma_data[h];
} else if (nullptr == beta_data) {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
}
}
}

template <>
ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
return val;
void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const MLFloat16* gamma_data,
const MLFloat16* beta_data,
const MLFloat16* bias_data,
const IAllocatorUniquePtr<float>& skip_fp32,
const IAllocatorUniquePtr<float>& gamma_fp32,
const IAllocatorUniquePtr<float>& beta_fp32,
const IAllocatorUniquePtr<float>& bias_fp32,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
MLFloat16* output_data,
MLFloat16* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
const MLFloat16* p_skip = skip_data + (offset % skip_size);
MLFloat16* p_output = output_data + offset;
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

float mean(0.0f);
float mean_square(0.0f);
const size_t num_elems = static_cast<size_t>(hidden_size);

float* float_input = (float*)alloc->Alloc(num_elems * sizeof(float));
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);

float* float_skip = skip_fp32.get();
if (nullptr == float_skip) {
float_skip = (float*)alloc->Alloc(num_elems * sizeof(float));
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);
}

float* float_bias = nullptr;
if (bias_data) {
if (nullptr != bias_fp32) {
float_bias = bias_fp32.get();
} else {
float_bias = (float*)alloc->Alloc(num_elems * sizeof(float));
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
}
}

float* float_output = (float*)alloc->Alloc(num_elems * sizeof(float));

for (size_t h = 0; h < num_elems; h++) {
float val = float_input[h] + float_skip[h];

if (nullptr != float_bias) {
val += float_bias[h];
}

float_output[h] = val;
mean += val;
mean_square += val * val;
}

if (float_bias && (nullptr == bias_fp32)) {
delete[] float_bias;
Fixed Show fixed Hide fixed
}
amarin16 marked this conversation as resolved.
Show resolved Hide resolved

if (nullptr != p_skip_input_bias_add_output) {
MlasConvertFloatToHalfBuffer(float_output, p_skip_input_bias_add_output, num_elems);
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

float* float_gamma = gamma_fp32.get();
if (nullptr == float_gamma) {
float_gamma = float_input; // overwrite float_input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
}

float* float_beta = nullptr;
if (beta_data) {
if (nullptr != beta_fp32) {
float_beta = beta_fp32.get();
} else {
float_beta = (float*)alloc->Alloc(num_elems * sizeof(float));
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);
}
}

for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
float_output[h] = float_output[h] / mean_square * float_gamma[h];
} else if (nullptr == float_beta) {
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h];
} else {
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h] + float_beta[h];
}
}

alloc->Free(float_input); // also takes care of float_gamma if reused
amarin16 marked this conversation as resolved.
Show resolved Hide resolved
if (float_skip && (nullptr == skip_fp32)) {
alloc->Free(float_skip);
}
if (beta_data && (nullptr == beta_fp32)) {
alloc->Free(float_beta);
}

MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
alloc->Free(float_output);
}

template <>
ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
return val;
}
void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
auto tensor_data_ptr = tensor.Data<MLFloat16>();
auto tensor_size = static_cast<size_t>(tensor.Shape().Size());
auto float_ptr = IAllocator::MakeUniquePtr<float>(alloc, tensor_size, true);

// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
return val;
MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size);
dest = std::move(float_ptr);
is_packed = true;
}
}

template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
return MLFloat16(val);
}

template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
return MLFloat16(static_cast<float>(val));
}
} // namespace

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info) {
: OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
Expand All @@ -94,8 +250,7 @@
const Tensor* beta = p_ctx->Input<Tensor>(3);
const Tensor* bias = p_ctx->Input<Tensor>(4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());

const auto& input_dims = input->Shape().GetDims();
Expand All @@ -120,75 +275,44 @@

T* output_data = output->MutableData<T>();

// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const auto& skip_size = skip->Shape().Size();
const int64_t& skip_size = skip->Shape().Size();

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
auto offset = task_idx * hidden_size;

const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;

using DoubleOrFloat = typename std::conditional<
std::is_same<T, double>::value, // If T is double
double, // Use double
float // Otherwise, use float (covers float and MLFloat16)
>::type;

DoubleOrFloat mean(0.0f);
DoubleOrFloat mean_square(0.0f);

std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);

DoubleOrFloat value = input_value + skip_value;

if (nullptr != bias_data) {
value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
}

output_buffer[h] = value;
T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
if (nullptr != p_skip_input_bias_add_output_data) {
p_skip_input_bias_add_output_data[h] = converted_value;
}

mean += value;
mean_square += value * value;
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon_);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
}

for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
if (simplified) {
p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
} else if (nullptr == beta_data) {
p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
} else {
DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
}
}
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
},
0);

return Status::OK();
}

template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);

is_packed = false;
if (input_idx == 1) { // skip
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
} else if (input_idx == 4) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
}

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ class SkipLayerNorm final : public OpKernel {
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* p_op_kernel_context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) override;

private:
float epsilon_;
IAllocatorUniquePtr<float> skip_fp32_;
IAllocatorUniquePtr<float> gamma_fp32_;
IAllocatorUniquePtr<float> beta_fp32_;
IAllocatorUniquePtr<float> bias_fp32_;
};

} // namespace contrib
Expand Down
Loading
Loading