
[js/web] BiasSplitGelu and BiasAdd kernels #17161

Merged: 10 commits merged into microsoft:main on Oct 3, 2023

Conversation

@dakenf (Contributor) commented Aug 15, 2023

Description

Two contrib kernels that are supposed to speed up StableDiffusion according to this doc: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md

However, there is no noticeable effect on speed or memory consumption, so I guess the only way to make it faster is to implement MultiHeadAttention, but I'm not capable of doing that right now. So I'll focus on existing PRs and on finding the JSEP kernel that produces incorrect results. It should be one of the old ones (I suspect Conv or ConvTranspose), as SD has not been generating images correctly on WebGPU since I started working on it. I hoped someone else would fix that by the time I finished with kernels/optimizations 😅
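For context, here is a rough reference of what BiasSplitGelu computes, based on the contrib-op description (add bias, split the last dimension in half, multiply the first half by GELU of the second half). The function shape and the tanh GELU approximation below are my assumptions for this sketch, not the actual JSEP kernel code:

```ts
// Reference sketch of BiasSplitGelu semantics (not the actual kernel).
// x has shape [rows, dLast]; bias has shape [dLast]; output is [rows, dLast/2].
function biasSplitGeluRef(x: Float32Array, bias: Float32Array, dLast: number): Float32Array {
  const half = dLast / 2;
  const rows = x.length / dLast;
  const out = new Float32Array(rows * half);
  // tanh approximation of GELU (an assumption for this sketch)
  const gelu = (v: number) =>
    0.5 * v * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (v + 0.044715 * v ** 3)));
  for (let r = 0; r < rows; r++) {
    for (let i = 0; i < half; i++) {
      const a = x[r * dLast + i] + bias[i];
      const b = x[r * dLast + half + i] + bias[half + i];
      out[r * half + i] = a * gelu(b); // first half gates GELU of the second half
    }
  }
  return out;
}
```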

@guschmue added the ep:WebGPU (ort-web webgpu provider) label on Aug 16, 2023
@guschmue (Contributor) commented:

/azp run ONNX Runtime Web CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@satyajandhyala (Contributor) commented:

If applicable, please add JSONC test cases to validate. For example, see onnxruntime/js/web/test/data/ops/gelu.jsonc.
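For reference, a rough sketch of the shape such a test file could take, modeled on the layout of the existing op tests. The exact field names, the com.microsoft domain entry, and the BiasAdd semantics (output = input + bias + skip) are my reading of the docs, so verify against gelu.jsonc before copying:

```jsonc
// bias-add.jsonc — hypothetical test file following the existing op-test layout
[
  {
    "name": "BiasAdd with input, bias and skip",
    "operator": "BiasAdd",
    "opset": { "domain": "com.microsoft", "version": 1 },
    "cases": [
      {
        "name": "T[1,2,4]",
        "inputs": [
          { "data": [1, 2, 3, 4, 5, 6, 7, 8], "dims": [1, 2, 4], "type": "float32" },
          { "data": [0.5, 0.5, 0.5, 0.5], "dims": [4], "type": "float32" },
          { "data": [1, 1, 1, 1, 1, 1, 1, 1], "dims": [1, 2, 4], "type": "float32" }
        ],
        "outputs": [
          { "data": [2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5], "dims": [1, 2, 4], "type": "float32" }
        ]
      }
    ]
  }
]
```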

@satyajandhyala (Contributor) commented:

> Two contrib kernels that are supposed to speed up StableDiffusion according to this doc […]

Can you share a (small) test case that demonstrates the problem?

@dakenf (Contributor, author) commented Aug 16, 2023

> Can you share a (small) test case that demonstrates the problem?

Unfortunately, no. I will make different builds with some JSEP kernels disabled to see which one causes the problem. Right now, if I run it on WebGPU, I get an image like this for any prompt; if I run it on CPU, everything works fine.

[attached screenshot: corrupted output image]
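One practical way to bisect this kind of kernel bug is to run identical feeds through the wasm and webgpu EPs and diff the outputs. A minimal sketch using onnxruntime-web's public InferenceSession API (the helper name is made up, and depending on the build at the time the webgpu EP may need the onnxruntime-web/webgpu entry point instead):

```ts
import * as ort from 'onnxruntime-web';

// Sketch: run the same feeds on the wasm and webgpu EPs and report the
// largest elementwise difference per output, to narrow down a bad kernel.
async function diffEps(model: Uint8Array, feeds: Record<string, ort.Tensor>) {
  const cpu = await ort.InferenceSession.create(model, { executionProviders: ['wasm'] });
  const gpu = await ort.InferenceSession.create(model, { executionProviders: ['webgpu'] });
  const [ref, out] = await Promise.all([cpu.run(feeds), gpu.run(feeds)]);
  for (const name of Object.keys(ref)) {
    const a = ref[name].data as Float32Array;
    const b = out[name].data as Float32Array;
    let maxDiff = 0;
    for (let i = 0; i < a.length; i++) {
      maxDiff = Math.max(maxDiff, Math.abs(a[i] - b[i]));
    }
    console.log(`${name}: max abs diff = ${maxDiff}`);
  }
}
```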

@satyajandhyala (Contributor) commented:

> Unfortunately, no. I will make different builds with some JSEP kernels disabled to see which one causes the problem. […]

@dakenf It would help if you could add steps to reproduce this problem.

@dakenf (Contributor, author) commented Aug 17, 2023

> It would help if you could add steps to reproduce this problem.

To reproduce the error you'll need a 64-bit build of the runtime, so it's quite complicated.

I've found two issues that gave incorrect results:

  1. MatMul did not support broadcasting; fixed in [js/web] MatMul broadcasting #17191 (see the shape sketch below).
  2. The Conv kernel gives incorrect results for the StableDiffusion unet. If I comment it out, everything works fine but slow. Not sure where to start, as it passes all tests.
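For context on item 1, broadcast MatMul follows NumPy semantics: every dimension except the last two broadcasts, and the last two multiply as [M,K] x [K,N]. A minimal shape-inference sketch of that rule (mine, not ORT's code; 1-D operands omitted for brevity):

```ts
// Compute the output shape of a broadcast MatMul for rank >= 2 inputs.
function matMulOutputShape(a: number[], b: number[]): number[] {
  const [m, k1] = a.slice(-2);
  const [k2, n] = b.slice(-2);
  if (k1 !== k2) throw new Error(`inner dims mismatch: ${k1} vs ${k2}`);
  const batchA = a.slice(0, -2);
  const batchB = b.slice(0, -2);
  const rank = Math.max(batchA.length, batchB.length);
  const batch: number[] = [];
  for (let i = 0; i < rank; i++) {
    // Right-align the batch dims; missing dims act as 1.
    const da = batchA[batchA.length - rank + i] ?? 1;
    const db = batchB[batchB.length - rank + i] ?? 1;
    if (da !== db && da !== 1 && db !== 1) throw new Error('not broadcastable');
    batch.push(Math.max(da, db));
  }
  return [...batch, m, n];
}

// e.g. a stacked input against a 2-D weight:
// matMulOutputShape([2, 8, 4, 16], [16, 32]) -> [2, 8, 4, 32]
```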

@fs-eire (Contributor) commented Aug 17, 2023

> To reproduce the error you'll need a 64-bit build of the runtime, so it's quite complicated. […]

Could you help to dump a set of input/output samples (also the attributes) for the incorrect Conv op so that we can take a look at it?

@dakenf (Contributor, author) commented Aug 17, 2023

> Could you help to dump a set of input/output samples (also the attributes) for the incorrect Conv op so that we can take a look at it?

Yeah. There are 65 Conv nodes, so I feel it is going to be fun.

@dakenf (Contributor, author) commented Aug 18, 2023

I could not find an easy way to dump the inputs/outputs, as they are very big, but I found a solution that fixes my issues: #17219

@guschmue (Contributor) commented:

CI is nagging: run 'npm run format'.

@dakenf (Contributor, author) commented Aug 18, 2023

> CI is nagging: run 'npm run format'.

Yup. I will also add some tests and update the PR.

@guschmue (Contributor) commented:

/azp run ONNX Runtime Web CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@satyajandhyala (Contributor) commented:

@dakenf Can you try running with the latest code? I merged fixes to ConvTranspose.

@dakenf (Contributor, author) commented Aug 25, 2023

> Can you try running with the latest code? I merged fixes to ConvTranspose.

Yeah, will do on the weekend. Also, this week Emscripten got a release with all the required changes (except 64-bit threads), so I'll clean up the 64-bit PR.

There's still a big issue with the VRAM limit in Chrome: it's ~16 GB on Windows, and the StableDiffusion unet eats 10 GB+, so it's not possible to fit the unet and the VAE into VRAM. I'm trying to solve it with the Attention + MultiHeadAttention ops, and usage went down to ~5 GB, but it does not give correct results. Most likely I've messed up the indices for packed weights or batched gemm/matmul. If I can't fix it in a reasonable time, I'll make a draft PR and ask for your help.

@dakenf (Contributor, author) commented Aug 30, 2023

> Can you try running with the latest code? I merged fixes to ConvTranspose.

It seems fine now. However, I'm experiencing some other issues, like getting weird images completely unrelated to the prompt. It happens for both the wasm and webgpu EPs, though, so I need to check whether the problem is in the model optimizer, the new Emscripten release, or my DNA.

[attached image]

@fs-eire (Contributor) commented Aug 30, 2023

/azp run Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline

@fs-eire (Contributor) commented Aug 30, 2023

/azp run Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-python-checks-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 7 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@dakenf (Contributor, author) commented Sep 1, 2023

> However, I'm experiencing some other issues, like getting weird images completely unrelated to the prompt. It happens for both the wasm and webgpu EPs.

This is a bug in the Python fusion optimizer. It breaks the text-encoder model, so replacing it with the original one fixes the issue. However, it is quite weird, because when I use Python code to generate images with that broken model, it works fine.

Maybe I'll investigate more and file a bug later. I want to focus on optimizing the InstanceNorm/LayerNorm kernels, as InstanceNorm has 3 for-loops from 0 to 40k in each invocation, and then finally revisit the Attention kernel with a fresh look.

@dakenf (Contributor, author) commented Sep 1, 2023

BTW, if you are struggling with the choice of the next op to implement, can it be NhwcConv?

@dakenf (Contributor, author) commented Sep 2, 2023

With MultiHeadAttention (without packed weights), Attention with some vec2/vec4 optimizations and in-place SoftMax, and vec2/vec4 InstanceNorm, I've gotten the unet down to 5.6 GB of VRAM and ~2.5 s per step. A rough sketch of the vec4 idea is shown below.

I will apply the same optimizations to LayerNorm/SkipLayerNorm and implement GroupNorm to see if that speeds things up. Maybe with GroupNorm the VAE will use less than 7.5 GB of VRAM.

So you can expect a few more PRs next week with all this stuff.
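For readers following along, here is a hypothetical illustration of the vec4 trick, written the way JSEP kernels assemble WGSL inside TypeScript template strings: viewing the input as array<vec4<f32>> lets each loop iteration reduce four elements at once. The variable names (uniforms.sizeDiv4, channelOffset) are made up for this sketch; it is not the actual kernel code:

```ts
// Sketch of a vec4-vectorized mean reduction, in the JSEP style of
// building WGSL as a TypeScript template string (names are hypothetical).
const reduceSnippet = /* wgsl */ `
  var sum = vec4<f32>(0.0);
  for (var i = 0u; i < uniforms.sizeDiv4; i++) {
    sum += input[channelOffset + i]; // input is array<vec4<f32>>: 4 floats per load
  }
  let mean = (sum.x + sum.y + sum.z + sum.w) / f32(uniforms.size);
`;
```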

[attached video: 2023-09-03.02-01-07.mp4]

@dakenf (Contributor, author) commented Sep 4, 2023

@fs-eire I finally got the shader-f16 extension working with the latest Chrome and --enable-dawn-features=allow_unsafe_apis.

Since almost all ops use the indices helper, it will be an easy change (however, I've seen some hardcoded `var x = fp32(0)`). If that gives a 2x boost, and if I manage to implement flash attention with packed weights, it all might go down from 2.5 s to less than 1 s. Feels like Christmas; will run tests tomorrow.

[attached screenshot]
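For anyone trying to reproduce the f16 setup, a minimal sketch of detecting and requesting the feature with the standard WebGPU API (at the time of this thread, Chrome also needed the --enable-dawn-features=allow_unsafe_apis flag for the feature to be exposed):

```ts
// Request the shader-f16 feature only when the adapter reports it.
const adapter = await navigator.gpu.requestAdapter();
if (adapter?.features.has('shader-f16')) {
  const device = await adapter.requestDevice({ requiredFeatures: ['shader-f16'] });
  // WGSL compiled on this device may then start with `enable f16;`
  // and use f16 / vec4<f16> types.
}
```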

@fs-eire (Contributor) commented Sep 8, 2023

The part with the two operators is good. Could you please revert the changes that add test support for several other EPs? That change is helpful but should be a separate PR.

Please mention me after this change; I will try to kick off the CI ASAP.

@dakenf (Contributor, author) commented Sep 12, 2023

@fs-eire I've reverted the test-runner changes and added fp16 support.

@fs-eire (Contributor) commented Sep 12, 2023

/azp run Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline

@fs-eire (Contributor) commented Sep 12, 2023

/azp run Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-python-checks-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 7 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@fs-eire (Contributor) commented Oct 2, 2023

/azp run Windows x64 QNN CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@fs-eire (Contributor) commented Oct 2, 2023

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

@fs-eire (Contributor) commented Oct 2, 2023

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-python-checks-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@guschmue merged commit d0519a7 into microsoft:main on Oct 3, 2023
63 of 65 checks passed
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request on Mar 22, 2024 (the commit message repeats the PR description above).
siweic0 pushed a commit to siweic0/onnxruntime-web that referenced this pull request on May 9, 2024 (the commit message likewise repeats the PR description).