GemmFloat8 as a contrib op #16051

Merged on Oct 27, 2023

Changes from 250 commits

Commits (342)
8c792cc
fix unit tests
xadupre Apr 20, 2023
9c17567
add more unit test and better error message
xadupre Apr 20, 2023
833ff29
fix quantize
xadupre Apr 20, 2023
c694dbd
fix lint
xadupre Apr 24, 2023
2bd7144
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre Apr 24, 2023
5a47850
onnx is using 1.14.0rc2
xadupre Apr 24, 2023
c4c7b19
disable a few unit test related to float8
xadupre Apr 24, 2023
b47c688
fix registration
xadupre Apr 24, 2023
4219b12
disable tests to enable when Resize, AveragePool 19 are merged
xadupre Apr 24, 2023
cca384c
disable tests failing because of float8
xadupre Apr 25, 2023
b0341b2
disable more tests
xadupre Apr 25, 2023
d3a81de
fix exclusion list
xadupre Apr 25, 2023
d8a2763
disable more tests
xadupre Apr 25, 2023
fbd4065
fix exclusion
xadupre Apr 25, 2023
43c7e08
fix exclusion
xadupre Apr 25, 2023
275eee9
fix exclusion
xadupre Apr 25, 2023
0efb603
fix merge conflict
xadupre Apr 26, 2023
10b2886
extend float 8 types support to C#, flatbuffers
xadupre Apr 26, 2023
9dc5219
update flatbuffers header
xadupre Apr 26, 2023
036ee20
lint
xadupre Apr 26, 2023
1b97938
fix misspelling C#
xadupre Apr 26, 2023
b7df7cd
fix C# issue
xadupre Apr 26, 2023
bd846f1
check error msg
xadupre Apr 26, 2023
1673895
fix qnn
xadupre Apr 26, 2023
ffa2f52
qnn
xadupre Apr 26, 2023
f835549
remove QDQ 13-18
xadupre Apr 26, 2023
4e38758
disable some tests
xadupre Apr 26, 2023
37354c2
fix exclusion and tests
xadupre Apr 27, 2023
5f3bede
fix self.failed
xadupre Apr 27, 2023
a43be2e
disable more tests
xadupre Apr 27, 2023
5e0dc14
exclude test
xadupre Apr 27, 2023
6cae621
enable disabled tests
xadupre Apr 27, 2023
c78ad2c
fix QuantizeLinear
xadupre Apr 27, 2023
4cf5c60
fix ci
xadupre Apr 27, 2023
56ad3d1
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre Apr 27, 2023
ffbdb82
fix rocmbuild
xadupre Apr 27, 2023
325048b
disable one test on CUDA
xadupre Apr 28, 2023
ff674ca
fix merge conflict
xadupre Apr 28, 2023
83de6ac
rocm
xadupre Apr 28, 2023
57a33e5
fix negative axis
xadupre Apr 28, 2023
9c1e145
fix compilation issue
xadupre Apr 28, 2023
e28844c
fix quantize negative axis
xadupre Apr 28, 2023
6f3f7ec
fix compilation issue
xadupre Apr 28, 2023
07a032b
fix compilation issue
xadupre Apr 28, 2023
7451848
fix compilation issue
xadupre Apr 28, 2023
41bf7fc
fix merge conflicts
xadupre Apr 28, 2023
d7ff10c
fix wrong import
xadupre Apr 28, 2023
47c0c7c
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre Apr 29, 2023
9d5c1dc
sort exceptions
xadupre Apr 29, 2023
bcb119b
update operator.md
xadupre Apr 29, 2023
fd711f6
disable more unit test on dwml
xadupre Apr 29, 2023
3f43080
fix merge conflict
xadupre May 2, 2023
e1e0657
fix merge conflict
xadupre May 3, 2023
d67c76c
disable a couple of test to see which one is useful
xadupre May 3, 2023
2ca6e47
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 4, 2023
a497626
rename all types
xadupre May 4, 2023
c13cb27
improve exclusion list
xadupre May 4, 2023
7497c6f
fix link issues
xadupre May 4, 2023
f595338
lint
xadupre May 4, 2023
de4c8d5
fix merge conflicts
xadupre May 4, 2023
29400c5
fix dependency number
xadupre May 4, 2023
d016cf0
use released onnx
xadupre May 4, 2023
3588789
update deps.txt with the latest onnx
xadupre May 4, 2023
4dfb2b3
update onnx repo
xadupre May 4, 2023
8d00750
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 4, 2023
5924c43
update manifest
xadupre May 4, 2023
857601a
1.0.55
xadupre May 4, 2023
eb9dc01
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 4, 2023
5839829
update md
xadupre May 4, 2023
e7b3ac3
disable new tests
xadupre May 4, 2023
5962571
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 5, 2023
333c712
remove comment after enabling disabled tests
xadupre May 5, 2023
993bbb8
fix merge conflicts
xadupre May 7, 2023
551d31e
refactor tests
xadupre May 7, 2023
29672ea
refactor disabled tests
xadupre May 7, 2023
dc37f60
update version for download-deps.yml
xadupre May 7, 2023
db263a8
remove comments, improves exclusion list
xadupre May 8, 2023
126362d
improve error message
xadupre May 8, 2023
7724e9d
fix one misspelling
xadupre May 8, 2023
324f985
fix cast op
xadupre May 8, 2023
8d2ed30
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 9, 2023
c9f11ff
minor updates based on comments
xadupre May 9, 2023
8d2f9d2
fix compiling issues
xadupre May 9, 2023
ce9a963
lint
xadupre May 9, 2023
0f5b314
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 10, 2023
a03ab5b
minor changes, replies to comment
xadupre May 11, 2023
c38667d
fix compilation issues
xadupre May 11, 2023
275a321
restore model test exclusion lists
xadupre May 11, 2023
36da2ec
fix a warning
xadupre May 11, 2023
57c484d
fix merge conflicts
xadupre May 12, 2023
dcf3b1a
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 12, 2023
717c264
support some operator for opset 19 in optimizers
xadupre May 12, 2023
01691bf
remove empty line
xadupre May 12, 2023
f00c273
scope down ut
RandyShuai May 13, 2023
069ede2
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 15, 2023
10b0054
Merge commit 'f00c273d820ce3008c0c2f3d6d2d07f93d564afa' into f8
xadupre May 15, 2023
1ddc1f2
Update onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
xadupre May 15, 2023
f5f7129
Update onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
xadupre May 15, 2023
9f01066
merge conflict
xadupre May 15, 2023
3f55df4
remove unused functions
xadupre May 15, 2023
eb56ce2
lint
xadupre May 15, 2023
3482858
suggested modifications
xadupre May 16, 2023
d5acbe0
remove unnecessary cpu test exclusion
xadupre May 16, 2023
bb8ef69
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 16, 2023
92f972a
lint
xadupre May 16, 2023
557e9c9
lint2
xadupre May 16, 2023
e74bb22
restore typedef
xadupre May 16, 2023
920ef2a
lint
xadupre May 16, 2023
a3bfa5d
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 16, 2023
b07dcca
extended the list of supported ORT formats
xadupre May 16, 2023
d14071f
improve the code with review comments
xadupre May 16, 2023
a408803
comments from review
xadupre May 17, 2023
9de02b2
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 17, 2023
05dd3e4
add missing files in previous commit
xadupre May 17, 2023
6a2230b
rename IR4 into IRv4
xadupre May 17, 2023
dda473b
Update onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
xadupre May 17, 2023
b62667c
fix AMD build
xadupre May 17, 2023
9a9179c
enable disabled test
xadupre May 17, 2023
bebc467
First draft for GemmFloat8
xadupre May 17, 2023
3e2f877
refactoring
xadupre May 18, 2023
8d51059
remove empty line
xadupre May 18, 2023
5c6eb78
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 18, 2023
081d9c1
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into f8
xadupre May 18, 2023
0fb28a1
add one comment in test file
xadupre May 22, 2023
b66b5e1
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 22, 2023
f2ddc47
add one more test for saturate
xadupre May 22, 2023
db3de4c
add one more test
xadupre May 22, 2023
87aa6b0
refactor cast implementation on CPU
xadupre May 22, 2023
863c576
Update include/onnxruntime/core/framework/float8.h
xadupre May 23, 2023
68eda16
Update onnxruntime/core/framework/tensor_type_and_shape.cc
xadupre May 23, 2023
f0deb22
lint
xadupre May 23, 2023
878ffd1
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into f8
xadupre May 23, 2023
efb36c0
Update onnxruntime/core/framework/tensorprotoutils.cc
xadupre May 23, 2023
dd3cab5
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into f8
xadupre May 23, 2023
2b1a53d
Update onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
xadupre May 23, 2023
6eadaf4
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into f8
xadupre May 23, 2023
3fc31c4
update the code following review comments
xadupre May 23, 2023
5727127
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 23, 2023
8bca51f
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 23, 2023
ea0e0e5
last, link issue
xadupre May 23, 2023
d1e7b37
fix link issue
xadupre May 23, 2023
a470901
Update include/onnxruntime/core/framework/float8.h
xadupre May 24, 2023
b9685a7
Update include/onnxruntime/core/framework/float8.h
xadupre May 24, 2023
e647892
Update include/onnxruntime/core/framework/float8.h
xadupre May 24, 2023
ee5a0b2
fix gemm8
xadupre May 24, 2023
fe9c0e0
refactoring
xadupre May 24, 2023
a463178
fix default value
xadupre May 24, 2023
5123bf8
gemm
xadupre May 24, 2023
856a6f3
fix issue
xadupre May 24, 2023
0f34760
fix compilation
xadupre May 24, 2023
cf30843
update code
xadupre May 24, 2023
bc138da
update example
xadupre May 24, 2023
e893f14
iteration
xadupre May 24, 2023
3f4d81f
update
xadupre May 24, 2023
ddb44ac
iteration
xadupre May 24, 2023
0877f58
gemm
xadupre May 24, 2023
6af6dab
gemm
xadupre May 24, 2023
143847a
fix inference
xadupre May 24, 2023
e90cfa5
better error message
xadupre May 25, 2023
31bda94
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 25, 2023
7d0edc2
fix compilation
xadupre May 25, 2023
840fc54
fix compilation
xadupre May 25, 2023
6298112
update
xadupre May 26, 2023
e455930
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 26, 2023
cd23e84
Update onnxruntime/core/providers/cpu/tensor/reshape.cc
xadupre May 26, 2023
46ce272
v9
xadupre May 26, 2023
50d8363
Update orttraining/orttraining/python/ort_trainer.py
xadupre May 26, 2023
5a1a379
Update onnxruntime/core/providers/cuda/tensor/reshape.cc
xadupre May 26, 2023
fb1fd1d
raise an exception for an unsupported case
xadupre May 26, 2023
d3e1fec
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into f8
xadupre May 26, 2023
d08ea5f
fix rocmprovider for opset 19
xadupre May 26, 2023
640b66b
removed unnecessary comments
xadupre May 26, 2023
9dcf3a4
disable specific code for bfloat16
xadupre May 26, 2023
1ef646e
add flag DISABLE_FLOAT8_TYPES
xadupre May 26, 2023
0fa96f9
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 26, 2023
812e1cf
temp
xadupre May 26, 2023
05bf92a
lint
xadupre May 26, 2023
8175814
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 26, 2023
0df344d
temp
xadupre May 26, 2023
96748a6
lint
xadupre May 26, 2023
6e7292a
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 26, 2023
ff8feca
fix disable float8
xadupre May 26, 2023
d58c9cb
missing S
xadupre May 26, 2023
4f6ea36
disable if on rocm
xadupre May 26, 2023
267b526
update code
xadupre May 26, 2023
33ec91b
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 26, 2023
e54e85c
fix two compilation issues
xadupre May 26, 2023
a55176b
workspace
xadupre May 26, 2023
786c04b
update documentation
xadupre May 27, 2023
dd1d34c
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre May 27, 2023
3f61423
enables f8 tests
xadupre May 27, 2023
beed7a5
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f8
xadupre May 27, 2023
b2f6b6d
enable float8 tests
xadupre May 27, 2023
bfcd61b
disable float8 types for qnn, dnnl
xadupre May 27, 2023
d1e9eb1
avoid disabling tests twice
xadupre May 28, 2023
cf5faa7
change opset in function generate_size_op_test
xadupre May 28, 2023
ccc6876
fix json and lint
xadupre May 28, 2023
c6d20a6
Merge branch 'f8' of https://github.com/xadupre/onnxruntime into gemm8
xadupre May 29, 2023
dfa9085
fix merge conflicts
xadupre May 29, 2023
7e7da3c
fix merge conflicts
xadupre Jun 5, 2023
ca8aa15
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jun 20, 2023
cc641f9
draft
xadupre Jun 23, 2023
6d3f7d9
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jun 23, 2023
890f228
finalization of the first draft
xadupre Jun 23, 2023
877b1fc
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jun 26, 2023
00719ad
update gemm float 8
xadupre Jun 26, 2023
d9c7778
gemm8
xadupre Jun 27, 2023
f9ba8f5
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jun 28, 2023
c768d44
fix GemmFloat8
xadupre Jun 28, 2023
7a21260
fix compilation on windows
xadupre Jun 28, 2023
b9ce311
lint
xadupre Jun 29, 2023
5d2a969
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jun 29, 2023
29b8ae4
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jun 30, 2023
85e1126
finalize the unit tests
xadupre Jun 30, 2023
68571a5
add activation
xadupre Jul 3, 2023
cf55c92
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jul 3, 2023
131734f
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jul 10, 2023
cddd1c0
fix markdown
xadupre Jul 10, 2023
e5a5356
fix documentation, exclude gemm_float8 from the list of rocm files
xadupre Jul 11, 2023
559ceda
fix merge conflicts
xadupre Aug 8, 2023
dff451d
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Aug 9, 2023
ea052f2
fix misspelling
xadupre Aug 9, 2023
cbdd21d
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Aug 28, 2023
907c4ca
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Aug 30, 2023
78a9d0e
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Sep 1, 2023
55b06ea
fix null pointer
xadupre Sep 1, 2023
db31126
lint c++
xadupre Sep 1, 2023
a7db787
allows scaleY to be null
Sep 1, 2023
2160f4a
Merge branch 'gemm8' of https://github.com/xadupre/onnxruntime into g…
Sep 1, 2023
554a949
Merge branch 'gemm8' of https://github.com/xadupre/onnxruntime into g…
xadupre Sep 29, 2023
fc89b9f
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Sep 29, 2023
3e4435e
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 10, 2023
e7e3107
fix merge conflicts
xadupre Oct 16, 2023
d075a6d
lint
xadupre Oct 16, 2023
1de2c1d
fix lint issues
xadupre Oct 16, 2023
650bfe9
lint
xadupre Oct 16, 2023
5ac6288
lint
xadupre Oct 16, 2023
d7753b0
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 17, 2023
f33f99d
first wave of fixes
xadupre Oct 17, 2023
471b079
rename rowMajor into rowMajorCompute
xadupre Oct 17, 2023
c9dcef7
more test
xadupre Oct 17, 2023
c5df2a1
fix SetParams
xadupre Oct 17, 2023
70b7302
simplify SetParams
xadupre Oct 17, 2023
6821e65
remove unnecessary parameters
xadupre Oct 18, 2023
9ab37db
reply to comment
xadupre Oct 23, 2023
abb8c85
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 23, 2023
975049b
lint
xadupre Oct 23, 2023
9834752
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 25, 2023
e9ea8ea
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 26, 2023
53aba07
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 27, 2023
3 changes: 3 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -48,6 +48,9 @@ set(contrib_ops_excluded_files
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/nhwc_conv.cc"
"math/gemm_float8.cc"
"math/gemm_float8.cu"
"math/gemm_float8.h"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
66 changes: 66 additions & 0 deletions docs/ContribOperators.md
@@ -40,6 +40,7 @@ Do not modify directly.*
* <a href="#com.microsoft.GatherND">com.microsoft.GatherND</a>
* <a href="#com.microsoft.Gelu">com.microsoft.Gelu</a>
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
* <a href="#com.microsoft.GemmFloat8">com.microsoft.GemmFloat8</a>
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
@@ -2135,6 +2136,71 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.GemmFloat8"></a><a name="com.microsoft.gemmfloat8">**com.microsoft.GemmFloat8**</a>

Generic Gemm for float and float 8.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation</tt> : string</dt>
<dd>Activation function, RELU or GELU or NONE (default).</dd>
<dt><tt>alpha</tt> : float</dt>
<dd>Scalar multiplier for the product of input tensors A * B.</dd>
<dt><tt>beta</tt> : float</dt>
<dd>Scalar multiplier for the input bias C.</dd>
<dt><tt>dtype</tt> : int</dt>
<dd>Output Type. Same definition as attribute 'to' for operator Cast.</dd>
<dt><tt>transA</tt> : int</dt>
<dd>Whether A should be transposed. Float 8 only supports transA=0.</dd>
<dt><tt>transB</tt> : int</dt>
<dd>Whether B should be transposed. Float 8 only supports transB=1.</dd>
</dl>

#### Inputs (2 - 6)

<dl>
<dt><tt>A</tt> : TA</dt>
<dd>Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.</dd>
<dt><tt>B</tt> : TB</dt>
<dd>Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.</dd>
<dt><tt>C</tt> (optional) : TC</dt>
<dd>Input tensor C.</dd>
<dt><tt>scaleA</tt> (optional) : TS</dt>
<dd>Scale of tensor A if A is a float 8 tensor.</dd>
<dt><tt>scaleB</tt> (optional) : TS</dt>
<dd>Scale of tensor B if B is a float 8 tensor.</dd>
<dt><tt>scaleY</tt> (optional) : TS</dt>
<dd>Scale of the output tensor if A or B is float 8.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : TR</dt>
<dd>Output tensor of shape (M, N).</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>TA</tt> : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)</dt>
<dd>Constrain type to input A.</dd>
<dt><tt>TB</tt> : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)</dt>
<dd>Constrain type to input B.</dd>
<dt><tt>TC</tt> : tensor(float16), tensor(bfloat16), tensor(float)</dt>
<dd>Constrain type to input C.</dd>
<dt><tt>TR</tt> : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)</dt>
<dd>Constrain type to result type.</dd>
<dt><tt>TS</tt> : tensor(float)</dt>
<dd>Constrain type for all input scales (scaleA, scaleB, scaleY).</dd>
</dl>


### <a name="com.microsoft.GreedySearch"></a><a name="com.microsoft.greedysearch">**com.microsoft.GreedySearch**</a>

Greedy Search for text generation.
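The GemmFloat8 specification added to docs/ContribOperators.md above can be exercised from Python with `onnx.helper`. The sketch below is illustrative only and is not part of this PR: the tensor names, symbolic shapes, scale handling, and the opset versions imported are assumptions chosen to match the spec (float 8 inputs with transA=0 and transB=1, scales passed as float scalars).

```python
# Illustrative sketch (not from this PR): build a GemmFloat8 node that multiplies
# two float8e4m3fn tensors and produces a float16 result.
import onnx
import onnx.helper as oh
from onnx import TensorProto

node = oh.make_node(
    "GemmFloat8",
    # C is omitted (empty name); scaleY is left out entirely.
    inputs=["A", "B", "", "scaleA", "scaleB"],
    outputs=["Y"],
    domain="com.microsoft",
    transA=0,                    # float 8 inputs only support transA=0
    transB=1,                    # float 8 inputs only support transB=1
    alpha=1.0,
    beta=0.0,
    dtype=TensorProto.FLOAT16,   # output type, same convention as Cast's 'to'
    activation="NONE",
)

graph = oh.make_graph(
    [node],
    "gemm_float8_example",
    [
        oh.make_tensor_value_info("A", TensorProto.FLOAT8E4M3FN, ["M", "K"]),
        oh.make_tensor_value_info("B", TensorProto.FLOAT8E4M3FN, ["N", "K"]),
        oh.make_tensor_value_info("scaleA", TensorProto.FLOAT, []),
        oh.make_tensor_value_info("scaleB", TensorProto.FLOAT, []),
    ],
    [oh.make_tensor_value_info("Y", TensorProto.FLOAT16, ["M", "N"])],
)

# Float 8 element types need a recent onnx (1.14+, opset 19 for the default domain).
model = oh.make_model(
    graph,
    opset_imports=[oh.make_opsetid("", 19), oh.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "gemm_float8_example.onnx")
```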
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -840,6 +840,7 @@ Do not modify directly.*
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -140,6 +140,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8);

#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
@@ -305,6 +306,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8)>,

#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
70 changes: 70 additions & 0 deletions onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <string>
#include "core/providers/cuda/math/gemm.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cpu/math/gemm_helper.h"
#include "contrib_ops/cuda/math/gemm_float8.h"

using namespace ONNX_NAMESPACE;

Lint C++ warning (cpplint via reviewdog, line 11): Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL() \
ONNX_OPERATOR_KERNEL_EX( \
GemmFloat8, \
kMSDomain, \
1, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("TA", BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()) \
.TypeConstraint("TB", BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()) \
.TypeConstraint("TR", BuildKernelDefConstraints<Float8E4M3FN, Float8E5M2, MLFloat16, BFloat16, float>()) \
.TypeConstraint("TS", BuildKernelDefConstraints<float>()), \
GemmFloat8);

REGISTER_KERNEL()

GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) {
transA_ = info.GetAttrOrDefault<int64_t>("transA", 0);
transB_ = info.GetAttrOrDefault<int64_t>("transB", 0);
dtype_ = info.GetAttrOrDefault<int64_t>("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
auto& device_prop = GetDeviceProp();
sm_count_ = device_prop.multiProcessorCount;
alpha_ = info.GetAttrOrDefault<float>("alpha", 1);
beta_ = info.GetAttrOrDefault<float>("beta", 0);

#if (CUDA_VERSION <= 12000)
ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0.");
#endif

std::string stemp = info.GetAttrOrDefault<std::string>("activation", "NONE");
if (stemp == "NONE") {
epilogue_ = CUBLASLT_EPILOGUE_DEFAULT;
} else if (stemp == "RELU") {
epilogue_ = CUBLASLT_EPILOGUE_RELU;
} else if (stemp == "GELU") {
epilogue_ = CUBLASLT_EPILOGUE_GELU;
} else {
ORT_THROW("Unexpected value for activation: '", stemp, "'.");
}
}

Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const {
GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({}));
if (!helper.State().IsOK())
return helper.State();

M = gsl::narrow_cast<int>(helper.M());
N = gsl::narrow_cast<int>(helper.N());
K = gsl::narrow_cast<int>(helper.K());
return helper.State();
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
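As a quick smoke test of the kernel above, the operator can also be driven with float16 inputs, which need no scale inputs and no FP8-capable GPU. The following is a hedged sketch, assuming a CUDA-enabled onnxruntime build in which this contrib kernel is registered; the shapes and tolerances are illustrative and not taken from the PR's tests.

```python
# Hedged usage sketch (not part of this PR): run GemmFloat8 with float16 inputs
# on the CUDA execution provider and compare against a float32 matmul.
import numpy as np
import onnx.helper as oh
from onnx import TensorProto
import onnxruntime as ort

node = oh.make_node(
    "GemmFloat8", ["A", "B"], ["Y"], domain="com.microsoft",
    transA=0, transB=0, alpha=1.0, beta=0.0,
    dtype=TensorProto.FLOAT16, activation="NONE",
)
graph = oh.make_graph(
    [node], "gemm_float8_fp16",
    [oh.make_tensor_value_info("A", TensorProto.FLOAT16, [4, 8]),
     oh.make_tensor_value_info("B", TensorProto.FLOAT16, [8, 3])],
    [oh.make_tensor_value_info("Y", TensorProto.FLOAT16, [4, 3])],
)
model = oh.make_model(
    graph,
    opset_imports=[oh.make_opsetid("", 19), oh.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CUDAExecutionProvider"])
a = np.random.rand(4, 8).astype(np.float16)
b = np.random.rand(8, 3).astype(np.float16)
(y,) = sess.run(None, {"A": a, "B": b})

expected = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(y.astype(np.float32), expected, rtol=1e-2, atol=1e-2)
```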