SimplifiedLayerNormalization Fusion BFloat16 support for Llama-v2 on A100 (#18898)

### Description

Adds bfloat16 as a supported dtype for SimplifiedLayerNormFusion, which
provides a speedup for Llama-v2 on A100 when running in the bfloat16
numerical format.

_layernorm_optimized_training.onnx exported in bfloat16 vs. float16:_

![image](https://github.com/microsoft/onnxruntime/assets/31260940/8c0a5f0f-5fcb-4637-bcd9-f34272ec0284)

### Repro Instructions

```python
import torch
from torch import nn
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

dtype = torch.bfloat16
# dtype = torch.float16  # switch to compare against the float16 export

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10, dtype=dtype)
        self.layernorm = nn.LayerNorm([784], dtype=dtype)

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # flatten (N, 28, 28) -> (N, 784)
        x = self.layernorm(x)
        x = self.fc(x)
        return x

model = Net()
# save_onnx=True dumps the exported and optimized graphs using the given prefix
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='layernorm', log_level=LogLevel.INFO))
model.to("cuda")

images = torch.randn((8, 28, 28), dtype=dtype).to("cuda")
output = model(images)
```
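
To verify that the fusion actually fired, inspect the optimized graph that `DebugOptions(save_onnx=True, ...)` writes out. A minimal sketch, assuming the run above saved `layernorm_optimized_training.onnx` in the working directory (whether the fused node is `LayerNormalization` or `SimplifiedLayerNormalization` depends on the source pattern):

```python
import onnx

# Hedged check: the optimized training graph should contain a fused
# layer-norm node rather than a decomposed Pow/ReduceMean/Sqrt/Div chain.
model = onnx.load("layernorm_optimized_training.onnx")
print([n.op_type for n in model.graph.node if "LayerNormalization" in n.op_type])
```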

### Motivation and Context

ONNX Runtime integration with Llama-v2 family of LLMs.

---------

Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
prathikr and Prathik Rao authored Feb 14, 2024
1 parent 18f76bd commit 5444070
Showing 5 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
```diff
@@ -765,7 +765,7 @@ Do not modify directly.*
 |Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
 |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
```
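
The new `tensor(bfloat16)` entries for **T** and **V** mean a graph can now bind X, scale, and Y in bfloat16 while the `inv_std_var` statistics output (**U**) stays float or double. A hypothetical construction sketch using `onnx.helper` (note this is an ORT contrib op registered in the default ONNX domain, so `onnx.checker` may reject the model even though ONNX Runtime executes it):

```python
import onnx
from onnx import TensorProto, helper

# Hypothetical one-node graph exercising the new bfloat16 type bindings.
node = helper.make_node(
    "SimplifiedLayerNormalization",
    inputs=["X", "scale"],
    outputs=["Y"],
    axis=-1,
    epsilon=1e-5,
)
graph = helper.make_graph(
    [node], "slnorm_bf16",
    [helper.make_tensor_value_info("X", TensorProto.BFLOAT16, [2, 4]),
     helper.make_tensor_value_info("scale", TensorProto.BFLOAT16, [4])],
    [helper.make_tensor_value_info("Y", TensorProto.BFLOAT16, [2, 4])],
)
onnx.save(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]),
          "slnorm_bf16.onnx")
```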
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
```diff
@@ -120,6 +120,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
@@ -318,6 +319,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
```
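
The three components of the typed registration name map to the kernel's type parameters: input **T**, statistics type **U** (the `inv_std_var` output), and scale/output type **V**. So `BFloat16_float_BFloat16` keeps tensors in bfloat16 but computes the variance statistics in float32. A reference sketch of the math (my reading of the naming convention; the CUDA kernel's exact rounding behavior may differ):

```python
import torch

def simplified_layer_norm_ref(x: torch.Tensor, scale: torch.Tensor,
                              eps: float = 1e-5) -> torch.Tensor:
    # Compute the mean-of-squares statistic in float32 (the middle "float"
    # in BFloat16_float_BFloat16), then cast back to the input dtype.
    x32 = x.float()
    inv_std = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * inv_std).to(x.dtype) * scale
```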
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/layer_norm_fusion.cc
```diff
@@ -13,7 +13,7 @@ using namespace onnxruntime::common;
 namespace onnxruntime {
 
 // LayerNorm supports limited data types.
-static constexpr std::array<std::string_view, 3> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"};
+static constexpr std::array<std::string_view, 4> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
 // Default epsilon
 static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;
 
```
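
For context, `supported_data_types` gates the layer-norm fusion passes in this file; before this change a bfloat16 graph would fail the dtype check and keep the decomposed subgraph. The pattern SimplifiedLayerNormFusion collapses is, roughly, the RMSNorm decomposition exporters emit for Llama-style models (a sketch of the pattern, not the matcher itself):

```python
import torch

def rms_norm_decomposed(x: torch.Tensor, scale: torch.Tensor,
                        eps: float = 1e-5) -> torch.Tensor:
    # Exported as Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul; the pass
    # rewrites this chain into one SimplifiedLayerNormalization node.
    variance = x.pow(2).mean(dim=-1, keepdim=True)  # Pow, ReduceMean
    x = x / torch.sqrt(variance + eps)              # Add, Sqrt, Div
    return x * scale                                # Mul
```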
23 changes: 23 additions & 0 deletions onnxruntime/test/contrib_ops/layer_norm_op_test.cc
```diff
@@ -7,6 +7,7 @@
 #include "core/session/inference_session.h"
 #include "test/common/dnnl_op_test_utils.h"
 #include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
 #include "test/framework/test_utils.h"
 #include "test/util/include/default_providers.h"
 #include "test/providers/provider_test_utils.h"
@@ -75,6 +76,28 @@ TEST(LayerNormTest, LayerNorm) {
   test.Run();
 }
 
+TEST(LayerNormTest, LayerNorm_BFloat16Input) {
+// prevents test from running on non-BF16-supporting hardware
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+#endif
+  OpTester test("LayerNormalization");
+  test.AddAttribute<float>("epsilon", 1e-05f);
+
+  std::vector<int64_t> dims{1, 2, 3};
+  test.AddInput<BFloat16>("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+  test.AddInput<BFloat16>("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f}));
+  test.AddOutput<BFloat16>("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
+  // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
+            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
+}
+
 TEST(LayerNormTest, LayerNorm_Scale) {
   OpTester test("LayerNormalization");
   test.AddAttribute<float>("epsilon", 1e-05f);
```
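
The expected values in the new test follow directly from the LayerNormalization definition: each row of [[1, 2, 3], [4, 5, 6]] is centered and scaled by its inverse standard deviation, giving ±sqrt(3/2) ≈ ±1.2247 with a unit gamma. A quick cross-check against PyTorch (bfloat16's 8-bit mantissa rounds these constants, which the test's BFloat16 comparison tolerance absorbs):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = F.layer_norm(x, normalized_shape=[3], eps=1e-5)
print(y)  # tensor([[-1.2247, 0.0000, 1.2247], [-1.2247, 0.0000, 1.2247]])
```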
```diff
@@ -196,6 +196,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad);
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
@@ -452,6 +453,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
 
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
 
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
```
