Cherry-pick LLaMA/SDXL to rel-1.16.2 #18202

Merged
40 commits merged into rel-1.16.2 on Nov 1, 2023

Conversation

tianleiwu
Contributor

Description

Cherry-pick changes related to LLaMA and StableDiffusion XL to 1.16.2 release branch.

Motivation and Context

kunal-vaishnavi and others added 30 commits October 31, 2023 20:32
### Description
This PR adds the following scripts for LLaMA:
- LLaMA conversion (support for TorchScript and Dynamo exporters)
- LLaMA parity
- LLaMA benchmark
- LLaMA quantization
- LLaMA integration with [Hugging Face
Optimum](https://github.com/huggingface/optimum)



### Motivation and Context
This PR adds scripts for using LLaMA. There is a [follow-up
PR](#17043) for adding
scripts for Whisper.
### Description
This PR adds benchmark scripts for Whisper. It is a follow-up to [this
PR](#17020) that adds the
LLaMA scripts.



### Motivation and Context
This PR enables benchmarking Whisper across various configurations.
…ts (#17249)

### Description
Tested with stable diffusion unet models exported by both pytorch 2.1.0
(nightly) and pytorch 1.13.1, with and without LoRA weights.



### Motivation and Context
LoRA weights modify the UNet model by adding MatMul and scale operations to every q/k/v/out tensor, which breaks the current MHA pattern recognition.
Some operators have a dtype attribute (search `dtype` in
https://github.com/onnx/onnx/blob/main/docs/Operators.md).
This change makes sure the dtype attribute is handled correctly in float16
conversion.
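As an illustration of the kind of handling involved, here is a minimal sketch (not the actual float16 converter) that flips integer `dtype` attributes from FLOAT to FLOAT16 while walking the graph:

```
import onnx
from onnx import TensorProto

def convert_dtype_attributes_to_fp16(model: onnx.ModelProto) -> None:
    # Ops such as EyeLike, RandomNormal and RandomUniform carry a `dtype`
    # attribute; a float16 conversion must update it along with the tensors.
    for node in model.graph.node:
        for attr in node.attribute:
            if attr.name == "dtype" and attr.i == TensorProto.FLOAT:
                attr.i = TensorProto.FLOAT16
```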
Add a check of num_heads and hidden_size to avoid an assert error (#17254)
…ta (#17427)

Some initializers are added without the raw=True flag, which prevents those
tensors from being saved to external data. If those tensors exceed 2GB in
total, the optimized model cannot be saved due to the protobuf limit.

This change saves attention weights and bias as raw data.

Note: it is optional to use raw data for shape tensors since they are tiny.
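For context, a minimal sketch of the difference (with hypothetical tensor names); only initializers stored as raw bytes can later be relocated to an external data file:

```
import numpy as np
from onnx import TensorProto, helper, numpy_helper

weights = np.random.randn(1024, 1024).astype(np.float32)

# raw=True stores the bytes in raw_data, which external-data saving can offload;
# without it, the values land in float_data and stay inside the protobuf.
attn_weight = helper.make_tensor(
    "attention.weight",  # hypothetical name
    TensorProto.FLOAT,
    weights.shape,
    weights.tobytes(),
    raw=True,
)

# numpy_helper.from_array also stores raw bytes.
attn_bias = numpy_helper.from_array(np.zeros(1024, dtype=np.float32), name="attention.bias")
```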

### Motivation and Context
#17212
#15349
The embedding sum could be a graph output (when exporting with output hidden
state enabled). Previously, we only checked whether there are multiple child
nodes to decide whether to output the embedding sum from the fused node. With
this fix, if the sum is a graph output, we retain its name.
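A minimal sketch of the extra check, using a hypothetical helper name rather than the one in the fusion code:

```
def is_graph_output(graph, tensor_name: str) -> bool:
    # The embedding sum must keep its original name if the graph exposes it
    # as an output, otherwise downstream consumers lose the hidden state.
    return any(output.name == tensor_name for output in graph.output)
```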
Add attention fusion for the Stable Diffusion CLIP model to improve performance of SD and SDXL.
During optimization of the SDXL UNet, prune_graph takes up to 5 minutes.
The cause is that finding a node by scanning all nodes is time-consuming.
This optimization reduces the latency of prune_graph to 2 seconds.

The new algorithm uses a hash table (key: first node output, value: node) to
speed up the lookup.
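A minimal sketch of the idea (not the actual prune_graph code): build the table once, then each lookup is O(1) instead of a scan over all nodes.

```
def build_output_to_node_map(graph):
    # Key: the first output name of each node; value: the node itself.
    output_to_node = {}
    for node in graph.node:
        if node.output:
            output_to_node[node.output[0]] = node
    return output_to_node

# Example: resolve the producer of a tensor without scanning graph.node.
# output_to_node = build_output_to_node_map(model.graph)
# producer = output_to_node.get("some_tensor_name")
```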
- [x] Optimize SDXL models exported by optimum.
- [x] Enable it to run locally instead of using module.
- [x] Detect external data file in original model, and save with same
format by default.
- [x]  Add tests

### Example
```
pip install optimum transformers diffusers onnx "onnxruntime-gpu>=1.16"
optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx
python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16
```

### Known issues
(1) The VAE decoder cannot be converted to float16; otherwise, the output will contain black images.
(2) To use the float16 models, a minor change in Optimum is needed to convert the inputs for the VAE decoder from float16 to float32, since we keep the VAE decoder in float32. The change is to append a line like the following after [this line](https://github.com/huggingface/optimum/blob/afd2b5a36663bebd1f501486acee065c728947bc/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py#L483):
```
latents = latents.astype(np.float32)
```
…t_to_present (#17559)

To avoid a huge .cu file and make the code more readable:
 - Move PrepareQKV to a separate .cu file (attention_prepare_qkv.cu)
 - Move ConcatPastToPresent to attention_concat.cu
 - Add default values for AttentionData
 - Add a data structure QkvData to track the Q, K and V pointers and the QKV format.
### Description
This PR changes the Whisper export scripts to further optimize the
process of removing duplicate initializers from two subgraphs.

The current greedy approach is quicker by a large factor, but results in
some duplicate initializers not being caught and removed. This not only
results in a slightly larger Whisper model, but also a model that uses
more GPU memory.

The approach in this PR uses data hashes and caches to keep the export quick
while no longer relying on a greedy approach.
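A minimal sketch of the hashing idea (hypothetical helper, not the actual export code): hash each initializer's bytes so that only tensors whose hashes collide need a full comparison.

```
import hashlib
from collections import defaultdict

from onnx import numpy_helper

def candidate_duplicates(graph1, graph2):
    """Group initializers from two subgraphs by a hash of their contents."""
    def hash_table(graph):
        table = defaultdict(list)
        for init in graph.initializer:
            data = numpy_helper.to_array(init)
            key = (hashlib.sha1(data.tobytes()).hexdigest(), data.dtype.str, data.shape)
            table[key].append(init.name)
        return table

    table1, table2 = hash_table(graph1), hash_table(graph2)
    # Only pairs with matching hash/dtype/shape are candidates for deduplication.
    return {key: (table1[key], table2[key]) for key in table1.keys() & table2.keys()}
```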

---------

Co-authored-by: Peter McAughan <[email protected]>
### Description
Fixes a bug in `get_shared_initializers` where `signature_cache1,
signature_cache2` are passed as positional arguments to
`remove_shared_initializers` but their positions don't match the
function signature, so `signature_cache1` is passed to `min_elements`
and causes a comparison error at line 907.

Pass the arguments as kwargs so that the call doesn't rely on their positions.


### Motivation and Context
Fixes the bug described above.
* Break QkvToContext into small functions. Each fused and unfused kernel gets a separate function.
* Move the DecoderAttention kernel to a separate file
* Move KV cache related kernels to attention_kv_cache.cu

### Motivation and Context
To make the code easier to maintain.
- Update the Whisper benchmark for the ROCm EP.
Accelerate Stable Diffusion XL with the TensorRT EP. It is modified from the
TensorRT demo diffusion, and we updated the design to make the pipeline
work with different backend engines.

The following result is from A100 80GB with 30 steps of Base, or 30 steps of
Base plus 30 steps of Refiner, to generate 1024x1024 images. The engine is
built with static input shape, and CUDA graph is enabled.

Pipeline | Batch Size | TRT Latency (ms) | ORT_TRT Latency (ms) | Diff
-- | -- | -- | -- | --
Base | 1 | 2714 | 2679 | -1.3%
Base & Refiner | 1 | 3593 | 3530 | -1.8%

The test environment: onnxruntime-gpu is built from source, and the following packages or
libraries are used in this test:
* tensorrt==8.6.1.post1
* torch==2.2.0.dev20230920+cu121
* transformers==4.31.0
* diffusers==0.19.3
* onnx==1.14.1
* onnx-graphsurgeon==0.3.27
* polygraphy==0.47.1
* protobuf==3.20.2
* onnxruntime-gpu==1.17.0 (built from source of main branch)
* CUDA 12.2.2
* cuDNN 8.9.5.29
* python 3.10.13
Add CUDA EP to the demo of stable diffusion.

### A100 Performance
Test | Engine Property | Batch Size | TRT Latency (ms) | ORT_TRT Latency (ms) | ORT_CUDA Latency (ms) | TORCH Latency (ms)
-- | -- | -- | -- | -- | -- | --
SD 1.5, 50 steps, 512x512 | Static Input Shape | 1 | 861 | 851 | 861 | N/A
SD 1.5, 50 steps, 512x512 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 1 | 974 | 1079 | 928 | 1222
SD 1.5, 50 steps, 768x768 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 1 | 2492 | OOM | 1901 | 1971
SD 1.5, 50 steps, 768x768 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 4 | 9091 | OOM | 6785 | 6700

We can see that ORT_CUDA is the most robust for handling dynamic input shapes.
PyTorch could be a good choice if you run large batch sizes.

The above result is from one A100-SXM4-80GB GPU (in
Standard_ND96amsr_A100_v4 Azure VM) with 50 steps to generate 512x512 or
768x768 images using StableDiffusion 1.5. Onnxruntime-gpu is built from
source, and the following packages or libraries are used in this test:
* tensorrt==8.6.1.post1
* torch==2.2.0.dev20230920+cu121
* transformers==4.31.0
* diffusers==0.19.3
* onnx==1.14.1
* onnx-graphsurgeon==0.3.27
* polygraphy==0.47.1
* protobuf==3.20.2
* onnxruntime-gpu==1.17.0 (built from source of main branch)
* CUDA 12.2.2
* cuDNN 8.9.5.29
* python 3.10.13

For static input shape, the engine is built with static batch size and
static image shape, and cuda graph is enabled.

For dynamic input shape, the engine is built to support dynamic batch
size and dynamic image shape, and cuda graph is disabled. The TensorRT
engine is built for batch size 1~4, image size 256x256 ~ 1024x1024, and
the optimized image size is 512x512.

The scripts to test static and dynamic input shapes are like the following:
```
prompt="a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital paintining"
for e in TRT ORT_TRT ORT_CUDA
do
  python demo_txt2img.py --engine $e "$prompt"
  python demo_txt2img.py --engine $e --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape "$prompt"
  python demo_txt2img.py --engine $e --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape --height 768 --width 768 "$prompt"
done
```

Performance of PyTorch is from commands like the following:
```
python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 1 --height 512 --width 512
python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 1 --height 768 --width 768
python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 4 --height 768 --width 768
```
### Description

Replace
onnxruntime::cuda::Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim()
with a custom transpose kernel in ReorderPastState(). The original
implementation doesn't benefit from vectorized loading or coalesced writes,
and does not fully utilize the threads in the block.

Benchmarked with the TNLGv4 model (batch=4, seq_len=4K):
transpose kernel speed-up: ~1.9X (392 μs -> 206 μs)
overall reordering speed-up: ~1.48X

Latency:
before:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/34c7ab73-3da1-4c41-a036-e9fb6a966891)
after:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/337818ec-9598-4d8a-9e9b-7215b6862498)

GPU matrix:
before:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/4962248f-703c-49bd-8586-deaeccd9bce0)
after:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/a795a892-4c5d-432d-8375-0bb67385d2bc)


### Motivation and Context

---------

Co-authored-by: Your Name <[email protected]>
### Description
Added the Group Query Attention op, supporting a number of Q heads that is an
integer multiple of the number of KV heads. As of now, this op can only use
the FlashAttention kernel, meaning it only supports sm>=80 on Linux.

Results from onnxruntime/test/python/transformers/benchmark_gqa.py show
an average ~37% speed-up over Decoder Masked Multi-Head Attention,
with even greater improvements for long past sequence lengths.

```
op      batch   s_kv    heads   h_dim   ms      TFLOPS
gqa     16      2048    8       32      0.34    0.10
dmmha   16      2048    8       32      0.39    0.09
---------
gqa     16      2048    8       64      0.45    0.15
dmmha   16      2048    8       64      0.61    0.11
---------
gqa     16      2048    8       128     0.54    0.25
dmmha   16      2048    8       128     0.83    0.16
---------
gqa     16      2048    16      32      0.45    0.15
dmmha   16      2048    16      32      0.69    0.10
---------
gqa     16      2048    16      64      0.69    0.19
dmmha   16      2048    16      64      0.83    0.16
---------
gqa     16      2048    16      128     0.71    0.38
dmmha   16      2048    16      128     1.28    0.21
---------
gqa     16      2048    32      32      0.58    0.23
dmmha   16      2048    32      32      0.77    0.17
---------
gqa     16      2048    32      64      0.58    0.46
dmmha   16      2048    32      64      1.25    0.21
---------
gqa     16      2048    32      128     0.76    0.71
dmmha   16      2048    32      128     2.15    0.25
---------
gqa     16      2048    64      32      0.68    0.39
dmmha   16      2048    64      32      1.23    0.22
---------
gqa     16      2048    64      64      0.77    0.70
dmmha   16      2048    64      64      2.11    0.25
---------
gqa     16      2048    64      128     1.10    0.97
dmmha   16      2048    64      128     4.06    0.26
---------
gqa     16      2048    128     32      1.00    0.54
dmmha   16      2048    128     32      2.09    0.26
---------
gqa     16      2048    128     64      1.10    0.97
dmmha   16      2048    128     64      4.08    0.26
```


### Motivation and Context
As of now, this op is targeted for use in LLaMA models, as it supports
KV caching and different numbers of heads for Q and KV (grouped query
attention). We plan to add support for more platforms, input formats,
etc. in the future.

---------

Co-authored-by: Tianlei Wu <[email protected]>
Co-authored-by: [email protected] <[email protected]>
The SD XL Refiner model has new hidden dimension sizes not supported by BiasSplitGelu. This updates the kernel to support them.
### Motivation and Context
The current BiasSplitGelu does not support optimization for the SD XL refiner model.
Previously, BiasAdd only supported hidden dimensions of 320, 640 and 1280
for stable diffusion. This adds a kernel that can support any number
of channels.

### Motivation and Context
Stable Diffusion XL refiner model uses hidden dimensions of 768 or 1536,
which was not supported in BiasAdd.
Right now, GroupNorm only supports a limited number of channels (320, 640,
960, 1280, 1920, 2560, 128, 256, 512). Skip the fusion if the number of
channels is not supported.

### Motivation and Context
The SD XL refiner model uses channel counts of 384, 768, 1152, 2304 and 3072
in GroupNorm.
In SLN strict mode, the current code (#16510) does not handle skip broadcast
nicely. There are two issues:
(1) skip-related parameters are not passed to the CUDA kernel in strict mode
(2) the strict mode kernel also has a bug in handling skip broadcasting (for
example, cuWelfordMuSigma2 does not handle skip broadcasting).

Here we remove support for skip broadcasting in strict mode, and the operator
will return an error message that strict mode only supports the same shape
for input and skip.

Other changes:
* skip_size is misleading when there is no broadcasting. Change it to the
correct value.
* Refactor the code to be more efficient: (1) no need to check whether there
is broadcasting in the kernel; (2) remove one local buffer (load input to
sum_v directly to save a local buffer copy).
* Compute input + bias + skip instead of input + skip + bias. The order
follows the common pattern in transformer models (this assumes graph fusion
distinguishes input and skip correctly; the fusion code needs a double check
later).
* Update the unit tests so that strict mode is triggered in each test case
(unless skip broadcasting is used) for higher test coverage.

### Motivation and Context
SLN strict mode does not support skip broadcast, but the current code silently
runs it (the kernel might fail).
### Description
nvcc 12.2 crashes while building
onnxruntime/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_*
for SM<8.0. nvcc 11.8 works, though, so it appears to be a bug in nvcc 12.2.

This PR excludes building FlashAttention for arch < 800.
Fix a bug in #11803:
When the hidden size is not exactly the same as the next size (for example,
ld=320 in stable diffusion), the current vectorized kernel might read out of
bounds and might cause a CUDA failure.

Also resolved another issue: for the first and last sizes, the current macro
generates some dead code (some branches will never run). Here we change it
to avoid those branches for the boundary sizes.

Performance tests with stable diffusion show that the performance is on par
before and after this fix.
The previous shape inference never had the chance to infer the past_key
and past_value outputs because we were returning early.
Add the CUDA EP to the Stable Diffusion XL demo, including:
(1) Add fp16 VAE support for the CUDA EP.
(2) Configure each model separately (for example, some models can run with
CUDA graph and some cannot).

Some remaining work will boost performance further later:
(1) Enable CUDA graph for Clip2 and UNet. Currently, some parts of the graph
are partitioned to CPU, which blocks CUDA graph.
(2) Update the GroupNorm CUDA kernel for the refiner. Currently, the CUDA
kernel only supports a limited number of channels in the refiner, so we shall
see some gain there if we remove the limitation.

Some extra work that is nice to have (thus lower priority):
(3) Support denoising_end to ensemble base and refiner.
(4) Support classifier-free guidance (the idea is from
https://www.baseten.co/blog/sdxl-inference-in-under-2-seconds-the-ultimate-guide-to-stable-diffusion-optimiza/).


#### Performance on A100-SXM4-80GB

Example commands to test an engine built with static shape or dynamic
shape:
```
engine_name=ORT_CUDA
python demo_txt2img_xl.py --engine $engine_name "some prompt"
python demo_txt2img_xl.py --engine $engine_name --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape "some prompt"
```
An engine built with dynamic shape can support different batch sizes (1 to
4 for TRT; 1 to 16 for CUDA) and image sizes (256x256 to 1024x1024).
An engine built with static shape can only support a fixed batch size (1)
and image size (1024x1024).

The latency (ms) of generating an image of size 1024x1024 (sorted by
total latency):

 Engine | Base (30 Steps)* | Refiner (9 Steps) | Total Latency (ms)
-- | -- | -- | --
ORT_TRT (static shape) | 2467 | 1033 | 3501
TRT (static shape) | 2507 | 1048 | 3555
ORT_CUDA (static shape) | 2630 | 1015 | 3645
ORT_CUDA (dynamic shape) | 2639 | 1016 | 3654
TRT (dynamic shape) | 2777 | 1099 | 3876
ORT_TRT (dynamic shape) | 2890 | 1166 | 4057

\* The VAE decoder is not used in Base since the output from the base is a
latent, which is consumed by the refiner to output the image.

We can see that ORT_CUDA is faster with dynamic shape, while slower with
static shape (the cause is that Clip2 and UNet cannot run with CUDA graph
right now; we will address the issue later).

### Motivation and Context
Follow up of #17536
### Description
Initialize previously uninitialized parameters that were causing the op to
crash.



### Motivation and Context
Solves a CUDA memory misalignment / illegal memory access error when
FlashAttention was used in Packed Multi-Head Attention.
wejoncy and others added 7 commits October 31, 2023 21:32
### Description

This PR implements an exporter that works for large language models (LLMs).
It works for models like Llama2-70b or gpt-175.

The main idea is to utilize multiple GPUs and dispatch different layers to
different GPUs; in short, it simply implements automatic pipeline
parallelism.

For example, to export Llama2-70b, you need 8x V100-32GB or 4x A100-80GB, or
more GPU memory.

It is expected to export decoder-only models. For encoder-decoder
architectures, we have not tested it yet.
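A minimal sketch of the layer-dispatch idea (hypothetical helper, assuming a Hugging Face LLaMA-style module layout such as `model.model.layers`; a real pipeline also has to move activations between devices during the forward pass):

```
import torch

def dispatch_layers(model: torch.nn.Module, num_gpus: int) -> torch.nn.Module:
    # Place consecutive decoder layers on consecutive GPUs (simple pipeline split).
    layers = model.model.layers              # assumption: LLaMA-style layout
    per_gpu = (len(layers) + num_gpus - 1) // num_gpus
    for i, layer in enumerate(layers):
        layer.to(f"cuda:{i // per_gpu}")
    model.model.embed_tokens.to("cuda:0")    # embeddings stay with the first stage
    model.model.norm.to(f"cuda:{num_gpus - 1}")
    model.lm_head.to(f"cuda:{num_gpus - 1}")
    return model
```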
### Motivation and Context

---------

Co-authored-by: Justin Chu <[email protected]>
This PR contains fusion-level and kernel-level optimizations for [Meta's
LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/).

Some of the added optimizations include:

- SimplifiedLayerNorm changes
  - Fusions for multiple variants
- SkipSimplifiedLayerNorm changes
  - Kernel support for CPU
- Rotary embeddings (previously did not exist)
  - Fusions for multiple variants
  - CPU and CUDA kernels
  - Supports interleaving and non-interleaving in the same kernels
  - Optimized cache that requires half of its originally exported sizes
- Reduced from `(max_sequence_length, head_size)` to
`(max_sequence_length, head_size / 2)`
- Multi-head attention
  - Support for 2D and 3D attention masks
- Group query attention (for FP16 CUDA and INT4 CUDA)
  - Integration with flash attention v2 and past-present buffer sharing
- Removes need for `attention_mask` input as it is supported in the
kernel
- 4 bit quantization
  - `block_size` parameter is available for customizing
- Support the new changes for [Microsoft
version](https://github.com/microsoft/Llama-2-Onnx)
- Support combinations of the below variants (ex: export ORT version and
run with Optimum)

Supported variants of LLaMA-2 include:
- [ORT
version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama)
- Produces one ONNX file that is already optimized (and quantized if
requested)
  - Integrates with Optimum
- [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Already exported and available off-the-shelf
  - Faster versions of those models will be uploaded there soon
- [Hugging Face version](https://huggingface.co/meta-llama)
  - Models that end with `-hf`
- Some older and current versions of
[`transformers`](https://github.com/huggingface/transformers) and
[`optimum`](https://github.com/huggingface/optimum) that export the
model to ONNX differently
- Note that while some older versions are supported, it is recommended
to use the latest package versions.

To use the optimizations, please see `README.md` for details. Please
note the various `requirements.txt` files for the package versions
recommended in order to use these changes.

To run the ORT transformer optimizer separately, run the script as
follows:
```
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type gpt2 --num_heads <number of attention heads> --hidden_size <attention hidden size> --use_external_data_format --opt_level 0
```

This PR helps the following issues:
- #14997
- #16254
- #17681
- #17925
- microsoft/onnxruntime-inference-examples#320

This PR uses changes from the following PRs:
- pytorch/pytorch#104468
- pytorch/pytorch#109759
- #17020
- #17674
- #17890
- #17920
- huggingface/transformers#26162
- huggingface/optimum#1257
- huggingface/optimum#1289
- huggingface/optimum#1462

This PR uses changes from the following issues and PRs to begin
supporting the [new TorchDynamo
exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter):
- huggingface/transformers#26307
- pytorch/pytorch#104903
- pytorch/pytorch#105040
- microsoft/onnxscript#847
- microsoft/onnxscript#862
- microsoft/onnxscript#493
### Description
Add a contrib op, MatMulBnb4 (FP4 and NF4), and the related toolchain to
support quantization of weights.

This PR adds:
- a schema for the contrib op MatMulBnb4, which supports FP4 (4-bit floating
point) and NF4 (4-bit NormalFloat) quantization of weights.
- a naive implementation of MatMulBnb4 on CPU and GPU, i.e., implemented like
MatMul(A, Dequantize(B)) (see the sketch after this list).
- a special GEMV implementation for MatMulBnb4 and a related benchmark tool.
- a tool to quantize models to FP4 or NF4.
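For intuition, a numpy sketch of the naive MatMul(A, Dequantize(B)) formulation with an approximate NF4 code book and a single quantization block (illustrative only, not onnxruntime's kernel or its quantization tool):

```
import numpy as np

# Approximate NF4 (4-bit NormalFloat) code book: 16 levels in [-1, 1].
NF4_CODES = np.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=np.float32)

def quantize_nf4(block: np.ndarray):
    # Absmax scaling, then map each value to the index of the nearest code.
    scale = np.abs(block).max()
    indices = np.abs(block / scale - NF4_CODES[:, None]).argmin(axis=0)
    return indices.astype(np.uint8), scale

def dequantize_nf4(indices: np.ndarray, scale: np.float32) -> np.ndarray:
    return NF4_CODES[indices] * scale

# Naive MatMulBnb4: dequantize the weight, then run a regular MatMul.
A = np.random.randn(2, 8).astype(np.float32)
B = np.random.randn(8, 4).astype(np.float32)
indices, scale = quantize_nf4(B.reshape(-1))
C = A @ dequantize_nf4(indices, scale).reshape(B.shape)
```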
### Description
This PR adds a few updates to scripts in the LLaMA folder:
- Fixes the precision re-naming in the LLaMA export
- Adds a "prerequisites" section in the README
- Adds IO binding synchronizations during benchmarking for other EPs



### Motivation and Context
- With precision re-naming, the LLaMA parity check does not produce
errors when creating the FP32 CPU model
- The "prerequisites" section shows that there are specific package
versions needed
- This allows for benchmarking with other EPs besides CPU and CUDA (see the sketch below)
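A minimal sketch of the IO binding synchronization pattern (hypothetical model path and input name), so that timing covers the actual EP execution rather than just the asynchronous launch:

```
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("llama.onnx", providers=["CUDAExecutionProvider"])  # hypothetical model
binding = sess.io_binding()
binding.bind_cpu_input("input_ids", np.ones((1, 8), dtype=np.int64))            # hypothetical input
for output in sess.get_outputs():
    binding.bind_output(output.name)

binding.synchronize_inputs()       # wait for input copies before starting the timer
sess.run_with_iobinding(binding)
binding.synchronize_outputs()      # wait for EP work to finish before stopping the timer
```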
### Description
Optimize SkipLayerNorm for large hidden dimensions (>=2048) by handling 8
elements in one thread. This avoids re-writing and re-loading the sum of
input, skip and bias to main memory. It reduces the latency for dimension
4096 with a small batch size from ~18us to ~3.8us on A100.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Add an MHA mask pattern for the scenario where the mask has already been
broadcasted via an Expand node.
* Add a new operator, SkipGroupNorm, to support skip and bias inputs.
* Update the GroupNorm kernel to support the channel counts used in the SD XL refiner.
* Add epsilon in the kernel
* Add parity and performance test scripts
* Remove many limitations, including max batch size, max number of groups, c % cPerBlock == 0, etc.

### Motivation and Context

Update GroupNorm to support SD XL Refiner and beyond.
@tianleiwu tianleiwu marked this pull request as draft October 31, 2023 21:54
@tianleiwu tianleiwu marked this pull request as ready for review November 1, 2023 00:38
kunal-vaishnavi and others added 2 commits November 1, 2023 04:25
### Description
This PR reduces the memory usage when exporting and benchmarking LLaMA.



### Motivation and Context
- Exporting: The PyTorch model is deleted from memory after a successful
export instead of after exporting and converting the ONNX model to the
desired precision (see the sketch after this list).
- Benchmarking: In the ONNX model with GroupQueryAttention, the KV cache
inputs use the same GPU memory for both the prompt and token generation
benchmarks.
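A minimal sketch of the export-time release described in the first bullet, using a hypothetical stand-in model; the point is the ordering, not the model itself:

```
import gc

import torch

# Hypothetical tiny stand-in model; the point is the ordering: export first,
# release the PyTorch weights, then do ONNX post-processing (e.g. fp16 conversion).
model = torch.nn.Linear(16, 16).cuda()
sample = torch.randn(1, 16, device="cuda")
torch.onnx.export(model, (sample,), "model.onnx")

del model                      # drop the PyTorch weights before conversion
gc.collect()
torch.cuda.empty_cache()

# ... load model.onnx and convert it to the desired precision here ...
```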
### Description
Implement Split KV optimization for FlashAttention in MHA and Attention
operators.

### Motivation and Context
Can help further accelerate these ops.
@tianleiwu tianleiwu merged commit c273f7a into rel-1.16.2 Nov 1, 2023
95 of 99 checks passed
@tianleiwu tianleiwu deleted the tlwu/rel-1.16.2_sdxl_llama branch November 1, 2023 21:39