diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index 66ef0c5125a74..584c49d612293 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -25,17 +25,91 @@ index a6efc98e..8c4de8e7 100644 "/wd4800", ] diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index 0d6c1ec3..75fd935f 100644 +index e6e11949..0aa7d868 100644 --- a/absl/copts/copts.py +++ b/absl/copts/copts.py -@@ -132,10 +132,6 @@ COPT_VARS = { - "/wd4068", # unknown pragma - # qualifier applied to function type has no meaning; ignored - "/wd4180", -- # conversion from 'type1' to 'type2', possible loss of data -- "/wd4244", -- # conversion from 'size_t' to 'type', possible loss of data -- "/wd4267", - # The decorated name was longer than the compiler limit - "/wd4503", - # forcing value to bool 'true' or 'false' (performance warning) +@@ -115,10 +115,6 @@ MSVC_WARNING_FLAGS = [ + "/wd4068", # unknown pragma + # qualifier applied to function type has no meaning; ignored + "/wd4180", +- # conversion from 'type1' to 'type2', possible loss of data +- "/wd4244", +- # conversion from 'size_t' to 'type', possible loss of data +- "/wd4267", + # The decorated name was longer than the compiler limit + "/wd4503", + # forcing value to bool 'true' or 'false' (performance warning) +diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc +index 53a099a1..34d210d6 100644 +--- a/absl/debugging/symbolize_win32.inc ++++ b/absl/debugging/symbolize_win32.inc +@@ -35,15 +35,15 @@ ABSL_NAMESPACE_BEGIN + + static HANDLE process = NULL; + +-void InitializeSymbolizer(const char*) { +- if (process != nullptr) { +- return; +- } ++namespace { ++void InitializeSymbolizerImpl() { ++ + process = GetCurrentProcess(); + + // Symbols are not loaded until a reference is made requiring the + // symbols be loaded. This is the fastest, most efficient way to use + // the symbol handler. ++ + SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME); + if (!SymInitialize(process, nullptr, true)) { + // GetLastError() returns a Win32 DWORD, but we assign to +@@ -54,6 +54,36 @@ void InitializeSymbolizer(const char*) { + } + } + ++bool LookupAndInitialize(const void* pc, SYMBOL_INFO* symbol) { ++ auto hProcess = (process != NULL) ? 
process : GetCurrentProcess(); ++ if (SymFromAddr(hProcess, reinterpret_cast<DWORD64>(pc), nullptr, symbol) != TRUE) { ++ if (GetLastError() == ERROR_INVALID_HANDLE && process == NULL) { ++ InitializeSymbolizerImpl(); ++ if (SymFromAddr(process, reinterpret_cast<DWORD64>(pc), nullptr, symbol) != TRUE) { ++ return false; ++ } ++ } else { ++ return false; ++ } ++ return false; ++ } ++ return true; ++} ++} ++ ++void InitializeSymbolizer(const char*) { ++ if (process != nullptr) { ++ return; ++ } ++ ++ alignas(SYMBOL_INFO) char buf[sizeof(SYMBOL_INFO) + MAX_SYM_NAME]; ++ SYMBOL_INFO* symbol = reinterpret_cast<SYMBOL_INFO*>(buf); ++ symbol->SizeOfStruct = sizeof(SYMBOL_INFO); ++ symbol->MaxNameLen = MAX_SYM_NAME; ++ ++ static_cast<void>(LookupAndInitialize(reinterpret_cast<const void*>(&InitializeSymbolizer), symbol)); ++} ++ + bool Symbolize(const void* pc, char* out, int out_size) { + if (out_size <= 0) { + return false; +@@ -62,9 +92,11 @@ bool Symbolize(const void* pc, char* out, int out_size) { + SYMBOL_INFO* symbol = reinterpret_cast<SYMBOL_INFO*>(buf); + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; +- if (!SymFromAddr(process, reinterpret_cast<DWORD64>(pc), nullptr, symbol)) { ++ ++ if(!LookupAndInitialize(pc, symbol)) { + return false; + } ++ + const size_t out_size_t = static_cast<size_t>(out_size); + strncpy(out, symbol->Name, out_size_t); + if (out[out_size_t - 1] != '\0') { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 11c8778b72335..080b24a2432aa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -164,17 +164,14 @@ export const createConv2DTransposeMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = - isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0; // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = isVec4 ? - [8, 8, 1] : - [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ?
[4, 1, 1] : [4, 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 7038e2a4f8766..8ed48dd07e6f1 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -392,5 +392,267 @@ ] } ] + }, + { + "name": "ConvTranspose without bias addition C", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "strides", "data": [2, 2], "type": "ints" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ], + "dims": [1, 4, 16, 16], + "type": "float32" + }, + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15 + ], + "dims": [4, 4, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 
0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 0, 4, 0, 8, 0, 12, 0, 16, 0, 20, 0, 24, 0, 28, 0, 32, 0, 36, 0, 40, 0, 44, 0, 48, 0, 52, 0, 56, 0, + 60, 0, 0, 8, 12, 16, 24, 24, 36, 32, 48, 40, 60, 48, 72, 56, 84, 64, 96, 72, 108, 80, 120, 88, 132, 96, + 144, 104, 156, 112, 168, 120, 180, 0, 64, 0, 68, 0, 72, 0, 76, 0, 80, 0, 84, 0, 88, 0, 92, 0, 96, 0, 100, + 0, 104, 0, 108, 0, 112, 0, 116, 0, 120, 0, 124, 128, 192, 136, 204, 144, 216, 152, 228, 160, 240, 168, + 252, 176, 264, 184, 276, 192, 288, 200, 300, 208, 312, 216, 324, 224, 336, 232, 348, 240, 360, 248, 372, + 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, + 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, + 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, + 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, + 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, + 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, + 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, + 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, + 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, + 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, + 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, + 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, + 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, + 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 
96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, + 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, + 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, + 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, + 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, + 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, + 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, + 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, + 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, + 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, + 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, + 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, + 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, + 336, 392, 360, 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, + 480, 400, 500, 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, + 456, 532, 480, 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, + 812, 720, 840, 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, + 180, 160, 200, 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, + 120, 140, 144, 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, + 420, 256, 320, 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, + 416, 520, 432, 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, + 560, 504, 588, 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, + 744, 868, 0, 0, 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, + 176, 220, 192, 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, + 168, 168, 196, 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, + 272, 340, 288, 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, + 540, 448, 560, 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, + 528, 616, 552, 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 812, 720, 840, 744, 868, 0, 0, + 16, 20, 32, 40, 48, 60, 64, 80, 80, 100, 96, 120, 112, 140, 128, 160, 144, 180, 160, 200, 176, 220, 192, + 240, 208, 260, 224, 280, 240, 300, 0, 0, 24, 28, 48, 56, 72, 84, 96, 112, 120, 140, 144, 168, 168, 196, + 192, 224, 216, 252, 240, 280, 264, 308, 288, 336, 312, 364, 336, 392, 360, 420, 256, 320, 272, 340, 288, + 360, 304, 380, 320, 400, 336, 420, 352, 440, 368, 460, 384, 480, 400, 500, 416, 520, 432, 540, 448, 560, + 464, 580, 480, 600, 496, 620, 384, 448, 408, 476, 432, 504, 456, 532, 480, 560, 504, 588, 528, 616, 552, + 644, 576, 672, 600, 700, 624, 728, 648, 756, 672, 784, 696, 
812, 720, 840, 744, 868, 0, 0, 32, 36, 64, 72, + 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, + 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, + 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, + 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, + 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, + 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, + 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, + 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, + 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, + 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, + 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, + 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, + 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, + 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, + 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, + 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, + 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, + 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, + 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, 288, 324, + 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, + 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, 600, 660, + 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, 900, 832, + 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, 836, 800, + 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, + 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, 256, 288, + 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, 132, 160, + 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, 528, 520, 572, 560, 616, + 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, 864, 800, + 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, 792, 760, + 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, 1120, 1232, + 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, 216, 224, 252, + 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, 44, 80, 88, 120, + 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 
440, 440, 484, 480, 528, 520, 572, + 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, 736, 828, 768, + 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, 680, 748, 720, + 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, 1080, 1188, + 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, 180, 192, + 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, 0, 0, 40, + 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, 484, 480, + 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, 704, 792, + 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, 640, 704, + 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, 1040, 1144, + 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 32, 36, 64, 72, 96, 108, 128, 144, 160, + 180, 192, 216, 224, 252, 256, 288, 288, 324, 320, 360, 352, 396, 384, 432, 416, 468, 448, 504, 480, 540, + 0, 0, 40, 44, 80, 88, 120, 132, 160, 176, 200, 220, 240, 264, 280, 308, 320, 352, 360, 396, 400, 440, 440, + 484, 480, 528, 520, 572, 560, 616, 600, 660, 512, 576, 544, 612, 576, 648, 608, 684, 640, 720, 672, 756, + 704, 792, 736, 828, 768, 864, 800, 900, 832, 936, 864, 972, 896, 1008, 928, 1044, 960, 1080, 992, 1116, + 640, 704, 680, 748, 720, 792, 760, 836, 800, 880, 840, 924, 880, 968, 920, 1012, 960, 1056, 1000, 1100, + 1040, 1144, 1080, 1188, 1120, 1232, 1160, 1276, 1200, 1320, 1240, 1364, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 
1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, + 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, + 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, + 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, + 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, + 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, 1020, 1008, 1080, 1064, 1140, 1120, 1200, 1176, 1260, + 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, + 1800, 1736, 1860, 0, 0, 48, 52, 96, 104, 144, 156, 192, 208, 240, 260, 288, 312, 336, 364, 384, 416, 432, + 468, 480, 520, 528, 572, 576, 624, 624, 676, 672, 728, 720, 780, 0, 0, 56, 60, 112, 120, 168, 180, 224, + 240, 280, 300, 336, 360, 392, 420, 448, 480, 504, 540, 560, 600, 616, 660, 672, 720, 728, 780, 784, 840, + 840, 900, 768, 832, 816, 884, 864, 936, 912, 988, 960, 1040, 1008, 1092, 1056, 1144, 1104, 1196, 1152, + 1248, 1200, 1300, 1248, 1352, 1296, 1404, 1344, 1456, 1392, 1508, 1440, 1560, 1488, 1612, 896, 960, 952, + 1020, 1008, 1080, 1064, 1140, 
1120, 1200, 1176, 1260, 1232, 1320, 1288, 1380, 1344, 1440, 1400, 1500, + 1456, 1560, 1512, 1620, 1568, 1680, 1624, 1740, 1680, 1800, 1736, 1860 + ], + "dims": [1, 4, 32, 32], + "type": "float32" + } + ] + } + ] } ] diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 814aa1fb3c8f0..112f609d46598 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -159,8 +159,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && - do_rotary_ == false && - key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -172,18 +170,31 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (use_memory_efficient_attention && needs_buff) { kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); } + size_t rotary_buffer_bytes = 0; + if (use_memory_efficient_attention && do_rotary_) { + rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size; + rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; + } size_t fmha_buffer_bytes = 0; if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); } + size_t unpacked_qkv_bytes = 0; + if (use_memory_efficient_attention && parameters.is_packed_qkv) { + unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); + } auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream()); auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream()); #else constexpr bool use_memory_efficient_attention = false; auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto rotary_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif // seqlens_k buffer @@ -251,7 +262,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } - // Rotary + if (unpacked_qkv_buffer != nullptr) { + data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); + } + if (rotary_buffer != nullptr) { + data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); + } + // Rotary Embedding if (parameters.do_rotary) { data.cos_cache = 
reinterpret_cast(cos_cache->Data()); data.sin_cache = reinterpret_cast(sin_cache->Data()); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index afba83be34e2d..f519be1c97149 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -42,6 +42,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "core/providers/cuda/shared_inc/cuda_call.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include using namespace onnxruntime::cuda; @@ -150,6 +151,8 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, cudaStream_t stream, const int max_threads_per_block, const bool past_only = false) { @@ -171,14 +174,14 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, reinterpret_cast(data.past_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, reinterpret_cast(data.past_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -191,7 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter H, kv_num_heads, reinterpret_cast(data.past_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -200,7 +203,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter H, kv_num_heads, reinterpret_cast(data.past_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); @@ -281,6 +284,8 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, template Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, + const void* new_key, + const void* new_value, cudaStream_t stream, const int max_threads_per_block) { const int batch_size = parameters.batch_size; @@ -300,12 +305,12 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, const dim3 block(H, kv_num_heads, 1); ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { @@ -316,14 +321,14 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, H, kv_num_heads, reinterpret_cast(data.present_key), - reinterpret_cast(data.key), + reinterpret_cast(new_key), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatKVInPlaceLarge<<>>(present_sequence_length, H, 
kv_num_heads, reinterpret_cast(data.present_value), - reinterpret_cast(data.value), + reinterpret_cast(new_value), seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } @@ -468,6 +473,83 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } +// Kernel to unpack qkv from packed qkv +template +__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, + const int batch_size) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + int d = (num_heads + 2 * kv_num_heads) * head_size; + const int qkv_size = batch_size * sequence_length * d; + const int q_size = num_heads * head_size; + const int k_size = kv_num_heads * head_size; + if (tid < qkv_size) { + int batch = tid / (d * sequence_length); + int sequence = (tid % (d * sequence_length)) / d; + int offset = tid % d; + if (offset < q_size) { + int unpacked_i = batch * sequence_length * num_heads * head_size + sequence * num_heads * head_size + offset; + unpacked_q[unpacked_i] = packed_qkv[tid]; + } else if (offset < q_size + k_size) { + int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size); + unpacked_k[unpacked_i] = packed_qkv[tid]; + } else { + int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size - k_size); + unpacked_v[unpacked_i] = packed_qkv[tid]; + } + } +} + +// Unpack packed qkv +template +Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, + cudaStream_t stream, const int max_threads_per_block) { + const int threads = max_threads_per_block; + const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads; + UnpackQKV<<>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, + head_size, sequence_length, batch_size); + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return 
CUDA_CALL(cudaGetLastError()); +} + ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -517,7 +599,8 @@ Status FlashAttention( seqlens_k = data.seqlens_k_total; } } else if (!parameters.kv_share_buffer) { // copy past kv to present kv - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, + true)); } void* present_key = reinterpret_cast(const_cast(data.present_key)); @@ -563,15 +646,62 @@ Status EfficientAttention( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const void* query = reinterpret_cast(data.query); - const void* key = reinterpret_cast(data.key); - const void* value = reinterpret_cast(data.value); + const void* query; + const void* key; + const void* value; + + if (!parameters.is_packed_qkv) { + query = reinterpret_cast(data.query); + key = reinterpret_cast(data.key); + value = reinterpret_cast(data.value); + } else { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto q = reinterpret_cast(data.unpacked_qkv_buffer); + auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size); + auto v = reinterpret_cast(data.unpacked_qkv_buffer + q_size + k_size); + ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast(data.query), q, k, v, num_heads, kv_num_heads, + head_size, sequence_length, batch_size, stream, max_threads_per_block)); + query = reinterpret_cast(q); + key = reinterpret_cast(k); + value = reinterpret_cast(v); + } + + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto q_buffer = reinterpret_cast(data.rotary_buffer); + auto k_buffer = q_buffer + q_size; + auto position_ids_buff = reinterpret_cast(k_buffer + k_size); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream, + max_threads_per_block)); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, q_buffer, reinterpret_cast(query), + position_ids_buff, data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*transposed*/ false)); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key), + position_ids_buff, data.cos_cache, data.sin_cache, + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*transposed*/ false)); + query = reinterpret_cast(q_buffer); + key = reinterpret_cast(k_buffer); + } if (parameters.is_prompt) { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + repeat_seqlen<<>>(data.seqlens_k_total, 
parameters.sequence_length, + batch_size); } else { ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } @@ -583,7 +713,7 @@ Status EfficientAttention( "Past and present kv shall share the same tensor when kv_share_buffer is on."); } // Concatenate new kv in place - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, stream, max_threads_per_block)); } else { // Not share buffer case if (data.past_key != nullptr && data.past_key == data.present_key) { @@ -591,7 +721,7 @@ Status EfficientAttention( "Past and present kv share the same tensor but kv_share_buffer is not on."); } // Copy past and concat new KV to present buffer - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, key, value, stream, max_threads_per_block)); } // Ungroup if grouped, otherwise use present kv directly diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 1bf91f9c875eb..32341afa0e3fa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -30,6 +30,8 @@ struct GroupQueryAttentionData { int* seqlens_k_total = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; + T* unpacked_qkv_buffer = nullptr; + T* rotary_buffer = nullptr; T* k = nullptr; T* v = nullptr; // Output Tensors diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp index 955b7c5deee9a..43a12b37e4ffa 100644 --- a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp +++ b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp @@ -171,11 +171,9 @@ Return Value: if (k > 0) { Row0AElements0 = a[0]; - Row0AElements1 = a[1]; if (ProcessTwoRows) { Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; } BElements0 = MlasLoadFloat32x4(B + 0); diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index ff6a059607367..f3520b4f7f7f5 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -55,41 +55,67 @@ struct MemoryBlock { }; struct SymbolHelper { - SymbolHelper() noexcept { - SymSetOptions(SymGetOptions() | SYMOPT_DEFERRED_LOADS); - SymInitialize(GetCurrentProcess(), nullptr, true); + HANDLE process_handle_ = GetCurrentProcess(); + bool initialized_ = false; + + bool InitializeWhenNeeded() { + // We try only once + if (!initialized_) { + SymSetOptions(SymGetOptions() | SYMOPT_DEFERRED_LOADS); + // We use GetCurrentProcess() because other libs are likely to use it + if (!SymInitialize(process_handle_, nullptr, true)) { + const unsigned long long error{GetLastError()}; + std::cerr << "SymInitialize() failed: " << error << std::endl; + return false; + } + initialized_ = true; + } + return true; + } + + SymbolHelper() = default; + + static constexpr size_t kInitialBufferSize = sizeof(SYMBOL_INFO) + MAX_SYM_NAME; + + bool LoookupSymAndInitialize(const ULONG_PTR address, char* buffer, size_t buffer_size, SYMBOL_INFO* symbol) { + if (SymFromAddr(process_handle_, address, 0, symbol) != TRUE) { + if (GetLastError() == ERROR_INVALID_HANDLE) { + // Try to initialize first + if 
(!InitializeWhenNeeded() || SymFromAddr(process_handle_, address, 0, symbol) != TRUE) { + _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + return false; + } + } else { + _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + return false; + } + } + return true; + } void Lookup(std::string& string, const ULONG_PTR address) { - char buffer[2048] = {0}; - Symbol symbol; - if (SymFromAddr(GetCurrentProcess(), address, 0, &symbol) == false) { - _snprintf_s(buffer, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + alignas(SYMBOL_INFO) char buffer[kInitialBufferSize] = {0}; + SYMBOL_INFO* symbol = reinterpret_cast<SYMBOL_INFO*>(buffer); + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; + + if (!LoookupSymAndInitialize(address, buffer, kInitialBufferSize, symbol)) { string.append(buffer); return; } Line line; DWORD displacement; - if (SymGetLineFromAddr(GetCurrentProcess(), address, &displacement, &line) == false) { - _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol.Name); + if (SymGetLineFromAddr(process_handle_, address, &displacement, &line) == false) { + _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol->Name); string.append(buffer); return; } - _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast<int>(line.LineNumber), symbol.Name); + _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast<int>(line.LineNumber), symbol->Name); string.append(buffer); } - struct Symbol : SYMBOL_INFO { - Symbol() noexcept { - SizeOfStruct = sizeof(SYMBOL_INFO); - MaxNameLen = _countof(buffer); - } - - char buffer[1024] = {0}; - }; - struct Line : IMAGEHLP_LINE { Line() noexcept { SizeOfStruct = sizeof(IMAGEHLP_LINE); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 6d8c80bd2aaa1..08c9a8449cc33 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -244,7 +244,9 @@ Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrap TensorShape new_tensor_shape(new_tensor_shape_dims); Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(Env::Default(), nullptr, initializer, in_tensor)); + onnxruntime::PathString model_path = qnn_model_wrapper.GetGraphViewer().ModelPath().ToPathString(); + const ORTCHAR_T* model_path_str = model_path.empty() ?
nullptr : model_path.c_str(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(Env::Default(), model_path_str, initializer, in_tensor)); ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 667d7047c1fbd..80617b7b5edaa 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -21,19 +21,15 @@ from .quant_utils import ( ONNX_TYPE_TO_NP_TYPE, TENSOR_NAME_QUANT_SUFFIX, - QuantizedValue, - QuantizedValueType, QuantType, - compute_scale_zp, - compute_scale_zp_float8, find_by_name, - get_qmin_qmax_for_qType, model_has_infer_metadata, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, tensor_proto_to_array, ) +from .tensor_quant_overrides import TensorQuantOverridesHelper class QuantizationParams: @@ -121,27 +117,17 @@ def __init__( self.opset_version = self.check_opset_version() - # Map of all original value names to quantized value names - self.quantized_value_map = {} + # Get tensor-level quantization overrides and ensure they are valid. + self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {})) - self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides() - self.quantization_params = self.calculate_quantization_params() - - # to store specified scale and zeropoint instead of calculated value, tensor_name->(scale, zeropoint) - self.used_scale_zp_map = {} - - def set_quant_scale_zp(self, tensor_name, value): - assert isinstance(value, tuple) and len(value) == 2, "value must be scale(float or float16) and zeropoint" - assert hasattr(value[0], "dtype") - assert tensor_name not in self.used_scale_zp_map, f"{tensor_name} has been setted before" - self.used_scale_zp_map[tensor_name] = value + initializer_names = {initzer.name for initzer in self.model.initializer()} + overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid( + initializer_names, self.value_infos.keys(), activation_qType + ) + if not overrides_valid: + raise ValueError(overrides_err) - def find_quant_scale_zp(self, input_name): - if input_name in self.used_scale_zp_map: - return self.used_scale_zp_map[input_name] - if self.parent is not None: - return self.parent.find_quantized_value(input_name) - return (None, None) + self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types() def quantize_model(self): raise NotImplementedError @@ -212,36 +198,16 @@ def check_opset_version(self): return opset_version - def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0): """ Quantized the bias. 
Zero Point == 0 and Scale == Input_Scale * Weight_Scale """ - # Handle case where bias already in quantization map - if bias_name in self.quantized_value_map: - return self.quantized_value_map[bias_name].q_name - - # get scale for weight - weight_scale_name = self.quantized_value_map[weight_name].scale_name - weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) - weight_scale = tensor_proto_to_array(weight_initializer) - # get bias bias_initializer = find_by_name(bias_name, self.model.initializer()) bias_data = tensor_proto_to_array(bias_initializer) quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX - # get scale for input - if input_name in self.quantized_value_map: - input_scale_name = self.quantized_value_map[input_name].scale_name - elif input_name in self.quantization_params: - _, input_scale_name, _, _, _ = self._get_quantization_params(input_name) - else: - raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization") - - inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) - input_scale = tensor_proto_to_array(inputscale_initializer) - # quantize bias if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN: data = np.asarray(bias_data) @@ -293,22 +259,16 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0]) self.model.initializer_extend([packed_bias_zp_initializer]) - assert bias_name not in self.quantized_value_map - quantized_value = QuantizedValue( - bias_name, + return ( quantized_bias_name, quantized_bias_scale_name, quantized_bias_zp_name, - QuantizedValueType.Initializer, - 0 if bias_scale_data.size > 1 else None, - node_type=node_type, - node_qtype=node_qtype, + bias_scale_data, + node_type, + node_qtype, ) - self.quantized_value_map[bias_name] = quantized_value - - return quantized_bias_name - def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False): + def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False): """ :param weight: TensorProto initializer :param qType: type to quantize to @@ -316,22 +276,13 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei If keep_float_weight is False, quantize the weight, or don't quantize the weight. :return: quantized weight name, zero point name, scale name """ - # Find if this input is already quantized - if weight.name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight.name] - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" # Quantize weight data. Use quantization overrides if provided by the user. 
weight_data = tensor_proto_to_array(weight) - quant_overrides = self.get_per_tensor_quant_overrides(weight.name) + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name) if "quant_type" in quant_overrides: qType = quant_overrides["quant_type"].tensor_type # noqa: N806 @@ -392,19 +343,9 @@ def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_wei q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) self.model.initializer_extend([q_weight_initializer]) - # Log entry for this quantized weight - quantized_value = QuantizedValue( - weight.name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight.name] = quantized_value return q_weight_name, zp_name, scale_name - def quantize_weight_per_channel( + def quantize_weight_per_channel_impl( self, weight_name, weight_qType, @@ -412,22 +353,13 @@ def quantize_weight_per_channel( reduce_range=True, keep_float_weight=False, ): - # Find if this input is already quantized - if weight_name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight_name] - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: raise ValueError("{} is not an initializer", weight_name) weights = tensor_proto_to_array(initializer) channel_count = weights.shape[channel_axis] - quant_overrides_for_channels = self.get_per_channel_quant_overrides(weight_name, channel_count) + quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(weight_name, channel_count) # If user provides per-channel quantization overrides, all channels must use the same quantization type. # So, just use the first channel's type. @@ -499,16 +431,6 @@ def quantize_weight_per_channel( zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale" - quantized_value = QuantizedValue( - weight_name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight_name] = quantized_value - # Update packed weight, zero point, and scale initializers zero_scale_shape = [initializer.dims[channel_axis]] scale_initializer = onnx.helper.make_tensor( @@ -530,194 +452,25 @@ def quantize_weight_per_channel( return q_weight_name, zp_name, scale_name - def _get_and_check_tensor_quant_overrides(self): - """ - Get tensor quantization overrides and check correctness. - """ - tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {}) - tensor_quant_override_types = set() - - # Validate that compatible/valid overrides are provided. 
- if tensor_quant_overrides: - initializer_names = self.model.get_initializer_name_set() - value_info_names = set(self.value_infos.keys()) - keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} - - for tensor_name, quant_overrides_list in tensor_quant_overrides.items(): - if tensor_name not in initializer_names and tensor_name not in value_info_names: - raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model") - - if not isinstance(quant_overrides_list, list): - raise ValueError(f"Tensor quantization overrides for '{tensor_name}' are not in a list") - - is_initializer = tensor_name in initializer_names - if not is_initializer and len(quant_overrides_list) > 1: - raise ValueError( - f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer" - ) - - quant_type = None - for index, quant_overrides in enumerate(quant_overrides_list): - if not isinstance(quant_overrides, dict): - raise ValueError( - f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict" - ) - - # For per-channel quantization, all channels must use the same quantization type. - # Therefore, if the user tries to override the quant_type for a channel, it must match in all - # other channels. - if index == 0: - quant_type = quant_overrides.get("quant_type") - if quant_type: - tensor_quant_override_types.add(quant_type.tensor_type) - elif quant_type != quant_overrides.get("quant_type"): - raise ValueError( - "Channel quantization types for tensor '{tensor_name}' do not match at index {index}." - ) - - has_scale = "scale" in quant_overrides - has_zero_point = "zero_point" in quant_overrides - - if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): - raise ValueError( - "Must provide both 'scale' and 'zero_point' if one of the overrides is provided" - ) - - if has_scale: - for key in keys_unsupported_with_scale_zp: - if key in quant_overrides: - raise ValueError( - f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'" - ) - - return tensor_quant_overrides, tensor_quant_override_types - - def get_per_tensor_quant_overrides(self, tensor_name): - quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}]) - num_overrides = len(quant_overrides_list) - if num_overrides > 1: - raise ValueError( - f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " - f"but found {num_overrides} per-channel overrides." - ) - - return quant_overrides_list[0] if num_overrides > 0 else {} - - def get_per_channel_quant_overrides(self, tensor_name, num_channels): - quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{} for i in range(num_channels)]) - - if len(quant_overrides_list) != num_channels: - raise ValueError( - f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " - f"but found {len(quant_overrides_list)} instead." - ) - - return quant_overrides_list - - def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): - """ - Create initializers and inputs in the graph for zero point and scale of output. - Zero point and scale values are obtained from self.quantization_params if specified. - parameter param_name: Name of the quantization parameter. - return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. 
- """ - zero_point_type = self.activation_qType - - if use_scale is None or use_zeropoint is None: - if self.quantization_params is None or param_name not in self.quantization_params: - logging.info(f'Quantization parameters for tensor:"{param_name}" not specified') - return False, "", "", "", "" - - params = self.quantization_params[param_name] - if not isinstance(params, QuantizationParams): - raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.") - if params is None or len(params) != 3: - raise ValueError( - "Quantization parameters should contain zero point, scale, quant type. " - f"Specified values for output {param_name}: {params}" - ) - - zero_point_values = np.array([params["zero_point"]]) - if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): - raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") - scale_values = np.array([params["scale"]]) - assert scale_values.dtype != np.float64 - zero_point_type = params["quant_type"] - else: - zero_point_values = np.array([use_zeropoint]) - scale_values = np.array([use_scale]) - params = self.quantization_params[param_name] - if "scale" in params: - dtype = params["scale"].dtype - scale_values = scale_values.astype(dtype) - assert scale_values.dtype != np.float64 - - zero_point_shape = [] - zero_point_name = param_name + "_zero_point" - scale_shape = [] - scale_name = param_name + "_scale" - - # Add initializers - init_zp = onnx.helper.make_tensor( - zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() - ) - self.model.add_initializer(init_zp) - if scale_values.dtype == np.float32: - scale_type = onnx.TensorProto.FLOAT - elif scale_values.dtype == np.float16: - scale_type = onnx.TensorProto.FLOAT16 - else: - raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") - init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) - self.model.add_initializer(init_scale) - - return True, scale_name, zero_point_name, scale_shape, zero_point_shape - - def calculate_quantization_params(self): + def adjust_tensor_ranges(self): if self.tensors_range is None: - return {} + return - # adjust tensor_ranges for input of Clip and Relu node for node in self.model.nodes(): - if node.op_type not in ["Clip", "Relu"]: - continue - if self.is_activation_symmetric: - continue - if not self.should_quantize_node(node): - continue - if len(self.model.input_name_to_nodes()[node.input[0]]) != 1: - continue - if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range: - continue - td = self.tensors_range[node.output[0]] - if not isinstance(td, TensorData): - raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.") - self.tensors_range[node.input[0]] = td - - quantization_params = {} - for tensor_name in self.tensors_range: - td = self.tensors_range[tensor_name] - if not isinstance(td, TensorData): - raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") - - quant_overrides = self.get_per_tensor_quant_overrides(tensor_name) - - quant_type = self.activation_qType - if "quant_type" in quant_overrides: - quant_type = quant_overrides["quant_type"].tensor_type - - if "scale" in quant_overrides and "zero_point" in quant_overrides: - zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] - elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: - zero, scale = compute_scale_zp_float8(quant_type, 
td.avg_std[1]) - else: - rmin = quant_overrides.get("rmin", td.range_value[0]) - rmax = quant_overrides.get("rmax", td.range_value[1]) - symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) - reduce_range = quant_overrides.get("reduce_range", False) - qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) - zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) - - quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) - - return quantization_params + # adjust tensor_ranges for input of Clip and Relu node + if node.op_type in ["Clip", "Relu"]: + if self.is_activation_symmetric: + continue + if not self.should_quantize_node(node): + continue + if len(self.model.input_name_to_nodes()[node.input[0]]) != 1: + continue + if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range: + continue + td = self.tensors_range[node.output[0]] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.") + self.tensors_range[node.input[0]] = td + # Adjust Softmax to range from 0.0 to 1.0 + elif node.op_type == "Softmax": + self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0)) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py new file mode 100644 index 0000000000000..d59a0ec74ca7c --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py @@ -0,0 +1,413 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import onnx + +from ...quant_utils import QuantType +from ...tensor_quant_overrides import QuantTypeInfo, TensorQuantOverridesHelper + + +@dataclass +class TensorTypeRequest: + """ + Bundles desired quantization type requests for a tensor. A distinction is made between the + produced type and the consumed type. + """ + + # The tensor's quant type at the producer end. If None, assumed to be the default activation quant type. + producer: QuantTypeInfo | None + + # The tensor's quant type received by a set of consumer nodes. + # If None, assumed to be the default activation quant type for all consumers. + # consumers[1] is a set of consumer node names. + consumers: tuple[QuantTypeInfo, set[str]] | None + + +class MixedPrecisionTensorQuantOverridesFixer: + """ + Helper that generates tensor quantization overrides for mixed-precision QDQ models. + + Specifically, this helper fixes an initial set of quantization overrides that assign a non-default + activation quantization type to one or more tensors by doing the following: + - Inferring which other tensors need to be overridden to the non-default activation quantization type. + - Inserting quantization data type conversions. 
+ + Example: + -------- + + Float model: + + input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0 + ^ + | + input_1 --> Op2 -+-> Op4 ----+ + | + +-> Op7 --> output_1 + | + +-> Op8 --> output_2 + + If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out" + is quantized to 16-bit, then we would specify the following initial tensor quantization overrides: + + ``` + init_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]} + ``` + + These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output + to be the same type (e.g., uint16). This helper fixes the overrides so that input/output data types + are valid: + + ``` + overrides = TensorQuantOverridesHelper(init_overrides) + + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, QuantType.QUInt8) + fixer.apply( + default_activation_qtype=QuantType.QUInt8, + default_activation_symmetric=False, + ) + ``` + + The above snippet generates the following "fixed" overrides (get via overrides.get_dict()): + + { + "Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}], + "Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}], + "Op4_out": [{"quant_type": QUInt16}], + "Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}] + } + + How to interpret the fixed overrides: + - Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type, + but Op7 and Op8 consume the original u8 type. + - Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type. + - Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type. + - Op5's output is converted from u16 to u8. Op6 consumes the u8 type. + """ + + def __init__( + self, + overrides: TensorQuantOverridesHelper, + producers: dict[str, onnx.NodeProto], + consumers: dict[str, list[onnx.NodeProto]], + value_infos: dict[str, onnx.ValueInfoProto], + initializers: dict[str, onnx.TensorProto], + ): + """ + Params: + overrides: The initial tensor quantization overrides to fix. + producers: Dictionary that maps a tensor name to the producer node that generates the tensor. + consumers: Dictionary that maps a tensor name to the consumer nodes that take the tensor as input. + value_infos: Dictionary that maps a tensor name to its onnx.ValueInfoProto. + initializers: Dictionary that maps an initializer name to its onnx.TensorProto. + """ + self.overrides = overrides + self.consumers = consumers + self.producers = producers + self.value_infos = value_infos + self.initializers = initializers + + @staticmethod + def create_from_model( + overrides: TensorQuantOverridesHelper, model: onnx.ModelProto, default_activation_qtype: QuantType + ) -> MixedPrecisionTensorQuantOverridesFixer: + """ + Helper function that creates an instance of this class from a loaded ONNX model. + + Params: + overrides: The initial tensor quantization overrides to fix. + model: Loaded ONNX model + default_activation_qtype: The intended default activation quantization type. + Used to validate the initial overrides. + + Returns: + Initialized MixedPrecisionTensorQuantOverridesFixer object + """ + model = onnx.shape_inference.infer_shapes(model) # Need to infer shapes to get value_infos + + # Build dictionaries that enable convenient lookups of initializers and value_infos by name. 
+ initializers = {initializer.name: initializer for initializer in model.graph.initializer} + value_infos = {vi.name: vi for vi in model.graph.value_info} + value_infos.update({ot.name: ot for ot in model.graph.output}) + value_infos.update({it.name: it for it in model.graph.input}) + + # Ensure that the user-provided initial overrides are actually valid. + valid, err = overrides.is_valid(set(initializers), set(value_infos), default_activation_qtype) + if not valid: + pprint_overrides = overrides.pprint_str(indent=4) + logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}") + raise ValueError(err) + + consumers = {} + producers = {} + + # Build dictionaries that map a tensor name to the consumer or producer nodes. + for node in model.graph.node: + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + for output_name in node.output: + producers[output_name] = node + + return MixedPrecisionTensorQuantOverridesFixer(overrides, producers, consumers, value_infos, initializers) + + def apply( + self, + default_activation_qtype: QuantType, + default_activation_symmetric: bool, + ): + """ + Fixes the initial tensor quantization overrides (in-place) for use in mixed-precision QDQ models. + + Params: + default_activation_qtype: The intended default activation quantization type. + default_activation_symmetric: The intended default symmetry used to quantize activations. + """ + type_requests = self.get_desired_tensor_types(default_activation_qtype, default_activation_symmetric) + + # Use type requests to "fix" tensor quantization overrides by adding + # quantization type conversions where necessary. + for tensor_name, type_req in type_requests.items(): + all_consumers = set([node.name for node in self.consumers.get(tensor_name, [])]) + has_producer_req = type_req.producer is not None + has_consumer_req = bool(type_req.consumers) + + # Only producer type: Add conversion back to default activation type + if has_producer_req and not has_consumer_req: + self._update_converted_tensor( + tensor_name, type_req.producer, QuantTypeInfo(default_activation_qtype), all_consumers + ) + # Only consumers + elif not has_producer_req and has_consumer_req: + prod_type_info = self.overrides.get_node_output_qtype_info(tensor_name, default_activation_qtype) + consumer_type_info = type_req.consumers[0] + + if prod_type_info != consumer_type_info: + self._update_converted_tensor( + tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1] + ) + else: + if not self._check_nodes_are_not_convert_consumers(tensor_name, type_req.consumers[1]): + raise ValueError( + f"Tensor override for '{tensor_name}' converts the type for consumers that need the original type." + ) + # Both producer and consumers + elif has_producer_req and has_consumer_req: + prod_type_info = type_req.producer + consumer_type_info = type_req.consumers[0] + + if prod_type_info != consumer_type_info: + self._update_converted_tensor( + tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1] + ) + else: + consumers_for_original_type = all_consumers.difference(type_req.consumers[1]) + + if len(consumers_for_original_type) == 0: + # All consumers want the overridden type, so no need for convert nodes! + # Just add the override to the new new if not already present. 
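+ # (The producer's requested type is recorded directly on the tensor's single override entry;
+ #  no "convert" sub-dict is added since every consumer accepts the new type.)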
+ if tensor_name not in self.overrides: + self.overrides[tensor_name] = [{}] + prod_type_info.save_to_dict(self.overrides[tensor_name][0]) + + assert "convert" not in self.overrides[tensor_name][0] + else: + # Some consumers don't want the overridden type. + self._update_converted_tensor( + tensor_name, + prod_type_info, + QuantTypeInfo(default_activation_qtype), + consumers_for_original_type, + ) + else: + raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.") + + # Done. Check if the overrides are valid. + valid, err = self.overrides.is_valid(set(self.initializers), set(self.value_infos), default_activation_qtype) + if not valid: + pprint_overrides = self.overrides.pprint_str(indent=4) + logging.error( + f"Generated invalid tensor quantization overrides for mixed-precision QDQ model:\n{pprint_overrides}" + ) + raise ValueError(err) + + def get_desired_tensor_types( + self, + default_activation_qtype: QuantType, + default_activation_symmetric: bool, + ) -> dict[str, TensorTypeRequest]: + """ + Iterates through the initial tensor quantization overrides and builds a set of TensorTypeRequests objects + that describe the quantization types required at each tensor. These TensorTypeRequests objects are ultimately + used to generated the "fixed" overrides. + + Params: + default_activation_qtype: The intended default activation quantization type. + default_activation_symmetric: The intended default symmetry used to quantize activations. + + Returns: + TensorTypeRequest objects as a dict that maps a tensor name to its requested types. + """ + type_requests = {} + default_activation_type_info = QuantTypeInfo(default_activation_qtype, default_activation_symmetric) + + # Scan tensor overrides for type conversion requests. + for tensor_name, override_list in self.overrides.items(): + if not self.__is_tensor_quantizable(tensor_name): + continue # Skip non-quantizable tensors (e.g., not a float) + + if tensor_name in self.initializers: + continue # Skip initializers + + if not override_list or len(override_list) > 1: + continue # Skip per-channel stuff + + override_dict = override_list[0] + quant_type_info = QuantTypeInfo.load_from_dict(override_dict, default_activation_type_info.quant_type) + producer_node = self.producers.get(tensor_name) # None if this is a model input + + if quant_type_info != default_activation_type_info and "convert" not in override_dict: + if producer_node is not None: + self._add_type_requests_for_node(type_requests, quant_type_info, producer_node) + + # Find all consumer nodes of `tensor_name` and update their inputs/outputs to the new type. + for consumer_node in self.consumers.get(tensor_name, []): + self._add_type_requests_for_node(type_requests, quant_type_info, consumer_node) + + return type_requests + + def _add_type_requests_for_node( + self, + type_requests: dict[str, TensorTypeRequest], + quant_type_info: QuantTypeInfo, + node: onnx.NodeProto, + ): + """ + Adds TensorTypeRequest objects for a given node, assuming that we want all its inputs and outputs + to have the same quantization type (as specified by the `quant_type_info` parameter). + + Params: + type_requests: Dictionary of type requests to append to for this node. + quant_type_info: The quantization type to use for inputs and outputs. + node: The node for which the TensorTypeRequest objects are created and added to type_requests. 
+ """ + # Add output side + for output_name in node.output: + if not self.__is_tensor_quantizable(output_name): + continue + + if output_name not in type_requests: + type_requests[output_name] = TensorTypeRequest(quant_type_info, None) + else: + if ( + type_requests[output_name].producer is not None + and type_requests[output_name].producer != quant_type_info + ): + raise ValueError(f"Tensor {output_name} has multiple types.") + + type_requests[output_name].producer = quant_type_info + + # Add the consumer side + for input_name in node.input: + if input_name and input_name not in self.initializers and self.__is_tensor_quantizable(input_name): + if input_name not in type_requests: + type_requests[input_name] = TensorTypeRequest(None, None) + + if type_requests[input_name].consumers is None: + type_requests[input_name].consumers = (quant_type_info, set()) + + if type_requests[input_name].consumers[0] != quant_type_info: + raise ValueError(f"Tensor {input_name} has consumers requesting different types.") + + if not node.name: + raise ValueError( + f"Node of type {node.op_type} with output 0 {node.output[0]} does not have a name!" + ) + + type_requests[input_name].consumers[1].add(node.name) + + def _update_converted_tensor( + self, + tensor_name: str, + producer_type_info: QuantTypeInfo, + consumer_type_info: QuantTypeInfo, + consumer_names: set[str], + ): + """ + Updates the tensor quantization overrides for a tensor that is converted from one type to another. + + Params: + tensor_name: The name of the tensor for which to update overrides. + producer_type_info: Info for the tensor's produced type. + consumer_type_info: Info for the tensor's consumed (i.e., converted) type. + consumer_names: Nodes names of consumers that consume the converted type. + """ + if tensor_name not in self.overrides or not self.overrides[tensor_name]: + self.overrides[tensor_name] = [{}] + producer_type_info.save_to_dict(self.overrides[tensor_name][0]) + + overrides = self.overrides[tensor_name][0] + if producer_type_info != QuantTypeInfo.load_from_dict(overrides): + raise ValueError(f"Desired producer quant_type for {tensor_name} doesn't match existing type.") + + if consumer_names: + if "convert" not in overrides: + overrides["convert"] = {} + consumer_type_info.save_to_dict(overrides["convert"]) + + convert_dict = overrides["convert"] + if consumer_type_info != QuantTypeInfo.load_from_dict(convert_dict): + raise ValueError(f"Desired consumer quant_type for {tensor_name} doesn't match existing type.") + + if "recv_nodes" not in convert_dict: + convert_dict["recv_nodes"] = set() + + convert_dict["recv_nodes"].update(consumer_names) + + def _check_nodes_are_not_convert_consumers(self, tensor_name: str, node_names: set[str]): + """ + Returns true if the given nodes do not consume/receive a converted quantization type. + + Params: + tensor_name: The name of the tensor to check. + node_names: Set of node names that should not be consumers of the converted type. 
+ """ + if tensor_name not in self.overrides or not self.overrides[tensor_name]: + return True + + overrides = self.overrides[tensor_name][0] + + if "convert" not in overrides: + return True + + convert_dict = overrides["convert"] + + if "recv_nodes" not in convert_dict: + return False + + return not convert_dict["recv_nodes"].intersection(node_names) + + def __is_tensor_quantizable(self, tensor_name): + weight = self.initializers.get(tensor_name) + if weight is not None: + if weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16): + return True + elif tensor_name in self.value_infos: + vi = self.value_infos[tensor_name] + if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in ( + onnx.TensorProto.FLOAT, + onnx.TensorProto.FLOAT16, + ): + return True + + return False diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index e9affae7ac263..479eaf5b0c542 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -3,6 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + +import copy +import logging from pathlib import Path import numpy as np @@ -11,6 +15,8 @@ from ...calibrate import CalibrationDataReader, CalibrationMethod from ...quant_utils import QuantType from ...quantize import StaticQuantConfig +from ...tensor_quant_overrides import TensorQuantOverridesHelper +from .mixed_precision_overrides_utils import MixedPrecisionTensorQuantOverridesFixer Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} @@ -18,6 +24,20 @@ MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB +def warn_unable_to_override( + node: onnx.NodeProto, + what_str: str, + tensor_name: str, + io_kind: str, +): + logging.warning( + f"Unable to override {what_str} for {node.op_type} node's {io_kind} " + "because it has already been overridden! Check the initial quantization overrides provided " + "to get_qnn_qdq_config() if the generated QDQ model does not run on QNN EP. 
" + f"Node name: {node.name}, {io_kind} name: {tensor_name}" + ) + + def get_qnn_qdq_config( model_input: Path, calibration_data_reader: CalibrationDataReader, @@ -25,14 +45,20 @@ def get_qnn_qdq_config( activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, per_channel=False, + init_overrides=None, + add_qtype_converts=True, + activation_symmetric=False, + weight_symmetric=None, ): if per_channel: raise ValueError("QNN EP does not yet support per-channel quantization.") + if weight_symmetric is None: + weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16} + model = onnx.load_model(model_input, load_external_data=False) op_types = set() - tensor_quant_overrides = {} model_has_external_data = False name_to_initializer = {} @@ -43,52 +69,40 @@ def get_qnn_qdq_config( if onnx.external_data_helper.uses_external_data(initializer): model_has_external_data = True - # Setup quantization overrides for specific operator types - for node in model.graph.node: - op_types.add(node.op_type) + overrides_helper = TensorQuantOverridesHelper(copy.deepcopy(init_overrides) if init_overrides else {}) - if node.op_type == "MatMul" and activation_type in Q16_TYPES and weight_type in Q8_TYPES: - weight_symmetric = weight_type == QuantType.QInt8 + if not overrides_helper.empty() and add_qtype_converts: + # Fix mixed-precision overrides. + overrides_fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model( + overrides_helper, model, activation_type + ) + overrides_fixer.apply(activation_type, activation_symmetric) - # Override initializers to use the weight_type - for input_name in node.input: - if input_name in name_to_initializer: - tensor_quant_overrides[input_name] = [{"quant_type": weight_type, "symmetric": weight_symmetric}] - elif node.op_type == "LayerNormalization" and activation_type in Q16_TYPES and weight_type in Q8_TYPES: - weight_symmetric = weight_type == QuantType.QInt8 + # Setup quantization overrides for specific operator types to ensure compatibility with QNN EP. + qnn_compat = QnnCompatibilityOverrides( + activation_type, + weight_type, + activation_symmetric, + weight_symmetric, + overrides_helper, + name_to_initializer, + ) - # Override initializers to use the weight_type. Don't override the bias input. 
- for i in range(2): - input_name = node.input[i] - if input_name in name_to_initializer: - tensor_quant_overrides[input_name] = [{"quant_type": weight_type, "symmetric": weight_symmetric}] - elif node.op_type == "Sigmoid": - if activation_type == QuantType.QUInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 65536.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.uint16)} - ] - elif activation_type == QuantType.QInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.int16)} - ] - elif node.op_type == "Tanh": - if activation_type == QuantType.QUInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(32768, dtype=np.uint16)} - ] - elif activation_type == QuantType.QInt16: - tensor_quant_overrides[node.output[0]] = [ - {"scale": np.array(1.0 / 32768.0, dtype=np.float32), "zero_point": np.array(0, dtype=np.int16)} - ] + for node in model.graph.node: + op_types.add(node.op_type) + qnn_compat.process_node(node) extra_options = { "MinimumRealRange": 0.0001, "DedicatedQDQPair": False, # Let ORT optimizer duplicate DQ nodes - "TensorQuantOverrides": tensor_quant_overrides, + "TensorQuantOverrides": overrides_helper.get_dict(), + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, } # TODO: Remove this extra option once ORT uses an ONNX version that supports 16-bit Q/DQ ops. - if activation_type in Q16_TYPES or weight_type in Q16_TYPES: + overrides_have_int16 = any(t in Q16_TYPES for t in overrides_helper.get_quant_types()) + if activation_type in Q16_TYPES or weight_type in Q16_TYPES or overrides_have_int16: extra_options["UseQDQContribOps"] = True return StaticQuantConfig( @@ -100,3 +114,163 @@ def get_qnn_qdq_config( use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), extra_options=extra_options, ) + + +class QnnCompatibilityOverrides: + """ + Helper that processes nodes to generate quantization overrides that make the resulting QDQ model + compatible with QNN EP. 
+ """ + + def __init__( + self, + default_activation_qtype: QuantType, + default_weight_qtype: QuantType, + activation_symmetric: bool, + weight_symmetric: bool, + overrides: TensorQuantOverridesHelper, + initializers: dict[str, onnx.TensorProto], + ): + self.default_activation_qtype = default_activation_qtype + self.default_weight_qtype = default_weight_qtype + self.activation_symmetric = activation_symmetric + self.weight_symmetric = weight_symmetric + self.overrides = overrides + self.initializers = initializers + + self.process_fns = { + "MatMul": self._process_matmul, + "LayerNormalization": self._process_layernorm, + "Sigmoid": self._process_sigmoid, + "Tanh": self._process_tanh, + } + + def process_node(self, node: onnx.NodeProto): + process_fn = self.process_fns.get(node.op_type) + + if process_fn is not None: + process_fn(node) + + def _process_matmul(self, node: onnx.NodeProto): + """ + Overrides MatMul's initializer input(s) to use the default weight type if: + - The default weight type is 8-bit + - One of the inputs is a 16-bit activation + """ + assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}" + if self.default_weight_qtype not in Q8_TYPES: + return + + input_16bit_act = None + input_wgt = None + + for input_name in node.input: + if input_name and input_name not in self.initializers: + qtype = self.overrides.get_node_input_qtype_info( + input_name, node.name, self.default_activation_qtype + ).quant_type + if qtype in Q16_TYPES: + input_16bit_act = input_name + else: + input_wgt = input_name + + # Override initializer to use the default weight type. + if input_16bit_act and input_wgt: + did_update = self.overrides.update_tensor_overrides( + input_wgt, + {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, + overwrite=False, + ) + + if not did_update: + warn_unable_to_override(node, "quant_type/symmetric", input_wgt, "input weight") + + def _process_layernorm(self, node: onnx.NodeProto): + """ + Overrides LayerNormalization's initializer input(s), except for bias, to use the default weight type if: + - The default weight type is 8-bit + - One of the inputs is a 16-bit activation + """ + assert node.op_type == "LayerNormalization", f"Expected LayerNormalization, but got {node.op_type}" + if self.default_weight_qtype not in Q8_TYPES: + return + + has_q16_activation = False + for input_name in node.input: + if input_name and input_name not in self.initializers: + qtype = self.overrides.get_node_input_qtype_info( + input_name, node.name, self.default_activation_qtype + ).quant_type + if qtype in Q16_TYPES: + has_q16_activation = True + break + + # Override initializers to use the self.default_weight_qtype. Don't override the bias input. + if has_q16_activation: + for i in range(2): + input_name = node.input[i] + if input_name and input_name in self.initializers: + did_update = self.overrides.update_tensor_overrides( + input_name, + {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric}, + overwrite=False, + ) + + if not did_update: + warn_unable_to_override(node, "quant_type/symmetric", input_name, "input weight") + + def _process_sigmoid(self, node: onnx.NodeProto): + """ + Overrides 16-bit Sigmoid's output scale and zero-point as per QNN requirements. 
+ """ + assert node.op_type == "Sigmoid", f"Expected Sigmoid, but got {node.op_type}" + output_type = self.overrides.get_node_output_qtype_info( + node.output[0], self.default_activation_qtype + ).quant_type + + if output_type == QuantType.QUInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 65536.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.uint16), + }, + ) + elif output_type == QuantType.QInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + }, + ) + + def _process_tanh(self, node: onnx.NodeProto): + """ + Overrides 16-bit Tanh's output scale and zero-point as per QNN requirements. + """ + assert node.op_type == "Tanh", f"Expected Tanh, but got {node.op_type}" + output_type = self.overrides.get_node_output_qtype_info( + node.output[0], self.default_activation_qtype + ).quant_type + + if output_type == QuantType.QUInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(32768, dtype=np.uint16), + }, + ) + elif output_type == QuantType.QInt16: + self.overrides.update_tensor_overrides( + node.output[0], + { + "quant_type": output_type, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + }, + ) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 716dd1eacec6a..174bf5fd1509c 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -441,6 +441,11 @@ def replace_input_of_all_nodes(self, old_input_name, new_input_name): for node in self.model.graph.node: ONNXModel.replace_node_input(node, old_input_name, new_input_name) + def replace_input_of_nodes(self, old_input_name, new_input_name, node_names_set): + for node in self.model.graph.node: + if node.name in node_names_set: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + @staticmethod def replace_node_output(node, old_output_name, new_output_name): assert isinstance(old_output_name, str) and isinstance(new_output_name, str) @@ -452,6 +457,11 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name): for node in self.model.graph.node: ONNXModel.replace_node_output(node, old_output_name, new_output_name) + def replace_output_of_nodes(self, old_output_name, new_output_name, node_names_set): + for node in self.model.graph.node: + if node.name in node_names_set: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + def remove_unused_constant(self): input_name_to_nodes = self.input_name_to_nodes() diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index e2044db04303d..4b76de6ecf1cb 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -5,30 +5,31 @@ # -------------------------------------------------------------------------- import logging +import numpy as np import onnx import onnx.numpy_helper from onnx import onnx_pb as onnx_proto -try: - from onnx.reference.op_run import to_array_extended -except ImportError: - # old version of onnx. 
- to_array_extended = None - -from .base_quantizer import BaseQuantizer +from .base_quantizer import BaseQuantizer, QuantizationParams +from .calibrate import TensorData from .onnx_model import ONNXModel from .quant_utils import ( TENSOR_NAME_QUANT_SUFFIX, QuantizationMode, QuantizedValue, + QuantizedValueType, __producer__, __version__, add_infer_metadata, attribute_to_kwarg, + compute_scale_zp, + compute_scale_zp_float8, find_by_name, + get_qmin_qmax_for_qType, get_qrange_for_qType, ms_domain, save_and_reload_model_with_shape_infer, + tensor_proto_to_array, ) from .registry import CreateOpQuantizer @@ -77,6 +78,7 @@ def __init__( self.fuse_dynamic_quant = self.opset_version > 10 self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] + self.new_nodes = [] self.graph_scope = "/" # for human readable debug information self.tensor_names = {} # in case the shape inference not totally working @@ -88,6 +90,8 @@ def __init__( if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") + self.quantization_params = self.calculate_quantization_params() + # QuantizeRange tensor name and zero tensor name for scale and zero point calculation. # Used when static is False self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8" @@ -97,6 +101,8 @@ def __init__( # For int8 data-type, zero point is always zero (respresented by fixed_zero_point_name tensor) self.fixed_zero_zp_name = "fixed_zero_zp" + # Map of all original value names to quantized value names + self.quantized_value_map = {} # some output from nodes will be quantized, yet itself should be treat as existing so # no dequantized will be applied when needed later self.generated_value_names = self.model.get_non_initializer_inputs() @@ -494,6 +500,65 @@ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, i return input_scale_name, input_zp_name, [], [] + def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): + """ + Create initializers and inputs in the graph for zero point and scale of output. + Zero point and scale values are obtained from self.quantization_params if specified. + parameter param_name: Name of the quantization parameter. + return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. + """ + zero_point_type = self.activation_qType + + if use_scale is None or use_zeropoint is None: + if self.quantization_params is None or param_name not in self.quantization_params: + logging.info(f'Quantization parameters for tensor:"{param_name}" not specified') + return False, "", "", "", "" + + params = self.quantization_params[param_name] + if not isinstance(params, QuantizationParams): + raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.") + if params is None or len(params) != 3: + raise ValueError( + "Quantization parameters should contain zero point, scale, quant type. 
" + f"Specified values for output {param_name}: {params}" + ) + + zero_point_values = np.array([params["zero_point"]]) + if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): + raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") + scale_values = np.array([params["scale"]]) + assert scale_values.dtype != np.float64 + zero_point_type = params["quant_type"] + else: + zero_point_values = np.array([use_zeropoint]) + scale_values = np.array([use_scale]) + params = self.quantization_params[param_name] + if "scale" in params: + dtype = params["scale"].dtype + scale_values = scale_values.astype(dtype) + assert scale_values.dtype != np.float64 + + zero_point_shape = [] + zero_point_name = param_name + "_zero_point" + scale_shape = [] + scale_name = param_name + "_scale" + + # Add initializers + init_zp = onnx.helper.make_tensor( + zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + ) + self.model.add_initializer(init_zp) + if scale_values.dtype == np.float32: + scale_type = onnx_proto.TensorProto.FLOAT + elif scale_values.dtype == np.float16: + scale_type = onnx_proto.TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + self.model.add_initializer(init_scale) + + return True, scale_name, zero_point_name, scale_shape, zero_point_shape + def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None): """ Given an input for a node (which is not a initializer), this function @@ -564,6 +629,55 @@ def find_quantized_value(self, input_name): return self.parent.find_quantized_value(input_name) return None + def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + """ + Quantized the bias. 
Zero Point == 0 and Scale == Input_Scale * Weight_Scale + """ + + # Handle case where bias already in quantization map + if bias_name in self.quantized_value_map: + return self.quantized_value_map[bias_name].q_name + + # get scale for weight + weight_scale_name = self.quantized_value_map[weight_name].scale_name + weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_initializer) + + # get scale for input + if input_name in self.quantized_value_map: + input_scale_name = self.quantized_value_map[input_name].scale_name + elif input_name in self.quantization_params: + _, input_scale_name, _, _, _ = self._get_quantization_params(input_name) + else: + raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization") + + inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(inputscale_initializer) + + ( + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + bias_scale_data, + node_type, + node_qtype, + ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, beta) + + assert bias_name not in self.quantized_value_map + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + node_type=node_type, + node_qtype=node_qtype, + ) + self.quantized_value_map[bias_name] = quantized_value + + return quantized_bias_name + def contains_tensor(self, tensor_name): """ only check for value info and newly generated tensor names, initializers are checked separately @@ -721,6 +835,71 @@ def __quantize_inputs( return quantized_input_names, zero_point_names, scale_names, nodes + def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False): + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. 
+ :return: quantized weight name, zero point name, scale name + """ + # Find if this input is already quantized + if weight.name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight.name] + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( + weight, qType, reduce_range, keep_float_weight + ) + + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight.name] = quantized_value + return q_weight_name, zp_name, scale_name + + def quantize_weight_per_channel( + self, + weight_name, + weight_qType, + channel_axis, + reduce_range=True, + keep_float_weight=False, + ): + # Find if this input is already quantized + if weight_name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight_name] + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( + weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight + ) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight_name] = quantized_value + + return q_weight_name, zp_name, scale_name + def _dequantize_value(self, value_name): """ Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize @@ -771,3 +950,37 @@ def _dequantize_outputs(self): dequantize_node = self._dequantize_value(output.name) if dequantize_node is not None: self.new_nodes.append(dequantize_node) + + def calculate_quantization_params(self): + if self.tensors_range is None: + return None + + self.adjust_tensor_ranges() + + quantization_params = {} + for tensor_name in self.tensors_range: + td = self.tensors_range[tensor_name] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") + + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + + quant_type = self.activation_qType + if "quant_type" in quant_overrides: + quant_type = quant_overrides["quant_type"].tensor_type + + if "scale" in quant_overrides and "zero_point" in quant_overrides: + zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] + elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: + zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1]) + else: + rmin = quant_overrides.get("rmin", td.range_value[0]) + rmax = quant_overrides.get("rmax", td.range_value[1]) + symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) + reduce_range = quant_overrides.get("reduce_range", False) + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) + zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) + + quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index 06204585ba1ca..7054173450569 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -252,4 +252,4 @@ 
def quantize(self): self.quantizer.quantize_weight_tensor(node.input[1]) if len(node.input) == 3: - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1]) + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) diff --git a/onnxruntime/python/tools/quantization/operators/direct_q8.py b/onnxruntime/python/tools/quantization/operators/direct_q8.py index c14532b96acbc..ae9679ae8ec7a 100644 --- a/onnxruntime/python/tools/quantization/operators/direct_q8.py +++ b/onnxruntime/python/tools/quantization/operators/direct_q8.py @@ -73,6 +73,6 @@ def quantize(self): if self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(self.node.input[0]) if not self.disable_qdq_for_node_output: - self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0]) + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) elif self.quantizer.is_tensor_quantized(self.node.input[0]) and not self.disable_qdq_for_node_output: - self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0]) + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) diff --git a/onnxruntime/python/tools/quantization/operators/gather.py b/onnxruntime/python/tools/quantization/operators/gather.py index f48725d1e428f..e390e874a2662 100644 --- a/onnxruntime/python/tools/quantization/operators/gather.py +++ b/onnxruntime/python/tools/quantization/operators/gather.py @@ -59,6 +59,6 @@ def quantize(self): if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(node.input[0]) - self.quantizer.quantize_activation_tensor(node.output[0], node.input[0]) + self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name) elif self.quantizer.is_tensor_quantized(node.input[0]): - self.quantizer.quantize_activation_tensor(node.output[0], node.input[0]) + self.quantizer.quantize_output_same_as_input(node.output[0], node.input[0], node.name) diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py index d269c8fb47bd1..df24e256aa7fc 100644 --- a/onnxruntime/python/tools/quantization/operators/gemm.py +++ b/onnxruntime/python/tools/quantization/operators/gemm.py @@ -153,7 +153,9 @@ def quantize(self): if len(node.input) == 3: if self.quantizer.is_input_a_initializer(node.input[2]): - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1], get_beta(self.node)) + self.quantizer.quantize_bias_tensor( + node.name, node.input[2], node.input[0], node.input[1], get_beta(self.node) + ) set_default_beta(self.node) else: logging.warning( diff --git a/onnxruntime/python/tools/quantization/operators/norm.py b/onnxruntime/python/tools/quantization/operators/norm.py index e825fe6075601..3c14c926a7e75 100644 --- a/onnxruntime/python/tools/quantization/operators/norm.py +++ b/onnxruntime/python/tools/quantization/operators/norm.py @@ -29,7 +29,7 @@ def quantize(self): self.quantizer.quantize_activation_tensor(node.input[1]) # Bias - self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1]) + self.quantizer.quantize_bias_tensor(node.name, node.input[2], node.input[0], node.input[1]) # Output if not self.disable_qdq_for_node_output: diff --git a/onnxruntime/python/tools/quantization/operators/softmax.py 
b/onnxruntime/python/tools/quantization/operators/softmax.py index 61a69ab3649dd..4b39fae8ac063 100644 --- a/onnxruntime/python/tools/quantization/operators/softmax.py +++ b/onnxruntime/python/tools/quantization/operators/softmax.py @@ -1,18 +1,8 @@ -import numpy as np import onnx import onnx.helper -from ..quant_utils import ( - TENSOR_NAME_QUANT_SUFFIX, - QuantizedValue, - QuantizedValueType, - attribute_to_kwarg, - compute_scale_zp, - get_qmin_qmax_for_qType, - ms_domain, -) +from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase -from .qdq_base_operator import QDQOperatorBase class QLinearSoftmax(QuantOperatorBase): @@ -82,29 +72,3 @@ def quantize(self): nodes.append(qnode) self.quantizer.new_nodes += nodes return None - - -class QDQSoftmax(QDQOperatorBase): - def quantize(self): - super().quantize() - output_name = self.node.output[0] - quant_overrides = self.quantizer.get_per_tensor_quant_overrides(output_name) - - quant_type = self.quantizer.activation_qType - if "quant_type" in quant_overrides: - quant_type = quant_overrides["quant_type"].tensor_type - - if "scale" in quant_overrides and "zero_point" in quant_overrides: - out_zero_point, out_scale = quant_overrides["zero_point"], quant_overrides["scale"] - else: - # Unless overridden by the user, force Softmax to range from 0.0 to 1.0 - qparams = self.quantizer.quantization_params[output_name] - dtype = qparams.data["scale"].dtype - rmin = quant_overrides.get("rmin", np.array(0, dtype=dtype)) - rmax = quant_overrides.get("rmax", np.array(1, dtype=dtype)) - symmetric = quant_overrides.get("symmetric", self.quantizer.is_activation_symmetric) - reduce_range = quant_overrides.get("reduce_range", False) - qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) - out_zero_point, out_scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=symmetric) - - self.quantizer.set_quant_scale_zp(output_name, (out_scale, out_zero_point)) diff --git a/onnxruntime/python/tools/quantization/operators/split.py b/onnxruntime/python/tools/quantization/operators/split.py index c36b767f5abcc..74fc30cd075d2 100644 --- a/onnxruntime/python/tools/quantization/operators/split.py +++ b/onnxruntime/python/tools/quantization/operators/split.py @@ -60,4 +60,4 @@ def quantize(self): self.quantizer.quantize_activation_tensor(node.input[0]) if not self.disable_qdq_for_node_output: for output in node.output: - self.quantizer.quantize_activation_tensor(output, node.input[0]) + self.quantizer.quantize_output_same_as_input(output, node.input[0], node.name) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 1875c552fab9c..c323c6fec545a 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -3,15 +3,21 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations + import logging +from dataclasses import dataclass from enum import Enum +from typing import Any +import numpy as np import onnx import onnx.numpy_helper from onnx import TensorProto from onnx import onnx_pb as onnx_proto -from .base_quantizer import BaseQuantizer +from .base_quantizer import BaseQuantizer, QuantizationParams +from .calibrate import TensorData from .quant_utils import ( DEQUANT_OP_NAME, QUANT_OP_NAME, @@ -24,8 +30,12 @@ add_quant_input_suffix, add_quant_output_suffix, add_quant_suffix, + compute_scale_zp, + compute_scale_zp_float8, find_by_name, + get_qmin_qmax_for_qType, ms_domain, + tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -36,6 +46,17 @@ class QDQQuantTensorType(Enum): BIAS = 2 +# Holds the name of the node input from which a node output will share the +# same quantization param initializers (zero-point and scale initializers). +# Ex: A Transpose node's output will use the same quant param initializers used at the input. +@dataclass +class QDQQuantParamProvider: + input_name: str + node_name: str + + +# Holds information for tensors that have been marked for quantization by operator quantizers. +# Does not hold information for bias tensors. class QDQTensorQuantInfo: def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None): self.tensor_type = tensor_type @@ -46,6 +67,64 @@ def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provide self.data_type = data_type +# Holds information for bias tensors that have been marked for quantization by operator quantizers. +@dataclass +class QDQBiasQuantInfo: + node_name: str + input_name: str + weight_name: str + beta: float + + +# Holds quantization parameter values (scale, zp) for a tensor. +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). +@dataclass +class QDQTensorQuantParams: + original: QuantizationParams # Generated by producer node. + converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes. + converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type. + + +# Holds scale and zero_point initializer TensorProtos. +@dataclass +class QDQScaleZpInitializers: + scale: TensorProto + zero_point: TensorProto + + +# Holds all scale and zero-point initializers for a tensor. +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). +@dataclass +class QDQTensorScaleZpInitializers: + original: QDQScaleZpInitializers + converted: QDQScaleZpInitializers | None + converted_recv_nodes: set[str] | None + + +# Holds cached information of a tensor's quantized values (types, zp/scale initializer names, etc.). +# A tensor typically has a one set of quantization parameters, unless the tensor is +# at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16). 
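+# For example (hypothetical names): if tensor "conv_out" is produced with uint8 params but converted to uint16
+# only for consumer node "node_b", then `original` holds the uint8 QuantizedValue, `converted` the uint16 one,
+# and converted_recv_nodes == {"node_b"}; get_for_consumer("node_b") returns the converted value while every other
+# consumer receives the original. A converted_recv_nodes of None means all consumers receive the converted value.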
+@dataclass +class QDQTensorQuantizedValue: + original: QuantizedValue + converted: QuantizedValue | None + converted_recv_nodes: set[str] | None + + def get_for_consumer(self, consumer_node_name) -> QuantizedValue: + if self.converted is None: # Quantized value is not converted, return original + return self.original + + if self.converted_recv_nodes is None: # All consumers receive the converted value + return self.converted + + # Check if consumer node name is in the list of nodes that + # receive the converted quantization value. If not, return the original value generated + # by the tensor's producer. + return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original + + class QDQQuantizer(BaseQuantizer): def __init__( self, @@ -74,7 +153,7 @@ def __init__( extra_options, ) self.tensors_to_quantize = {} - self.bias_to_quantize = [] + self.bias_to_quantize = {} self.nodes_to_remove = [] @@ -100,8 +179,7 @@ def __init__( # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) - if self.dedicated_qdq_pair: - self.tensor_to_its_receiving_nodes = {} + self.tensor_to_its_receiving_nodes = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) @@ -112,7 +190,7 @@ def __init__( # if the activation or weight types are 16-bit integers. # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. int16_types = (TensorProto.UINT16, TensorProto.INT16) - overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types) + overrides_have_int16 = any(t.tensor_type in int16_types for t in self.tensor_quant_override_qtypes) if not self.qdq_op_domain and ( self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 ): @@ -123,6 +201,11 @@ def __init__( ) self.qdq_op_domain = ms_domain + self.quantization_params = self.calc_graph_quant_params() + + # Map of all original value names to quantized value names + self.quantized_value_map = {} + def _get_tensor_type(self, tensor_name): """ Check if tensor can be quantized @@ -158,45 +241,71 @@ def _is_tensor_quantizable(self, tensor_name): return False - def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=QDQQuantTensorType.ACTIVATION): + def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION): """ - Quantize tensors. If quant_param_tensor is not None, tensor with name tensor_name will be quantized with same - quantization parameters as tensor quant_param_tensor + Adds a tensor to the list (actually a dict) of tensors to quantize. Called indirectly by op quantizers that + want to quantize a tensor (i.e., "mark" a tensor for quantization). + + If quant_sharing_provider is not None, tensor with name tensor_name will be quantized with the same + quantization parameters as the node input specified in quant_sharing_provider. Ex: A Tranpose node's output + will typically use the same quantization parameter initializers used at the Transpose node's input. 
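The per-consumer lookup performed by get_for_consumer() above reduces to a small standalone rule. The sketch below restates it outside the class; resolve_for_consumer and the placeholder string values are hypothetical and only for illustration:

def resolve_for_consumer(original, converted, converted_recv_nodes, consumer_node_name):
    # No conversion recorded for this tensor: every consumer sees the original value.
    if converted is None:
        return original
    # Conversion recorded with no explicit receiver list: every consumer sees the converted value.
    if converted_recv_nodes is None:
        return converted
    # Otherwise only the listed consumer nodes see the converted value.
    return converted if consumer_node_name in converted_recv_nodes else original

# Only "op4" consumes the converted (e.g., uint16) value; other consumers keep the original (e.g., uint8) one.
assert resolve_for_consumer("original_u8", "converted_u16", {"op4"}, "op4") == "converted_u16"
assert resolve_for_consumer("original_u8", "converted_u16", {"op4"}, "op7") == "original_u8"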
Args: tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter + quant_sharing_provider: name of the tensor and node that provides quantization parameter tensor_type: QDQQuantTensorType default ACTIVATION """ if self._is_tensor_quantizable(tensor_name): - if quant_sharing_param: + if quant_sharing_provider: + if not isinstance(quant_sharing_provider, QDQQuantParamProvider): + raise TypeError( + f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}." + ) + data_type = self._get_tensor_type(tensor_name) self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo( - tensor_type=tensor_type, quant_para_provider=quant_sharing_param, data_type=data_type + tensor_type=tensor_type, quant_para_provider=quant_sharing_provider, data_type=data_type ) elif tensor_name not in self.tensors_to_quantize: data_type = self._get_tensor_type(tensor_name) self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type, data_type=data_type) - def quantize_activation_tensor(self, tensor_name, quant_sharing_param=None): + def quantize_activation_tensor(self, tensor_name: str): """ - Quantize Activation Tensor + Adds a tensor to the list of tensors to quantize. Called by op quantizers that + want to quantize a tensor (i.e., "mark" a tensor for quantization). + Args: tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter - """ - return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.ACTIVATION) + return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.ACTIVATION) - def quantize_weight_tensor(self, tensor_name, quant_sharing_param=None): + def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str): """ - Quantize Weight Tensor + Adds a tensor to the list of tensors to quantize. Called by op quantizers that + want to quantize an output tensor using the same quantization parameters as one of the node's inputs. + + Ex: A Tranpose node's output will typically use the same quantization parameter initializers used at + the Transpose node's input. + Args: - tensor_name: name of the tensor to quantize - quant_sharing_param: name of the tensor that provides quantization parameter + output_name: name of the node output to quantize so that it uses the same quantization params as an input. + input_name: name of the node input from which the output tensor will get its quantization params. + node_name: name of the node that consumes `input_name`. + """ + return self.__quantize_tensor( + output_name, QDQQuantParamProvider(input_name, node_name), QDQQuantTensorType.ACTIVATION + ) + def quantize_weight_tensor(self, tensor_name: str): """ - return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.WEIGHT) + Adds a tensor to the list of weight tensors to quantize. Called by op quantizers that + want to quantize a weight (i.e., "mark" a weight for quantization). + + Args: + tensor_name: name of the weight to quantize + """ + return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.WEIGHT) def quantize_weight_tensor_per_channel(self, tensor_name, axis): weight = find_by_name(tensor_name, self.model.initializer()) @@ -208,7 +317,19 @@ def quantize_weight_tensor_per_channel(self, tensor_name, axis): else: logging.warning(f"only support per-channel quantization on weight. 
Tensor: {tensor_name} is not quantized.") - def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): + def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0): + """ + Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that + want to quantize a bias with bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta. + TODO: Explain the reasoning for using this formula. + + Args: + node_name: name of the node that consumes the bias, input, and weight tensors. + bias_name: name of the bias tensor to quantize. + input_name: name of the input tensor whose scale is used to compute the bias's scale. + weight_name: name of the weight tensor whose scale is used to compute the bias's scale. + beta: Multiplier used to compute the bias's scale. + """ # If the user provided quantization overrides for this tensor, treat it as a regular weight. if self.tensor_quant_overrides.get(bias_name): logging.info( @@ -223,7 +344,10 @@ def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): weight = find_by_name(bias_name, self.model.initializer()) if weight is not None: if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): - self.bias_to_quantize.append((bias_name, input_name, weight_name, beta)) + if bias_name not in self.bias_to_quantize: + self.bias_to_quantize[bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) + else: + logging.warning(f"Bias {bias_name} has already been marked for quantization") else: logging.warning(f"Expected {bias_name} to be a weight") @@ -239,11 +363,10 @@ def quantize_model(self): op_quantizer = CreateQDQQuantizer(self, node) op_quantizer.quantize() - if self.dedicated_qdq_pair: - for tensor_name in node.input: - if tensor_name not in self.tensor_to_its_receiving_nodes: - self.tensor_to_its_receiving_nodes[tensor_name] = [] - self.tensor_to_its_receiving_nodes[tensor_name].append(node) + for tensor_name in node.input: + if tensor_name not in self.tensor_to_its_receiving_nodes: + self.tensor_to_its_receiving_nodes[tensor_name] = [] + self.tensor_to_its_receiving_nodes[tensor_name].append(node) self._quantize_normal_tensors() self._quantize_sharing_param_tensors() @@ -263,6 +386,8 @@ def quantize_model(self): def try_replacing_upstream_output(self, upstream_output_name, output_name): if ( output_name in self.quantization_params + and self.quantization_params[output_name].converted is None + and self.quantization_params[upstream_output_name].converted is None and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1 and not self.model.is_graph_output(upstream_output_name) and not self.model.is_graph_input(upstream_output_name) @@ -273,6 +398,50 @@ def try_replacing_upstream_output(self, upstream_output_name, output_name): return True return False + def _create_q_node( + self, + q_input: str, + q_output: str, + quant_node_name: str, + scale_name: str, + zp_name: str, + axis: int | None = None, + ): + """ + Creates a QuantizeLinear node and adds it to the model. + """ + qlinear_node = onnx.helper.make_node( + QUANT_OP_NAME, + [q_input, scale_name, zp_name], + [q_output], + quant_node_name, + axis=axis, + domain=self.qdq_op_domain, + ) + self.model.add_nodes([qlinear_node]) + + def _create_dq_node( + self, + dq_input: str, + dq_output: str, + dequant_node_name: str, + scale_name: str, + zp_name: str, + axis: int | None = None, + ): + """ + Creates a DequantizeLinear node and adds it to the model. 
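As an aside on the bias path marked above: quantize_bias_tensor documents the formula bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta. A standalone, simplified sketch of that arithmetic (the helper name is hypothetical, and the real quantize_bias_static_impl also handles details such as per-channel weight scales):

import numpy as np

def quantize_bias_sketch(bias: np.ndarray, input_scale: float, weight_scale: float, beta: float = 1.0):
    # bias_scale = input_scale * weight_scale * beta; the zero point is fixed at 0.
    bias_scale = input_scale * weight_scale * beta
    q_bias = np.round(bias / bias_scale).astype(np.int32)
    return q_bias, bias_scale

q_bias, bias_scale = quantize_bias_sketch(np.array([0.05, -0.1]), input_scale=0.02, weight_scale=0.01)
# q_bias -> [250, -500], bias_scale -> 0.0002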
+ """ + dequant_node = onnx.helper.make_node( + DEQUANT_OP_NAME, + [dq_input, scale_name, zp_name], + [dq_output], + dequant_node_name, + axis=axis, + domain=self.qdq_op_domain, + ) + self.model.add_nodes([dequant_node]) + def _create_qdq_nodes( self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None ): @@ -383,7 +552,7 @@ def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_ty QuantizedValueType.Input, scale_type=data_type, ) - self.quantized_value_map[tensor_name] = quantized_value + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None) else: q_input = tensor_name dq_output = add_dequant_output_suffix(tensor_name) @@ -413,9 +582,165 @@ def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_ty QuantizedValueType.Input, scale_type=data_type, ) - self.quantized_value_map[tensor_name] = quantized_value + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + def _add_qdq_ops_for_converted_activation( + self, + tensor_name, + first_scale_name, + first_zp_name, + scale_data_type, + convert_scale_name, + convert_zp_name, + convert_recv_nodes, + ): + """ + Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the + original data type from the producer, while other consumers use the converted data type. + This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another (e.g., uint16). + + T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float' + where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) ---> + + This function handles the following scenarios: + + 1) Tensor T is not a graph output; all consumers use the converted type + + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> + + 2) Tensor T is not a graph output; some consumers use the original type, others use the converted type + + ---> Q1 -+-> DQ1 ---> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + + 3) Tensor T is a graph output; all consumers use the converted type + + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> + | + +-> + + 4) Tensor T is a graph output; some consumers use the original type, others use the converted type + + ---> Q1 -+-> DQ1 -+-> + | | + | +-> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + tensor_recv_nodes = set([node.name for node in self.tensor_to_its_receiving_nodes[tensor_name]]) + + if ( + self.dedicated_qdq_pair + and tensor_name in self.tensor_to_its_receiving_nodes + and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1 + ): + # TODO: Add support for dedicated_qdq_pair if/when needed. + raise ValueError( + "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled" + ) + + # Determine which nodes consume the original quantized type and which nodes + # consume the converted quantized type. + original_recv_nodes = tensor_recv_nodes + if convert_recv_nodes is None: # In this case, all consumers receive the converted type. + convert_recv_nodes = tensor_recv_nodes + original_recv_nodes = set() + else: + original_recv_nodes = original_recv_nodes - convert_recv_nodes + + all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes) + is_graph_output = self.model.is_graph_output(tensor_name) + + # Create first Q op. 
+ first_q_input = tensor_name + if is_graph_output: + first_q_input = add_quant_input_suffix(tensor_name) + self.model.replace_output_of_all_nodes(tensor_name, first_q_input) + + first_q_output = add_quant_output_suffix(tensor_name) + self._create_q_node( + first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name + ) + + # Create first DQ op. + first_dq_output = add_dequant_output_suffix(tensor_name) + if is_graph_output and not all_use_converted: + first_dq_output = tensor_name + if original_recv_nodes and first_dq_output != tensor_name: + self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes) + + self._create_dq_node( + first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name + ) + + # Create parallel clone of first DQ op if _not all_ consumers use the converted type. + # --> DQ1' --> Q2 --> DQ2 --> + # + # This DQ clone would only have one consumer Q node (Q2) and could be potentially fused with + # it by some EPs (e.g., QNN) without breaking other "node units". + # Ex QNN fusion: + # --> Convert (fused) --> DQ2 --> + second_q_input = first_dq_output + if not all_use_converted: + second_q_input = add_quant_input_suffix(f"{tensor_name}_convert") + self._create_dq_node( + first_q_output, + second_q_input, + add_dequant_suffix(f"{tensor_name}_convert_clone"), + first_scale_name, + first_zp_name, + ) + + # Create second Q op. + second_q_output = add_quant_output_suffix(f"{tensor_name}_convert") + self._create_q_node( + second_q_input, + second_q_output, + add_quant_suffix(f"{tensor_name}_convert"), + convert_scale_name, + convert_zp_name, + ) + + # Create second DQ op. + second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert") + if is_graph_output and all_use_converted: + second_dq_output = tensor_name + if convert_recv_nodes and second_dq_output != tensor_name: + self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes) + self._create_dq_node( + second_q_output, + second_dq_output, + add_dequant_suffix(f"{tensor_name}_convert"), + convert_scale_name, + convert_zp_name, + ) + + # Store in quantized_value_map + original_quantized_value = QuantizedValue( + tensor_name, + first_dq_output, + first_scale_name, + first_zp_name, + QuantizedValueType.Input, + scale_type=scale_data_type, + ) + converted_quantized_value = QuantizedValue( + tensor_name, + second_dq_output, + convert_scale_name, + convert_zp_name, + QuantizedValueType.Input, + scale_type=scale_data_type, + ) + self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue( + original_quantized_value, converted_quantized_value, convert_recv_nodes + ) def _quantize_normal_tensors(self): + """ + Adds Q/DQ ops to tensors (activations and weights) that have been marked for quantization by op quantizers. 
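The conversion sequences above are driven by a "convert" entry in the TensorQuantOverrides extra option. A minimal sketch of such an override, checked with the helper introduced later in this patch (the tensor and node names are illustrative):

from onnxruntime.quantization import QuantType
from onnxruntime.quantization.tensor_quant_overrides import TensorQuantOverridesHelper

# "op4_out" is produced as uint8, but its consumer "op4" receives a uint16 (converted) value;
# any other consumer keeps the original uint8 value.
overrides = TensorQuantOverridesHelper(
    {
        "op4_out": [
            {
                "quant_type": QuantType.QUInt8,
                "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}},
            }
        ]
    }
)

valid, err = overrides.is_valid(
    initializer_names=set(),          # no initializer overrides in this sketch
    activation_names={"op4_out"},
    default_activation_qtype=QuantType.QUInt8,
)
assert valid, err

The same raw dictionary would be passed to the quantizer through the "TensorQuantOverrides" key of extra_options.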
+ """ for tensor_name, tensor_info in self.tensors_to_quantize.copy().items(): if tensor_name in self.quantized_value_map: continue @@ -426,53 +751,105 @@ def _quantize_normal_tensors(self): if initializer: self._add_qdq_pair_for_initializer(initializer, tensor_info.tensor_type, tensor_info.axis) else: - used_scale, used_zp = self.find_quant_scale_zp(tensor_name) - if used_scale is not None and not hasattr(used_scale, "dtype"): - raise TypeError( - f"Unexpected type {type(used_scale)} for used_scale and tensor_name={tensor_name!r}" - ) - data_found, scale_name, zp_name, _, _ = self._get_quantization_params( - tensor_name, used_scale, used_zp - ) - - if not data_found: + tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name) + if not tensor_qparam_initializers: raise ValueError( f"Quantization parameters are not specified for param {tensor_name}. " "In static mode quantization params for inputs and outputs of nodes to be quantized are required." ) - self._add_qdq_pair_for_activation(tensor_name, scale_name, zp_name, data_type=tensor_info.data_type) + if tensor_qparam_initializers.converted is None: + # Normal case: --> Q --> DQ --> + self._add_qdq_pair_for_activation( + tensor_name, + tensor_qparam_initializers.original.scale.name, + tensor_qparam_initializers.original.zero_point.name, + data_type=tensor_info.data_type, + ) + else: + # Conversion case: ---> Q1 -+-> DQ1 --> + # | + # +-> DQ1' --> Q2 --> DQ2 --> + assert tensor_info.data_type == tensor_qparam_initializers.original.scale.data_type + self._add_qdq_ops_for_converted_activation( + tensor_name, + tensor_qparam_initializers.original.scale.name, + tensor_qparam_initializers.original.zero_point.name, + tensor_info.data_type, + tensor_qparam_initializers.converted.scale.name, + tensor_qparam_initializers.converted.zero_point.name, + tensor_qparam_initializers.converted_recv_nodes, + ) del self.tensors_to_quantize[tensor_name] def _quantize_sharing_param_tensors(self): + """ + Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers. + Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor. + For example, a Transpose node's output tensor will typically want to use the same quantization parameter + initializers as the Transpose node's input. + """ while self.tensors_to_quantize: for tensor_name, tensor_info in self.tensors_to_quantize.copy().items(): - tensor_provider_name = tensor_info.quant_para_provider - if tensor_provider_name in self.quantized_value_map: + quant_provider = tensor_info.quant_para_provider + if quant_provider and quant_provider.input_name in self.quantized_value_map: del self.tensors_to_quantize[tensor_name] - quantized_value = self.quantized_value_map[tensor_provider_name] - # Quantize the input - initializer = find_by_name(tensor_name, self.model.initializer()) - if initializer is not None: + quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer( + quant_provider.node_name + ) + if self.is_input_a_initializer(tensor_name): raise ValueError("Quantization parameter shared mode is not supported for weight yet") - self._add_qdq_pair_for_activation(tensor_name, quantized_value.scale_name, quantized_value.zp_name) + + # Need to check if this tensor's quant_type is converted for some consumers. + # If so, create new scale/zp initializers for these consumers. 
+ converted_qparam_inits = None + converted_recv_nodes = None + if tensor_name in self.quantization_params: + tensor_params = self.quantization_params[tensor_name] + if tensor_params.converted: + converted_qparam_inits = self._make_scale_zp_initializers( + tensor_name, tensor_params.converted, "_convert" + ) + converted_recv_nodes = tensor_params.converted_recv_nodes + + if converted_qparam_inits is None: + # Normal case: --> Q_shared --> DQ_shared --> + self._add_qdq_pair_for_activation( + tensor_name, quantized_value.scale_name, quantized_value.zp_name + ) + else: + # Conversion case: ---> Q_shared -+-> DQ_shared --> + # | + # +-> DQ_shared' --> Q2 --> DQ2 --> + self._add_qdq_ops_for_converted_activation( + tensor_name, + quantized_value.scale_name, + quantized_value.zp_name, + converted_qparam_inits.scale.data_type, + converted_qparam_inits.scale.name, + converted_qparam_inits.zero_point.name, + converted_recv_nodes, + ) def _quantize_bias_tensors(self): - for bias_name, input_name, weight_name, beta in self.bias_to_quantize: + """ + Adds DQ ops (or Cast) for bias tensors that have been marked for quantization by op quantizers. + """ + for bias_name, bias_info in self.bias_to_quantize.items(): if bias_name in self.quantized_value_map: continue # Quantize the input - self.quantize_bias_static(bias_name, input_name, weight_name, beta) + self.quantize_bias_static(bias_name, bias_info) init = find_by_name(bias_name, self.model.initializer()) self.model.remove_initializer(init) - quant_value = self.quantized_value_map[bias_name] + quant_value = self.quantized_value_map[bias_name].original if quant_value.node_type == "Cast": # simple cast to float 16 and not DequantizeLinear # cublasLtMatmul only supports (b)float16, float bias. if not isinstance(init.data_type, int): - raise TypeError(f"Unexpected type {type(init.data_type)} for input={input_name!r}") + raise TypeError(f"Unexpected type {type(init.data_type)} for input={bias_info.input_name!r}") node_name = add_dequant_suffix(bias_name) dequant_node = onnx.helper.make_node( "Cast", @@ -511,5 +888,233 @@ def _quantize_bias_tensors(self): raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.") self.model.add_node(dequant_node) - def is_tensor_quantized(self, tensor_name): + def is_tensor_quantized(self, tensor_name: str): return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize + + def quantize_initializer( + self, + weight: onnx.TensorProto, + qType: onnx.TensorProto.DataType, + reduce_range: bool = False, + keep_float_weight: bool = False, + ) -> tuple[str, str, str]: + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. 
+ :return: quantized weight name, zero point name, scale name + """ + # Find if this input is already quantized + if weight.name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight.name].original + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( + weight, qType, reduce_range, keep_float_weight + ) + + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) + return q_weight_name, zp_name, scale_name + + def quantize_weight_per_channel( + self, + weight_name: str, + weight_qType: onnx.TensorProto.DataType, + channel_axis: int, + reduce_range: bool = True, + keep_float_weight: bool = False, + ) -> tuple[str, str, str]: + # Find if this input is already quantized + if weight_name in self.quantized_value_map: + quantized_value = self.quantized_value_map[weight_name].original + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) + + q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( + weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight + ) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) + self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + return q_weight_name, zp_name, scale_name + + def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str: + """ + Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale + """ + + # Handle case where bias already in quantization map + if bias_name in self.quantized_value_map: + return self.quantized_value_map[bias_name].original.q_name + + # get scale for weight + weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name + weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_initializer) + + # get scale for input + input_scale_name = ( + self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name + ) + inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(inputscale_initializer) + + ( + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + bias_scale_data, + node_type, + node_qtype, + ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, bias_info.beta) + + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + node_type=node_type, + node_qtype=node_qtype, + ) + self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(quantized_value, None, None) + + return quantized_bias_name + + def _make_scale_zp_initializers( + self, param_name: str, params: QuantizationParams, init_name_suffix: str = "" + ) -> QDQScaleZpInitializers: + """ + Creates and returns scale and zero-point initializers for the given quantization params. 
The initializers are + named: + - {param_name}_zero_point{init_name_suffix} + - {param_name}_scale{init_name_suffix} + """ + zero_point_values = np.array([params["zero_point"]]) + if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): + raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") + scale_values = np.array([params["scale"]]) + assert scale_values.dtype != np.float64 + zero_point_type = params.data.get("quant_type", self.activation_qType) + + zero_point_shape = [] + zero_point_name = param_name + "_zero_point" + init_name_suffix + scale_shape = [] + scale_name = param_name + "_scale" + init_name_suffix + + # Add initializers to model + init_zp = onnx.helper.make_tensor( + zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + ) + self.model.add_initializer(init_zp) + + if scale_values.dtype == np.float32: + scale_type = onnx_proto.TensorProto.FLOAT + elif scale_values.dtype == np.float16: + scale_type = onnx_proto.TensorProto.FLOAT16 + else: + raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + self.model.add_initializer(init_scale) + + return QDQScaleZpInitializers(init_scale, init_zp) + + def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None: + """ + Create and returns all scale/zero_point initializers for a given tensor. If the tensor is converted + to a different quantization type, this function creates two pairs of zp/scale initializers. Otherwise, + only one pair of zp/scale initializers is created. + """ + if self.quantization_params is None or tensor_name not in self.quantization_params: + logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified') + return None + + tensor_params = self.quantization_params[tensor_name] + if not isinstance(tensor_params, QDQTensorQuantParams): + raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.") + + original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original) + converted_inits = ( + self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert") + if tensor_params.converted + else None + ) + + return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes) + + def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams: + """ + Calculates quantization parameters (scale/zero-point) given a tensor's min/max range and optional + user-provided overrides. 
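For reference, the basic asymmetric scale/zero-point math behind compute_scale_zp can be sketched standalone as below. This is simplified: the real helper also covers symmetric ranges, the min_real_range clamp, and the separate float8 case.

def scale_zp_sketch(rmin: float, rmax: float, qmin: int, qmax: int):
    # The representable range must include zero so that 0.0 maps exactly to an integer.
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    if scale == 0.0:  # degenerate all-zero range
        return qmin, 1.0
    zero_point = int(round(qmin - rmin / scale))
    return zero_point, scale

# uint8 activation observed in [-1.0, 1.0]:
zp, scale = scale_zp_sketch(-1.0, 1.0, 0, 255)
# scale ~ 0.00784, zp == 128, so a real value of 0.0 dequantizes exactly back to 0.0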
+ """ + quant_type = self.activation_qType + if "quant_type" in quant_overrides: + quant_type = quant_overrides["quant_type"].tensor_type + + if "scale" in quant_overrides and "zero_point" in quant_overrides: + zero, scale = quant_overrides["zero_point"], quant_overrides["scale"] + elif quant_type == onnx.TensorProto.FLOAT8E4M3FN: + zero, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1]) + else: + rmin = quant_overrides.get("rmin", tensor_data.range_value[0]) + rmax = quant_overrides.get("rmax", tensor_data.range_value[1]) + symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric) + reduce_range = quant_overrides.get("reduce_range", False) + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) + zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) + + return QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + + def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: + """ + Calculates quantization parameters (scale/zero-point) for all tensors in the graph using each tensor's min/max range + and optional user-provided overrides. + """ + if self.tensors_range is None: + return {} + + self.adjust_tensor_ranges() + + quantization_params = {} + for tensor_name in self.tensors_range: + td = self.tensors_range[tensor_name] + if not isinstance(td, TensorData): + raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.") + + quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name) + original = self.calc_quant_params(td, quant_overrides) + converted = None + converted_recv_nodes = None + + if "convert" in quant_overrides: + converted = self.calc_quant_params(td, quant_overrides["convert"]) + converted_recv_nodes = quant_overrides["convert"].get("recv_nodes") + + quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index a693f4192bc2b..b00e830a2a366 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -18,7 +18,7 @@ from .operators.pooling import QLinearPool from .operators.qdq_base_operator import QDQOperatorBase from .operators.resize import QDQResize, QResize -from .operators.softmax import QDQSoftmax, QLinearSoftmax +from .operators.softmax import QLinearSoftmax from .operators.split import QDQSplit, QSplit from .operators.where import QDQWhere, QLinearWhere from .quant_utils import QuantizationMode @@ -79,7 +79,6 @@ "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, - "Softmax": QDQSoftmax, "Where": QDQWhere, "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py new file mode 100644 index 0000000000000..793d58cbc4e3e --- /dev/null +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -0,0 +1,345 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +from __future__ import annotations + +import json +from collections.abc import MutableMapping +from dataclasses import dataclass +from typing import Any + +from .quant_utils import QuantType + + +@dataclass +class QuantTypeInfo: + """ + The quantization type information for a tensor override. + """ + + quant_type: QuantType + symmetric: bool | None = None # If None, assumes default is used. + reduce_range: bool | None = None # If None, assumes default is used. + + def __eq__(self, other: object): + if isinstance(other, QuantTypeInfo): + return ( + self.quant_type == other.quant_type + and (self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric) + and (self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range) + ) + return NotImplemented + + @staticmethod + def load_from_dict( + raw_dict: dict[str, Any], + default_activation_qtype: QuantType | None = None, + default_activation_symmetric: bool | None = None, + default_activation_reduce_range: bool | None = None, + ) -> QuantTypeInfo: + return QuantTypeInfo( + raw_dict.get("quant_type", default_activation_qtype), + raw_dict.get("symmetric", default_activation_symmetric), + raw_dict.get("reduce_range", default_activation_reduce_range), + ) + + def save_to_dict(self, raw_dict: dict[str, Any]): + raw_dict["quant_type"] = self.quant_type + if self.symmetric is not None: + raw_dict["symmetric"] = self.symmetric + if self.reduce_range is not None: + raw_dict["reduce_range"] = self.reduce_range + + +class TensorQuantOverridesHelper(MutableMapping): + """ + Utility wrapper over the tensor quantization overrides passed via extra_options. + """ + + def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]): + self.overrides = raw_overrides + self.quant_types = None + + def get_per_tensor_overrides(self, tensor_name: str) -> dict[str, Any]: + overrides_list = self.overrides.get(tensor_name, [{}]) + num_overrides = len(overrides_list) + if num_overrides > 1: + raise ValueError( + f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " + f"but found {num_overrides} per-channel overrides." + ) + + return overrides_list[0] if num_overrides > 0 else {} + + def get_per_channel_overrides( + self, + tensor_name: str, + num_channels: int, + ) -> list[dict[str, Any]]: + overrides_list = self.overrides.get(tensor_name, [{} for i in range(num_channels)]) + + if len(overrides_list) != num_channels: + raise ValueError( + f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, " + f"but found {len(overrides_list)} instead." + ) + + return overrides_list + + def get_quant_types(self) -> set[QuantType]: + if self.quant_types is not None: + return self.quant_types + + self.quant_types = set() + + if self.overrides: + for quant_overrides_list in self.overrides.values(): + for quant_overrides in quant_overrides_list: + if "quant_type" in quant_overrides: + self.quant_types.add(quant_overrides["quant_type"]) + + if "convert" in quant_overrides and "quant_type" in quant_overrides["convert"]: + self.quant_types.add(quant_overrides["convert"]["quant_type"]) + + return self.quant_types + + def is_valid( + self, + initializer_names: set[str], + activation_names: set[str], + default_activation_qtype, + ) -> tuple[bool, str | None]: + self.quant_types = set() + + # Validate that compatible/valid overrides are provided. 
+ if self.overrides: + keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} + + for tensor_name, quant_overrides_list in self.overrides.items(): + if tensor_name not in initializer_names and tensor_name not in activation_names: + return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model" + + if not isinstance(quant_overrides_list, list): + return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list" + + is_initializer = tensor_name in initializer_names + if not is_initializer and len(quant_overrides_list) > 1: + return ( + False, + f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer", + ) + + quant_type = None + for index, quant_overrides in enumerate(quant_overrides_list): + if not isinstance(quant_overrides, dict): + return ( + False, + f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict", + ) + + # For per-channel quantization, all channels must use the same quantization type. + # Therefore, if the user tries to override the quant_type for a channel, it must match in all + # other channels. + if index == 0: + quant_type = quant_overrides.get("quant_type") + if quant_type: + self.quant_types.add(quant_type) + elif quant_type != quant_overrides.get("quant_type"): + return ( + False, + "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.", + ) + + has_scale = "scale" in quant_overrides + has_zero_point = "zero_point" in quant_overrides + + if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): + return ( + False, + "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", + ) + + if has_scale: + for key in keys_unsupported_with_scale_zp: + if key in quant_overrides: + return ( + False, + f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'", + ) + + if "reduce_range" in quant_overrides and not is_initializer: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + if "convert" in quant_overrides: + if index > 0: + return ( + False, + f"Per-channel overrides (tensor '{tensor_name}') do not support 'convert'.", + ) + + if is_initializer: + return False, "Cannot use 'convert' override for initializers" + + if "quant_type" not in quant_overrides["convert"]: + return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'" + + if "reduce_range" in quant_overrides["convert"]: + return ( + False, + f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", + ) + + convert_quant_type = quant_overrides["convert"]["quant_type"] + original_quant_type = quant_type if quant_type is not None else default_activation_qtype + if convert_quant_type == original_quant_type: + return ( + False, + f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')", + ) + + convert_has_scale = "scale" in quant_overrides["convert"] + convert_has_zero_point = "zero_point" in quant_overrides["convert"] + + if (convert_has_scale and not convert_has_zero_point) or ( + convert_has_zero_point and not convert_has_scale + ): + return ( + False, + f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')", + ) + + if convert_has_scale: + for key in keys_unsupported_with_scale_zp: + if key in quant_overrides["convert"]: + return ( + False, + f"Tensor override option '{key}' is 
invalid with 'scale' and 'zero_point' (tensor '{tensor_name}')", + ) + + self.quant_types.add(convert_quant_type) + + return True, None + + def update_tensor_overrides( + self, + tensor_name: str, + new_vals: dict[str, Any], + channels: list[int] | None = None, + overwrite: bool = True, + ) -> bool: + if not new_vals: + return False + + channels = set(channels) if channels is not None else None + have_overrides = self.overrides.get(tensor_name) + + # If `overwrite` is False, check if we would overwrite anything. + do_update = True + if not overwrite and have_overrides: + for channel, overrides in enumerate(self.overrides[tensor_name]): + if channels is not None and channel not in channels: + continue + if set(new_vals).intersection(set(overrides)): + do_update = False + break + + # Do the update if `overwrite` is True or if nothing is overwritten (do not want partial overwrites). + if do_update: + if not have_overrides: + self.overrides[tensor_name] = [{}] + + for channel, overrides in enumerate(self.overrides[tensor_name]): + if channels is not None and channel not in channels: + continue + overrides.update(new_vals) + + return do_update + + def get_node_output_qtype_info( + self, + output_name: str, + default_qtype: QuantType | None, + default_symmetric: bool | None = None, + ) -> QuantTypeInfo: + if output_name not in self.overrides: + return QuantTypeInfo(default_qtype, default_symmetric) + + # Get the first overrides dict in the list. This works for both per-tensor and per-channel + # quantization because all channels must use the same quant type. + tensor_overrides = self.overrides[output_name][0] + + return QuantTypeInfo( + tensor_overrides.get("quant_type", default_qtype), + tensor_overrides.get("symmetric", default_symmetric), + ) + + def get_node_input_qtype_info( + self, + input_name: str, + node_name: str, + default_qtype: QuantType | None, + default_symmetric: bool | None = None, + default_reduce_range: bool | None = None, + ) -> QuantTypeInfo: + if input_name not in self.overrides or not self.overrides[input_name]: + return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range) + + # Get the first overrides dict in the list. This works for both per-tensor and per-channel + # quantization because all channels must use the same quant type. + tensor_overrides = self.overrides[input_name][0] + producer_type = tensor_overrides.get("quant_type", default_qtype) + + if "convert" not in tensor_overrides: + return QuantTypeInfo(producer_type, default_symmetric, default_reduce_range) + + # This tensor is converted. Check if the node gets the original qtype or the converted qtype. + convert_dict = tensor_overrides["convert"] + qtype_info = QuantTypeInfo( + producer_type, + convert_dict.get("symmetric", default_symmetric), + convert_dict.get("reduce_range", default_reduce_range), + ) + + # Check if all nodes receive the converted type (i.e., recv_nodes is None) or this node + # is in the list of consumers (recv_nodes). + if ("recv_nodes" not in convert_dict) or (node_name in convert_dict["recv_nodes"]): + qtype_info.quant_type = convert_dict["quant_type"] + + return qtype_info + + def pprint_str(self, indent=None) -> str: + return json.dumps(self.overrides, default=str, indent=indent) + + def empty(self) -> bool: + return not self.overrides + + def get_dict(self) -> dict[str, list[dict[str, Any]]]: + return self.overrides + + # Required implementations of abstract methods in collections.abc.MutableMapping + # so that this class can be used like a dict. 
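Combined with the MutableMapping methods that follow, the helper can be used like a plain dict over the raw overrides. A small illustration with made-up tensor names:

from onnxruntime.quantization import QuantType
from onnxruntime.quantization.tensor_quant_overrides import TensorQuantOverridesHelper

helper = TensorQuantOverridesHelper(
    {
        # Per-tensor override: a single-element list.
        "act_out": [{"quant_type": QuantType.QUInt16}],
        # Per-channel override (initializers only): one dict per channel.
        "conv_w": [{"symmetric": True}, {"symmetric": True}, {"symmetric": True}],
    }
)

assert helper.get_per_tensor_overrides("act_out") == {"quant_type": QuantType.QUInt16}
assert len(helper.get_per_channel_overrides("conv_w", num_channels=3)) == 3

# Dict-like access via the MutableMapping interface:
helper["extra_out"] = [{"quant_type": QuantType.QUInt8}]
assert "extra_out" in helper and len(helper) == 3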
+ def __setitem__(self, key: str, value: list[dict]): + self.overrides[key] = value + + def __getitem__(self, key: str) -> list[dict]: + return self.overrides[key] + + def __delitem__(self, key: str): + del self.overrides[key] + + def __iter__(self): + return iter(self.overrides) + + def __len__(self): + return len(self.overrides) + + def __str__(self) -> str: + return str(self.overrides) + + def __repr__(self) -> str: + return f"{super().__repr__()}, TensorQuantOverridesHelper({self.overrides})" diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py index 4d0d2e68e8983..47b7f35cbdd7c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py @@ -400,11 +400,7 @@ def main(): sampling_times.append(sampling_end_time - sampling_start_time) all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) - - # Return early if all batch entries have reached EOS token id current_length += 1 - if torch.all(has_eos) or current_length > max_length: - break # Update inputs for next inference run inputs["input_ids"] = tokens_to_add diff --git a/onnxruntime/python/tools/transformers/onnx_utils.py b/onnxruntime/python/tools/transformers/onnx_utils.py new file mode 100644 index 0000000000000..64fade9369395 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_utils.py @@ -0,0 +1,55 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from fusion_utils import NumpyHelper +from onnx import ModelProto, TensorProto +from onnx.external_data_helper import set_external_data +from onnx_model import OnnxModel + +from onnxruntime import OrtValue + + +def extract_raw_data_from_model(model: ModelProto): + """ + Extract external data from model and return the external data as a list of tuples (name, value). + Note this function does not handle external data that is not loaded into the model as raw data. + + Args: + model (ModelProto): the model proto to extract external data from. + Returns: + (external_names, external_values): a tuple of two lists of external data names and values. + """ + external_data = [] + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + name = initializer.name + + if initializer.HasField("raw_data"): + numpy_tensor = NumpyHelper.to_array(initializer) + ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor) + external_data.append((name, ort_value)) + # mimic set_external_data + set_external_data(initializer, location="foo.bin") + initializer.name = name + initializer.ClearField("raw_data") + + return zip(*external_data) + + +def has_external_data(model: ModelProto): + """ + Check if the model has external data. + + Args: + model (ModelProto): the model proto to check for external data. + Returns: + bool: True if the model has external data, False otherwise. 
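Sketch of how these helpers fit together with SessionOptions.add_external_initializers when creating a session for a large in-memory model. The model path is a placeholder, and the model's external data is assumed to already be loaded into raw_data (onnx.load does this by default):

import onnx
import onnxruntime
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data

model = onnx.load("large_model.onnx")  # also loads external data into raw_data by default
if has_external_data(model):
    raise ValueError("External data must be loaded into memory before extracting it.")

# Move the raw weights out of the proto and hand them to ORT as OrtValues.
names, values = extract_raw_data_from_model(model)
so = onnxruntime.SessionOptions()
so.add_external_initializers(list(names), list(values))

session = onnxruntime.InferenceSession(model.SerializeToString(), so, providers=["CPUExecutionProvider"])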
+ """ + onnx_model = OnnxModel(model) + for graph in onnx_model.graphs(): + for initializer in graph.initializer: + if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: + return True + return False diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ce0be6b3449ed..068ccefef7d97 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -21,11 +21,12 @@ import logging import os import tempfile -from typing import Dict, List, Optional +from pathlib import Path +from typing import Dict, List, Optional, Union import coloredlogs from fusion_options import FusionOptions -from onnx import ModelProto, TensorProto, load_model +from onnx import ModelProto, load_model from onnx_model import OnnxModel from onnx_model_bart import BartOnnxModel from onnx_model_bert import BertOnnxModel @@ -40,6 +41,9 @@ from onnx_model_unet import UnetOnnxModel from onnx_model_vae import VaeOnnxModel +import onnxruntime +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + logger = logging.getLogger(__name__) # Map model type to tuple: optimizer class, export tools (pytorch, tf2onnx, keras2onnx), and default opt_level @@ -64,7 +68,7 @@ def optimize_by_onnxruntime( - onnx_model_path: str, + onnx_model: Union[str, ModelProto], use_gpu: bool = False, optimized_model_path: Optional[str] = None, opt_level: Optional[int] = 99, @@ -80,7 +84,7 @@ def optimize_by_onnxruntime( Use onnxruntime to optimize model. Args: - onnx_model_path (str): the path of input onnx model. + onnx_model (str | ModelProto): the path of input onnx model or ModelProto. use_gpu (bool): whether the optimized model is targeted to run in GPU. optimized_model_path (str or None): the path of optimized model. opt_level (int): graph optimization level. @@ -95,8 +99,6 @@ def optimize_by_onnxruntime( assert opt_level in [1, 2, 99] from torch import version as torch_version - import onnxruntime - if ( use_gpu and provider is None @@ -105,9 +107,13 @@ def optimize_by_onnxruntime( ) ): logger.error("There is no gpu for onnxruntime to do optimization.") - return onnx_model_path + return onnx_model - model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) + model = ( + OnnxModel(load_model(onnx_model, load_external_data=False)) + if isinstance(onnx_model, str) + else OnnxModel(onnx_model) + ) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. 
" @@ -125,7 +131,10 @@ def optimize_by_onnxruntime( sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL if optimized_model_path is None: - path_prefix = onnx_model_path[:-5] # remove .onnx suffix + if isinstance(onnx_model, str): + path_prefix = str(Path(onnx_model).with_suffix("")) # remove .onnx suffix + else: + path_prefix = "optimized_model" optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu") sess_options.optimized_model_filepath = optimized_model_path @@ -174,7 +183,20 @@ def optimize_by_onnxruntime( else: providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) + # For large model, extract external data from model and add to session options + if isinstance(onnx_model, ModelProto): + if has_external_data(onnx_model): + raise ValueError( + "ModelProto has external data not loaded into memory, ORT cannot create session. " + "Please load external data before calling this function. " + "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information." + ) + external_names, external_values = extract_raw_data_from_model(onnx_model) + sess_options.add_external_initializers(list(external_names), list(external_values)) + + # Inference session is only used to optimize the model. + onnx_model = onnx_model.SerializeToString() if isinstance(onnx_model, ModelProto) else onnx_model + onnxruntime.InferenceSession(onnx_model, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -187,7 +209,7 @@ def optimize_by_fusion( num_heads: int = 0, hidden_size: int = 0, optimization_options: Optional[FusionOptions] = None, -): +) -> OnnxModel: """Optimize Model by graph fusion logic. Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable @@ -241,7 +263,7 @@ def optimize_by_fusion( def optimize_model( - input: str, + input: Union[str, ModelProto], model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, @@ -252,7 +274,7 @@ def optimize_model( verbose: bool = False, *, provider: Optional[str] = None, -): +) -> OnnxModel: """Optimize Model by OnnxRuntime and/or python fusion logic. ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html). @@ -275,7 +297,7 @@ def optimize_model( For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters. Args: - input (str): input model path. + input (str | ModelProto): input model path or ModelProto. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. 0 allows detect the parameter from graph automatically. 
@@ -298,9 +320,9 @@ def optimize_model( if model_type not in MODEL_TYPES: logger.warning(f"Unsupported model type: {model_type} for optimization, directly return model.") - return OnnxModel(load_model(input)) + return OnnxModel(load_model(input)) if isinstance(input, str) else OnnxModel(input) - (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type] + (optimizer_class, _, default_opt_level) = MODEL_TYPES[model_type] if opt_level is None: opt_level = default_opt_level @@ -316,11 +338,9 @@ def optimize_model( # Auto detect if input model has external data has_external_data_file = False - original_model = load_model(input, load_external_data=False) - for initializer in original_model.graph.initializer: - if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL: - has_external_data_file = True - break + original_model = load_model(input, load_external_data=False) if isinstance(input, str) else input + if has_external_data(original_model): + has_external_data_file = True del original_model if opt_level > 1: @@ -365,7 +385,12 @@ def optimize_model( if only_onnxruntime and not temp_model_path: logger.warning("Please specify a positive value for opt_level when only_onnxruntime is True") - model = load_model(temp_model_path or input) + if temp_model_path is not None: + model = load_model(temp_model_path) + elif isinstance(input, str): + model = load_model(input) + else: + model = input if only_onnxruntime: optimizer = optimizer_class(model, num_heads, hidden_size) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 4f294f899c170..7fd2441441dcf 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -168,6 +168,26 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { } } +// Conv node `Conv` is not supported: GetFileLength for conv_qdq_external_ini.bin failed:open file conv_qdq_external_ini.bin fail, +// errcode = 2 - The system cannot find the file specified. +TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { + Ort::SessionOptions so; + onnxruntime::ProviderOptions options; +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + so.AppendExecutionProvider("QNN", options); + + Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv_qdq_external_ini.onnx"; + + Ort::Session session(*ort_env, ort_model_path, so); +} + // Helper function that runs an ONNX model with a NHWC Resize operator to test that // type/shape inference succeeds during layout transformation. // Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. diff --git a/onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py b/onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py new file mode 100644 index 0000000000000..96277056adee0 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_mixed_prec_quant_overrides_fixer.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import unittest + +import onnx + +from onnxruntime.quantization import QuantType +from onnxruntime.quantization.execution_providers.qnn.mixed_precision_overrides_utils import ( + MixedPrecisionTensorQuantOverridesFixer, +) +from onnxruntime.quantization.tensor_quant_overrides import TensorQuantOverridesHelper + + +class TestMixedPrecisionQuantOverridesFixer(unittest.TestCase): + def build_test_model_1(self, shape): + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, shape) + input_1 = onnx.helper.make_tensor_value_info("input_1", onnx.TensorProto.FLOAT, shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, shape) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx.TensorProto.FLOAT, shape) + output_2 = onnx.helper.make_tensor_value_info("output_2", onnx.TensorProto.FLOAT, shape) + + op1_node = onnx.helper.make_node("Sigmoid", ["input_0"], ["op1_out"], name="op1") + op2_node = onnx.helper.make_node("Cos", ["input_1"], ["op2_out"], name="op2") + op3_node = onnx.helper.make_node("Sin", ["op1_out"], ["op3_out"], name="op3") + op4_node = onnx.helper.make_node("Tanh", ["op2_out"], ["op4_out"], name="op4") + op5_node = onnx.helper.make_node("Mul", ["op3_out", "op4_out"], ["op5_out"], name="op5") + op6_node = onnx.helper.make_node("Relu", ["op5_out"], ["output_0"], name="op6") + op7_node = onnx.helper.make_node("Cos", ["op2_out"], ["output_1"], name="op7") + op8_node = onnx.helper.make_node("Sigmoid", ["op2_out"], ["output_2"], name="op8") + + graph = onnx.helper.make_graph( + [ + op1_node, + op2_node, + op3_node, + op4_node, + op5_node, + op6_node, + op7_node, + op8_node, + ], + "mixed_prec_test", + [input_0, input_1], + [output_0, output_1, output_2], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_fixer_1(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QUInt8 + raw_overrides = {"op4_out": [{"quant_type": QuantType.QUInt16}]} + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [ + {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op6"}}} + ], + } + self.assertDictEqual(overrides.get_dict(), expected) + + def test_fixer_with_symmetric(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QInt8 + raw_overrides = {"op4_out": [{"quant_type": QuantType.QInt16, "symmetric": True}]} + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "op2_out": [ + { + "quant_type": QuantType.QInt8, + 
"convert": {"quant_type": QuantType.QInt16, "symmetric": True, "recv_nodes": {"op4"}}, + } + ], + "op3_out": [ + { + "quant_type": QuantType.QInt8, + "convert": {"quant_type": QuantType.QInt16, "symmetric": True, "recv_nodes": {"op5"}}, + } + ], + "op4_out": [{"quant_type": QuantType.QInt16, "symmetric": True}], + "op5_out": [ + { + "quant_type": QuantType.QInt16, + "symmetric": True, + "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"op6"}}, + } + ], + } + self.assertDictEqual(overrides.get_dict(), expected) + + def test_fixer_upgrade_output(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QUInt8 + raw_overrides = { + "op4_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + self.assertDictEqual(overrides.get_dict(), expected) + + def test_fixer_upgrade_input(self): + shape = (1, 2, 3) + model = self.build_test_model_1(shape) + onnx.save_model(model, "model.onnx") + + default_act_qtype = QuantType.QUInt8 + raw_overrides = {"op4_out": [{"quant_type": QuantType.QUInt16}], "input_0": [{"quant_type": QuantType.QUInt16}]} + overrides = TensorQuantOverridesHelper(raw_overrides) + fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, default_act_qtype) + fixer.apply(default_act_qtype, default_activation_symmetric=False) + + expected = { + "input_0": [{"quant_type": QuantType.QUInt16}], + "op1_out": [ + {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op3"}}} + ], + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [ + {"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"op6"}}} + ], + } + self.assertDictEqual(overrides.get_dict(), expected) diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 9e7a4a125121d..db4ab7e8a412c 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -4,7 +4,9 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations +import os import tempfile import unittest from pathlib import Path @@ -25,12 +27,12 @@ class TestQDQFormat(unittest.TestCase): - def input_feeds(self, n, name2shape): + def input_feeds(self, n, name2shape, np_float_type=np.float32): input_data_list = [] for _i in range(n): inputs = {} for name, shape in name2shape.items(): - inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + inputs.update({name: np.random.randint(-1, 2, shape).astype(np_float_type)}) input_data_list.extend([inputs]) dr = TestDataFeeds(input_data_list) return dr @@ -720,5 +722,593 @@ def test_activation_only(self): check_op_type_count(self, qdq_model_path, **qop_nodes) +class TestQDQMixedPrecision(TestQDQFormat): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.mixed_prec_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_test_model_for_add_qdq_ops( + self, + num_consumers: int, + is_graph_output: bool, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + op0_transpose: bool = False, + ): + """ + Builds a float32 model with a single producer node and a configurable number of consumer nodes. + The tensor between the producer and consumers can be optionally made a graph output. + op_0 can optionally be made a Transpose node to test sharing qparams across the input and output. + + +-> op_0_out (optional graph output) + | + input_0 --> op_0 --+-> op_1 --> output_0 + | + +-> op_2 --> output_1 + | + ... 
+ | + +-> op_{n} --> output_{n-1} + """ + shape = (1, 2, 3) + shape_t = (1, 3, 2) + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, shape) + output_shape = shape if not op0_transpose else shape_t + + outputs = [] + for i in range(num_consumers): + outputs.append(onnx.helper.make_tensor_value_info(f"output_{i}", float_type, output_shape)) + + if is_graph_output: + outputs.append(onnx.helper.make_tensor_value_info("op_0_out", float_type, output_shape)) + + nodes = [] + if op0_transpose: + nodes.append(onnx.helper.make_node("Transpose", ["input_0"], ["op_0_out"], perm=[0, 2, 1], name="op_0")) + else: + nodes.append(onnx.helper.make_node("Sigmoid", ["input_0"], ["op_0_out"], name="op_0")) + + for i in range(num_consumers): + op_index = i + 1 + nodes.append(onnx.helper.make_node("Cos", ["op_0_out"], [f"output_{i}"], name=f"op_{op_index}")) + + graph = onnx.helper.make_graph( + nodes, + "test_add_qdq_ops_for_converted_activation", + [input_0], + outputs, + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_add_tensor_qdq_ops_case_1(self): + """ + Tensor T is not a graph output; all consumers use the converted type + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_1{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_1{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 2, False, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}}, + } + ], + "output_0": [{"quant_type": QuantType.QUInt16}], + "output_1": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 --> DQ_8 --> Q_16 --> DQ_16 -+-> op_1 --> Q --> DQ --> output_0 + # | + # +-> op_2 --> Q --> DQ --> output_1 + qdq_node_counts = {"QuantizeLinear": 5, "DequantizeLinear": 5} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # 
qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT16) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + + def test_add_tensor_qdq_ops_case_2(self): + """ + Tensor T is not a graph output; some consumers use the original type, others use the converted type + ---> Q1 -+-> DQ1 ---> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_2{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_2{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 4, False, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_3", "op_4"}}, + } + ], + "output_2": [{"quant_type": QuantType.QUInt16}], + "output_3": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 -+-> DQ_8 -+-> op_1 --> Q --> DQ --> output_0 + # | | + # | +-> op_2 --> Q --> DQ --> output_1 + # | + # +-> DQ_8' --> Q_16 --> DQ_16 -+-> op_3 --> Q --> DQ --> output_2 + # | + # +-> op_4 --> Q --> DQ --> output_3 + qdq_node_counts = {"QuantizeLinear": 7, "DequantizeLinear": 8} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = 
onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT8) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT8) + output_2_zp_init = initializers["output_2_zero_point"] + self.assertEqual(output_2_zp_init.data_type, onnx.TensorProto.UINT16) + output_3_zp_init = initializers["output_3_zero_point"] + self.assertEqual(output_3_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + output_2_scale_init = initializers["output_2_scale"] + self.assertEqual(output_2_scale_init.data_type, float_type) + output_3_scale_init = initializers["output_3_scale"] + self.assertEqual(output_3_scale_init.data_type, float_type) + + def test_add_tensor_qdq_ops_case_3(self): + """ + Tensor T is a graph output; all consumers use the converted type + ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> + | + +-> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_3{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_3{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 2, True, float_type=float_type, op0_transpose=test_qparam_sharing + ) + onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_1", "op_2"}}, + } + ], + "output_0": [{"quant_type": QuantType.QUInt16}], + "output_1": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + 
data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 --> DQ_8 --> Q_16 --> DQ_16 -+-> op_1 --> Q --> DQ --> output_0 + # | + # +-> op_2 --> Q --> DQ --> output_1 + # | + # +--> op_0_out (is graph output) + qdq_node_counts = {"QuantizeLinear": 5, "DequantizeLinear": 5} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + graph_outputs = {g_output.name: g_output for g_output in qdq_model.graph.output} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + self.assertNotIn("op_0_out_scale", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT16) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + + self.assertIn("op_0_out", graph_outputs) + + def test_add_tensor_qdq_ops_case_4(self): + """ + Tensor T is a graph output; some consumers use the original type, others use the converted type + ---> Q1 -+-> DQ1 -+-> + | | + | +-> + | + +-> DQ1' ---> Q2 ---> DQ2 ---> + """ + # Test configurations (qparam_sharing, float_type) + subtest_configs = [ + (False, onnx.TensorProto.FLOAT, np.float32), + (False, onnx.TensorProto.FLOAT16, np.float16), + (True, onnx.TensorProto.FLOAT, np.float32), + (True, onnx.TensorProto.FLOAT16, np.float16), + ] + for test_qparam_sharing, float_type, np_float_type in subtest_configs: + with self.subTest(test_qparam_sharing=test_qparam_sharing, float_type=float_type): + label = f"_share{test_qparam_sharing}_f{float_type}" + float_model_path = os.path.join(self._tmp_dir_path, f"case_4{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"case_4{label}.qdq.onnx") + float_model = self.build_test_model_for_add_qdq_ops( + 4, True, float_type=float_type, op0_transpose=test_qparam_sharing + ) + 
onnx.save_model(float_model, float_model_path) + + data_reader = self.input_feeds(3, {"input_0": (1, 2, 3)}, np_float_type) + + mixed_prec_overrides = { + "op_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op_3", "op_4"}}, + } + ], + "output_2": [{"quant_type": QuantType.QUInt16}], + "output_3": [{"quant_type": QuantType.QUInt16}], + } + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in float_model.graph.node], + extra_options={ + "TensorQuantOverrides": mixed_prec_overrides, + "ForceQuantizeNoInputCheck": test_qparam_sharing, # To ensure Transpose is wrapped in DQ/Q + }, + ) + + # Expect the following QDQ model: + # input_0 --> Q --> DQ --> op_0 --> Q_8 -+-> DQ_8 -+-> op_1 --> Q --> DQ --> output_0 + # | | + # | +-> op_2 --> Q --> DQ --> output_1 + # | | + # | +-> op_0_out (is graph output) + # | + # +-> DQ_8' --> Q_16 --> DQ_16 -+-> op_3 --> Q --> DQ --> output_2 + # | + # +-> op_4 --> Q --> DQ --> output_3 + qdq_node_counts = {"QuantizeLinear": 7, "DequantizeLinear": 8} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + + initializers = {init.name: init for init in qdq_model.graph.initializer} + graph_outputs = {g_output.name: g_output for g_output in qdq_model.graph.output} + + # Check zero-point data types + orig_zp_init = None + if test_qparam_sharing: + # op_0_out_zero_point should not be in the model because the Transpose output is sharing + # qparams from the Transpose input. + self.assertNotIn("op_0_out_zero_point", initializers) + orig_zp_init = initializers["input_0_zero_point"] + else: + orig_zp_init = initializers["op_0_out_zero_point"] + + self.assertEqual(orig_zp_init.data_type, onnx.TensorProto.UINT8) + convert_zp_init = initializers["op_0_out_zero_point_convert"] + self.assertEqual(convert_zp_init.data_type, onnx.TensorProto.UINT16) + output_0_zp_init = initializers["output_0_zero_point"] + self.assertEqual(output_0_zp_init.data_type, onnx.TensorProto.UINT8) + output_1_zp_init = initializers["output_1_zero_point"] + self.assertEqual(output_1_zp_init.data_type, onnx.TensorProto.UINT8) + output_2_zp_init = initializers["output_2_zero_point"] + self.assertEqual(output_2_zp_init.data_type, onnx.TensorProto.UINT16) + output_3_zp_init = initializers["output_3_zero_point"] + self.assertEqual(output_3_zp_init.data_type, onnx.TensorProto.UINT16) + + # Check scale data types + orig_scale_init = None + if test_qparam_sharing: + self.assertNotIn("op_0_out_scale", initializers) + orig_scale_init = initializers["input_0_scale"] + else: + orig_scale_init = initializers["op_0_out_scale"] + + self.assertEqual(orig_scale_init.data_type, float_type) + convert_scale_init = initializers["op_0_out_scale_convert"] + self.assertEqual(convert_scale_init.data_type, float_type) + output_0_scale_init = initializers["output_0_scale"] + self.assertEqual(output_0_scale_init.data_type, float_type) + output_1_scale_init = initializers["output_1_scale"] + self.assertEqual(output_1_scale_init.data_type, float_type) + output_2_scale_init = initializers["output_2_scale"] + self.assertEqual(output_2_scale_init.data_type, float_type) + output_3_scale_init = initializers["output_3_scale"] + self.assertEqual(output_3_scale_init.data_type, float_type) + + self.assertIn("op_0_out", graph_outputs) + + def 
build_test_model_1(self, shape): + """ + Returns the following float32 model. + + input_0 --> op1 --> op3 --> op5 --> op6 --> output_0 + ^ + | + input_1 --> op2 -+-> op4 ----+ + | + +-> op7 --> output_1 + | + +-> op8 --> output_2 + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, shape) + input_1 = onnx.helper.make_tensor_value_info("input_1", onnx.TensorProto.FLOAT, shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, shape) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx.TensorProto.FLOAT, shape) + output_2 = onnx.helper.make_tensor_value_info("output_2", onnx.TensorProto.FLOAT, shape) + + op1_node = onnx.helper.make_node("Sigmoid", ["input_0"], ["op1_out"], name="op1") + op2_node = onnx.helper.make_node("Cos", ["input_1"], ["op2_out"], name="op2") + op3_node = onnx.helper.make_node("Sin", ["op1_out"], ["op3_out"], name="op3") + op4_node = onnx.helper.make_node("Tanh", ["op2_out"], ["op4_out"], name="op4") + op5_node = onnx.helper.make_node("Mul", ["op3_out", "op4_out"], ["op5_out"], name="op5") + op6_node = onnx.helper.make_node("Relu", ["op5_out"], ["output_0"], name="op6") + op7_node = onnx.helper.make_node("Cos", ["op2_out"], ["output_1"], name="op7") + op8_node = onnx.helper.make_node("Sigmoid", ["op2_out"], ["output_2"], name="op8") + + graph = onnx.helper.make_graph( + [ + op1_node, + op2_node, + op3_node, + op4_node, + op5_node, + op6_node, + op7_node, + op8_node, + ], + "mixed_prec_test", + [input_0, input_1], + [output_0, output_1, output_2], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_16bit_subgraph(self): + """ + Test correctness of a qdq model that uses a default 8-bit quantization type and contains + a subgraph that uses 16-bit activations. 
+ """ + shape = (1, 2, 3) + f32_model_path = os.path.join(self._tmp_dir_path, "model.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "model.qdq.onnx") + qdq_mixed_model_path = os.path.join(self._tmp_dir_path, "model.mixed.qdq.onnx") + f32_model = self.build_test_model_1(shape) + onnx.save_model(f32_model, f32_model_path) + + data_reader = self.input_feeds(3, {"input_0": shape, "input_1": shape}) + + # Create pure 8-bit qdq model + quantize_static( + f32_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + ) + + # Create mixed precision 8-bit/16-bit qdq model + mixed_prec_overrides = { + "op2_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op4"}}} + ], + "op3_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"op5"}}} + ], + "op4_out": [{"quant_type": QuantType.QUInt16}], + "op5_out": [{"quant_type": QuantType.QUInt16}], + "output_0": [{"quant_type": QuantType.QUInt16}], + } + data_reader.rewind() + quantize_static( + f32_model_path, + qdq_mixed_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + extra_options={"TensorQuantOverrides": mixed_prec_overrides}, + ) + + qop_nodes = {"Relu": 0, "QuantizeLinear": 11, "DequantizeLinear": 12} + check_op_type_count(self, qdq_mixed_model_path, **qop_nodes) + data_reader.rewind() + check_model_correctness(self, f32_model_path, qdq_mixed_model_path, data_reader.get_next()) + data_reader.rewind() + check_model_correctness(self, f32_model_path, qdq_model_path, data_reader.get_next()) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 9ea4719f3c595..77f20b3caed96 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -11,12 +11,12 @@ import numpy as np import onnx -from onnxruntime import quantization +from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain -class DummyDataReader(quantization.CalibrationDataReader): +class DummyDataReader(CalibrationDataReader): def __init__(self, activations): self.iterator = ({"INP": act} for act in activations) @@ -81,11 +81,11 @@ def perform_qdq_quantization(self, output_model_name, extra_options=None, per_ch if activation_type is None: activation_type = self.default_act_qtype - quantization.quantize_static( + quantize_static( model_input="model.onnx", model_output=output_model_name, calibration_data_reader=DummyDataReader(self.activations), - quant_format=quantization.QuantFormat.QDQ, + quant_format=QuantFormat.QDQ, activation_type=activation_type, weight_type=self.default_wgt_qtype, per_channel=per_channel, @@ -223,8 +223,8 @@ def test_qdq_overrides1(self): "SIG_OUT": [ {"scale": np.array(1.0, dtype=np.float32), "zero_point": np.array(127, dtype=np.uint8)} ], - "WGT": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": 
True}], - "BIAS": [{"quant_type": quantization.QuantType.QInt8, "symmetric": True, "reduce_range": True}], + "WGT": [{"quant_type": QuantType.QInt8, "symmetric": True, "reduce_range": True}], + "BIAS": [{"quant_type": QuantType.QInt8, "symmetric": True, "reduce_range": True}], } }, ) @@ -240,7 +240,7 @@ def test_qdq_overrides1(self): self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0)) # Weight should have different type, zero_point, and scale - self.assertEqual(wgt_zp.data_type, quantization.QuantType.QInt8.tensor_type) + self.assertEqual(wgt_zp.data_type, QuantType.QInt8.tensor_type) wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=True, symmetric=True) wgt_rmin, wgt_rmax = np.min(self.weight), np.max(self.weight) @@ -249,7 +249,7 @@ def test_qdq_overrides1(self): self.assertEqual(wgt_sc.float_data[0], np.float32(new_wgt_sc)) # Bias should now be treated as a weight and should have different type, zero_point, and scale - self.assertEqual(bias_zp.data_type, quantization.QuantType.QInt8.tensor_type) + self.assertEqual(bias_zp.data_type, QuantType.QInt8.tensor_type) bias_qmin, bias_qmax = get_qmin_qmax_for_qType(bias_zp.data_type, reduce_range=True, symmetric=True) bias_rmin, bias_rmax = np.min(self.bias), np.max(self.bias) @@ -375,7 +375,7 @@ def test_qdq_overrides_per_channel2(self): """ rmin_vals = [0.0, 0.2] rmax_vals = [1.0, 0.8] - quant_type = quantization.QuantType.QUInt8 + quant_type = QuantType.QUInt8 reduce_ranges = [True, False] ( _, @@ -434,8 +434,8 @@ def test_16bit_overrides_set_ms_domain(self): activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations extra_options={ "TensorQuantOverrides": { - "INP": [{"quant_type": quantization.QuantType.QUInt16}], - "SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}], + "INP": [{"quant_type": QuantType.QUInt16}], + "SIG_OUT": [{"quant_type": QuantType.QUInt16}], } }, ) @@ -559,31 +559,446 @@ def test_override_validation_bad_combination(self): self.assertIn("option 'reduce_range' is invalid with 'scale' and 'zero_point'", str(context.exception)) - def test_get_qnn_qdq_config(self): + def test_get_qnn_qdq_config_sigmoid(self): """ - Test that the QNN-specific configs override the scale and zero-point of Sigmoid. + Test that the QNN-specific configs override the scale and zero-point of 16-bit Sigmoid. 
+        """
+        # Create float model with an Abs --> Sigmoid
+        graph = onnx.helper.make_graph(
+            [
+                onnx.helper.make_node("Abs", ["input_0"], ["abs_out"], name="Abs_0"),
+                onnx.helper.make_node("Sigmoid", ["abs_out"], ["output_0"], name="Sigmoid_0"),
+            ],
+            "sigmoid_graph",
+            [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 2, 3))],
+            [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 2, 3))],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        model = onnx.shape_inference.infer_shapes(model)
+        float_model_path = "model.onnx"
+        onnx.save_model(model, float_model_path)
+
+        other_override_0 = {"abs_out": [{"symmetric": True}]}
+        other_override_1 = {
+            "abs_out": [
+                {
+                    "quant_type": QuantType.QUInt8,
+                    "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Sigmoid_0"}},
+                }
+            ]
+        }
+        other_override_2 = {
+            "abs_out": [
+                {
+                    "quant_type": QuantType.QInt8,
+                    "convert": {"quant_type": QuantType.QInt16, "recv_nodes": {"Sigmoid_0"}},
+                }
+            ]
+        }
+
+        # Enumerate subtests (default_act_qtype, sigmoid_out_qtype, other_override)
+        subtest_configs = [
+            (QuantType.QUInt16, None, {}),  # Sigmoid gets new scale/zp
+            (QuantType.QUInt16, None, other_override_0),  # Sigmoid gets new scale/zp
+            (QuantType.QInt16, None, {}),  # Sigmoid gets new scale/zp
+            (QuantType.QInt16, None, other_override_0),  # Sigmoid gets new scale/zp
+            (QuantType.QUInt8, QuantType.QUInt16, other_override_1),  # Sigmoid gets new scale/zp
+            (QuantType.QInt8, QuantType.QInt16, other_override_2),  # Sigmoid gets new scale/zp
+            (QuantType.QUInt8, None, other_override_0),  # Sigmoid does NOT get new scale/zp
+            (QuantType.QInt8, None, {}),  # Sigmoid does NOT get new scale/zp
+            (QuantType.QInt8, QuantType.QInt8, {}),  # Sigmoid does NOT get new scale/zp
+        ]
+
+        # Test that Sigmoid's output scale and zp are overridden for 16-bit Sigmoid.
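+        # (Expected fixed qparams: scale 1/65536 and zero point 0 for uint16; scale 1/32768 and zero point 0 for int16.)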
+ for default_act_qtype, sigmoid_out_qtype, abs_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, sigmoid_out_qtype=sigmoid_out_qtype, abs_override=abs_override + ): + init_overrides = {} + init_overrides.update(abs_override) + + if sigmoid_out_qtype is not None: + init_overrides["output_0"] = [{"quant_type": sigmoid_out_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "Sigmoid"}) + + if default_act_qtype == QuantType.QUInt16 or sigmoid_out_qtype == QuantType.QUInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QUInt16, + "scale": np.array(1.0 / 65536.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.uint16), + } + ], + ) + elif default_act_qtype == QuantType.QInt16 or sigmoid_out_qtype == QuantType.QInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QInt16, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + } + ], + ) + + def test_get_qnn_qdq_config_tanh(self): + """ + Test that the QNN-specific configs override the scale and zero-point of 16-bit Tanh. """ - self.build_float32_model() - qnn_config = get_qnn_qdq_config( - "model.onnx", DummyDataReader(self.activations), activation_type=quantization.QuantType.QUInt16 + # Create float model with a Abs --> Tanh + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_out"], name="Abs_0"), + onnx.helper.make_node("Tanh", ["abs_out"], ["output_0"], name="Tanh_0"), + ], + "tanh_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 2, 3))], ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + other_override_0 = {"abs_out": [{"symmetric": True}]} + other_override_1 = { + "abs_out": [ + {"quant_type": QuantType.QUInt8, "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"Tanh_0"}}} + ] + } + other_override_2 = { + "abs_out": [ + {"quant_type": QuantType.QInt8, "convert": {"quant_type": QuantType.QInt16, "recv_nodes": {"Tanh_0"}}} + ] + } - self.assertEqual(qnn_config.extra_options["MinimumRealRange"], 0.0001) + # Enumerate subtests (default_act_qtype, tanh_out_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt16, None, {}), # Tanh gets new scale/zp + (QuantType.QUInt16, None, other_override_0), # Tanh gets new scale/zp + (QuantType.QInt16, None, {}), # Tanh gets new scale/zp + (QuantType.QInt16, None, other_override_0), # Tanh gets new scale/zp + (QuantType.QUInt8, QuantType.QUInt16, other_override_1), # Tanh gets new scale/zp + (QuantType.QInt8, QuantType.QInt16, other_override_2), # Tanh gets new scale/zp + 
(QuantType.QUInt8, None, other_override_0), # Tanh DOES NOT gets new scale/zp + (QuantType.QInt8, None, {}), # Tanh DOES NOT gets new scale/zp + (QuantType.QInt8, QuantType.QInt8, {}), # Tanh DOES NOT gets new scale/zp + ] - inp_zp, inp_sc, sig_out_zp, sig_out_sc, _, _, _, _, _, _ = self.perform_qdq_quantization( - "model_qnn_quant_overrides.onnx", - extra_options=qnn_config.extra_options, - activation_type=quantization.QuantType.QUInt16, + # Test that Tanh's output scale and zp should be overridden for 16-bit Tanh. + for default_act_qtype, tanh_out_qtype, abs_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, tanh_out_qtype=tanh_out_qtype, abs_override=abs_override + ): + init_overrides = {} + init_overrides.update(abs_override) + + if tanh_out_qtype is not None: + init_overrides["output_0"] = [{"quant_type": tanh_out_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "Tanh"}) + + if default_act_qtype == QuantType.QUInt16 or tanh_out_qtype == QuantType.QUInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QUInt16, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(32768, dtype=np.uint16), + } + ], + ) + elif default_act_qtype == QuantType.QInt16 or tanh_out_qtype == QuantType.QInt16: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("output_0", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["output_0"], + [ + { + "quant_type": QuantType.QInt16, + "scale": np.array(1.0 / 32768.0, dtype=np.float32), + "zero_point": np.array(0, dtype=np.int16), + } + ], + ) + + def test_get_qnn_qdq_config_matmul(self): + """ + Test that the QNN-specific configs override MatMul's initializer input type to 8-bit if + the other input is 16-bit and the default weight type is 8-bit. 
+ """ + # Create float model with a Abs --> MatMul + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"), + onnx.helper.make_node("MatMul", ["abs_0_out", "weight"], ["matmul_0_out"], name="MatMul_0"), + onnx.helper.make_node("Abs", ["matmul_0_out"], ["output_0"], name="Abs_1"), + ], + "matmul_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (2, 2))], + initializer=[onnx.numpy_helper.from_array(np.random.random((3, 2)).astype(np.float32), "weight")], ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + q16_qtypes = {QuantType.QUInt16, QuantType.QInt16} + q8_qtypes = {QuantType.QUInt8, QuantType.QInt8} + symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16} + + other_override_0 = {"output_0": [{"symmetric": True}]} + other_override_1 = { + "matmul_0_out": [ + { + "quant_type": QuantType.QUInt16, + "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + other_override_2 = { + "matmul_0_out": [ + { + "quant_type": QuantType.QInt16, + "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + convert_matmul_input = { + "abs_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"MatMul_0"}}, + } + ] + } - # Input should have uint16 quant type - self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16) + # Enumerate subtests (default_act_qtype, default_wgt_qtype, matmul_in_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_0), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_1), + (QuantType.QInt8, QuantType.QInt8, QuantType.QInt16, other_override_2), + (QuantType.QUInt16, QuantType.QUInt8, None, other_override_0), + (QuantType.QInt16, QuantType.QInt8, None, {}), + (QuantType.QUInt16, QuantType.QUInt16, None, other_override_0), + (QuantType.QInt16, QuantType.QInt16, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, convert_matmul_input), + ] - # Sigmoid output should have overridden scale/zp - self.assertEqual(sig_out_zp.int32_data[0], 0) - self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) - self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0)) + # Test if MatMul's weight input is overridden. 
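+        # (Expected: when the activation input is 16-bit and the default weight type is 8-bit, the "weight"
+        # initializer is overridden to stay at the 8-bit default weight type, with matching symmetry.)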
+ for default_act_qtype, default_wgt_qtype, matmul_input_qtype, other_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, + default_wgt_qtype=default_wgt_qtype, + matmul_input_qtype=matmul_input_qtype, + other_override=other_override, + ): + init_overrides = {} + init_overrides.update(other_override) + + if matmul_input_qtype is not None: + init_overrides["abs_0_out"] = [{"quant_type": matmul_input_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + weight_type=default_wgt_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "MatMul"}) + input_is_16bit = ( + (default_act_qtype in q16_qtypes) + or (matmul_input_qtype in q16_qtypes) + or (other_override == convert_matmul_input) + ) + weight_is_symmetric = default_wgt_qtype in symmetric_wgt_qtypes + + if input_is_16bit and default_wgt_qtype in q8_qtypes: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["weight"], + [ + { + "quant_type": default_wgt_qtype, + "symmetric": weight_is_symmetric, + } + ], + ) + elif init_overrides: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertNotIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + + self.assertEqual(weight_is_symmetric, qnn_config.extra_options["WeightSymmetric"]) + + def test_get_qnn_qdq_config_layernorm(self): + """ + Test that the QNN-specific configs override LayerNorm's initializer input type to 8-bit if + the other input is 16-bit and the default weight type is 8-bit. 
+ """ + # Create float model with a Abs --> LayerNormalization + graph = onnx.helper.make_graph( + [ + onnx.helper.make_node("Abs", ["input_0"], ["abs_0_out"], name="Abs_0"), + onnx.helper.make_node( + "LayerNormalization", ["abs_0_out", "weight", "bias"], ["layernorm_0_out"], name="LayerNorm_0" + ), + onnx.helper.make_node("Abs", ["layernorm_0_out"], ["output_0"], name="Abs_1"), + ], + "layernorm_graph", + [onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (2, 3))], + [onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (2, 3))], + initializer=[ + onnx.numpy_helper.from_array(np.random.random((2, 3)).astype(np.float32), "weight"), + onnx.numpy_helper.from_array(np.random.random((2, 3)).astype(np.float32), "bias"), + ], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + float_model_path = "model.onnx" + onnx.save_model(model, float_model_path) + + q16_qtypes = {QuantType.QUInt16, QuantType.QInt16} + q8_qtypes = {QuantType.QUInt8, QuantType.QInt8} + symmetric_wgt_qtypes = {QuantType.QInt8, QuantType.QInt16} + + other_override_0 = {"output_0": [{"symmetric": True}]} + other_override_1 = { + "layernorm_0_out": [ + { + "quant_type": QuantType.QUInt16, + "convert": {"quant_type": QuantType.QUInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + other_override_2 = { + "layernorm_0_out": [ + { + "quant_type": QuantType.QInt16, + "convert": {"quant_type": QuantType.QInt8, "recv_nodes": {"Abs_1"}}, + } + ] + } + convert_layernorm_input = { + "abs_0_out": [ + { + "quant_type": QuantType.QUInt8, + "convert": {"quant_type": QuantType.QUInt16, "recv_nodes": {"LayerNorm_0"}}, + } + ] + } + + # Enumerate subtests (default_act_qtype, default_wgt_qtype, layernorm_in_qtype, other_override) + subtest_configs = [ + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, {}), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_0), + (QuantType.QUInt8, QuantType.QUInt8, QuantType.QUInt16, other_override_1), + (QuantType.QInt8, QuantType.QInt8, QuantType.QInt16, other_override_2), + (QuantType.QUInt16, QuantType.QUInt8, None, other_override_0), + (QuantType.QInt16, QuantType.QInt8, None, {}), + (QuantType.QUInt16, QuantType.QUInt16, None, other_override_0), + (QuantType.QInt16, QuantType.QInt16, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, {}), + (QuantType.QUInt8, QuantType.QUInt8, None, convert_layernorm_input), + ] + + # Test if LayerNorm's weight input is overridden. 
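+        # (Expected: same behavior as the MatMul case for "weight"; the "bias" initializer never gets an override.)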
+ for default_act_qtype, default_wgt_qtype, layernorm_input_qtype, other_override in subtest_configs: + with self.subTest( + default_act_qtype=default_act_qtype, + default_wgt_qtype=default_wgt_qtype, + layernorm_input_qtype=layernorm_input_qtype, + other_override=other_override, + ): + init_overrides = {} + init_overrides.update(other_override) + + if layernorm_input_qtype is not None: + init_overrides["abs_0_out"] = [{"quant_type": layernorm_input_qtype}] + + qnn_config = get_qnn_qdq_config( + float_model_path, + DummyDataReader([]), + activation_type=default_act_qtype, + weight_type=default_wgt_qtype, + init_overrides=(init_overrides if init_overrides else None), + add_qtype_converts=False, + ) + + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Abs", "LayerNormalization"}) + input_is_16bit = ( + (default_act_qtype in q16_qtypes) + or (layernorm_input_qtype in q16_qtypes) + or (other_override == convert_layernorm_input) + ) + weight_is_symmetric = default_wgt_qtype in symmetric_wgt_qtypes + + if input_is_16bit and default_wgt_qtype in q8_qtypes: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + self.assertEqual( + qnn_config.extra_options["TensorQuantOverrides"]["weight"], + [ + { + "quant_type": default_wgt_qtype, + "symmetric": weight_is_symmetric, + } + ], + ) + elif init_overrides: + self.assertIn("TensorQuantOverrides", qnn_config.extra_options) + self.assertNotIn("weight", qnn_config.extra_options["TensorQuantOverrides"]) + + self.assertEqual(weight_is_symmetric, qnn_config.extra_options["WeightSymmetric"]) + self.assertNotIn("bias", qnn_config.extra_options["TensorQuantOverrides"]) def test_get_qnn_qdq_config_ext_data(self): """ @@ -613,6 +1028,7 @@ def test_get_qnn_qdq_config_ext_data(self): ) qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations)) + self.assertEqual(set(qnn_config.op_types_to_quantize), {"Add"}) self.assertTrue(qnn_config.use_external_data_format) diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index b784c83329c76..183d6218567a7 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -1216,8 +1216,6 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) - # print(k.shape) - # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1328,10 +1326,6 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(cache_seqlens[0]) - # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) - # print((out - out_ref)[0, :, 0, 0]) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1724,9 +1718,6 @@ def parity_check_gqa_past( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(cache_seqlens[0]) - # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :]) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, 
atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1939,18 +1930,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(cache_seqlens[0]) - # print((out - out_ref)[0]) - # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) - - # Make sure past-present buffer updating correctly - # assert numpy.allclose( - # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True - # ) - # assert numpy.allclose( - # present_v[:, :, :-1, :], v_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True - # ) - # Compare results print( "NO buff", @@ -2078,10 +2057,27 @@ def test_gqa_no_past(self): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt( + config, + rtol=2e-3, + atol=2e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + rtol=2e-3, + atol=2e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if major < 8 or platform.system() != "Linux": return print("------- FLASH ATTENTION (PROMPT CASE) --------") @@ -2092,12 +2088,12 @@ def test_gqa_no_past(self): for h in h_sizes: for local in [False, True]: for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + for packed in [False, True]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) parity_check_gqa_prompt( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2105,7 +2101,7 @@ def test_gqa_no_past(self): parity_check_gqa_prompt_no_buff( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2145,21 +2141,28 @@ def test_gqa_past(self): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + 
packed=packed, + ) if major < 8 or platform.system() != "Linux": return print("------- FLASH ATTENTION (TOKEN GEN) -------") @@ -2170,13 +2173,13 @@ def test_gqa_past(self): for h in h_sizes: for local in [False, True]: for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + for packed in [False, True]: sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h) parity_check_gqa_past( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rtol=1e-3, atol=1e-3, rotary=rotary, @@ -2186,7 +2189,7 @@ def test_gqa_past(self): parity_check_gqa_past_no_buff( config, local=local, - past_format=past_kv_format, + past_format=Formats.BNSH, rtol=1e-3, atol=1e-3, rotary=rotary, diff --git a/onnxruntime/test/python/transformers/test_onnx_utils.py b/onnxruntime/test/python/transformers/test_onnx_utils.py new file mode 100644 index 0000000000000..974991359795e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_utils.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import unittest + +import numpy +from onnx import ModelProto, TensorProto, helper +from onnx.external_data_helper import set_external_data + +from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data + + +class TestOnnxUtils(unittest.TestCase): + def test_extract_raw_data_from_model(self): + model = self._get_model_proto_with_raw_data(False) + external_names, external_values = extract_raw_data_from_model(model) + self.assertEqual(list(external_names), ["inputs"]) + self.assertEqual(len(external_values), 1) + self.assertEqual(external_values[0].numpy(), [0.0]) + + def test_has_external_data(self): + model = self._get_model_proto_with_raw_data() + self.assertTrue(has_external_data(model)) + + def test_has_external_data_with_no_external_data(self): + model = self._get_model_proto_with_raw_data(False) + self.assertFalse(has_external_data(model)) + + def _get_model_proto_with_raw_data(self, has_external_data: bool = True) -> ModelProto: + input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None]) + output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None]) + raw_data = numpy.array([0.0], dtype=numpy.float32).tobytes() + tensor = helper.make_tensor("inputs", TensorProto.FLOAT, [1], raw_data, True) + if has_external_data: + set_external_data(tensor, location="foo.bin") + node = helper.make_node("Identity", inputs=["inputs"], outputs=["outputs"]) + return helper.make_model(helper.make_graph([node], "graph", [input], [output], initializer=[tensor])) diff --git a/onnxruntime/test/testdata/conv_qdq_external_ini.bin b/onnxruntime/test/testdata/conv_qdq_external_ini.bin new file mode 100644 index 0000000000000..e749ab5af29c5 Binary files /dev/null and b/onnxruntime/test/testdata/conv_qdq_external_ini.bin differ diff --git a/onnxruntime/test/testdata/conv_qdq_external_ini.onnx b/onnxruntime/test/testdata/conv_qdq_external_ini.onnx new file mode 100644 index 0000000000000..fad6074aea133 Binary files /dev/null and b/onnxruntime/test/testdata/conv_qdq_external_ini.onnx differ diff --git a/orttraining/orttraining/python/training/ort_triton/_cache.py 
b/orttraining/orttraining/python/training/ort_triton/_cache.py index ede9cd86a9da5..b70064377abfc 100644 --- a/orttraining/orttraining/python/training/ort_triton/_cache.py +++ b/orttraining/orttraining/python/training/ort_triton/_cache.py @@ -9,6 +9,7 @@ import getpass import hashlib import os +import sys import tempfile from types import ModuleType from typing import Tuple @@ -61,6 +62,7 @@ def load(cls, source_code) -> ModuleType: mod.__file__ = path mod.key = key exec(code, mod.__dict__, mod.__dict__) + sys.modules[mod.__name__] = mod # another thread might set this first cls.cache.setdefault(key, mod) return cls.cache[key] diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index e104ea13c59a3..14bc2779aa05b 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -6,11 +6,13 @@ import functools import json import os +import re import sys from types import ModuleType from typing import List, Tuple, Union import onnx +from onnx import ModelProto from torch._C import _from_dlpack from torch.utils.dlpack import to_dlpack @@ -41,18 +43,39 @@ class _ShapeCache: """ cache = dict() # noqa: RUF012 + symbolic_shape_hint = None + min_symbolic_shape = 0 clear = staticmethod(cache.clear) @classmethod - def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]: + def set_symbolic_shape_hint(cls, symbolic_shape_hint_config): + for k, v in symbolic_shape_hint_config.items(): + if k == "*": + cls.min_symbolic_shape = v + else: + if cls.symbolic_shape_hint is None: + cls.symbolic_shape_hint = dict() + cls.symbolic_shape_hint[k] = v + + @classmethod + def get_shape(cls, onnx_key: int, model: ModelProto, shapes: List[List[int]]) -> List[List[Union[int, str]]]: if onnx_key not in cls.cache: + if cls.symbolic_shape_hint is not None: + for i, input in enumerate(model.graph.input): + if input.type.tensor_type.HasField("shape"): + for j, dim in enumerate(input.type.tensor_type.shape.dim): + if dim.dim_param: + for k, v in cls.symbolic_shape_hint.items(): + if re.fullmatch(k, dim.dim_param): + shapes[i][j] = f"i{i}_dim{j}_{v}" + break cls.cache[onnx_key] = shapes else: changed = False for i, shape in enumerate(shapes): for j, dim in enumerate(shape): - if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int): - max_dim = max(dim, cls.cache[onnx_key][i][j]) + if isinstance(cls.cache[onnx_key][i][j], int) and dim != cls.cache[onnx_key][i][j]: + max_dim = max(dim, cls.cache[onnx_key][i][j], cls.min_symbolic_shape) shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}" changed = True elif isinstance(cls.cache[onnx_key][i][j], str): @@ -67,13 +90,12 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in return cls.cache[onnx_key] -def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: +def _gen_key(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> int: # pylint: disable=unused-argument return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") -def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: - model = onnx.load_model_from_string(onnx_str) +def _gen_module(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: sorted_graph = SortedGraph(model, 
[parse_shape(shape) for shape in shapes]) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) @@ -96,14 +118,28 @@ def get_config() -> str: "scalar": only related scalar initializers will be added to subgraphs. "all": all related initializers will be added to subgraphs. The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph. + User can also specify symbolic_shape_hint in the config, which is a dict to control the symbolic shape hint. + Each entry is a regex pattern to match the dim_param in ONNX model and the value is the power of 2 for the symbolic + shape. Each dim_param will be replaced by i{input_index}_dim{dim_index}_{power_of_2} in the symbolic shape. """ + config = dict() config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "") if config_file and os.path.exists(config_file): with open(config_file, encoding="UTF-8") as f: - return f.read() + config = json.load(f) + + if "ops" not in config: + config["ops"] = get_supported_ops() + if "initializer" not in config: + config["initializer"] = "scalar" + if "min_nodes" not in config: + config["min_nodes"] = 2 + + if "symbolic_shape_hint" in config and len(config["symbolic_shape_hint"]) > 0: + _ShapeCache.set_symbolic_shape_hint(config["symbolic_shape_hint"]) + del config["symbolic_shape_hint"] - config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2} return json.dumps(config) @@ -136,8 +172,9 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors): assert all(tensor is not None for tensor in tensors) torch_tensors = [_from_dlpack(tensor) for tensor in tensors] concrete_shapes = [list(tensor.size()) for tensor in torch_tensors] - shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes) - func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes) + model = onnx.load_model_from_string(onnx_str) + shapes = _ShapeCache.get_shape(onnx_key, model, concrete_shapes) + func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, model, shapes) func = getattr(mod, func_name) output = func(*torch_tensors) if isinstance(output, tuple):
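
A minimal sketch of a Triton config file that exercises the symbolic_shape_hint handling added above. The dim_param patterns ("batch.*", "seq_len") and bucket values are hypothetical examples; only the keys "initializer", "min_nodes", "symbolic_shape_hint", and the special "*" entry come from the code in get_config() and _ShapeCache.set_symbolic_shape_hint().

    import json
    import os
    import tempfile

    # "ops" is filled in by get_config() from get_supported_ops(); the other keys below keep their documented defaults.
    config = {
        "initializer": "scalar",
        "min_nodes": 2,
        "symbolic_shape_hint": {
            "batch.*": 16,   # inputs whose dim_param fully matches this regex are labeled with bucket value 16
            "seq_len": 512,  # inputs whose dim_param is exactly "seq_len" are labeled with bucket value 512
            "*": 16,         # minimum bucket for dims that are promoted to symbolic at runtime
        },
    }

    config_path = os.path.join(tempfile.gettempdir(), "ort_triton_config.json")
    with open(config_path, "w", encoding="UTF-8") as f:
        json.dump(config, f)

    # get_config() reads this environment variable and forwards symbolic_shape_hint to _ShapeCache.
    os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = config_path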