From ac4e72604605be524f70d1e92da81b70c5db984a Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 2 Oct 2023 12:25:28 +1000 Subject: [PATCH 01/10] Add bytes model loading test to react native e2e (#17749) ### Description Update E2E test to also check InferenceSession.create with bytes. ### Motivation and Context Add tests to validate #17739 --- js/react_native/e2e/package.json | 3 ++- js/react_native/e2e/src/App.tsx | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/package.json b/js/react_native/e2e/package.json index 969c70c110123..cd97ec1d099e4 100644 --- a/js/react_native/e2e/package.json +++ b/js/react_native/e2e/package.json @@ -10,7 +10,8 @@ }, "dependencies": { "react": "^18.1.0", - "react-native": "^0.69.1" + "react-native": "^0.69.1", + "react-native-fs": "^2.20.0" }, "devDependencies": { "@babel/core": "^7.17.0", diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index f3e415f0c5a55..8a76edabc613e 100644 --- a/js/react_native/e2e/src/App.tsx +++ b/js/react_native/e2e/src/App.tsx @@ -8,6 +8,7 @@ import { Image, Text, TextInput, View } from 'react-native'; import { InferenceSession, Tensor } from 'onnxruntime-react-native'; import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; import { Buffer } from 'buffer'; +import { readFile } from 'react-native-fs'; interface State { session: @@ -39,10 +40,21 @@ export default class App extends React.PureComponent<{}, State> { this.setState({ imagePath }); const modelPath = await MNIST.getLocalModelPath(); - const session: InferenceSession = await InferenceSession.create(modelPath); + + // test creating session with path + console.log('Creating with path'); + const pathSession: InferenceSession = await InferenceSession.create(modelPath); + pathSession.release(); + + // and with bytes + console.log('Creating with bytes'); + const base64Str = await readFile(modelPath, 'base64'); + const bytes = Buffer.from(base64Str, 'base64'); + const session: InferenceSession = await InferenceSession.create(bytes); this.setState({ session }); - void this.infer(); + console.log('Test session created'); + void await this.infer(); } catch (err) { console.log(err.message); } From f158f394d695a196ae2e06b350523a7dc741321d Mon Sep 17 00:00:00 2001 From: zesongw Date: Tue, 3 Oct 2023 04:01:04 +0800 Subject: [PATCH 02/10] [WebNN EP] Support Softmax since version 13 (#17714) ### Description WebNN only supports 2-D input tensor along axis 1. For now, we use Reshape and Transpose wraparound to get the compatible input. ### Motivation and Context Enable more models to run on WebNN. --- .../webnn/builders/impl/softmax_op_builder.cc | 97 +++++++++++++------ 1 file changed, 69 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index b207b804416aa..6a86ca7aca6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -35,30 +35,79 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const auto input_size = input_shape.size(); - // WebNN Softmax only support 2d input shape, reshape input to 2d. - if (input_size != 2) { - NodeAttrHelper helper(node); + NodeAttrHelper helper(node); + if (node.SinceVersion() < 13) { int32_t axis = helper.Get("axis", 1); - if (node.SinceVersion() >= 13) - // Opset 13 has default value -1. - axis = helper.Get("axis", -1); + axis = static_cast(HandleNegativeAxis(axis, input_size)); // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. + if (input_size != 2) { + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape output to the same shape of input. + if (input_size != 2) { + emscripten::val new_shape = emscripten::val::array(); + for (size_t i = 0; i < input_size; i++) { + new_shape.call("push", static_cast(input_shape[i])); + } + output = model_builder.GetBuilder().call("reshape", output, new_shape); + } + } else { + int32_t axis = helper.Get("axis", -1); axis = static_cast(HandleNegativeAxis(axis, input_size)); - int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, - 1, std::multiplies())); - int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), - 1, std::multiplies())); - emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); - input = model_builder.GetBuilder().call("reshape", input, new_shape); - } - output = model_builder.GetBuilder().call("softmax", input); - // Reshape output to the same shape of input. - if (input_size != 2) { - emscripten::val new_shape = emscripten::val::array(); - for (size_t i = 0; i < input_size; i++) { - new_shape.call("push", static_cast(input_shape[i])); + // Wraparound for transpose the target axis to the last. + // WebNN compute the softmax values of the 2-D input tensor along axis 1. + // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.erase(permutation.begin() + axis); + permutation.push_back(axis); + options.set("permutation", emscripten::val::array(permutation)); + input = model_builder.GetBuilder().call("transpose", input, options); + } + // Wraparound for reshape input tensor to 2-D. + if (input_shape.size() != 2) { + uint32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + first_dim *= static_cast(std::reduce(input_shape.begin() + axis + 1, input_shape.end(), + 1, std::multiplies())); + uint32_t second_dim = static_cast(input_shape[axis]); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Transpose back to the axis. + if (input_shape.size() != 2) { + std::vector new_shape; + std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + new_shape.push_back(static_cast(input_shape[axis])); + output = model_builder.GetBuilder().call("reshape", + output, emscripten::val::array(new_shape)); + } + // Reshape to the original shape. + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.pop_back(); + permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + options.set("permutation", emscripten::val::array(permutation)); + output = model_builder.GetBuilder().call("transpose", output, options); } - output = model_builder.GetBuilder().call("reshape", output, new_shape); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -80,14 +129,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali << input_size << "d shape"; return false; } - NodeAttrHelper helper(node); - const int64_t axis = helper.Get("axis", 1); - // WebNN softmax only support reshape for the last axis or version before 13. - // TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1). - if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) { - LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; - return false; - } return true; } From 451c02543a3a696d5ebe5ad5ef38f9298e6c1bb0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Oct 2023 21:25:12 -0700 Subject: [PATCH 03/10] [js/webgpu] allow specify preferredLayout (#17756) ### Description Allow WebGPU backend to specify `preferredLayout`. Default is NHWC. ```js const options = {executionProviders: [{name:'webgpu', preferredLayout: 'NCHW'}]}; sess1 = await ort.InferenceSession.create('./mobilenetv2-12.onnx', options); ``` ### Motivation and Context - implement @qjia7's requirement for an easier way to do performance comparison between NCHW vs NHWC. - It's possible that NCHW does better on some models and NHWC on others. So offer user the capability to switch. --- js/common/lib/inference-session.ts | 5 +++++ js/web/lib/wasm/session-options.ts | 15 +++++++++++++++ .../core/providers/js/js_execution_provider.cc | 3 ++- .../core/providers/js/js_execution_provider.h | 18 ++++++++++++++---- .../core/session/provider_registration.cc | 4 ++++ 5 files changed, 40 insertions(+), 5 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 71a5912df2464..8c1e69a68ca7e 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -192,6 +192,7 @@ export declare namespace InferenceSession { wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; xnnpack: XnnpackExecutionProviderOption; + webgpu: WebGpuExecutionProviderOption; webnn: WebNNExecutionProviderOption; nnapi: NnapiExecutionProviderOption; } @@ -233,6 +234,10 @@ export declare namespace InferenceSession { export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { readonly name: 'xnnpack'; } + export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'webgpu'; + preferredLayout?: 'NCHW'|'NHWC'; + } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; deviceType?: 'cpu'|'gpu'; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 2659b471733f5..02ff229cc4954 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -88,6 +88,21 @@ const setExecutionProviders = break; case 'webgpu': epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + } + const keyDataOffset = allocWasmString('preferredLayout', allocs); + const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + checkLastError( + `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); + } + } + } break; case 'wasm': case 'cpu': diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index ae33fb752fe00..444f50958eb7e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -624,7 +624,8 @@ std::unique_ptr RegisterKernels() { using namespace js; JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) - : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true} { + : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, + preferred_data_layout_{info.data_layout} { } std::vector JsExecutionProvider::CreatePreferredAllocators() { diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 091aa2904604a..39d43498c0717 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -19,12 +19,21 @@ KernelCreateInfo BuildKernelCreateInfo(); } // namespace js -// placeholder for future use. no options currently struct JsExecutionProviderInfo { - JsExecutionProviderInfo() = default; - JsExecutionProviderInfo(const ProviderOptions& po) { + auto it = po.find("preferred_layout"); + if (it != po.end()) { + auto& value = it->second; + if (value == "NCHW") { + data_layout = DataLayout::NCHW; + } else if (value == "NHWC") { + data_layout = DataLayout::NHWC; + } + } } + + // JSEP default preferred layout is NHWC + DataLayout data_layout = DataLayout::NHWC; }; class JsExecutionProvider : public IExecutionProvider { @@ -39,7 +48,7 @@ class JsExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } @@ -48,6 +57,7 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + DataLayout preferred_data_layout_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 9326c6eaff240..50c9f6681a0c8 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -109,6 +109,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "JS") == 0) { #if defined(USE_JSEP) + std::string preferred_layout; + if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { + provider_options["preferred_layout"] = preferred_layout; + } options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); From d11e053412960fc677d45215a0b3f8d9e23fc53c Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Tue, 3 Oct 2023 15:53:09 +0900 Subject: [PATCH 04/10] Add option to specify the EP to use, enabling DML EP and others (#17490) ### Description Add DML EP to the acceptable provider list in the optimizer. ### Motivation and Context With DML EP, graph optimization was not performed in onnxruntime. --- .../stable_diffusion/optimize_pipeline.py | 10 ++++ .../python/tools/transformers/optimizer.py | 52 ++++++++++++++++--- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 4512c971ac27c..aef60a534608a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -144,6 +144,7 @@ def _optimize_sd_pipeline( opt_level=0, optimization_options=fusion_options, use_gpu=True, + provider=args.provider, ) if float16: @@ -168,6 +169,7 @@ def _optimize_sd_pipeline( optimize_by_onnxruntime( str(tmp_model_path), use_gpu=True, + provider=args.provider, optimized_model_path=str(ort_optimized_model_path), save_as_external_data=use_external_data_format, ) @@ -324,6 +326,14 @@ def parse_arguments(argv: Optional[List[str]] = None): ) parser.set_defaults(use_external_data_format=None) + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use.", + ) + FusionOptions.add_arguments(parser) args = parser.parse_args(argv) diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 3f274eb6c835a..5ded027b36f74 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -69,6 +69,8 @@ def optimize_by_onnxruntime( save_as_external_data: bool = False, external_data_filename: str = "", external_data_file_threshold: int = 1024, + *, + provider: Optional[str] = None, ) -> str: """ Use onnxruntime to optimize model. @@ -82,6 +84,7 @@ def optimize_by_onnxruntime( save_as_external_data (bool): whether to save external data outside of ONNX model external_data_filename (str): name of external data file. If not provided, name is automatically created from ONNX model. external_data_file_threshold (int): threshold to decide whether to save tensor in ONNX model or in external data file + provider (str or None): execution provider to use if use_gpu Returns: optimized_model_path (str): the path of optimized model """ @@ -90,8 +93,12 @@ def optimize_by_onnxruntime( import onnxruntime - if use_gpu and set(onnxruntime.get_available_providers()).isdisjoint( - ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + if ( + use_gpu + and provider is None + and set(onnxruntime.get_available_providers()).isdisjoint( + ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + ) ): logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path @@ -138,17 +145,32 @@ def optimize_by_onnxruntime( kwargs["disabled_optimizers"] = disabled_optimizers if not use_gpu: - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=["CPUExecutionProvider"], **kwargs) + providers = ["CPUExecutionProvider"] + elif provider is not None: + if provider == "dml": + providers = ["DmlExecutionProvider"] + elif provider == "rocm": + providers = ["ROCMExecutionProvider"] + elif provider == "migraphx": + providers = ["MIGraphXExecutionProvider", "ROCMExecutionProvider"] + elif provider == "cuda": + providers = ["CUDAExecutionProvider"] + elif provider == "tensorrt": + providers = ["TensorrtExecutionProvider", "CUDAExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + + providers.append("CPUExecutionProvider") else: - gpu_ep = [] + providers = [] if torch_version.hip: - gpu_ep.append("MIGraphXExecutionProvider") - gpu_ep.append("ROCMExecutionProvider") + providers.append("MIGraphXExecutionProvider") + providers.append("ROCMExecutionProvider") else: - gpu_ep.append("CUDAExecutionProvider") + providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=gpu_ep, **kwargs) + onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -220,6 +242,8 @@ def optimize_model( use_gpu: bool = False, only_onnxruntime: bool = False, verbose: bool = False, + *, + provider: Optional[str] = None, ): """Optimize Model by OnnxRuntime and/or python fusion logic. @@ -257,6 +281,7 @@ def optimize_model( use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False. only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. Defaults to False. + provider (str, optional): execution provider to use if use_gpu. Defaults to None. Returns: object of an optimizer class. @@ -302,6 +327,7 @@ def optimize_model( temp_model_path = optimize_by_onnxruntime( input, use_gpu=use_gpu, + provider=provider, optimized_model_path=optimized_model_path, opt_level=opt_level, disabled_optimizers=disabled_optimizers, @@ -316,6 +342,7 @@ def optimize_model( temp_model_path = optimize_by_onnxruntime( input, use_gpu=use_gpu, + provider=provider, optimized_model_path=optimized_model_path, opt_level=1, disabled_optimizers=disabled_optimizers, @@ -423,6 +450,14 @@ def _parse_arguments(): ) parser.set_defaults(use_gpu=False) + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use if use_gpu", + ) + parser.add_argument( "--only_onnxruntime", required=False, @@ -501,6 +536,7 @@ def main(): opt_level=args.opt_level, optimization_options=optimization_options, use_gpu=args.use_gpu, + provider=args.provider, only_onnxruntime=args.only_onnxruntime, ) From d0519a760326fc7d87afdcbc1ea4c110ecbe3ce8 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Tue, 3 Oct 2023 23:20:20 +0400 Subject: [PATCH 05/10] [js/web] BiasSplitGelu and BiasAdd kernels (#17161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Two contrib kernels that supposed to speed-up StableDiffusion according to this doc https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md However, there is no noticable effect in speed or memory consumption. So i guess the only way to make it faster is to implement MultiHeadAttention but i'm not capable of doing that right now. So i'll focus on existing PRs and finding the JSEP kernel that produces incorrect results. It should be one of the old ones (i suspect Conv or ConvTranspose), as SD was not generating images correctly on webgpu since i started working on it. I hoped someone else would fix that by the time i finish with kernels/optimizations 😅 --------- Co-authored-by: Guenther Schmuelling Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/web/docs/webgpu-operators.md | 2 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 + js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts | 69 + .../wasm/jsep/webgpu/ops/bias-split-gelu.ts | 76 + js/web/test/data/ops/bias-add.jsonc | 874 +++++++++++ js/web/test/data/ops/bias-split-gelu.jsonc | 1332 +++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 + onnxruntime/contrib_ops/js/bias_add.cc | 23 + onnxruntime/contrib_ops/js/bias_add.h | 17 + onnxruntime/contrib_ops/js/bias_split_gelu.cc | 23 + onnxruntime/contrib_ops/js/bias_split_gelu.h | 17 + .../contrib_ops/js/js_contrib_kernels.cc | 4 + 12 files changed, 2443 insertions(+) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts create mode 100644 js/web/test/data/ops/bias-add.jsonc create mode 100644 js/web/test/data/ops/bias-split-gelu.jsonc create mode 100644 onnxruntime/contrib_ops/js/bias_add.cc create mode 100644 onnxruntime/contrib_ops/js/bias_add.h create mode 100644 onnxruntime/contrib_ops/js/bias_split_gelu.cc create mode 100644 onnxruntime/contrib_ops/js/bias_split_gelu.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index f8ac29e5f82ca..4e33368a7aa65 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -21,6 +21,8 @@ Do not modify directly.* | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation | +| BiasAdd | com.microsoft(1+) | | +| BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index cbe845b882468..2fba39c939a16 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,8 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {biasAdd} from './ops/bias-add'; +import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; @@ -45,6 +47,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Atanh', [unaryOps.atanh]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], + ['BiasAdd', [biasAdd]], + ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], ['ClipV10', [unaryOps.clipV10]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts new file mode 100644 index 0000000000000..688bedc619ce6 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 3) { + throw new Error('input should have 3 dimensions'); + } + + if (![320, 640, 1280].includes(inputs[0].dims[2])) { + throw new Error('number of channels should be 320, 640 or 1280'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('bias is expected to have 1 dimensions'); + } + + if (inputs[0].dims[2] !== inputs[1].dims[0]) { + throw new Error('last dimension of input and bias are not the same'); + } +}; + +const createBiasAddProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const outputShape = inputs[0].dims; + + const channels = inputs[0].dims[2]; + // since channel number can be only 320/640/1280, it's always divisable by 4 + const outputSize = ShapeUtil.size(outputShape) / 4; + + const dataType = inputs[0].dataType; + const input = inputVariable('input', dataType, outputShape, 4); + const bias = inputVariable('bias', dataType, [channels], 4); + const residual = inputVariable('residual', dataType, outputShape, 4); + const output = outputVariable('output', dataType, outputShape, 4); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const channels = ${channels}u / 4; + ${shaderHelper.declareVariables(input, bias, residual, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let value = ${input.getByOffset('global_idx')} + + ${bias.getByOffset('global_idx % channels')} + ${residual.getByOffset('global_idx')}; + ${output.setByOffset('global_idx', 'value')} + }`; + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; +}; + +export const biasAdd = (context: ComputeContext): void => { + validateInputs(context.inputs); + const inputTypes = Array(context.inputs.length).fill(GpuDataType.default); + const metadata = { + name: 'BiasAdd', + inputTypes, + }; + + context.compute(createBiasAddProgramInfo(metadata, context.inputs)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts new file mode 100644 index 0000000000000..8ec4ff5c870d3 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {erfImpl} from './unary-op'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 3) { + throw new Error('input should have 3 dimensions'); + } + + if (![2560, 5120, 10240].includes(inputs[0].dims[2])) { + throw new Error('hidden state should be 2560, 5120 or 10240'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('bias is expected to have 1 dimensions'); + } + + if (inputs[0].dims[2] !== inputs[1].dims[0]) { + throw new Error('last dimension of input and bias are not the same'); + } +}; + +const createBiasSplitGeluProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const outputShape = inputs[0].dims.slice(); + outputShape[2] = outputShape[2] / 2; + + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims, 4); + const bias = inputVariable('bias', inputs[0].dataType, [inputs[0].dims[2]], 4); + const output = outputVariable('output', inputs[0].dataType, outputShape, 4); + + const outputSize = ShapeUtil.size(outputShape) / 4; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M_SQRT2 = sqrt(2.0); + const halfChannels = ${inputs[0].dims[2] / 4 / 2}u; + + ${shaderHelper.declareVariables(input, bias, output)} + + ${erfImpl('vec4f')} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasIdx = global_idx % halfChannels; + let batchIndex = global_idx / halfChannels; + let inputOffset = biasIdx + batchIndex * halfChannels * 2; + let valueLeft = input[inputOffset] + bias[biasIdx]; + let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels]; + let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1); + + ${output.setByOffset('global_idx', 'valueLeft * geluRight')} + }`; + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; +}; + +export const biasSplitGelu = (context: ComputeContext): void => { + validateInputs(context.inputs); + + const metadata = { + name: 'BiasSplitGelu', + inputTypes: [GpuDataType.default, GpuDataType.default], + }; + + context.compute(createBiasSplitGeluProgramInfo(metadata, context.inputs)); +}; diff --git a/js/web/test/data/ops/bias-add.jsonc b/js/web/test/data/ops/bias-add.jsonc new file mode 100644 index 0000000000000..e89c5dd81cc23 --- /dev/null +++ b/js/web/test/data/ops/bias-add.jsonc @@ -0,0 +1,874 @@ +[ + { + "name": "BiasAdd", + "operator": "BiasAdd", + "attributes": [], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "bias add [2,2,320]x[320]x[2,2,320]", + "inputs": [ + { + "data": [ + -0.43078827160569055, -1.3343044914862014, -1.3875186857333706, 0.7550491056943578, -0.015677493769832296, + 1.2761569600658005, -0.033221806310058, 1.3590872185896297, 1.706516873231478, 1.429932535224122, + -0.39598259309643424, -0.6735725816667406, 0.5551536414690581, 1.9642482005713608, 0.1481069760277398, + -0.6706041659325672, 1.7626582023316804, 1.8587314247747715, 0.6761988056588422, -1.2865903673587722, + 0.17451384248240753, -1.468579522325724, -1.4426371855482047, -1.101469412150644, 0.32065561169114787, + -1.7280217475756476, 0.8472414982170147, 1.089925472742105, -1.9188568763699578, -1.6897769809368715, + -0.22443847491362767, -1.592027916773227, -1.403545711986184, -0.992398507333391, 1.501364958055385, + 1.1969883087000177, -1.742362862568375, -1.3122082038099396, -1.602285316113413, 1.7122406231186638, + -1.7331729420495074, -1.1660018974894326, 1.8466385944978407, 0.8059189037532883, 0.42819875266849383, + 1.2622036788694384, 1.09253289978553, 1.7169775828443141, 0.8965924110901806, 1.731515754612638, + -0.710260858323644, 0.6949611010518701, 1.9392089898271783, -0.06855145845621369, 1.831602787077924, + 1.919258334566825, -0.1750228495679611, -1.385684174878624, 0.1864456393673155, -0.7627497248975219, + 0.5460459705476177, 1.0151263124631544, 1.0476468224745439, -1.4662684764200389, 0.8737230293064417, + -1.5018396956497346, 1.7171541562949724, 0.33089656971471637, -0.4497250989135697, 0.03155539207838842, + 0.051811401515028166, -0.6417182241860351, 0.2892848388089577, 1.7313934166665597, 0.07283972370973935, + -1.1648958342544793, 0.31235891981372976, 0.8667946196377363, -1.8955234844794324, -0.7935013829177411, + 1.7461436471789726, -0.7863083195781497, -0.24062975690975286, 0.5319240675275907, -1.093540712132092, + -0.08133035383567844, 0.24368896952864638, -0.7817379372747375, 0.09213174790435374, 0.7008130043325052, + 1.2549161788773802, -1.3617356636970648, 0.9203525826936687, -0.17473559280037954, 0.8615190853991157, + 1.7227407352985065, -1.6408918805405301, -1.2530895034442668, -1.561750342643168, 0.9122283638044753, + 1.52061857140668, -1.757751307093299, 1.9174275778315568, 0.8751522548565935, -0.40373986100628567, + -0.9081360776169474, -1.0200280593711373, 1.8684056043061243, -0.8656213251095854, 1.2016289009297738, + -0.6420074354987753, 0.8277984540094971, -0.7794691695697935, -0.3682767999575489, -0.8431108361591697, + -1.8337997630639924, -1.497840701527875, 0.3582228355620778, -1.7648364421707035, 1.2295209603744386, + -1.4630737828552771, -1.0901242069115051, 0.16944674169381635, -1.1367614122844483, 1.1108472254071025, + -1.823264607399035, -1.9393711989501217, 1.5655717488598917, -1.2791502509294475, 0.5264599738022175, + 1.2793776364048464, 0.6643881105536087, -1.8021867427977876, 1.2507197673418382, 1.5096918274472868, + -0.35787327298378635, -0.8142338946526149, 0.05554038558981045, -0.8292295885917778, -1.9802571383548457, + 0.15711486642700123, 1.0606485766817002, -0.75648306593627, 0.7821766109776371, 1.5576015081398582, + 1.274065256431233, -0.11137614589209388, 1.5523940512159609, -0.9225168834642474, -1.52372850136818, + -0.09285110141468689, 0.951570914266842, 1.4257064421442056, -1.1151817450836203, -1.0405979771695408, + -0.3512294660068509, 0.5544676624698717, -1.1641268738929513, -1.1024667332826468, -0.8470834386045532, + 1.0200780368213875, -0.4111438148101385, 1.5060564363688043, -0.9264136525377458, -1.1144625610145829, + -0.8955596783917263, 1.1298063812792858, 1.1669470152018224, -0.9624856700584665, -0.5105634488042536, + 0.3350150739180062, 1.8252865056498973, 1.8318837072890988, 1.247181799708187, 0.953253999223505, + 0.33583064317486144, 0.8060935063640473, 0.025351359439778953, -1.9020713785704064, -1.4395100399524523, + -1.4434709436802136, 1.1887525549807645, -0.7404500263025655, 0.43567229698745447, 0.697905733863994, + 1.4955234099092705, 1.8798334809531436, 0.28521714680739496, -1.8247616864327485, -0.8899636892784697, + -1.9902986052685518, -0.8612351129594646, 1.4473654540376284, -0.027164144754272534, -1.3409554794285556, + 1.8216393608213144, 1.6287547173387606, 0.34546830263617245, -1.8554217975954366, -0.1599936883574884, + -1.1517327823070804, -1.1043403244478194, 0.9959357690900248, 0.47246889924079394, 1.3040046448975406, + 1.3219564442035914, 1.5522794215260332, -1.3132538270303629, -0.6568357397541558, -0.3166961264888011, + 1.155856117058299, -0.2856093347506521, 0.37282953339474023, 1.3624325643069906, -0.5961360461760625, + 0.15430022597664106, 1.459554220213934, -0.4578169766295437, -1.5520231063597194, -0.9692019226485984, + -0.7868623417517249, -0.47648838559948903, 0.521518493577898, -0.43202750995667394, -0.2800695092149823, + -0.062073964072573595, -1.280928976710845, 1.653227202229826, -1.3266797422187144, -1.8026162116413413, + -1.6342686696555697, -1.8535208355884771, -0.6740602051435252, -1.1112082500330684, -0.2784013709506814, + -0.38477554625820254, -1.8030977138882127, 1.0792754568917307, 1.7171188079371378, 1.0025920437036717, + 0.4909327973461517, 1.7583392556519017, -0.9513053055982352, 0.9836640607205354, 0.676230592061307, + -1.532719596602865, 1.7094356339043584, 1.275239052465298, -0.7097409099122736, -0.07447816392320394, + 0.6054981041837371, -1.368606109962121, 0.5491138352978151, 0.9607716616110995, -0.8472145355857039, + -1.5962071764683028, -1.3945278279448647, 1.6804737734010837, -0.2799307641611124, -1.5199944467651827, + -0.6284155111805374, -1.387427601467464, -1.3689101475786654, 0.5479648937942034, -0.35058122336125397, + -0.01913895405776156, 1.9421209900010803, 1.521339125582216, 0.44317460300853284, 1.4673321860144632, + 1.0844686121268872, 1.5454854860881193, -0.10676223292802778, 1.304830685737219, 1.5144316149285784, + -0.9206202207515908, 1.800878209478248, -1.4022748170800003, 1.8475728691446207, -0.6049511260966813, + 1.3794138268485217, -0.5504308384418284, -1.308589620881369, -1.5124914949500843, -0.9843890898875198, + -0.47517812559034134, 0.019190610059508728, 1.3102818494016093, 0.02046481460734384, 0.8966955765481393, + -1.9544981815442322, 1.4818794231229235, 1.701164813487166, -0.1853726987447688, -0.9284129805631327, + -0.6665475855012435, 1.7174739766102922, -0.6866660477705144, 1.3741442823600236, 1.3144332615465366, + -1.6821345841035624, 1.4369774381581397, -0.40113703543299106, 1.2155358501458204, 0.3746014766056245, + -1.2797835093659478, -1.2733146708696594, 1.7592472733050606, 0.7411515442284751, -0.8417009900350942, + 0.43950532643238294, -0.4849697519720575, 0.4561330711561924, -0.9993716519294349, 1.2993450426953688, + 0.6077681589510648, -1.1884025219126162, 0.7423783440021925, 0.2164111466551839, -1.8404401979886122, + -1.6417250174486764, 0.33169566624656355, 1.4670345153781224, 0.05770292022609169, -0.7574653293187339, + -1.8332930679907182, 0.30545703812352265, 1.2645746651782588, 0.06851662509789058, -1.5007737660843672, + -0.7895770838504896, -0.16973917251686643, -1.1309776011957018, -0.5245832860094533, 1.2865060486729964, + -0.6877713804065921, 0.8603087489217494, -1.6760835939805823, -0.3421307621860663, -0.08163355960848584, + 0.8960283575910557, -0.07414831058651483, 0.06552900030788233, -0.28343554832982676, 0.8580233436387052, + 0.9756624200728901, -0.31454210863680565, -1.3961591961928228, 1.4357660237613095, 0.995011454280724, + 1.7963312095595976, 0.44298210926821735, -0.4481873251636479, 0.04661154364950004, -1.0205093044441567, + -1.416501713829339, -1.5609202237237856, -1.9678529588325686, 0.054492766820211536, 1.9234783545287293, + -1.730337423937696, 1.0002270533854967, 0.912626942921996, 1.6082280213507438, -0.0007848079092749316, + -0.16260961343095381, -0.4951675344800499, -1.358161813033476, 0.0952066467959174, 1.1451260564923738, + 0.5658192251004577, 1.638639222136547, 1.4561159943128157, 1.394720518360975, -1.6647338550020887, + -0.35227744791908844, -1.4997218325299349, 1.7269852905741923, -1.6265531305359326, 0.08177780426753678, + -0.2548418092940219, -0.5491200865930868, -1.3448390466679427, -1.4669459626844805, 1.6556888857778747, + 0.9797534441152536, 0.011435434178378223, -1.0444316805293496, -0.21951923449225497, 1.2775877064028922, + 1.3912532203430734, -0.41034439679528845, 1.6871573166908078, -0.2949753341249375, -0.19366027881779058, + -0.8063505536355873, -1.9683493413513462, 1.1343233715358245, 0.5863358659604163, -1.7161271586916902, + -1.1511237584065572, -1.026334237004506, 0.051935455424261256, -0.47999388673650234, 1.3653965992268642, + -1.6246601676055645, 1.0720439180368135, 1.7513891096415195, -1.9557074487732296, -0.6497504028268937, + 1.5657013366841293, 1.480374539039186, 0.6254160749725743, -0.1386074990036681, 1.4879074098885399, + 0.8608723264201785, 1.3707693807951804, -1.9599278528285744, -0.08178821222314703, -0.8801306486657889, + 0.2672127121047323, -1.0905213718834403, 1.2845923771997692, 0.7483674434994203, 1.692846569262815, + -1.4081461907911006, 0.4847990697153808, -1.384710972149036, -0.48477323237460457, 0.5075573472017414, + -0.469975647682892, -1.7330175330396296, 0.4900846176343876, -0.5504384887286493, -1.3105683633488505, + 1.7832520063588246, -0.008133111014310579, -0.4117632340667363, -1.364428903076842, 1.1502921666750634, + -1.297368143184646, 1.2657228492344146, -0.8339630526727912, 1.6036841813714835, 1.7388392748246462, + 1.0764044765977916, 0.8635247292187316, 0.8915437658945047, 0.048593859303998954, 0.6397243841879581, + -0.32613694518247716, -0.5352551370184315, 1.0102009591194312, 1.2995349433229801, 0.873915978492783, + 0.28688858188963984, 1.0299796543068158, -0.21558135465200134, -0.8175153872775942, -0.2228988257546698, + 1.3651108393401516, 1.7419897496979102, -0.4287431897946634, -1.7210342396761868, -1.1328295428945232, + 0.45940815613566865, 1.225728222018235, -1.8492110470394127, 0.11914217556997819, -0.5298005043189082, + 0.2548881953681379, 1.4159379393782983, 1.4882705620355736, -1.778746106600524, -0.1322461839301079, + 0.959787258509837, -1.2087764883807948, -0.656679223823172, -0.8052024037155077, 0.28857972466710535, + -1.9964605007110707, -1.618829946697912, 0.0003216720189040956, 1.3814498186817072, 1.267459171727996, + 1.7808899740756745, -0.026284995100580133, 1.8415787116931153, 1.588128771971907, -1.379239932131604, + 1.4954639033867787, -0.5365530743272542, -1.689396586297355, -1.5263365207258657, -0.15914036263121556, + -0.2236484993504826, -0.233339567785662, -0.6479528728470942, 1.7496861376407091, 1.9658328381866337, + 0.7465900143344948, 0.9538445157519684, -1.233501367636971, 1.7832767928842772, -0.05854857072903297, + -1.4741131369232807, -0.8878273465278168, 0.7857197910561684, -1.3138933804749309, -1.6218392819603826, + -1.5976871654455351, 0.7862199318904839, 0.7930452731942381, -0.3851655998300325, 1.454630867585033, + 0.07699835053088844, 1.9635265550996897, -0.9018052833921555, -1.3260562107679466, 1.591219679914813, + -1.9325066606486736, 1.6231422543543284, 0.06461353520395186, 1.938690006387282, 1.18237528878611, + 0.08913168691813134, -0.24910533784113476, -0.24270271174439095, -0.058892972285330636, + 0.3017252069250542, -1.6443343279597213, -0.35086444759952506, 1.2568748876859326, 1.2564694463685155, + 0.5001323821158774, 1.6753738435417835, 0.12536895878538168, 0.11445707906508318, -1.4674734226308148, + -1.664480357904826, 1.2455086575046002, 0.016219310045000768, 0.6915540154848818, 0.8863957336976913, + -1.902221644800517, 1.1943261373616156, 1.5343838485344943, 0.26951920872184143, 1.1650565743686334, + 0.8796643687219943, -0.9914368645961655, -0.4140197685853506, -0.31070218781242964, 0.18576121794650557, + -1.8127899542033798, 0.7364855692315055, 0.61436653795095, -1.340047295554938, -0.9787870418028497, + 1.467065674277963, -1.3989058830195544, 0.8760353392917564, 1.103715585163978, 1.5993888842634876, + -1.4487285544801765, 1.0735445052923094, -1.7221746626643117, 1.3406229033114787, -0.31808486639617684, + 0.18778714988957468, 0.3360116223452154, 0.9957693335536861, -0.7082173962327145, 0.13563127391102814, + -0.6103565638388746, -0.27366984218234336, 1.670182285266585, 0.7450288088790513, -0.7376055316855803, + -1.9155349843040552, -0.5647372929979877, 0.38985718472222164, -1.0859956865362559, -1.2774336101522819, + 1.6716179824020179, -1.8248740137755952, -0.2462417745015948, -1.6330831414998679, 1.2633365236447727, + 1.8405595764836615, -1.5041331151923663, 0.7795364178340263, 0.30228912276856246, -1.4199604845773672, + -1.9675070264789785, -1.5847567364669262, 0.8417404542568088, 1.411592224813667, -0.7798761695294161, + 0.3836369752883826, -1.8656778968593093, -0.5928302815127546, -1.9170044216421092, -0.05015509660959072, + -0.2949046704017606, 0.7706683291518104, -1.0318399043402815, 1.673370077717097, -0.014510791777835763, + -1.7993886007705928, -1.6310599365379153, 0.47517801288072725, 1.7067820519317198, 0.2398133204231856, + 0.8136128605181492, 1.7661699247824538, -1.7400434804598, 1.6440515677939507, 1.0630270197569356, + 1.604276559525923, -0.6932776253265054, -0.5527120052345769, -0.15050627035753017, 1.9862178971553677, + -0.015247839923115514, -1.973561045412188, -0.7527150187762626, -1.8069828707308844, -1.0877483004326844, + 1.5632037263398146, -0.7907173278699853, 0.020757273269920162, -0.9162761989634776, -0.871619440202827, + 0.9274209555443278, 0.2788106849606935, -1.139267389509799, 1.8554359980704245, -1.2900715232363025, + 1.5740910757851019, 1.2022687452616108, -1.3263358341556861, 1.705785771419051, -0.2966767936771886, + 1.0184489507711243, 0.9331775150786781, -0.09496306724418613, 1.6169644094277178, -0.43071270921050253, + 0.00617695499471882, 1.9046927055754255, -0.3800897035987143, -1.2677566100047573, -1.6177276055783159, + 0.41451247814580316, 1.9136674265154827, -1.4827469552090689, -1.7754603228705612, -0.6490018494468908, + 1.0832837177200183, 1.5236866225466352, -1.3663768537968632, 0.3345867923992154, -1.3264457751116838, + 1.1302163417527789, 0.7987367218594539, -0.16727587952795364, -0.7323977510251147, -1.3812007830618693, + 0.7210535269663234, 1.1937695850586403, 1.9659603230065574, -0.12903581925605057, 0.22599793243849042, + -1.055270780364089, 1.0726029698666473, -1.8921222720165147, 1.232307457573829, -1.6501289040499376, + -1.2203689042981933, -1.5782693671316723, 0.84090024006949, 1.1443588250664956, -0.8862909159261179, + -0.34912733758720993, 0.15172846490974035, 0.6801601864117464, -1.3948321257969702, 1.9080269171533608, + -1.851455291444804, -0.8525139187927957, -0.9629354426957466, -0.3970802170728076, 0.2714456125847784, + 1.9355349765888947, 1.643295056500194, 0.461049347109487, -0.819228054522112, -0.7773196244370615, + 0.27305821451894285, -0.5007383808686932, -0.5307070901422906, -1.6087255013287924, 1.708746273758031, + -0.18903643771546896, 0.4237658727537639, 0.7912222637914885, -0.06713232018529425, -0.9303076017910259, + -0.9715788400465986, 1.9239773722802864, -0.13605413264730704, 1.63369388301102, -1.4098364226383078, + -1.8177000155999794, -1.0786108587058756, 1.1467184875003493, 1.6942868827942474, 0.3743937965110735, + -0.5205992959118086, 1.9008470494268277, -0.33613881942363744, 1.408798355509151, 1.4374267676314432, + -1.6266431395044751, -1.2795792184706105, -1.6942133113310183, 0.44536668255477885, -0.5438457389607523, + -0.10839952389615792, -0.741753360770752, 1.9252204052044393, -1.298683868110981, -0.6674482773778925, + 1.7250903096066814, -0.31851968039415723, 0.09917784413898989, -1.2300134366181839, -1.6661221342983064, + -1.7392500850180967, 1.1202143474944348, -0.008191271742616912, -1.9527622539466778, -1.112511983718873, + 0.8023412387190376, -0.8667541685114779, 1.5324025634620666, -0.885176692529595, 0.28103075310068526, + 0.5555445946177473, 0.0746344154810279, 1.8279995093264354, 1.589298526844452, -1.1728977455798502, + -1.2266193029402643, -0.08324504041088154, -1.5997622864209005, -1.9929554997781622, 1.649795503298476, + -1.825621659313522, 1.2037940913995477, -0.1478578137836779, -1.791209833936434, -1.5875938503209248, + -1.54554126750838, 0.22767426425123372, -0.9730859655005233, 1.01498892879457, -0.14002274071109522, + 1.5196250773427282, -0.262915623574389, -0.8395076912698629, -1.6509617804598316, 0.4853743370864656, + -1.2426649457879062, -0.34746037671673236, -0.2668537063224701, -1.5727266735586252, -1.117657993368196, + 0.38564526206938243, 1.8759791120426446, -0.9537042932874478, 1.8291799127445367, -0.005069359588546263, + 0.9881621403906173, -1.0839433778149212, -0.4270854569893858, 1.6796287101836995, 1.4763471217127018, + -1.0142364314111942, -1.9083290608311998, 1.8813585092075078, -1.7461290703297596, -0.21674842672674544, + 0.5502685329995662, -0.4411322123830974, -1.8416200674737109, -1.0844672555196375, -1.4812521607836615, + 1.5083908944091222, -1.038918979055473, 0.7142655565242073, -1.1043893519874795, -0.9397073551320325, + 1.335714667963627, -0.39552262419583606, 0.941189318776745, -1.973265716530558, 0.29314288363790375, + 1.1149136203907473, -1.3280789878295156, 0.4344131789940944, -1.4976865930360805, -0.18683967016091696, + 1.6841429470215354, 0.9990587834251263, -0.9896683309850935, 1.7344823188063838, 1.2809894053746431, + 0.11253057753370044, 0.6202153889005659, -1.1393527545232924, -0.13267950871413792, 1.9929188473410813, + -0.5751171946115745, -1.5738159418363038, -0.722979035597497, -0.3845867260905491, 0.6726496775295967, + 1.5922485171567802, -1.2068498206133187, 0.412853384927236, -1.1136351543496117, 0.4046306514337976, + 1.2348205714395197, -0.6837477630512465, 0.4396481646202046, -1.689933367495156, 1.2059642129302333, + -0.9663178383992985, 1.1606541969900555, 1.6008140707565817, -0.49901900408756017, -0.3295636539441862, + -1.6961254597784983, -1.2594474125668986, -1.8290342261999246, -1.3195471501043112, -1.3360729012617636, + 0.4819319828437685, -0.6044650413524977, 1.5401325916637765, 1.3108212503564856, 0.6641610189431937, + -1.167107917049922, 1.3717452437078013, 1.8234968598766077, -0.059000791459610014, 0.5078759939367101, + 1.7012186950957178, -0.8153543329038886, -0.8116555600682265, 1.1042603281155614, -0.5230370384584662, + 1.5907666644047485, 1.3585126059484214, -1.1604013321546773, 0.2832653250904853, 0.6146831527317715, + 1.710136815942171, -0.4339250659200289, 0.5404568535663827, 0.9731252061576328, 1.667624932562064, + -0.8395294570873553, 0.9655900545408684, -0.8459497721445768, -0.3303605936459242, -1.0228351996527785, + 1.64618479653826, -0.17369144560780647, -0.762588585662475, -1.420812072676508, 0.04977508086731497, + 1.1411840603073538, 1.2658579855151144, -0.7695492083057207, -1.5802557368206935, 1.8063629418173504, + -1.3314629726856042, 0.926332179655863, -0.29403591774091886, -1.4368324532624133, 1.794172160269225, + -1.45190145484798, -0.6970484207532923, -0.7393596578156938, -1.3777134665955009, 1.698548607127142, + -1.5277458031231408, 0.8121596602029895, -0.4871450419173451, -0.8973481250486168, 0.999939728035903, + -1.622149524743996, -1.8097564354752569, -1.9184903777986992, 1.6699213455758075, -1.8966494026411178, + -1.5537605568683883, 1.0114197035541261, -1.7582102654477056, -1.881605728865896, 1.9424710063889181, + -1.387404030261334, 1.3858022761207254, 0.6698691050652563, -1.7882425787208467, 0.05463416162273482, + 1.151588572666963, 0.9448022669750591, 1.2079032520591921, 1.3271868932748552, 0.7459734492068639, + 1.4448931947680101, 1.741554075288172, -0.49778064496799157, -0.05357363118867564, 0.17060355537450267, + 1.8858952797885928, 1.1828850394664219, 1.4398264570692447, -1.0269803221490008, 0.12795611159499387, + 1.4338295119095292, 1.7680805378740985, 1.8889001943775678, -0.023646131688371597, -0.055364618226636075, + 1.3107732868213482, -1.3726761197824935, -0.48421640176631975, 1.8520978683112554, 0.14900451528494418, + 0.7553309487914097, 0.995210897988966, -0.2653148753757497, -1.0047335940870337, -0.15140716923573905, + 1.0342357378533045, 0.9590192011054128, 1.3276618340182669, 1.7076552004070518, 1.3639368693762144, + 1.0626034699464553, 1.0985888634186862, 1.0871213821052033, 0.3518298069849042, 0.6905794127769393, + -0.5700252629850588, -0.10814050178161683, 1.3965639143955952, 0.8292785089561896, 0.25327348015151596, + -1.334944218927732, 1.6209990328336517, -1.26244979705121, 1.771347153639546, -1.4436659851102362, + -0.5033590550326617, 1.484309499478445, -0.3804774758165417, -0.6854434358446646, 0.11814627802495625, + -1.8940220672649586, 1.9948327521339193, -0.31419774418955715, -1.708608376028823, -1.8717143806284637, + 1.4405268554284918, 1.8275766986420505, 1.613878636732296, -1.6842307903910925, 0.40437384436799473, + -1.7974817731693786, 1.1222020737272933, 0.8616912290576968, 0.7450260507858868, -1.262819663341098, + -1.3890825964448705, -0.24733521021608595, -1.9566763230316564, 1.240599031645397, 1.0005078570110628, + -1.2429059760645886, -0.8343788480444481, 1.8180187022655172, -1.2406155795725864, 1.206905860440954, + 1.0657535671563094, 1.0688089382594308, -0.1623146723818225, -0.9394369384107097, 0.1644095126054408, + -1.5836754669396766, -0.45886757503464093, 0.02309988717518241, -0.11609000844996142, -1.8627934732063016, + 1.9415986762986073, -1.1741977923279308, 1.2850766678166368, -1.6650362245910895, 0.8711235235579853, + 1.703790813181027, -0.20031635110543, 0.7709122587840396, 0.4676763407673441, 0.7333591027965438, + 0.010661459873729129, -0.6248209657856156, 1.3499622620431584, -1.010555890938674, -1.9094767924575482, + -0.4118954261411796, -1.9764407569645153, -0.11597198863706204, -1.2058671493925148, -0.9480128119239577, + -0.5293044156931037, -1.0802020289569132, -0.17428061346628443, 1.79479586490669, 0.4507608914027035, + -1.5890677193516893, -0.4180241158854212, 1.1247910122551152, -0.10769135533882057, 0.2413054436244062, + 1.3070197809453399, -0.7234463442247714, -0.9044481440724681, 0.808060474881219, -1.9681916392611978, + 0.6794353030118225, -1.8140413117066592, -0.21172484209703857, -0.3970612901969721, -0.22610168646442563, + 1.8446444889972504, -0.14161684848047962, -1.0612317380319158, 1.6805704263182024, -1.5680342684533937, + 1.6367583739045974, 1.9603810572547848, -1.9461850695662726, -1.669279137293902, -0.4582040515383534, + -0.4903387593204007, 1.700009791169296, -0.1006260106974528, 0.46356012096704813, -1.148937310426322, + 1.1291686959534264, -0.43326682883595513, -0.7294554760312604, -1.3404464141642505, 1.4428709283346048, + -0.31527585219642784, 0.5232849548965959, 0.840594052127317, 0.8605144687711537, -1.6471991161039679, + -0.32119284017874694, -1.199582906673192, 0.3080262169024959, 0.14151294810583348, -1.7354321287231, + 0.4873457081904098, 0.30837874931057474, 1.3003825539901728, 1.9636934942685267, -1.017158000827136, + -1.7596484196087827, -0.6817110071876389, 1.8933995404689652, -0.27989446483984093, -1.067564992804905, + 0.48291319348514605, -1.6740080386493696, 0.2807422201361458, -1.9110968101706796, 1.6831495761856532, + 0.11512823651793447, -0.01736208420134222, 0.4405414565435697, -1.8085718479527353, -0.9564467319845411, + -1.58340676850203, -0.647955866711154, 1.6543925512752926, -0.7724666857844449, 0.5113949921548908, + -0.39267895944443065, 0.40761631065773063, 1.7690968773142668, -0.21764226027462374, -0.5169606846714005, + 1.4412302687210286, -1.8896763215831802, 0.07979887381656514, -0.11436797170178448, 1.0634323241712869, + 0.26858767414261475, 1.582753510128553, -0.010528543666477042, -0.3613643892495162, 1.391514209117238, + -1.4700872595733046, -1.5362821122874086, -1.818586245442571, 1.9678697208900475, -0.9362374796105595, + -1.4709960962767532, -1.182374325374778, 1.6385607669867177, -0.46775448579892487, 1.9576437315276696, + 0.915531228777362, -1.0860235734385926, 1.0655104012081509, -1.7770877181643954, 0.1443659128293806, + -0.23955298993629803, 0.41725367891443366, -0.8558589757408512, 1.024674449305122, -0.8538581096220099, + -0.17121172366938264, -1.0495343198650096, -0.8461809157835463, 1.956660524400533, -0.10451516941234473, + -0.9119888509755709, 0.9341633453090434, 0.5765821236303488, -1.8017153374435075, -1.6959921212218267, + 0.3565838506194048, 1.986423720658717, -1.1810787364750697, -1.3554314442606277, 1.2292087344595828, + 1.9389467760629646, -0.06060251846881748, 0.6471281822482204, 1.028562584237319, -0.9889764039700069, + -0.30300382154607064, 1.8809742113734886, -1.7911374091327446, 0.4093234223382991, -0.692170253260544, + -0.5766217362114325, -1.234711065294488, 1.4845455677723791, 1.1142993730640143, 1.0351547495051978, + -0.2304804542756207, 0.7896680860540854, 0.7368394967498872, 0.6117647784314304, 0.9649509001647774, + -1.4794529756304886, -1.5330264541276408, -1.1347331780500776, -1.957296773370273, 1.0497217009949296, + -1.876577007676679, -1.8707142834400772, 1.0355676671507679, -1.3024864669068572, -0.8110172097035955, + -0.7956308468122133, 1.5651626086889294, -0.5950947090287055, -0.15363512205018193, -1.7408026469236138, + -1.8514840915078024, -0.40821529034130855, 1.2085979600022174, -0.953165571253348, -0.03303673441403454, + -1.6012777563202603, -1.1080907034689993, 0.4974859939502432, -1.2774517368694216, -0.20163785863448247, + 0.8261780324851822, -1.5127015843190126, 0.14100033563999403, -1.052646319316592, 1.6782929279009817, + -1.5464154280508744, 1.7486715792103427, 1.7780537663957405, 1.7209562957702067, -1.5054888719925499, + 1.5292916602749358, 1.179965119787405, 1.5170349093126205, 0.7433092687415837, 0.9522185073327041, + -1.2380413345480523, 1.9354870169011447, 0.625636243747028, -0.09000816987169546, -0.5335972012344188, + 0.9674628266238745, -0.03967494279717343, -1.8591428816092792, 0.2309446236016406, -0.9041639531030761, + -1.9026103934874206, 1.5541589920102101, -0.45696090245009824, 1.8466298423672463, -0.36055327204706167, + -1.2458073226198056, -1.4410345639464017, 0.4731557626110279, 1.2307498218360111, -0.30913913858563724, + 1.8557259220973865, 1.1724797822135766, -1.516241961681641, 1.147047572924638, 0.009295148811337306, + 0.8291735590935811, 0.9825251314963639, -1.6702374836146134, -1.3070146895265724, -0.35977729833032246, + 0.10882986028094521, -0.1545812635060546, 0.5966946312401102, -0.12463585998219351, 0.764026848253426, + -0.0653501987613172, 0.5337159207310522, -1.3783865008607394, 0.3440914524635428, -1.6128660868012537, + 1.0505520366072156, 1.4508195056160966, -0.3811605116562866, -1.0184337989154448, 0.3472432185953034, + 0.7934690008453043, -1.1871814411996455, 1.7160465328415073, -0.8932034391682153, 0.22684342695521842, + 0.006601468173127678, -0.8703158970162406, 0.7854001417512366, -1.6096149032006064, 0.5734105371918883, + 1.8323642183966413, 0.8494195484926621, 1.52159530384528, -0.43666400911265324, 0.6949749758230679, + 0.09014001060463439, 1.1181673671725294, 0.23797216532766896, 0.9091467606318648, -0.051242214293259813, + -0.3492957666583303 + ], + "dims": [2, 2, 320], + "type": "float32" + }, + { + "data": [ + 0.0018182522274203805, 0.06756509596322857, 1.889723294866065, -0.10289095754140298, 1.5711519216894745, + 0.027529292075774592, 0.9603256438495507, -1.497309631471758, 1.9251601219617065, 0.8851491878732389, + 0.05078780805071137, 0.40903741455911735, -1.6644840015459215, -1.348225759557871, 1.615832737926227, + 1.042719864089511, -1.9289326046242312, 1.3417535199012995, -1.710655801290117, -1.130165128147044, + -0.3755000776719024, 0.6155781582426902, -0.5883485771887473, -1.7159986811406176, -0.16333854572017525, + -0.06079239446971929, 1.6926064002585495, -1.8776332892248098, 1.4601970742576578, -1.3202352800423185, + -0.12899708506012164, -0.6003093613879029, -0.1726349092091164, 1.2394146364350664, -1.769629141089184, + 1.4197330981295524, -0.9504267735259635, -1.240675610662361, 1.4018548317486923, 0.5332018345413356, + -0.16073415033536875, 0.15303724703170385, -1.8037963193841238, -1.311714810716846, -0.5740602095553404, + -1.0372165240096223, -1.370949121899355, 0.29661966702940035, 0.07816374250571023, -0.41396787300651905, + -0.3694698645575212, 0.6759765867037197, 0.2952400780995772, -0.06275069272676337, -0.9130274419561628, + -1.8944701092982958, 0.33465806810173593, -0.404939193749847, 1.4043718178232805, -0.5590711165263631, + 1.2184926422968934, -0.7087036307842709, -1.6055109382118182, 1.968257767003597, 0.529695028652811, + -1.9967381817454308, 1.595078125176956, 0.9871155490120058, 1.6566751957870993, -1.609626148231829, + -1.1397801527001823, 0.02238544560446254, 1.4873497063814245, -0.4755743745599572, -1.6926423664304844, + -1.4161828320433028, 0.346372427157398, -0.18203459832580027, 1.4635583159542183, 1.5944148599650028, + 1.3186726267824955, 1.5675687012032968, -1.9754809408365706, 1.44557963549327, 0.9397875688795354, + 1.4424046221061442, -1.4458135310352649, 0.9975520389856136, -1.9027511578082317, 0.9144382540308484, + 1.052124261689804, -1.4732678674195494, 0.29024955712503164, -1.2231252144665383, 0.34787712508784985, + 0.3556934319800238, -1.7419738471239645, 0.8630538908485903, 0.5386452782458866, 1.7600516786463105, + 1.8905437777505014, -1.5744952794523028, -0.7530004157782235, 1.5678919268380707, 0.034533101389558674, + 0.7325333516090975, 0.9775064333478163, -1.6408433791748216, -0.7414398323785214, 1.6725433719876586, + 1.0072882099919305, 1.4341931058179327, 0.7139948421146176, 0.40545031341822124, -0.11478362063979386, + -0.9345270890825441, 1.4281286745225614, -1.39970554180245, -1.1485396410325386, 1.1495990036520798, + -1.2916127094423402, 1.4211660589871826, -1.0749317173140405, 1.4370776307284663, 0.7880288709576773, + 0.46732661965227873, 1.56798877542517, -0.9531716195760707, -1.3739051298849194, 0.9766290318098436, + -1.307661662111708, 1.8574559170417002, 1.35797073743995, -1.6940130226054606, 0.28491131826133387, + -0.36419491260352554, -1.3047662545854015, -0.9266815176320033, -1.2358711507932467, 0.9127887752631247, + -1.6466848327495578, 1.1607458121027339, 0.46297657760513733, -1.5495718508374514, 0.3292413137438217, + 0.7675934897387728, 1.4008121445440214, 1.4898570624591958, 1.6030744917802648, 1.2925872420362232, + -0.8421561750911684, -0.3407292616133608, 0.38924919209979336, 1.6793513775487527, -1.0373013949726966, + -1.5337353736283532, 1.5143316995909872, -1.6320472160478126, -1.3996482770156646, 0.6337864872715988, + 0.5406528347636357, 1.2967809902878562, -1.5182702863754916, -0.7399098341126589, 0.31978027899894723, + -0.4320909370805026, -0.057815767103424065, -0.42656779912656795, -0.7191461156604344, 1.732444695508872, + -0.16793165663622744, -1.029044319841585, -0.7379183254565955, 0.6335667491493258, -0.7407757474651113, + 0.737814588729532, -0.41713542698826167, -0.862992043249343, -1.8537968903371889, -0.480058608858549, + 0.04028745468513595, -0.11696118988455151, 1.2159286584219329, -1.4551651039165874, 1.8518920420484895, + 0.8324620148383071, -1.9503205997190287, 1.3118092522348013, -0.41781057862944326, 0.47025354333711356, + -0.08599400306878557, -1.398138636933056, 1.799030968066016, -0.9016154689967486, -0.44642885397970034, + -1.6161407274817075, 0.6108393015698415, -0.9652371448534662, 1.472448459030451, -0.12097411552763226, + -1.7427779621544364, 0.6772588555443013, 1.239525535102806, 1.4978793781566582, 0.9794171716198061, + 0.37480400234555056, 1.7099069435864092, 0.5339030487857208, -1.7368267186422761, 0.3401246801395308, + -1.495349576003802, 1.1154539341471592, -0.5739747352480027, 1.7719108709631328, 0.7087464471791378, + -1.407094251765498, -0.5711994993106657, 1.6197007171162792, -1.665245693725593, -1.4093290138388097, + 0.8150971020478908, 1.1565262598728276, 0.007036682898540647, -1.067724969488646, 1.1760370444772006, + -0.4660822995530971, 0.18663889657333232, 0.8600384570962394, 0.07639203983671461, 1.7055162765205303, + -1.7134292208088802, -1.3413558800873675, -1.338677372528159, 1.4246968540400653, -1.1823984287999973, + 1.4751654585472211, 0.5262834049380078, 1.5117343050060867, -1.409416488118043, -0.39544742603356386, + -0.6577586706586658, -0.5919201797053688, 0.6013534842506445, -1.1862135968111707, -1.229417973714626, + -1.803412419156234, -0.7655790098575235, 0.9128632433156794, 0.9036623476529559, -1.9831271121679324, + 0.8324308647368319, -1.759507307385337, -1.7725931616787687, -1.7039303423725647, -1.4439967872268928, + -1.7432455401143834, -0.02216033991501387, 1.2819676717165, 0.16659457648361364, -1.167642388668959, + -1.7143084152722228, -0.7289345444538382, -0.02925241516287791, 1.9566358667801342, -1.2857581699546135, + 1.0915031830445114, 0.05084200795390714, 1.083568818422366, -1.1315486700234478, -0.8881346175534794, + -0.63619987674788, 0.3799832019858531, 0.2477670922101094, -0.6132210208290614, -1.8451948781462812, + -0.22847217268867048, 0.0025467735349682386, 0.1315834394384794, 0.1776926575489597, -0.8691295174311664, + 1.6637912565242994, 0.448901769947029, 1.233816013145204, 1.7971799993597228, 1.8719614934816882, + 1.655937636621596, -0.27359273976124054, 0.08461142131684696, -0.2757947346097396, -0.9521228519499276, + 1.766034536643284, 0.8831916052200137, 0.9813027219865562, -0.322591101625501, -0.20675723380495992, + 1.0866641329284041, -0.6397672290782843, -1.9715973970816654, -0.36395252045986304, -1.4160336028155198, + -0.7487477697571272, 1.4091533113140509, 1.2152244001439598, 1.0139512253023701, -0.5841820989850488, + 0.36171343432432046, 1.1810326691265303, -0.044977125366693294, -1.5719763377131732, 1.636814383280785, + -0.8254090686593019, -0.2739258751225844, -1.5838736296117837, 0.057544692367468286, -1.6536791042504957, + 0.8676152862870037, -0.6012988236535559, 1.0789190140651197, -0.21655562768188386, 1.5865699400089461 + ], + "dims": [320], + "type": "float32" + }, + { + "data": [ + 1.535672932186043, -0.3469466691127403, 0.7594896463626952, -0.05450122463129414, 1.4639377922956394, + -0.6333990278356465, -0.8242789470237648, 0.5117653543833605, 1.6078505759993273, -1.410275750604895, + -1.6792951377646883, -1.783057576321041, 1.1956662347204423, -1.3979831191193002, 1.7644067312268517, + 0.4240762243207543, 1.986096182518743, 0.36545941859180964, -0.8774236745093011, -0.8647372274160503, + 0.8720148666725347, -1.022106286236455, 0.5503111675120635, 1.0204841436521281, -0.9254965061314904, + -1.3449432022823808, 0.006824458535456657, 0.07690008991648423, -0.8426817383905396, 0.9996621329373534, + -0.23056243949407484, -1.0440039859718286, 1.9168909615768683, 1.5600000104620682, -1.9890822883775865, + -0.3604004168107382, -0.028511959235538065, 0.8476098198214288, 0.22053970034789216, -0.42929632097288817, + 0.6599479925924321, 0.6647860485919495, 0.10175396167639938, 0.22650892002231515, 0.4701540897019987, + 0.624214514356682, -0.6652257805050041, 0.8518349008799753, -0.9562813618340789, -1.657496508881473, + -0.3312572279583206, -1.5494034812904562, 0.18877981986543801, -1.2351800795813066, -0.07918559380797063, + -0.09391536586009241, -0.2856357420391582, 1.9393958604954182, -0.7529216437305211, 1.525084648903749, + -0.07883509109638975, 1.376637107607113, 0.5783362536287875, 0.961847664027677, -1.6855455725917468, + -0.5830772019897683, 0.4271291901307981, -1.8373857521152086, -1.7965394924729141, 1.7115697467771378, + 0.2565457539488545, -1.3360260284983019, 0.4353676471582455, 1.7248708601658969, -0.9750598890096729, + 0.05312641822767361, 1.5787554531472985, 0.9667162219022503, -1.364971428290251, -0.2814850946411962, + 0.9013643208289279, -1.4725055888862961, 0.6001425944665559, 1.2723681158746203, 1.7714493392964075, + 1.4044899825398272, -1.5787548082153382, -1.1589036159974757, -0.4012478414167475, -0.06868641055197777, + -0.7534521481998526, -0.8700101449208493, -1.1662115104665567, 1.7611310737805477, -0.2501517942331226, + 0.12866215308587936, -1.4699875001512854, 1.0395395370450604, -0.5782390952646876, 0.63115653417037, + 0.10138581116634082, -0.07007439881121424, -0.4276277546360472, 0.418589841403306, 0.9267207479900215, + -1.0293664343515356, -0.1495871781602336, 1.4452889339030666, -1.7189823464809564, 1.8323799237149645, + 1.8914008693919682, 0.3829486757403364, -0.8735369861149813, 1.602486711188683, -0.39959917784662924, + -0.8673792916868024, -1.2627215362178648, 1.8597348040684398, -0.8688156300975107, 0.15713415561611388, + -0.13148226217512082, 1.883732805180382, 0.11420203807616502, -1.6552288945493094, 1.0335466032430753, + 1.9806710089769703, 1.988269693866676, 0.29427412741632075, 1.4966799360753749, 0.6937827119996989, + -1.298620046493725, -1.752952308784005, 0.46645438478103873, -0.898908219432915, 0.7139098459371658, + -0.16215199540773462, 0.07954281050960432, 0.795652990025399, 1.5967568490712063, -1.2445652996859247, + 1.9127555713254205, -0.4996844935898572, -1.1156627480592256, -1.2948343944985163, 0.18276720875230268, + -0.748683470251498, -1.5079466014120557, -0.2494558107532141, -0.9231537960141623, -1.4121241243829443, + 1.2059834829573104, 1.905725511300579, -0.39337905860681044, 1.8425190053973166, -1.6566221588247219, + -0.242919176072947, 1.2425502129492436, -1.4417507121400348, 0.015600407032383856, -0.2242098694907284, + 0.18796276556529357, 0.08107732342066765, 0.7149451467441841, -0.20769007081368773, 0.4421202004832834, + 1.5233025839787455, -0.6642431462292846, 1.5564028464468986, -0.1586815058735116, 0.6099306071219655, + 0.8180887224937807, 1.9911546339103818, -0.005984685083011421, 0.6777759409892354, 1.7289968623869099, + 0.5264262640237458, -0.511038272902959, -1.7235775305068346, -1.138944875679032, -0.9623892814614488, + -0.6380738572168294, 1.8832106250881075, 0.028541651530706424, 0.7394956760616829, -1.5455450050824036, + 0.598697699776686, 1.44227094769795, -1.842926293477114, -1.9786511228960357, -1.774125089606943, + 0.4273755931309067, -1.1833540770674968, -0.29742688579612864, -1.7932368057978882, -1.3999703979662605, + -0.5494229951060436, -0.7692231154827809, -1.0112160791506497, 1.633993910846237, 1.3945699010195831, + -0.8649776103569309, 1.921348771224042, 0.832322610715301, 1.3754060709990767, 0.3497018723561567, + -0.7191957838389857, -1.9794221990125722, -1.84384806203993, 1.2324522851211803, 1.7698494016317143, + 0.006624102198243165, 1.5911519918365267, -0.762455861009844, 1.4479210196035108, 0.7818151145500849, + -1.9876926272814606, 1.8202062970885162, -1.6446357331454369, 0.8692666690506741, -0.7358532212979823, + -0.8444806659707744, -1.6015224446062994, 0.7479281419258141, -1.6523782603794155, 0.6710725185977369, + -1.1710932073100304, 1.4784513737588512, -1.212966513263102, -1.3741809040280142, 0.25437428444308896, + 0.8440351752665407, -1.1722116121570672, 1.104161389783421, -1.645735790976162, -0.4286533806712738, + -0.37044520217626875, -1.574330285391767, 1.4899314272896893, -0.8495642882336822, -1.714377156019676, + -0.4893435327563349, -0.7616337581393848, -0.5339391487933929, -0.3003289730553087, 1.3489307896261735, + -0.14680109166432054, 0.1026969670558735, 0.32430953678969043, 0.1795871726769951, 0.9696062238740311, + -1.5296687271207166, -0.2770372362376037, -1.0409472934130868, -0.17306368093190905, -1.1960408781183967, + 0.984219061951209, -0.4077661181651919, 1.7423047847942446, 0.5608878908901787, 1.4329493489434109, + 1.8986413512869937, 0.19154199669760352, -1.3315756935180012, 1.8870822754754517, 0.5674631415439482, + 1.1017148980678542, 0.7256621357674105, -1.8682573426264009, -1.2687446906641284, 0.5430939279068951, + 1.8279931413962558, 0.15890686973919443, 1.394841983124743, -0.8330211159668224, -1.2412716683059033, + -1.1755274256803165, 0.3146422214936937, 0.5127310756940888, -0.6223681329826247, -1.3728009148038876, + -0.5073994704733549, -1.1727465329222264, -0.07518002339175833, -1.6218851358655701, -1.3314808424730247, + 0.2696107099425271, -1.353815758928219, 1.6070801592460056, -0.7018653814032136, -1.594649470877921, + 1.8662880614030657, 0.009632792539534307, 0.885433106263176, 0.7081454198732997, 0.12480572241808119, + -1.9002637028711113, -0.8823815470565757, -0.12794198437065507, 0.3682196882451354, -1.1962414622570767, + -1.101920787984521, 0.1703217046774217, 1.2755057257388405, 1.2757461273763866, -1.7253598839195732, + -0.3935586680170111, -1.6790297555951925, 1.1726640337873802, -0.7187759606615787, -1.5997974808572053, + 1.9512036824878933, 0.8991982283799391, -1.516998597379371, -1.0918962406357053, 0.12845929863120276, + 0.387447437135819, -1.6766371647631972, 0.4172435231617522, -0.8587881195399367, 0.8973805509978297, + 0.5384910477202398, 0.22290981983700497, -1.9824980848037859, -0.19789410371539873, 1.0396641208977249, + 1.9498654847750698, 1.752979186273122, -0.10251547854421528, -1.7031576116596918, -0.6422947693243835, + -0.5947775282776488, 0.25094777162345583, 0.5519773563378578, 1.1845669608153342, 0.07011886849473115, + -1.5689347607142432, -0.9068208446502926, 0.4518736648271817, -1.1908598340431444, 0.9123278060019366, + 0.3808045721687314, -0.5161116183400685, 1.4633312728276353, -0.24955275031843804, -1.9270793627181808, + 0.5510310380033525, 0.002103402836195478, -1.9722027133603266, 0.8207770388309132, -1.2709862666051333, + 0.03660015849392373, -0.08721025552259398, 0.1480447971653538, 1.3975878551198289, -1.8688681862560603, + -0.2735983144132472, -0.29150197793885635, -1.349553505848272, -0.14289894302424067, -1.0632608448362548, + 0.9197316019995538, 1.6766092374653363, 1.4333994578157911, 1.8497508886723608, -1.8365902161760914, + 0.3329875047259945, 0.28711035354851955, 0.018743287980965917, -0.47550704561352664, 0.026002587809994537, + -0.9815518239812109, -0.30422655490353545, 1.1236748290508274, -1.996970334350796, -1.663190926732148, + -1.4930228184840004, 1.2293686779591093, -0.11228295031816327, -1.399262159949875, -1.2745774075202778, + 1.0404471355251506, -0.9042932188930193, -0.483855773240883, -0.051899262666108115, -1.1517591694487734, + 1.631117268451015, -0.4341760983538707, -1.5093199848977354, -1.524695207871412, -1.179033179719653, + 1.203939869858461, 1.2278371883112191, -0.7764972190751465, 0.12469436067847539, 1.1254668267275294, + 0.1253270059252225, 1.02529025377972, 0.37477534712132243, -0.816607896481754, 0.7652933238577306, + 0.0816252203587613, -1.6877073529228523, 0.282188424454314, -0.48899417877023144, -1.10579595806544, + 0.4180711569457314, 1.239608967084651, -0.8553327976952234, 0.7601553028351749, 1.017191993054694, + -1.561711107008871, 0.18166558203866234, 1.4575039351725838, -0.7919992885427041, -0.05528739747934974, + 0.5393789182198327, 0.9208003955213648, 0.8037584630910892, 0.2508199691349171, 0.5025718274381168, + 0.40223725437742086, 0.43401128486340124, -1.918673978558985, 0.38895512761013773, 0.9647875436316022, + 0.356426573504554, 0.7676218046110401, 0.15946706730485438, 1.2727737024033576, -1.1428215846133938, + 0.36778995418490545, -0.41392909578544224, -1.550642999283478, -1.7016418383565188, -0.3516276355010701, + -1.6424434547903983, 1.2296355686757101, -1.3262048004001983, 1.9866748350391505, 1.9039145370701833, + -0.4605978047947623, 0.37289955561548194, 1.2909351136100344, -1.4775326769813537, 1.8608708474080071, + 0.6440656172393684, 1.4358923542702868, 1.6635530454398575, -1.7844300247360296, -1.470415868795488, + 0.21864396672047892, 1.5488664195436606, -1.0864322992770177, -1.550780881959068, -0.8331945313037004, + -0.9367699280324953, -1.9013228249406309, 0.7807264098375688, 0.06677827961955263, -0.4865949947067687, + -1.9079733463147095, -0.8445233464370387, 0.4065074139836655, -0.3310839064029283, -0.2904445573034993, + -1.1753367420636245, 1.3721435340052208, 0.36660883645931097, 1.27723053302687, -0.9359216637576937, + -1.732231846976478, 1.0644600709999477, -1.6378422934868384, -0.8826400850725795, 1.950622879844948, + -1.9911319096225792, -0.6598073662934398, 1.8955996856482145, -1.4071132961709223, -0.8795225115767629, + -0.5029228970810946, -0.7734268477225967, -1.716542237524993, 0.04010671043366898, -0.7937284158037281, + 1.030026939297609, 0.9801808342123648, 1.2953427689382302, 0.20610803631475605, -1.672761300291775, + -0.690673451769495, 1.6609033000524338, -1.897131087105456, 1.2029533984904228, -0.5681454803874688, + -1.3646956682920965, 0.22071074912276334, -1.4735886916157908, -0.9695144027680014, 1.626222864433485, + -1.8694559899308487, -1.8879003983306655, -1.0176033048635613, 1.7915586444709328, 1.8810192124623084, + 0.5319984718680058, 0.0113238596202212, -0.09805090157632446, -0.5444299501215024, 1.1135935258682927, + 1.17684427133796, -1.85426568001437, -1.7530946944132086, 0.3038089938756876, -1.5870230070820002, + 1.9106333020042747, 1.3937407603560725, 0.8591788216145968, -1.612956779000272, -1.7151209190289016, + -0.13707423626294535, -1.4389728179178984, -0.05236093609874359, 0.9751452825232896, 1.744306648935904, + 0.7254929535860901, -0.09824503868926815, -1.4925208247531838, -1.985227418605298, -1.405540454178178, + 1.692915817031472, 1.2230668144021735, 0.04262811065188643, -1.1756894733009666, 1.0222275091190456, + 0.4934708666464802, 0.08979456736565261, 0.10059671914562518, 0.7155249975927536, 0.04082949674837977, + -1.715826873724553, -0.979189481763262, -1.1065843508804214, 1.481429410565739, 0.5278383608268999, + 1.4941027771635946, 1.4151786058577498, 0.07974076288029774, 0.3167509060420519, -1.269619345887964, + 1.8667680276727765, 0.527367815431, -1.874110045435497, -1.9373013120064702, 1.5330729150450173, + -1.7833509822122444, -1.7182607692067773, 0.7561591559894678, -0.1056962696530368, 0.13014948563496898, + 1.804101947913626, 0.7276195691909635, 0.021465712639121115, -0.5163553036182069, 1.1855106734783103, + -0.532372100695512, -0.5871635445412506, 0.161643292508721, 0.61018489160484, -0.9869821416193743, + -0.5318766940780302, 0.9532631042147388, 0.4597709134353236, 1.2142228742259942, -0.8224515402258339, + -0.879922983657166, -1.4710925151016916, -0.29851124917883975, -1.6631372706933156, -1.417993373545026, + 0.6364481896978704, -0.6013255938328603, -0.046835161333119935, -1.7247175181005758, 0.9825982199711403, + -0.7264776248635592, -1.463988875635824, -0.5013956201257255, 1.1933878395314643, 1.3056455851087287, + -1.4398688148432273, 1.7038722585040453, 0.46568790654958114, 0.481485333420693, -1.2873930064513237, + -1.1475617778051763, -1.9673375617555031, 0.39874490557435127, 1.0960170584095357, -1.8987243885981488, + 0.36983554057526735, 1.9718844590478293, 0.7894176749822801, 0.06983603687412288, -0.9000466156841869, + 0.6428129286371904, 1.704798225993037, 0.5950045030048496, -1.1678955586471442, 1.237662010594594, + -1.6482921001228146, 0.7270614937877813, -1.0006186813130835, -0.5305400798817805, -1.5252716548293819, + 0.18855276048488978, 1.5437352372976703, -0.9397215004565727, -0.4258934153954881, -1.0950445559616853, + 1.1844079915434298, 0.024990774215178035, -0.16149461270780652, -1.4078837300903269, 0.09499589792836627, + 0.516842370641422, 0.4800833347119191, -0.539291197739594, -1.5117979701954605, 1.354396143788092, + 1.28278689745333, 1.8488206619245648, 1.3022599053953776, 1.6609548614809775, 0.3269713781203789, + -0.38485903666664, -1.464958277436181, -0.3992461504929734, 1.0699961189397085, -1.0135843210651023, + -1.458604697589653, 1.490121083969428, -0.2595359932483561, 0.20854389182544342, -1.7482190390121701, + 1.3007127316326326, -1.9884730509986825, -1.4952841032131454, -1.3179011133758536, -0.388478076479009, + -1.1589100485370416, 0.8387145536985532, -1.3384696651494759, 1.4683529176022008, 1.303145953986827, + -1.3041819891109316, -0.03449749547681513, -0.6608734038387656, 1.1683787754166381, 0.5655509145236746, + 0.37738607602963814, 1.2152762148898635, -0.29353655718583926, -1.0509092280694636, 0.7139081884804019, + 1.5196106141395527, -0.530586207320952, -1.558831387258346, 0.01131046295330318, 1.9344117181061735, + -0.6850503993030497, 0.9331665418290909, 1.1688357095654807, -0.42466684124295995, -0.49121961262440816, + -0.11897540791552874, 0.5942162255141863, 1.7548838522451646, 0.4438013028171106, -0.28183936745813476, + -0.07495854303862437, 0.9303587326961971, -1.3198776631733748, -0.591718773961956, 1.1127108159764676, + 1.3939520197540327, 0.6360105102962654, 1.3722503898910716, 0.1757960098808633, -1.8297470389548955, + -1.9205472381959057, 1.4666198830651629, -0.7830326162911714, -1.4564248566278515, -1.0967977812614587, + 1.164770039819981, -0.5760771475874042, -0.7667709006388028, 1.371522788043694, 1.7326600398634024, + -1.8193902025531763, -1.6090197630011929, 0.09836987546776577, 1.0677415637460363, -1.5307232030478781, + -1.5599516580470505, -0.2609675276531007, 1.0598276568453162, 1.8794380814113483, 0.5316994667209949, + -0.7552146503779023, -0.4617287817040179, -0.3819745586646004, -0.42575961119349426, -1.9237942552613312, + -0.8825058198571423, -1.8728798417790404, 1.7802885739921077, 1.8333435291969842, -0.3098252256281784, + -0.029956863413143964, -1.0772837825116914, -0.08180463340649524, -0.6945910113459792, 1.1668128443816146, + -0.02738480437430635, 0.39293059281590104, -1.6704359314356383, 1.6869956995774205, -0.2294375604199823, + -0.32757809443951746, 1.9764189201357567, 1.5201484938081151, 1.087504317388186, -1.6272710803209698, + 1.0397868469069298, -0.22176854941092294, -1.8820468396186323, -0.5303897107068192, 0.594170569473933, + 1.6960001373937432, 1.9644545152850057, -1.7960342834770175, 1.7873883299813267, 1.3489957623885935, + -1.6820391707003042, -1.5713129762775537, -1.3637851932919034, 1.5936068708950781, -1.2089638711610604, + 1.322028643928432, 0.8929678781012855, 1.8401053408016272, 1.5452683829326075, -0.9171427145163484, + -0.06745535875434427, 1.9379586035273615, 0.5855503730756357, 0.03549855059545948, 0.6527698319031092, + 1.6754602207349976, 0.7728323704391817, -0.9665588877441182, 0.6173545510334506, 1.3120695172827377, + 1.181821226786317, 0.1841309435168954, 0.32318631702986167, -0.790159398034489, -1.385019609574396, + -0.7118643666238835, -1.0439971536099275, 0.017584768122861583, 1.7536303032255187, -1.1965922071808155, + 1.4548082915973595, 1.562828560283652, 0.0920828524560271, 1.892000960124009, 0.3648061373597171, + -1.9613669287159263, 0.841563763070833, -0.9118328355251464, 1.9226565574363006, 0.3988462224271192, + -1.970188432590363, -0.8264337415665439, -1.9090851263430704, 1.8428915650288547, 0.28596991752644385, + -1.6708643684685667, -1.0762549708362332, 1.9492472760488564, -0.17802109704659852, 1.9236671687550047, + 1.9548632849049623, 1.9566450030001414, 1.3303550873049677, -0.5915124672929295, -0.0037832544010933944, + -1.6026781861800838, 1.7578833516354813, -0.2956678774270767, 1.4060455643195402, -0.7370157759032727, + 1.8789198126787916, 1.1123902493105078, -0.8769185681462304, -1.2618214191177506, 1.8674610245111278, + 0.5103075356648485, -0.020685118611023512, -1.407221324818173, -0.7491588381751608, -1.3743460812306214, + -1.7710712130536228, -0.19455369352318552, 1.3434990212660862, -0.5544338320325721, -1.324247058015053, + 0.8874849369101687, 0.6838871095643375, -1.5617313105262172, 1.204432716341258, -1.4981479923604955, + -0.06499977622096687, -0.8060264147106553, -0.36092597365795775, -1.307326777195418, 1.6399424900785, + 0.429157912433868, -0.9915570262096942, 1.5128426032058089, 1.6375586318255548, -0.1737010535017669, + -1.21285453753765, -1.8844155037723906, -0.2590630754224348, 1.7328565249414716, -1.260633142919116, + 0.3637043664955444, -0.48087365705468965, -1.7001295586898113, -1.0775016378447075, 0.2620695698901221, + 0.5015363913767086, -0.42080100290276246, 0.5338170065286052, -0.43568151602634764, -0.744286733837793, + 1.57647267103789, -1.2491109283310529, 0.10032144655805375, -0.46919353377702855, 1.9415827315644636, + 1.2111393515469855, 1.744725164783687, -0.6871612254817352, 1.406736078990102, -1.063724178982385, + -0.904699966390976, 1.5681407930006221, 0.79849818604837, -1.7759907970834616, -1.6325947440964974, + -1.0309733602826086, 0.29563414198237936, -1.7157737037653629, 0.2876568188935451, 0.21411659926835913, + -0.1601632043965786, -0.02605079418095091, 1.2041639219664182, -0.6351323647136597, 1.1149646585336592, + 0.6657515663650084, -0.4672646227384094, -0.5117766415018226, -1.4643244157794149, 0.39081520672097003, + -1.502649477455031, -1.637368884151761, 0.34542161036123176, 0.060151105688381, -0.5045040651104555, + -1.761988723037204, -1.9197872473179176, -1.3665270161331975, 1.3928026939637972, 0.39218445250695577, + -0.8470063024385848, 0.009038121027233892, 0.46871439485211575, 1.459827780771806, 1.4853766551455694, + -1.2321752545416356, 0.3748806345040103, 0.20582729258619814, -1.4266279966077402, 1.2950255786963805, + 1.7125611822808544, -1.407545517068188, 0.5169179018491512, 1.8595592751857541, 0.9487671455033482, + 1.9467423989905699, -1.5919149626150748, -0.4630901723451881, -0.698284068975914, 0.6197574561950008, + -0.8199869405915381, -1.7196626702920055, 0.6036024034626806, -0.8348164600145518, -1.7650166093756132, + 1.5829990521620996, -1.588645487863901, 0.7633248861408699, 0.5800948434762754, -1.7159447523887836, + -0.3836837699904496, -0.9746560067630572, 1.4480442893861705, -0.24403527878135645, -0.6397662241706819, + 0.956203271386264, -1.4601856308265049, -1.5649468816584298, 1.731664582215319, 0.9679933953420496, + -1.9722379093946607, 0.24076423675934056, -0.19244242272389211, -1.3854799949067935, 1.3744990882455346, + -1.121046645776083, -0.4342567706309435, -1.7159646482293107, -0.9317859666979054, -1.698219647396134, + -0.8288368620433939, -1.3875410583085985, -0.16399331338641066, 1.4667160798353667, -0.020345764369364083, + -1.9585520591695529, 0.9886716666217517, -1.0701744437434098, -1.3248249591382057, -1.3272246201915312, + 0.906046259715148, -0.9554587301398687, 0.16744253332193715, 0.40874734944503466, -0.7237514235199383, + -0.8028952942996463, 0.9478548199038599, 1.6268191787625108, -0.7376232063503751, -0.6874490141085632, + 0.20469380737641973, -1.940886393624119, -0.9715176541080677, -0.11409081023343237, 1.8208884259904847, + -0.05753377002269655, -1.2533113228725696, 0.18235199190840046, 1.3670427559403047, 0.7183594427747524, + -1.9834311091439476, 0.09488256814231644, 0.07406140049599319, 1.0427950622016802, -0.7928805141629418, + 1.7221214208634228, -0.06548459693275177, 1.6984102601031559, 1.8777510809050026, -1.735259524674964, + -1.1416240368033357, -1.7612022583614682, 1.721880360655705, 0.14372177475853665, 1.9269311955654835, + 0.19978809107127216, 1.0299566806165856, -1.7617419918814026, -1.2737765895488096, 0.7789099525859564, + 1.9816257384474012, 0.4482897887627919, -0.9051913536142644, -1.152506387584042, -1.8817136441487783, + 1.1054935295772461, 1.577999662025542, 1.600449927735128, -1.4919075331081064, -1.9550057574515671, + 0.1306184124670624, 1.4754764229533928, -0.808023880270273, -0.21695285993080393, -0.539628797891055, + -0.7836468765498132, -0.8815668388678288, 1.8917264703112755, 0.028934119940069003, 0.06879472114883711, + 0.4407647131615784, -1.702696302284755, 1.7815067716148931, 1.7950168026349171, 1.1405438335719111, + -1.1434283018085534, -1.720238715793207, -0.7729623733152229, 0.17672006075090962, -0.14942755614865622, + -1.6229777838891115, 0.3793725781830055, -1.0113407389657345, -1.9280392460441265, 0.7422498462017861, + 0.8331559663193939, 1.3063596659922263, 1.8113167679814106, -0.1401093760534291, -0.24674083884906395, + 0.15509679692376732, 1.8667087827355466, 1.1906398286118094, -1.673307806924564, 0.41063702592861695, + 0.2436862014560477, 0.24067383021132027, 0.22686603382511628, -1.7295093225806442, 0.6565075922933001, + -0.5514412373381097, -0.5236684516031653, 1.8733248509057603, 1.082970345098504, 0.3340204503283841, + 0.5450315229688343, 1.0954041212853163, 1.565649272477267, -0.5469992182522905, 0.7193108029242588, + -0.9050070254533322, -0.5121370718971949, 0.962566205706815, -0.24631520617082092, 1.3340325816997325, + -0.8820024080231894, 0.22736077826137535, 0.2252389330707789, -1.947448723525529, -0.9518843625899782, + 1.7502182429516546, 1.558646352665332, -0.838440251378624, 1.541757246903681, 0.44677553405529213, + 0.9918545507928869, 1.060951650228274, -1.3653319701374311, 0.2635328688559797, -1.6894618652561055, + -1.9316398959917604, -1.6545844047461316, -0.8374565390669062, 0.5467667551875302, -1.1703334497283162, + 0.79898936445238, -0.48742537394255603, 0.05126348262407365, -1.0630031367885744, -1.6755563384807575, + -1.7470496911251123, -1.4045037572548411, 1.697678496203098, 0.541058257415223, 1.9355948975325852, + -0.8470115353500489, -1.2030885197848056, 0.8919323754916997, 0.0702516207867685, 0.5155253592422371, + -0.3579514965668338, 1.7112737380442171, 1.9947965056065957, 1.2741397687110538, 0.09885151767767297, + -0.9770807797341039, -0.11682236263324342, 0.7272198637411007, -1.987824039940028, -1.1358258258310752, + -0.11090836034305251, -1.9915135816887366, 0.39056056844969866, 1.2932859858303178, 1.7109978939050503, + 1.1846928384025448, -1.7330449982026206, -1.1525984164585106, 1.104166927781134, -0.28565377527521196, + -0.9685059019914002, 1.7093828969134002, 1.9709107005494806, 0.049031526597832276, 0.4472417265612556, + 1.0921859039999235, 0.8763632205063123, 0.8161138478535914, -1.0720275802414108, 0.7266994153226873, + 1.233185460886041, -1.127435043988318, 1.0918127239321773, -1.8540096367958645, 1.9681192361925266, + -1.176325917090126, -0.30265014266672097, -0.44524467812690727, -0.9978154618024702, -0.667478816738317, + 0.15065079333379305, -1.0302715841959227, 0.829863553229278, 0.8134089689909292, -0.6474889076993566, + -1.079618527738825, -1.783292588379826, 0.748112963221554, -0.6286053844150628, 0.48331041409284303, + 1.663305348437456, 0.18479680885937455, 0.6293791717008288, -0.6005275360880811, 1.5747695362530774, + 1.5708476785905807, 0.5755861487097542, 1.2041008720516082, 1.6685888824542738, -0.10278064261508835, + 0.9057150675313927, 0.6510335974298398, -0.10744692672758216, -1.7129461062136837, -1.4064873182457918, + 0.4781316094642234, -0.37189635012197275, 0.16614793992804522, 0.03645962620683285, -1.6224894209420242, + -1.8138940006820983, 0.5069783696842336, 0.6849365239989318, 0.8037589654894051, 1.979213666270276, + 1.6861127134381242, 1.2233661916798626, -0.3986509966310168, 1.274497735591801, 1.605857523477285, + -1.2118797206236485, -1.0066619307873124, -1.358968189183389, -1.9510798049888383, -0.9808314235618916, + 1.742926920936518, -1.1022645984613817, 1.0929594394621382, 0.48488158650621127, -0.32877770055973077, + -0.47650834081572935, -0.5160849885006016, 1.3738126318494883, 0.8827072110361662, 0.48644110690758247, + 1.0382179714335322, 1.6721919595070132, 1.341715329629717, 1.7295025892939409, -1.522344995293861, + -1.5965490751871654, -1.7983927509857223, -1.0759710407128011, -1.3793282201703079, -1.443902375079908, + -0.9426382639949571, 0.5602210832754357, -1.0965977429851064, -0.19857124750589872, -0.7431182359930233, + -1.2699260459939286, 1.4876549876992726, 1.6274319173214051, -0.3309529465344534, -1.9454352826883534, + 0.12935585140981676, -1.0093456723551508, 1.7243377444859647, 0.10127369924344443, 0.697537788222963, + -1.521134755613331, -0.442714777461525, -1.6896188579102178, -1.9330985764980921, 1.9140786772267155, + -1.2925077312416482, -0.9509978830442094, -1.1889787216631778, 0.795835379830006, 1.4837581515063887, + -0.8597344665233413, -0.7448025823499504, 1.7455825639820093, 0.33723505300912304, -1.8208678041990423, + -0.12753920031860666, -1.757360720716986, 0.8256076807737855, -1.9972549760931448, 0.844750409785961, + -0.9594803067513551, 0.7862268813645183, 1.7393046013815212, -1.161126895447727, -1.6347790700422697, + -1.3348870119154936, -1.1621632421015011, 1.2696646718252413, 0.4845759791644788, 1.0668299384975475, + -0.6942334010657198, 1.4734240949292259, 0.4282074397978146, 1.6699946816827183, -0.6802086123370898, + 1.9208442056043609, -1.8532082289660545, -0.1592261674427098, -1.2431462763761214, -0.7286614982674164, + -1.522868872353036, -0.3825873577199159, 1.431979005569458, 0.43719966684470446, 1.6478260330278633, + -0.06620691473965401, -0.36945868484144917, -0.3809516652838498, 1.6855172591399752, 0.31969027259376137, + 0.09157179754578149, -1.3138870107882425, 1.4208147318276607, -0.03157398665509881, 0.03702456744844529, + 1.4819698957492982, -1.6015809663105944, 1.8331399913105164, -0.6094891041007129, 0.9393020005799118, + 0.6313754553821562, -0.3128111370670492, -1.324295564232262, 1.7609361120800635, 1.5935407847968914, + -1.280640014867119, 1.4668684416985176, 1.460389700948717, 1.0299991397017587, 1.2266139129378075 + ], + "dims": [2, 2, 320], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.1067028045654297, -1.6136860847473145, 1.261694312095642, 0.5976569056510925, 3.0194122791290283, + 0.6702871322631836, 0.10282492637634277, 0.3735429048538208, 5.239527702331543, 0.9048061370849609, + -2.0244898796081543, -2.0475926399230957, 0.08633577823638916, -0.7819606065750122, 3.528346300125122, + 0.796191930770874, 1.819821834564209, 3.565944194793701, -1.9118807315826416, -3.2814927101135254, + 0.6710286140441895, -1.8751076459884644, -1.4806747436523438, -1.796984076499939, -0.7681794166564941, + -3.1337573528289795, 2.5466723442077637, -0.710807740688324, -1.3013415336608887, -2.010349988937378, + -0.5839980244636536, -3.2363412380218506, 0.34071028232574463, 1.8070162534713745, -2.2573466300964355, + 2.2563209533691406, -2.721301555633545, -1.7052738666534424, 0.020109310746192932, 1.8161461353302002, + -1.2339590787887573, -0.3481786847114563, 0.14459623396396637, -0.2792869210243225, 0.3242926299571991, + 0.8492016792297363, -0.9436420798301697, 2.8654322624206543, 0.018474817276000977, -0.33994853496551514, + -1.4109879732131958, -0.17846572399139404, 2.4232289791107178, -1.366482138633728, 0.8393897414207458, + -0.06912710517644882, -0.1260005384683609, 0.14877259731292725, 0.8378958106040955, 0.2032637596130371, + 1.6857033967971802, 1.683059811592102, 0.020472168922424316, 1.4638370275497437, -0.2821274995803833, + -4.081655025482178, 3.739361524581909, -0.5193736553192139, -0.5895893573760986, 0.13349902629852295, + -0.8314229249954224, -1.955358862876892, 2.2120022773742676, 2.9806900024414062, -2.594862461090088, + -2.5279524326324463, 2.2374868392944336, 1.651476263999939, -1.7969365119934082, 0.5194283723831177, + 3.9661808013916016, -0.6912452578544617, -1.615968108177185, 3.2498717308044434, 1.6176962852478027, + 2.765564441680908, -2.780879497528076, -0.943089485168457, -2.211867332458496, 1.5465649366378784, + 1.5535883903503418, -3.7050137519836426, 0.04439067840576172, 0.36327028274536133, 0.9592444896697998, + 2.2070963382720947, -4.852853298187256, 0.6495039463043213, -1.601344108581543, 3.303436756134033, + 3.5125482082366943, -3.4023211002349854, 0.7367992997169495, 2.8616340160369873, 0.557513952255249, + -1.2049691677093506, -0.19210883975028992, 1.6728510856628418, -3.3260436058044434, 4.706552028656006, + 2.2566816806793213, 2.644940137863159, -0.9390113353729248, 1.6396602392196655, -1.3574936389923096, + -3.6357059478759766, -1.3324334621429443, 0.8182520866394043, -3.782191753387451, 2.5362539291381836, + -2.8861687183380127, 2.2147746086120605, -0.7912830114364624, -1.3549126386642456, 2.932422637939453, + 0.6247330904006958, 1.6168872117996216, 0.9066742658615112, -1.156375527381897, 2.196871757507324, + -1.3269041776657104, 0.7688918113708496, 0.02223837375640869, -1.3422014713287354, 2.5085129737854004, + -0.8842201828956604, -2.039457321166992, -0.0754881501197815, -0.4683438539505005, -2.3120336532592773, + 0.4231855869293213, 1.7217100858688354, -1.4091691970825195, -2.062229633331299, 2.0696098804473877, + 1.2929754257202148, -0.21851062774658203, 2.792795181274414, -0.24259614944458008, -1.6432653665542603, + 0.2709762454032898, 2.5165672302246094, 1.4215764999389648, 2.406688690185547, -3.7345216274261475, + -2.1278839111328125, 3.311349868774414, -4.237924575805664, -2.4865145683288574, -0.4375068247318268, + 1.7486937046051025, 0.9667145013809204, 0.7027313113212585, -1.8740135431289673, -0.3525621294975281, + 0.19565200805664062, 0.40774744749069214, 2.2967820167541504, -1.8403133153915405, 1.831811785697937, + 0.9851721525192261, 2.7873969078063965, 1.0879806280136108, 2.5585243701934814, 1.9414751529693604, + 1.6000714302062988, -0.12208014726638794, -2.56121826171875, -4.894813060760498, -2.881957769393921, + -2.041257381439209, 2.9550018310546875, 0.5040202736854553, -0.27999716997146606, 1.0042527914047241, + 2.926683187484741, 1.3717838525772095, -0.24589979648590088, -4.2212233543396, -2.1938352584838867, + -1.6489169597625732, -3.442727565765381, 2.948969602584839, -2.7220163345336914, -3.187354803085327, + -0.34392428398132324, 1.470370888710022, -1.630984902381897, 1.2510205507278442, 1.1136020421981812, + -3.759488344192505, 1.4942673444747925, 3.067783832550049, 3.345754384994507, 2.6331236362457275, + 0.9775646328926086, 1.2827643156051636, -2.623198986053467, -1.1612101793289185, 1.7932779788970947, + -0.332869291305542, 2.42099666595459, -0.9636011123657227, 4.5822649002075195, 0.8944255113601685, + -3.2404866218566895, 2.7085609436035156, -0.4827519655227661, -2.3480019569396973, -3.114384174346924, + -0.8162459135055542, -0.9214845895767212, 1.2764832973480225, -3.152130603790283, 1.567040205001831, + -1.699249505996704, 0.3841613531112671, 1.300299048423767, -2.6244685649871826, 0.1572742760181427, + -2.503662586212158, -4.367088317871094, -0.9085763692855835, -1.3322471380233765, -1.8894531726837158, + 0.7199447751045227, -2.851144790649414, 4.080941200256348, -0.541861891746521, -1.1072325706481934, + -0.6561694145202637, 0.40478527545928955, -0.8838909864425659, -0.5028785467147827, 0.7957435250282288, + -3.4829330444335938, 1.046553611755371, 2.5124118328094482, 0.3735085725784302, -1.0879991054534912, + -0.09173977375030518, -3.4051504135131836, -2.2644267082214355, -0.9162223935127258, -3.4872522354125977, + -2.355233669281006, -1.8244541883468628, 4.704746246337891, 0.4475516676902771, -1.2546875476837158, + -0.44408249855041504, -1.924820065498352, -2.729738235473633, 4.391683101654053, -1.0688762664794922, + 2.174078941345215, 2.718625068664551, 0.7366507053375244, -1.9571187496185303, 1.1222915649414062, + 2.276261806488037, 2.0843756198883057, 1.5358469486236572, -0.14141148328781128, -1.5720349550247192, + -2.324619770050049, 2.1180672645568848, -0.757960319519043, 1.402897596359253, -2.846881628036499, + 2.5358057022094727, -1.274275541305542, -0.14995357394218445, -1.3371965885162354, -0.4439084529876709, + 1.4503703117370605, -1.6082179546356201, 3.0019733905792236, -0.9571952819824219, -1.6500767469406128, + 1.6778243780136108, 2.374703884124756, 3.5679006576538086, 0.20018166303634644, -1.0103645324707031, + -1.480147123336792, 0.19532525539398193, -2.786205530166626, 1.3784115314483643, -1.2978419065475464, + -3.5328032970428467, 3.0164525508880615, 2.0895931720733643, 3.5052330493927, -1.9349405765533447, + -1.311628818511963, -1.7713117599487305, 2.886934280395508, -1.5496007204055786, -0.8046841025352478, + 1.5652999877929688, 0.1403025984764099, -2.6447391510009766, -2.0337233543395996, -0.22587481141090393, + 1.8628309965133667, -3.466338634490967, 2.2385408878326416, -0.858932614326477, 0.6435102820396423, + -1.1014156341552734, 0.6221705675125122, 1.3742595911026, -0.24308213591575623, 1.8533508777618408, + 0.14410161972045898, 3.0187618732452393, -0.33525052666664124, 0.290519118309021, -1.2579193115234375, + -1.3335667848587036, 0.4902459979057312, -2.2434842586517334, -0.6882419586181641, 2.9724576473236084, + -1.2139863967895508, -1.9754445552825928, 0.11754357814788818, -3.2436463832855225, -0.29947084188461304, + 0.9013328552246094, 0.025318264961242676, 0.9405116438865662, -2.2489869594573975, -1.2323944568634033, + 1.4659011363983154, 1.380167841911316, -5.245995044708252, 3.716740131378174, -1.5962101221084595, + 1.7039341926574707, -0.24453751742839813, -0.47277745604515076, 2.6836142539978027, -4.659006595611572, + -0.2703670263290405, -2.802849054336548, -4.558082103729248, 1.3134486675262451, 1.3934195041656494, + -0.9713399410247803, 2.829873561859131, 0.5422300696372986, 2.14626407623291, -2.411435127258301, + -0.8668385744094849, -1.579006314277649, -1.0427988767623901, -0.3021366596221924, 0.7571608424186707, + -0.7852025032043457, 2.0103890895843506, 2.875030994415283, -0.6650004386901855, -4.240952491760254, + -3.7397704124450684, 0.06430482864379883, 1.2097631692886353, -1.621443510055542, -1.7518706321716309, + 2.0040979385375977, -2.1621170043945312, -3.4342057704925537, 0.4494125247001648, 1.0336246490478516, + 0.6141325235366821, 1.1723374128341675, -1.566636085510254, -0.0875391960144043, -1.5110716819763184, + 1.4554129838943481, 0.839878261089325, 2.398009777069092, -0.6458553671836853, -0.7608357667922974, + -2.0972063541412354, -0.596686601638794, 1.327064037322998, 1.2332861423492432, 0.643580973148346, + 0.2491741180419922, -1.1464729309082031, -1.6413570642471313, 0.4765915870666504, 1.1993881464004517, + 0.2358156442642212, 0.8658393621444702, 1.8936083316802979, -3.0983033180236816, 1.2818799018859863, + 1.0561144351959229, 0.18877224624156952, 2.373169422149658, -2.1537320613861084, 1.7804971933364868, + 1.7559447288513184, 0.5495958924293518, -0.29311543703079224, 0.7076770067214966, 1.3824928998947144, + 2.5599937438964844, -2.2310054302215576, -1.3870820999145508, 2.705214500427246, 2.692167282104492, + -0.3191862404346466, 2.2299273014068604, -2.8660874366760254, 0.04656076431274414, 1.0372791290283203, + 0.9051024913787842, -0.7127535343170166, -0.346563458442688, -1.8466299772262573, -1.776979684829712, + -0.7937185168266296, 2.6496312618255615, -3.1376733779907227, -0.5262937545776367, 4.203805446624756, + -3.0495786666870117, 3.059788465499878, -0.6179596185684204, 1.5632293224334717, 4.387739181518555, + 2.1877965927124023, 3.867405891418457, 1.6019251346588135, -3.1097412109375, 0.14593756198883057, + -1.4151546955108643, 2.8710670471191406, 1.281739354133606, -1.9452589750289917, 0.3256327509880066, + -1.0140762329101562, -2.1761093139648438, -0.36153650283813477, -1.9866083860397339, 0.20329490303993225, + -2.189547300338745, 2.0582122802734375, 0.44074079394340515, -3.6016898155212402, -1.0940327644348145, + 0.05166494846343994, 3.9986839294433594, 0.007254809141159058, 2.9994473457336426, -0.17313486337661743, + -2.319499969482422, 2.1396687030792236, 0.23967742919921875, -0.9820348620414734, 0.7810753583908081, + -2.565080165863037, -0.3542521595954895, -0.39312660694122314, -3.611963987350464, 0.042843759059906006, + -1.9587305784225464, -1.0954759120941162, -3.2344908714294434, 0.6816467046737671, 0.7935110926628113, + 2.3788259029388428, 0.8960800766944885, 2.7103538513183594, 1.0750906467437744, -1.3195565938949585, + 0.6368587017059326, 0.09530603885650635, -4.324446201324463, 0.31018364429473877, -1.4680615663528442, + -0.8505295515060425, -0.4297642111778259, -2.9845335483551025, -1.073625087738037, 3.111997127532959, + -1.082578420639038, -1.0510170459747314, -1.0351759195327759, 2.1196703910827637, 3.6743626594543457, + -0.10965263843536377, -2.8268239498138428, 1.9994782209396362, -2.2761340141296387, -0.037992119789123535, + -0.5068368911743164, -2.466184377670288, 0.8389816284179688, -0.9829720854759216, -0.578821063041687, + 0.3714909553527832, 3.968106746673584, -1.0078635215759277, -1.4665645360946655, -0.24487531185150146, + -3.812358856201172, 0.8614283800125122, 1.251778244972229, 4.411714553833008, 3.906099319458008, + 1.1894285678863525, 1.3625565767288208, -1.2013204097747803, -3.780947208404541, -0.7636905908584595, + -1.4467679262161255, 1.9876563549041748, 0.7255282998085022, 1.8526909351348877, 2.2311062812805176, + 0.7617504596710205, -0.3560359477996826, 1.834754467010498, -2.417194128036499, -3.032979965209961, + 0.3447788953781128, 0.193556010723114, -0.4079936146736145, 1.300100326538086, -0.19834625720977783, + 2.222346782684326, 3.1362013816833496, 1.2092983722686768, 1.5581995248794556, 1.3155611753463745, + -0.8380979299545288, -1.2280077934265137, -3.5234897136688232, -0.32684326171875, -1.4621152877807617, + 0.428300142288208, -0.5776108503341675, 0.9278461337089539, -2.4938998222351074, 1.2017678022384644, + -0.2525625228881836, 1.0117347240447998, 1.7265347242355347, -0.10318005084991455, -1.492635726928711, + -1.2622400522232056, -3.0749173164367676, 2.4151294231414795, 1.1957623958587646, -2.7823221683502197, + 0.6365658044815063, 0.18952512741088867, -2.0210397243499756, -0.3540761470794678, -2.876804828643799, + -2.8968381881713867, 0.17692947387695312, 1.728485107421875, -2.2341482639312744, -4.501170635223389, + -1.6425974369049072, -0.9404029250144958, -1.1620832681655884, -1.0455152988433838, 1.3684580326080322, + -1.4598485231399536, -1.6593886613845825, -1.0509099960327148, 1.3251757621765137, 2.258070468902588, + -3.5802016258239746, 2.863391876220703, 1.0157440900802612, -1.5516963005065918, -5.100094795227051, + -2.9607906341552734, -1.1230504512786865, 1.9419206380844116, 0.4938334822654724, -2.3842170238494873, + 0.1679488718509674, 1.827955961227417, 0.10622924566268921, 1.8168610334396362, 0.677010178565979, + 3.0694189071655273, 0.3993656635284424, 2.3529860973358154, -1.4582010507583618, -1.5138496160507202, + -1.5133174657821655, 2.0854310989379883, 1.6874661445617676, -0.6133178472518921, -0.9184160232543945, + 3.041386842727661, -0.8360755443572998, -1.2672674655914307, 0.27318108081817627, -0.906801700592041, + -0.2576174736022949, 0.8814321160316467, 0.9032235145568848, 1.5922852754592896, -0.5044339895248413, + -1.0950052738189697, 0.9084010124206543, -2.3912510871887207, -4.171522617340088, 4.554413795471191, + -0.3333394527435303, 1.5956521034240723, -1.197889804840088, 0.8468800783157349, -0.39928677678108215, + 0.7615669369697571, -3.205524444580078, 2.535108804702759, -0.4366309642791748, 2.1470766067504883, + -0.25451767444610596, 0.23135042190551758, 3.335973024368286, -0.19102385640144348, 0.8413820266723633, + 2.2614195346832275, -1.1231105327606201, -1.3756293058395386, 0.17654633522033691, 0.5028480291366577, + 0.7965704202651978, 0.867662250995636, -4.270709991455078, -1.4976004362106323, 3.333491325378418, + 1.6522053480148315, -3.4461770057678223, -1.0945802927017212, -1.1912789344787598, 0.5186694860458374, + 1.525572657585144, 0.4644775390625, -0.5472983121871948, -4.093353748321533, 1.6807860136032104, + 2.2575550079345703, 0.9947443604469299, -4.168862342834473, 0.09030676633119583, 1.3352301120758057, + 0.37972205877304077, 2.2988173961639404, 0.8671650290489197, 1.040745735168457, -3.316119432449341, + 2.3733606338500977, -2.248332977294922, 1.7465157508850098, 0.19552722573280334, -0.9690064191818237, + -1.8139621019363403, 1.9242961406707764, -1.9793150424957275, -2.789724349975586, 0.18952327966690063, + 0.5084639191627502, -0.054778456687927246, 0.2740379571914673, 2.1619865894317627, -4.095170497894287, + -3.142530918121338, 1.1796610355377197, -0.8848727345466614, -1.2477298974990845, -0.07429039478302002, + 0.9135949611663818, 0.21963024139404297, -1.9909381866455078, 1.99857497215271, 1.4466471672058105, + -1.1016359329223633, -2.8484303951263428, -3.1158666610717773, 4.74474573135376, -1.1900646686553955, + -3.1329240798950195, 2.125332832336426, 1.9798109531402588, 2.6058056354522705, -2.0495054721832275, + 0.028982579708099365, -0.5753974914550781, 2.7390692234039307, -2.3111703395843506, -5.434136390686035, + -3.3772997856140137, -0.37978899478912354, 3.2925407886505127, 3.671295642852783, 0.7639904022216797, + 3.1895627975463867, 0.15414607524871826, -0.6484872102737427, 2.18841552734375, 0.4799572825431824, + 0.1354406625032425, -2.747096300125122, -0.22751712799072266, -0.7596011161804199, 0.5766011476516724, + -0.017207175493240356, 2.4283714294433594, 0.5117142200469971, -0.8030692338943481, 0.44569623470306396, + 1.076960563659668, -1.8645645380020142, -2.2490062713623047, -1.6578664779663086, 0.6149722337722778, + 4.706758499145508, 0.38176798820495605, -4.501796722412109, 2.2427682876586914, 2.1858701705932617, + -1.8162599802017212, 0.9385958909988403, -3.889805316925049, 1.1331977844238281, 1.0191240310668945, + 2.4039511680603027, 4.155160427093506, 4.143398284912109, 0.7778210639953613, -2.2585456371307373, + -1.085227608680725, 1.7663249969482422, -2.8071107864379883, 0.5367544293403625, -0.02325284481048584, + 1.5876415967941284, 2.046140670776367, -3.832700490951538, 0.46683841943740845, 0.5545571446418762, + 1.8768221139907837, 0.7790337800979614, 0.38500359654426575, -2.3040874004364014, 1.1112343072891235, + -2.2824416160583496, -0.026048898696899414, -0.27540627121925354, 0.5449916124343872, -2.154345989227295, + 0.7431529760360718, -0.008791446685791016, -2.407325506210327, -0.46152830123901367, 1.6632401943206787, + -1.7320727109909058, 1.0486053228378296, 1.3803236484527588, 0.3680152893066406, 1.716249704360962, + -2.2865381240844727, 0.14729297161102295, 1.260400652885437, 4.922313213348389, 0.5643207430839539, + -4.42134952545166, 0.464374303817749, 0.59236741065979, 1.2845817804336548, 1.4366343021392822, + -0.0200042724609375, 1.6293566226959229, -1.3861595392227173, -3.4724128246307373, 2.1383941173553467, + -2.1009442806243896, 3.7689297199249268, -2.918327569961548, -0.27357161045074463, 0.9184791445732117, + 1.0513062477111816, 1.9957637786865234, -3.276752233505249, -1.6878246068954468, 4.714818954467773, + -0.9857031106948853, -0.6153162121772766, -3.6428263187408447, -0.30243179202079773, -0.4309789538383484, + -0.03419780731201172, -1.6013574600219727, 2.214989185333252, -1.1272412538528442, -1.6917750835418701, + 1.547987699508667, -0.5724269151687622, -0.47848212718963623, 1.742186427116394, -0.2213730812072754, + -0.8063536882400513, -3.479326009750366, 0.5662966370582581, -1.0524877309799194, 3.702444553375244, + -0.8636859059333801, -1.9768422842025757, 2.1982383728027344, 1.1405737400054932, 0.6146906614303589, + -3.5127429962158203, -0.8339279890060425, -2.914233446121216, 4.411269187927246, -2.3479251861572266, + -0.2184194028377533, 1.7971992492675781, -0.9596229791641235, 0.09081411361694336, -0.4546387791633606, + -0.38310706615448, -0.5399283170700073, -0.2518271207809448, -3.5085813999176025, 0.077769935131073, + -0.5233420133590698, 1.4064757823944092, 0.8371680378913879, 1.9668782949447632, 1.483221173286438, + 1.1757903099060059, 2.9970226287841797, 0.8735387325286865, 0.24936652183532715, -0.7718344926834106, + -0.9049572348594666, 1.9130113124847412, 1.9097952842712402, -3.3667526245117188, 1.1342090368270874, + -0.1385430097579956, -0.6781283020973206, -0.57246994972229, 0.9787319898605347, 3.6297695636749268, + -2.3075175285339355, -0.8269498944282532, 0.8386117219924927, 2.4571895599365234, -0.9069632291793823, + 3.1065073013305664, -0.786931037902832, 0.6695969700813293, -3.896576166152954, 1.6415526866912842, + -2.334099531173706, -2.991877555847168, -0.4740370512008667, -1.0762873888015747, -0.5927379131317139, + -2.2995433807373047, -3.4549155235290527, -2.033919334411621, 4.102828025817871, -2.922405481338501, + 0.9117567539215088, -2.0445048809051514, -2.740710973739624, 1.5500695705413818, -1.4105217456817627, + -3.672469139099121, -0.38663938641548157, 0.1100814938545227, 0.43851518630981445, -1.4003627300262451, + 0.8104124069213867, -2.6236252784729004, -0.40968263149261475, 4.816134929656982, -1.9591403007507324, + 1.2284891605377197, -3.4595632553100586, 2.2904000282287598, -3.7264821529388428, -1.8221375942230225, + -0.44476717710494995, -3.0978899002075195, -1.0302362442016602, 0.49443352222442627, -4.997615814208984, + 2.7403292655944824, -0.9162295460700989, -0.8933342695236206, 0.8124216198921204, -1.433485746383667, + 2.224909782409668, 0.6821490526199341, 4.009047508239746, 2.2991182804107666, 2.677088499069214, + 4.353694915771484, -2.2315590381622314, -1.5339090824127197, 1.626473307609558, 1.9017658233642578, + 0.9766815900802612, 2.563782215118408, -0.2381199598312378, -1.5801150798797607, 2.601571559906006, + 1.7336980104446411, 0.8148760795593262, -3.7112083435058594, -1.3511030673980713, -1.8034800291061401, + -2.950260877609253, -0.4626041054725647, 2.9033288955688477, 2.629671812057495, 0.1508030742406845, + -0.6016277074813843, 1.9043893814086914, -2.119884967803955, -3.048208236694336, 1.3438254594802856, + 1.039656400680542, 0.0982772707939148, 0.29122409224510193, 1.8302289247512817, -1.3148270845413208, + 1.16330885887146, 1.4336774349212646, 3.057568073272705, -0.2635994255542755, 0.3290955424308777, + 0.09837156534194946, -0.4767574071884155, 1.7474840879440308, 0.036291711032390594, 2.057096004486084, + 1.5909944772720337, -1.5554354190826416, 0.45638948678970337, 1.8485369682312012, 1.1001496315002441, + -0.448333740234375, 0.12344249337911606, -2.2758660316467285, -0.18728435039520264, -1.0710699558258057, + 4.759674072265625, -2.308614730834961, 1.3315553665161133, -1.7046191692352295, -1.4248977899551392, + 0.31045258045196533, 0.4682546854019165, -0.5506991147994995, -1.167902946472168, -0.033889174461364746, + 1.2611976861953735, 3.584254264831543, -2.8943490982055664, -1.0763990879058838, -1.9304077625274658, + 1.6052935123443604, -2.1086959838867188, 0.16277271509170532, 1.087416172027588, -4.894248962402344, + 1.6908477544784546, 2.445591688156128, -0.8808413743972778, 1.1168533563613892, -0.35605037212371826, + 1.0386931896209717, 1.4661989212036133, -3.5512571334838867, -1.364258050918579, -2.697364568710327, + -2.279731035232544, -2.2294161319732666, 2.072256088256836, -1.7556955814361572, 1.5964255332946777, + -1.102902889251709, 0.25835680961608887, 0.41171061992645264, 2.6033897399902344, 1.931307077407837, + -3.2382147312164307, -0.6146683692932129, -0.7102252244949341, 2.314451217651367, -0.697837233543396, + -1.0293060541152954, 1.020631194114685, -3.6274075508117676, -1.869258165359497, 0.8600494861602783, + -3.1400227546691895, 2.785465717315674, 1.5925650596618652, 0.5685530304908752, -2.385671377182007, + -2.064885377883911, 1.7148135900497437, 4.472785472869873, -1.6981213092803955, -2.8710732460021973, + -1.5905208587646484, 1.7118372917175293, -0.0628599226474762, -0.024645566940307617, 3.5579423904418945, + 0.043785035610198975, 0.1394520401954651, 0.705904483795166, 0.5603584051132202, 1.9532432556152344, + 0.17339491844177246, -0.5621342658996582, 2.166140079498291, -2.675852060317993, 3.4783935546875, + 0.005500435829162598, -3.0466365814208984, 2.9333863258361816, -3.0374748706817627, 3.3186678886413574, + 1.4340720176696777, -3.4607982635498047, -0.5809862613677979, -1.8670074939727783, 0.31782853603363037, + 5.3407721519470215, -0.11647498607635498, -1.127880573272705, 1.9607118368148804, 1.6104774475097656, + 1.291121006011963, 1.3090026378631592, -4.346621990203857, 0.9649640321731567, -0.3321942090988159, + -0.40106678009033203, 0.6202027797698975, 0.737052857875824, -0.6949820518493652, -1.6063098907470703, + -1.335120677947998, 1.2487294673919678, -1.206929087638855, -3.946974754333496, -0.038611650466918945, + -2.730283737182617, 1.3170448541641235, 2.586440086364746, 0.9609778523445129, 0.9639753699302673, + -1.0613958835601807, 2.2582998275756836, -0.341133713722229, -2.37121844291687, 1.9750676155090332, + -3.339621067047119, 3.8494720458984375, 1.4416704177856445, 1.2632763385772705, 0.49889105558395386, + -1.358637809753418, -0.9810472130775452, -2.008033514022827, -4.180141925811768, -1.8064439296722412, + -2.4055490493774414, 0.8236247301101685, 0.08107048273086548, 0.2551090717315674, 1.6475603580474854, + 2.3599026203155518, 1.4368640184402466, 0.11961638927459717, 2.1902809143066406, 2.4481637477874756, + -3.700338363647461, 1.4484524726867676, 2.2457919120788574, 2.7918150424957275, -0.3214719891548157, + 1.1412039995193481, 2.3801662921905518, -1.1772977113723755, -1.080161690711975, -0.10960137844085693, + 0.23755615949630737, -1.6492403745651245, 1.5414122343063354, -3.5301568508148193, 0.8169034719467163, + -2.7907910346984863, 1.27809476852417, -1.339566946029663, 0.24068212509155273, 1.980497121810913, + -1.103304386138916, -0.9938054084777832, -1.6851425170898438, 1.5913416147232056, -1.6278176307678223, + 0.07544970512390137, -3.0562868118286133, 0.909795343875885, -3.3362603187561035, -0.16795992851257324, + 0.654058575630188, -0.7783452868461609, 3.801968574523926, -2.160207748413086, 2.5146727561950684, + 3.337472915649414, -0.7981523871421814, 0.7141947746276855, -0.44521379470825195, 0.7240567803382874, + 2.8061447143554688, -1.9281837940216064, 0.33615267276763916, -1.3853528499603271, 0.08603048324584961, + -1.1986116170883179, 0.8860710859298706, 0.22947335243225098, 0.5199316740036011, -2.0464673042297363, + -1.6756978034973145, -0.6069002151489258, 2.3337855339050293, 1.6094681024551392, 2.3820090293884277, + 0.8262057304382324, 4.417818546295166, 1.2495514154434204, 0.5728256702423096, 1.7155016660690308, + -1.9175611734390259, 0.8456315994262695, -1.3211780786514282, 0.7857818603515625, -1.7515380382537842, + -1.1971937417984009, -2.808197259902954, 0.7553633451461792, -0.1306423544883728, -3.6146838665008545, + -1.5321255922317505, 1.676008939743042, 0.07836294174194336, -0.9960349798202515, 0.8668472766876221, + 2.137298345565796, 1.2637362480163574, 2.0481185913085938, 0.06509196758270264, -1.6683127880096436, + -3.718193531036377, -1.9311506748199463, -3.367814064025879, 0.012331843376159668, -4.227578639984131, + -0.5755635499954224, 1.583990454673767, -0.86231529712677, -1.809625506401062, -0.3123876452445984, + -3.4403862953186035, 1.0367351770401, 1.1761391162872314, 0.16112631559371948, -4.721268653869629, + 0.07461494207382202, 0.003129124641418457, 4.358157157897949, -0.5005528330802917, 0.24370229244232178, + 0.4912611246109009, -0.6851872205734253, -2.718902587890625, -2.6848104000091553, -0.7679593563079834, + -1.1002662181854248, -0.3475220203399658, 0.002980828285217285, 0.4288327097892761, 1.283578634262085, + -2.613717794418335, -3.3328800201416016, 1.9472748041152954, 1.3897069692611694, -5.0092363357543945, + 2.518123149871826, -2.1634795665740967, 3.7558064460754395, -3.4893569946289062, -1.289191484451294, + -3.036714792251587, 1.6393659114837646, 3.2178215980529785, -2.083487033843994, -1.6242481470108032, + -0.39087945222854614, -2.675858497619629, 2.548295497894287, 0.6715638041496277, 1.0268739461898804, + 1.9520831108093262, 0.2520883083343506, 0.3550087511539459, 3.1073975563049316, 1.3005826473236084, + 3.4222006797790527, -1.5301063060760498, -0.19925060868263245, -0.7549142241477966, -1.7461345195770264, + 0.7768814563751221, -0.8777822256088257, 2.757373332977295, -1.4982573986053467, 2.4916207790374756, + 2.4712767601013184, -1.3903863430023193, -3.3709828853607178, 1.6688079833984375, -0.3028743267059326, + -1.8443574905395508, 1.8113127946853638, 1.7428357601165771, 1.2092206478118896, -0.5405560731887817, + 0.97336745262146, 0.36485183238983154, 0.17854809761047363, -1.6080548763275146, 4.408480644226074, + 0.6553859114646912, 0.9348583221435547, -3.3448331356048584, 2.513455867767334, 0.03000164031982422, + 0.705142617225647, 1.1035417318344116, 3.448455572128296, 0.7622013092041016, 2.463888168334961 + ], + "dims": [2, 2, 320], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/bias-split-gelu.jsonc b/js/web/test/data/ops/bias-split-gelu.jsonc new file mode 100644 index 0000000000000..23fcb488ca4d9 --- /dev/null +++ b/js/web/test/data/ops/bias-split-gelu.jsonc @@ -0,0 +1,1332 @@ +[ + { + "name": "BiasSplitGelu", + "operator": "BiasSplitGelu", + "attributes": [], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "bias split gelu [1,1,2560]x[2560]", + "inputs": [ + { + "data": [ + -0.2565546032426438, -0.4308542731494329, 0.7196725122004919, 0.049034255098233004, -0.6348569496555116, + 1.5952359184580631, 0.8451805251092992, -0.31838310590966934, 0.8187686985927041, 0.5208222841347814, + -1.36437690702164, 1.9883897131851596, 1.5564695381513953, 1.6179857166855847, -1.6925162826813818, + 1.7367654350206285, 0.7210294356326798, -0.16665830399749915, 0.7502629374213177, 0.49174373254887005, + 0.05228599187298144, -1.8492674426703974, -0.13963175206123601, 0.3689713052652124, -0.539726857153676, + 0.4047910328979869, -1.5364223249042084, 1.5401613243237753, -1.5533776810895263, -1.503725108181306, + 1.4980778036916167, 0.3954579111685037, 0.6720461504101429, 1.2082071287930374, -1.9053414457057452, + 0.5161807133471896, 1.876123128617479, 1.4026319996913186, -1.0578832074721918, -0.5667177472093234, + 0.7427152553125103, 1.3309336325346122, -1.250604470605933, -1.5581827636425984, 1.6094888834151977, + 0.8346945311402969, 1.964734881798945, -1.1480686695741893, -1.4692802823913134, 1.443324467502948, + -1.5059857923356113, -0.4324256636772654, 0.47787327572253613, 0.010830153648134555, -0.552568788591528, + 0.976425831078207, -0.3964598095089391, -0.8550178676766054, 1.7829908691025604, -1.4386851327204164, + 1.2586393384114123, 0.6282746108828974, -1.9228284428211788, -1.723501329370512, 0.07875604943128245, + -0.6805248059926141, -0.09444157676724618, -1.2578296605984862, 1.4748594927305314, -1.1326833840447499, + 1.2899356994911972, 0.36076376234182295, 1.0687687434117796, 0.016260539139403285, 0.8790330834375988, + 1.958147414219165, -0.6744379666162423, -0.11384066716133123, -0.8573301336636696, -1.6026243400276634, + 0.947103941481469, -1.9714924616237095, 1.4248854638955484, 0.20743958585530908, 1.1632144973105554, + -1.755170686870751, -0.7194149118428639, 0.14148466161285445, -0.2721811711092048, 1.5603318294054445, + 0.4281402047349676, 0.9777768965070042, -0.4340528216185948, -1.4027853075880978, -1.786451668170261, + -1.0399446061083921, 0.5016682459975055, 0.47912646038882833, -1.4994325065634833, -0.5813174761669684, + -1.7675659877499355, -0.5295448113440351, -0.185677936488422, 0.6133721122115032, 0.8342835422965917, + -1.8851704197237282, 1.8145028466620214, -0.6693817855471478, -1.2825600860901352, -1.9614837268208936, + 1.7498654134131364, -0.9137222020716864, -0.3419963114086366, 0.4718774290086678, -1.087807747894451, + 1.7131835407851748, -1.3262901364654693, -1.3116975227056313, -1.3834647528821646, -1.6786541344436587, + 1.7754973215348864, -0.2608287138885421, 0.7649932766417926, -1.620043859042676, 1.6639963441569705, + 0.8487054193287991, -0.0440434794759188, -1.7095242256175753, 1.285898353046436, -0.8541456852270608, + -1.0975953883519463, 1.6456192019739255, -0.9829899938198103, 1.5115291047537394, -0.1356101800325593, + 0.3163471293050524, -0.7350393764769096, -0.2635909692970628, -0.8961556079942969, 1.35594272476665, + 1.5346085770720466, -1.561332626472022, -0.7482744786530882, 0.7415275042470979, -1.3455806550694929, + -0.17403414417677876, 0.13174239791099396, 1.617942860808114, -0.21372789574955942, 0.1252618392038798, + 0.8356601282506393, -1.4645614042274238, -1.2750059490075998, 0.30114108212769697, 0.718759572394414, + -1.7476419557101002, 1.3774845116886816, -1.5548787928223922, 1.2843809545104072, 0.40166155990403407, + 0.7182073117046643, -0.9178049843649907, 0.23022861413351325, -0.16782915827518163, 0.809538498880066, + 1.5368543357363587, -1.9144180623212463, 0.22370787233170475, 1.754935494591769, -0.300812686337701, + 1.6324016835327004, 0.42942625235999277, -1.821577324695582, 0.7065235323323327, 1.6359230558031488, + -1.6698325562461855, -1.0322767487026727, 0.230813812551105, 0.10476742241238135, 1.8969745358419585, + 1.6692111883958614, -0.13495950982519744, -1.3187891429974368, 0.33479728114140084, -0.3567915710179923, + -0.757581852807693, 1.0471229062934597, -0.15293599163291116, 0.8603181880710826, 1.5915256233625152, + -1.146248153475943, 1.4333844348196614, -0.4549776123350959, -0.24793256123671004, -1.1103266883866416, + -1.8671773666157563, -1.2229905095092333, 0.06266819913029753, -1.0062234212958696, 0.8035035300131232, + -0.7305732410705756, -0.18203039831679124, -0.31014428798975313, 0.3622048186467879, 0.45985097885235504, + 1.4809283400531248, -0.43833106689297807, 0.04401934126909346, 1.5587830597086887, -1.6615450592134344, + -1.0549727490794032, 0.10619243811841983, 1.6457168234317425, 0.8542761080095067, 1.0005810856142707, + -1.5553877748221447, -0.9212057116120924, -0.8894180684757362, 1.2348207479698639, -1.867513894231462, + -0.3703762605562, 0.6409515313866843, 0.07519639585361748, -0.6450469579811191, -1.1954961398506443, + -1.1204107450288525, -1.7282576070578317, 1.3286514210595293, 0.9221546638897049, 0.3164348496087497, + 1.1124119320099153, -0.1657137161424238, 1.8878952460812632, 0.1293077579568429, 0.23358535468938246, + 1.6758942267934263, -1.4201635109571313, -0.15778983434253657, 0.7416701308494114, 0.5883762821820557, + 1.4927448431941999, 0.6935037789859804, -0.7735384747712253, -0.003653232296928266, 1.016177625769398, + 1.663989224847544, -1.860463076082154, 0.586684836283415, 0.35116588158246387, 1.1544623264859073, + 1.900422405885342, -1.4569177282686354, 0.04845063476308198, 1.3176042838894286, 0.7208418723320795, + -1.2204473994940885, -1.4000622968397813, 0.21650376984443565, 1.060631946423273, 1.5077365306547108, + -1.0212630681047514, -1.9198532775330452, 0.7154901442715236, -0.2676631096599271, 1.3808441670529703, + -0.09885367904367648, 1.4777610290353689, 0.056817122949397, -0.21052160911698614, 1.163420787671865, + -0.9278068609037406, 1.6649534188139832, 0.651092543699634, -1.9558728896680586, -0.17393499677517266, + 1.4142540331221616, -0.6414265768865439, -1.796476122987272, 1.3019592923170675, 1.8063975233357983, + -0.09846336861938809, -0.44669537007711746, -1.8509683249130777, -1.2535113000292872, + -0.26272617023889033, 1.618948771545468, -0.24333144562387243, 1.684567206060013, -0.0684671993602386, + -1.2513800864625626, -0.42209759253339296, -0.5204624726492675, -1.2354597847242017, -1.8741410257015954, + -0.4694216976292269, -0.6222814585320586, 0.21189389441688178, -0.07825818775306104, 1.7163253327182595, + -0.5828872666159741, 1.2859942412787504, -1.6213500187165488, 1.3292413342594758, -1.8919626255232016, + -0.5376250474966104, 1.8689278008038874, -1.7144714737277686, -1.0258522713372873, 1.7337707335280257, + 0.7736026666583085, 0.5294771168151202, 0.018158706442022776, 0.06181604275240726, 0.39921503387415935, + 1.1876014889699817, -0.3054069392890115, 0.16369159901033914, -1.7827759163094807, -1.972081887839714, + 0.6481860379313247, -0.35352594703309315, -1.3849570316603321, 0.623606163848514, -0.05008698162477465, + -0.7604090786231117, 0.018133973172791862, 1.4085369489144304, 0.0006579664080073044, -0.945901823508092, + 1.9768973551218298, 0.8031916889887203, 1.787407754969756, 1.4724476919448328, 0.4025600955058959, + -0.8188218839399735, 1.9450326825091242, -0.7081203957970317, 0.2756639708632225, 0.5758621242189532, + -1.3400477780804723, 0.9643250467598792, 1.6902983107628202, -0.9456991797498446, -1.214924995943468, + -1.7096508074160557, 0.6280377071410248, 0.043386564374169545, 1.4018487692877626, 1.9387558240379494, + 0.6404760132931342, 1.9907639837009397, -0.9629506658962566, 0.44160372290058625, 0.5311677453788013, + -1.98443782839251, -0.8098531120098231, 0.6492275695305256, -0.8908778962502977, 1.3356991895363226, + 0.22938352294803988, -0.7322473322214602, 1.5046263929054788, 1.3645550724874669, 0.9961702988823644, + 0.5139435687003742, 1.4501909090527283, -0.7800082224423388, 0.35291863289434566, -1.146090063932382, + 0.30350456267852977, 0.6439700368534709, -1.9972080095325557, -1.6099893116712707, 0.8052080484401971, + -1.0954234253610906, -0.7337556567039494, -0.31964730584596346, -0.0775965423638807, 1.9583401341749038, + -1.2551023178822023, 0.8562333066218226, 0.7147503883958679, -0.7448967487912448, 1.935551654646078, + -0.8599046782377142, 0.2970146461500134, 0.8930246286483605, 1.1182036059156406, -1.197201583834416, + 0.07161352111517694, 1.9408265571293892, -1.0932315538260982, -1.599500716435764, -1.3924326798020594, + 0.11983524432330928, -0.7679553338056699, 1.8370796383852186, -1.2704307623609719, -1.1997771708610943, + -1.888947914144012, 0.55632142669847, -0.34454450486308286, 1.4998341008691707, 1.5090724650893934, + -0.39294559939271156, -0.5834694562549148, -0.5369747247041268, 0.09250855982787431, 1.2568497810992714, + 0.3053435200201413, 1.9380606411981267, -1.7300384067590375, 1.8430090411386875, -0.09705704918122926, + -0.9748664840480341, -0.13310751475233928, 1.8695019507218804, 0.6192212840529407, -1.8594506513740239, + -0.5527467955831558, 0.42368164563035027, -0.3623000694050589, 1.4671722318441134, -1.9926015337743834, + 0.6390178007629022, 0.3930378886525858, -1.999173071586763, -0.10690790413254625, -1.8526590162958945, + -1.7626087675511721, -0.37087181562117433, 0.8387521856704474, -0.19069686876676517, -0.8188898303583647, + 0.12674602912986455, 1.1094188774808789, -1.3557441716004606, -0.6605764711080546, 1.5447276204253857, + -0.8769307424824397, -0.2545000118362921, -1.0291738091907403, 1.7938016119427465, -1.9984894749397677, + -1.2827802315328531, 0.9376743174843778, 1.6630184567373671, 1.3959796712419958, 0.9312602680557571, + 1.1022358675981447, -1.8339474355859497, -0.4903489710565081, -0.8895348660922506, 0.6526243952924284, + -1.3544145682501627, 1.7693021030400526, 0.16080394788547636, -0.012568942703699904, -1.15491020041351, + 1.2537510862653143, -1.0969062449531704, -1.2759047604132565, 1.5734543561168284, -0.10691180117407129, + -0.26934435340525376, -0.9933427134070101, 1.7043012450253494, 0.02549752010787465, -0.08701034102069993, + -0.2674176298500974, -0.08499649536913978, -1.9065384537204872, -1.8548124018210599, 1.4114174077197168, + -0.1003529232852527, 1.4168803569119728, 1.2546078925285231, 0.2836297019652072, -1.741854513232557, + 0.4335127344631058, -1.7722508801895138, 0.8906617987453824, -1.1907389082962334, -0.4570346317067928, + -1.65487466866219, -1.6799510105998428, -0.7345507991250928, 1.0450995087483994, 1.6718914380761127, + 1.9492487109001955, -1.603110762963735, -1.1386896023097748, -1.2314080609240623, 0.9155801473532383, + -1.7678615945378713, -1.2302067186131085, -0.008026479019882515, 1.7076934044219962, 1.2976204179783615, + 0.23979601472188072, -0.31452550079775676, 1.237656674724093, -0.5073967065142844, -1.4684831151454167, + -0.31751104037333455, -0.18632018989018295, 1.9021272175218185, -1.605133256102862, 0.7057429809617588, + 1.131130449558933, -0.37696933268230914, -1.9999062126002407, 1.2947486718643644, -1.5494680428486927, + -0.619492250822522, -0.7119383385868714, -1.2267826883342403, -0.05700963867832698, -1.6002949090802074, + -1.235668500595783, 1.0038351588940309, 0.9591804998296345, 0.3500270456646062, 0.9985378178389439, + 0.37800440518144107, -1.2954921521047265, -0.06652125976806822, -0.5529207993209635, 0.799277104478481, + -0.3195922296721463, 1.278513926080179, 1.7667762251394539, 1.6741322500537104, -0.7688735332563814, + 0.3718086223078023, -1.8559862073263451, 0.6460500799921309, 0.9538174223011966, -0.9599765113376018, + 0.2363354873707415, 1.1507344321172663, -1.5062754878549311, -0.6521456232305631, 0.26796953563726245, + -1.6029382384422037, -0.13330793525728346, -0.4700350464240364, 1.2426614036874923, 1.8410794059261688, + -0.29943068600730616, 0.23080019611542912, 1.1289988352417586, -0.9307260459011948, 1.5075366130642047, + -0.75430635423512, 0.7786360347486365, 0.43281144997343457, -1.9253410424489141, -1.7129308932126097, + -0.7115090993554389, 0.09728563961266712, 0.9612524095722907, 1.2677697737234288, 0.17580981533895557, + -0.07099813019347945, 0.6420914403628677, 1.929710466274961, 0.8224262711223203, 1.197715527419664, + -1.5306699759366271, -0.4379817520480751, -0.6193479789854166, -1.4632019512465932, -1.259260972161262, + 0.8218577982330961, -0.5872036188804506, 0.9619144554878618, 1.9136356497081524, -0.0054634952730427955, + 0.918959735070624, -0.5711442865104566, -1.7438335683612998, -0.09911336983085484, 0.9296115985185578, + 1.4616523373203654, 0.46274661162826547, 1.0360815568791644, 1.2212946749865532, 0.15313173395745583, + 0.822676634908933, 0.4284785502131063, -1.003861397431101, -0.1736541682765118, 0.6790160543077883, + -0.14098755750254632, 0.3123446972161563, -1.1989246192108354, -1.8129870085835993, 0.32605203312965525, + -1.1125265808505294, 1.7561842733362214, 1.7739812007749807, -1.4043301791912688, -0.8943531190371674, + -0.6297367970706924, -0.7981479808539875, 0.8794315514521189, -0.3994743473829594, -1.401573672109106, + -0.8912997118364743, 0.6495541504851907, 1.9230113410471903, -0.6785063630402766, 0.5089535993984642, + -1.8974271070828808, -0.2620702810213116, 0.1991815065696141, -1.207848636028186, -0.8986511371196979, + -0.7061642892592088, 1.9332451504308255, -0.9830761395756227, -1.244182804715014, -0.017993897992697683, + 0.8786296671983029, 0.540637680299664, 0.8969343374913903, 0.7955401708949532, -1.529657800511421, + 1.8639189521351138, 1.067177379712641, 0.7401683858489143, 0.7294620382453241, 0.5904153067983531, + -0.49843243603051146, -1.6804352670210738, -1.0693615378859294, -1.9090165306219946, -1.4036168596502474, + -1.2392401335977272, -0.7753989839018987, 1.8934648648263472, 0.3583307892541949, 0.1338886754962525, + 1.8690213745085558, 1.455690685933316, 1.1699346906107797, 1.6059636790463596, -1.2445702023047094, + -0.26540435557319864, 0.7367777535407951, -0.06848824820724975, -0.35847294758367365, -1.8422274132326661, + -1.0257628610133205, -1.6669678072109697, 0.37716994278958094, 0.238235954755047, -0.2999780255928437, + -1.8840499434617515, -0.2146515363892023, 1.683440503823257, 0.11413961404491424, -1.6059835147695, + 1.1951777078503847, -0.5174290916747344, 1.3440252628311207, 0.21500397686099326, -0.9421891710579917, + 0.0593410986318057, 0.34006850000048683, 1.644097326727004, 1.74676874874047, -1.6742906231992345, + -0.33329412020743643, -1.6254771174048148, -1.6581767039373991, 0.791705460659057, 0.9383035214148672, + 1.7805390345327456, 0.5776366760158806, 1.587860436382865, -1.5903762069130911, 0.16034052878776794, + -0.2414627652627388, 1.2751236768892227, 0.3209997221960421, 0.31176177950076234, 0.6234148156263783, + 1.144504541840126, 1.1535423529138784, 0.11665655599341473, 1.9697764827003628, 0.14558312336598078, + 0.36578124791517297, -1.769415346962682, -0.8303724165129278, 0.44703666963932154, 0.35095942056362794, + -1.197063815711486, -1.5390788457973201, -0.8313129454097989, 1.952907404456571, 0.30612523411761394, + -0.9380264922530621, 0.9822286847681259, -0.12269281399330456, 0.6557752769532215, 0.48870196679108435, + 1.347011507573625, 1.6519563808835915, -0.7385795429014586, -1.048048723619047, -1.0859902684402032, + 1.7187556784188924, -1.8663335499394762, 1.3448325108921972, 1.5973182779732955, -1.9246562924691855, + -0.896157435511153, 0.5381287932952938, -1.7180528810790427, 1.6998135385965343, -1.99001646343652, + 1.7545133161159834, 1.9333853932807186, 0.509746393965961, -0.0675591816104566, -0.5487596842885525, + 1.1654829640998834, 1.2418276538086106, 1.1046551528849635, 1.1541227103848648, 1.3688916432757363, + 1.8682574813085164, -0.9855654818965549, 0.5606398966348412, 1.9470096222814695, -1.9903658034181957, + -0.48344153717960925, 0.6353768282033556, -0.0513621772163555, 1.097950015862886, 0.24567995461843584, + 0.8059244051805585, 1.3521742423136072, 0.47230481196519936, 0.6133818327341913, 1.9135550317308576, + -1.8499763818701833, 0.5894967733509713, -1.920326871370679, 1.6405640996262383, -1.1695280806136248, + 1.4241423736983672, -1.67349053378969, -0.8405344375321118, 1.4551541878625045, -0.9884334361969964, + -1.274246357510064, -0.5067514453663833, 0.5492300144792424, -1.063358889380532, 0.3303561470509724, + -0.9579434760073298, 0.057620205260851876, -1.4742255794760677, -1.6308108759776871, 0.050309644830520917, + -0.4069039662995735, -0.046184249642958086, 0.07949429360802096, 0.12024378717844808, -0.7878588154579758, + 1.472302474285109, -0.29877233638272127, 0.3469939919860652, 1.0030914450113793, -0.033399023382078674, + 1.337733463292298, 1.322404140400704, -0.9344843531853249, -0.7431360466648114, 1.6844959250870701, + 1.361728870488129, 0.9343674315961268, -0.027258361373763584, 1.838512510349391, 1.6503912849225708, + -0.4099533673620952, 1.0611278108159041, -1.6127864143155026, 0.24393478918735934, -0.44450380494664365, + -1.6597118492450527, -0.6490315790577066, -1.7612694260424497, -1.4276602313620197, 0.940659328987655, + -1.263148442418843, 0.5147963537125868, -1.3438202624356137, -1.2843267233008957, -1.2655953556300048, + 1.8008689611198623, -1.0529595001338974, -0.8423440348538298, 1.4068621033910276, 0.007761032105414678, + 1.0935924640537378, -0.09622562614772612, 1.5915223843314257, -0.6702176163118629, 0.773624899360053, + -1.0994916609949348, 1.8469641173001987, 0.8777243936892871, -0.7039703784245424, -0.6127726098737645, + -0.8456557917329288, -0.21546583787614892, -0.9360415939958875, 0.5705354560747677, 1.4330326228154542, + 1.5996006008539831, -0.07367645458791117, -1.5628497471538445, 0.6808136908868017, -0.6639347562269995, + 0.9713336233567063, 0.914773615335891, -1.4236420984162779, -0.02242621653471044, -1.898261801965826, + 1.8572128094883267, 1.5674319640876826, -0.825255227018233, -1.2984530858686814, -1.8947198959796596, + 0.587966233462053, 0.7276082291256891, -0.7118337876069631, -1.9491744048376294, -1.9749840187715382, + 0.6636341374951193, -0.8911042565498537, 1.1064802004273853, 0.5698018751524465, 0.31161591659845556, + -1.4542202550786563, -0.4973178781500307, -1.261090074552551, 1.202428109165787, -0.6502124259985935, + 0.875011031676804, -1.0459657545700463, -0.42208724851625856, -0.056326319293169114, 1.033820082809723, + -1.7520891696720033, 0.8087980713817391, 1.4090272899721707, 1.7934272795205217, 1.625604511965781, + -1.1093024540507015, -1.1836309931833142, 1.2726582815916636, -1.348432635892622, -1.3318038843301023, + 1.8025664988508767, -0.9447266569517438, -0.5044961285117484, -1.6608791173499942, -0.5404247801346038, + -1.8417295161873213, 1.3703119773103376, -1.762654389000578, 0.654516908110395, 1.7444654879377754, + 0.06700313474848674, -0.11772438211787417, -1.4346934108510236, -1.3471639350541027, -1.8846130174975517, + 1.4341206352848488, -0.1781484331915868, 1.8660723422398098, -0.14067413462914313, 0.8690053039391579, + 0.14439977774786517, 0.39859516757572777, 1.338326821281048, -0.3540839451368738, -0.3799467769042604, + -1.6627115692361185, -0.27677435345149703, -0.9817434354820271, -1.7561599198792992, -0.09529627982909172, + 1.0892976295039958, 0.8013191504371955, -1.6378286446982058, -0.2422910554223776, 1.192900022577974, + 1.552603422632739, -1.9305123473184267, -0.20948858983751073, 1.9616644595376727, 1.6353585267617419, + 0.8302939534322658, -1.3824308029410997, 1.4671017215695894, 1.075525400633838, -0.8516882354421424, + -1.764852757300301, -0.6508588489877889, -0.5864767969763536, 1.1318668208548859, -1.2774769138035325, + 0.28054164298575834, -0.7225628233832841, 1.0930686104208345, 1.3757645119448334, -0.8436086923316299, + 0.15317111418019813, 1.0696425581898357, 1.2977798557148477, -0.9586985895916751, -0.9572006108424791, + -0.5173013825178812, 0.9558842805311762, 1.4588133070282732, 1.4300025626225263, 0.9236625391445488, + 0.6245988744867041, -1.4555968079681971, -1.9081528171774265, 1.4969305861374398, 1.8013611610843867, + -1.215264739954538, -0.6460349115684512, -1.7385666591205897, 0.4952215434575882, 1.7813317720717237, + 1.4779804705753685, -1.6448372710981527, -1.7914377349318746, -0.8701351514587072, -0.6086582613844049, + -1.7386736832736416, -1.6630595398426538, 0.06404471619092522, 0.12472538361116214, 0.25173024087120677, + -1.4925192750724054, 1.1326320024353294, -0.6723638830002354, 0.9276409891081574, 0.5160132920113458, + -1.7619800169226787, 0.7729306209435105, 1.4852229135771582, -1.4043797098218525, -0.26979894999654697, + -1.5676449182307888, -0.8285263778416416, 1.7968643376162596, 1.1503154963149846, 1.3682632091264955, + 1.9014644171021402, 1.2035803257561772, 1.5807589620662768, 0.6681530530278019, 1.7430055744872055, + 1.4516895295108698, -0.6636088362253298, -1.3265544726243066, 0.7260245399794885, 1.0068518018712478, + 0.2208570840730273, 0.8459119656002851, -1.8915833254450725, 1.2162158433713186, 0.9752766886721753, + -1.2607916054285697, 1.9003684087234687, -1.6824694825149118, -1.3545700227621689, 0.5912583336167279, + -0.7008114462062913, -0.022208913414461406, 1.4871167887266532, -0.47220337297808346, + 0.001402495408155957, -1.6337795432062068, 1.6557142707874517, -0.21911880468117495, 1.994215681564901, + -1.4675327906481472, 1.9744129850181764, -0.10991781070844464, -1.5582267736493964, 0.9509729601966104, + 0.27383527366630744, -1.0109967848840293, -0.2652951445752292, -0.31773126890169845, 1.9347214689284513, + -0.48900940865557896, 1.0348946328564068, 1.6101647718098393, 1.2224869337691553, -0.24528963586677577, + -0.8282995437227312, -0.74214677104667, 1.9022077987920571, 1.065511429772795, 0.5557978714258756, + -0.8552846035431614, 0.14131421568194025, 1.6415849500887356, -1.0979455229862056, -0.1899406250116744, + 1.882935380340533, -0.47086245203046495, 1.1173180765349162, -0.38373005169056196, -1.9204322585926832, + -1.9947555620271595, 0.25610856180969677, 1.672078838942273, -1.9323275104124873, -0.7955457526088345, + 0.9709167319971037, -0.6356155194456985, 1.6590519014057472, 0.2576466445092809, 0.06826734181142502, + -0.8077944062086111, 1.6116753503934271, -0.360670679232701, 0.665216992360695, -1.3443131827622485, + 1.004588252277319, 1.4090953833875757, -1.6465558763926547, 1.5390983488230878, -1.3071804786368446, + 0.6990024492525402, 1.0093027361688671, 1.097146827869869, -1.1912001568366906, 1.630219037886489, + 0.5744645582929406, 0.6113886454941007, 1.9913520749447864, 0.44093025553359233, 0.08768839917963245, + 0.8058107685571905, -1.033178784078995, -0.8225351475861098, -1.033391812065922, -1.9518439963887113, + -1.4013652343418483, -0.3575891026292508, 0.5391177792156094, -0.589608397581415, -0.8183723923536403, + 1.2663801923646707, 0.2780137273448853, -0.6282934713194672, -0.3515950748704757, 0.32366450506041744, + 1.0378037684314076, -1.4606869437722425, -1.1777315953586012, 0.20077305098986375, 1.1132819638088014, + -1.3403580753455646, -0.8471624166708764, -0.78670588443378, 0.411075409745707, -1.1567621787661597, + -1.644706858279382, -1.6923798602467333, 0.7295344690619672, -1.6826221173466598, -0.5883451091973848, + 0.401937953067649, -1.9630301043637237, 0.9577847247199625, 1.2919304030896734, 1.225292029861409, + 1.0377797459888836, -1.4163731758310805, -1.1848812939548496, -0.6833661697787141, 1.8924034115938815, + 1.0815655745578274, 1.6514910861689147, 0.13638193195368675, -0.7978052236681465, 1.534221720841157, + -0.9153275868493269, 1.8162414196916217, -0.18899261449512395, -0.41163974536752157, -1.7273888453908217, + -0.2259883780746561, -1.5648477168095223, 1.483239033708478, -1.3942275569974942, -0.47490997296002035, + -1.7533267576457146, 1.885960859958221, -0.2403666825901727, 0.5586086137789854, -1.0921597903975728, + 0.31058170316775424, 1.3277885858021783, 0.7148623072607876, 1.557240774068096, -0.37204592237067047, + 1.296024489070251, 1.0544578912529046, 1.6854705704653057, 1.511922732706835, -1.7363032773317322, + -0.2092964598440208, 0.4481058099152593, -1.6737966633530679, -1.090792800318261, 0.6450021408332125, + 0.20592552116749374, 0.8156159228622641, -1.568521435345816, -1.9264895113126421, 1.3354691459199177, + 0.5736596388344699, -1.1328195208816618, -0.9748407260828982, 0.12745284809908775, 0.12755406203831754, + -0.402313518792381, -1.2084251106645967, -0.766591500306526, 0.5831135977679764, -1.9502950151238094, + 0.02368696030780093, 0.9474278651806163, 1.200170490724969, -0.23626565831359603, -0.00155175398078633, + -1.4962274409033034, 1.4172670942494623, 1.6336646415640272, 0.9015591359514898, 0.0028804858578190817, + 1.8517288251219313, -0.5845464836204135, 1.4764716372203344, -1.3353282541360834, 1.2177029719595032, + 1.990053698845105, -1.5060495488639427, 0.3106662987296689, 0.2870946491789992, -1.0746935708007666, + 0.895601302779335, -1.5268871259283276, 0.2678390558787376, 1.864513144299627, -1.509732497733654, + -1.5904836915336782, 1.677740063904718, -1.5200319590355935, -0.7931901348349131, -0.057131201288623146, + -0.5299918958977132, 1.5844048394762984, -0.9074859645216646, -0.6250324408575869, -1.7483344039680153, + -1.648406830815511, 1.790525414138699, 1.1087046319299025, 0.42706999037150517, -1.1252987787944102, + -1.7926103436944185, -1.661929519642099, 1.6441197132107996, -0.7994467312378033, 0.4924994603048294, + -0.08241669477843683, 1.3633290277505736, -1.5418640660707128, 0.05946840692833444, -1.6001798975438088, + 1.1048132720023895, 1.0870814820606292, 0.5831092120375949, -0.4133915802678638, -0.9142815961432058, + 1.9472407631770015, -1.4219510012962004, -1.1714172709742634, 1.4079274160525728, -0.34506772652017137, + -1.0158502812481522, -0.8168949547547397, -1.1457275079452227, 0.12910364073916902, -0.4662867248454887, + -1.9437241965128846, 0.07261938805201762, 1.600763502162561, 1.0777174066413018, -1.2723083195923186, + 1.3113857387077417, -0.5228664205198017, 0.4450488409424249, -0.5683762553586282, 1.8256298065201282, + -0.5555324792306022, -0.5028443495682451, 1.550965251170056, -0.47857481650681066, -1.008285169693293, + 1.9029801635553145, 0.7739617661198315, 1.8099201835531415, -0.3994817059435505, -0.8127756747385817, + 1.4033307810724054, -1.359844813448376, -1.3846355261466767, 1.8540201060398234, 1.0970430821179962, + 1.3778953118217157, -0.6311210216871839, -1.8270928773238353, -0.29073753592093343, 1.063407723752193, + 0.769348666705115, 1.069807859635052, -0.13297318999054397, -0.7495627942438086, 0.4278305696495597, + -1.7534013899377605, 0.20503621122624516, 0.8877917416885026, 1.9219368972107151, -0.7795858126195832, + 1.8045722365205155, 0.01848994995789255, 1.9822395081707462, 0.5682615557282436, 1.096590333951183, + -1.1060317246730396, 1.0869871276155, -1.8681569307257382, 1.9498301468214843, -0.5725242199723457, + -0.754441782550737, -0.0400922717249097, 0.010590885689596874, -0.6969977940409491, 0.6620666861327669, + -1.5969982725685883, 0.17340909047153819, 0.47755863566996126, -1.6291589696000264, 0.5780359168220688, + -0.5173306807336635, -0.9514848124225095, 1.6169705288679577, -0.42893373795490586, -1.1528283547930025, + -1.0400955977716713, 1.9827399061918518, -1.327376858845513, 1.6043081593845727, -0.11039938533269034, + -0.24997046912904874, 0.47014693724974954, 0.729145170631436, 0.7014015249399357, -0.210704593378912, + -0.8898579600723409, 1.127679820439794, 1.379686041589359, -0.781681363239616, 0.2858562618428895, + -0.7131287792008063, 1.7878252016142655, -1.0588662101910593, 0.4893786031875278, -0.8406649057821527, + -0.1534439859760095, 0.2331374640695536, 0.469582749149386, 0.4137313563589631, -0.9769266749700227, + -0.37870419628070984, -0.778681560279427, 0.9076631441596534, 0.9624606550623103, -0.10250544734763523, + -0.2518054637298146, -1.8389925418951307, -1.7523906632013846, -1.8251290048632933, -0.49859600328817244, + -1.5645964425382966, -0.46692843392327266, 0.3025335867203003, -0.8785897670006184, -1.7537151065566459, + 1.9360855593038684, 0.03479121265661611, -1.9402430169146303, -0.8981491188900641, -0.9720525655542991, + -1.2872361339345169, 1.9657361835316278, -1.9654227152525223, 1.4590349841798576, 1.7417951527725704, + -0.7636287264036836, 1.6938802231015364, 0.3969017158868704, -0.9308527980280088, 0.44396078845267084, + 0.8114974124677037, 0.22323733905690712, -1.157000049324795, -1.2116172131012615, -0.9832275983234169, + 1.773233656033006, -1.6481062641009663, -1.9471951041445985, 1.1654342998679619, -1.8679076405021187, + 1.9134708504745168, 1.9270958489182615, -0.4877809076980144, -1.4674512268342745, 0.006115322418878577, + 1.5523105881305073, 1.008791751555858, 1.7292932521498168, -1.2446660428848375, -1.32058622408507, + -1.6942157582592943, 0.9218514004458749, -0.823621328629307, -1.0203195530541063, 0.07206341884947509, + 1.214451058931207, -0.40454188129729296, 1.0066091638178039, 0.0801907243894604, -1.3250420419558173, + 1.6740542746900093, 1.1525840942242223, -1.7538751715267296, 1.2289357346449874, 0.44273632243705485, + -1.480515264351281, 0.7203216915034076, 1.736757457268701, -1.6126702540429125, 1.3353291202473017, + -0.3386246414186953, -0.15824184756675486, -0.9082701908024067, 0.19739090770367174, -0.6056382353292964, + -1.5021205949442713, -0.13004055376809376, -1.9680841369920756, -1.0085004366482355, -1.4660753620146698, + 0.5310372051600734, 0.6252656799282139, 0.28834705715856845, 1.9582690133559009, 1.0284365097891248, + 1.5791162728965862, 1.5890315798131152, -1.5740074592032895, -1.2346249557276874, 0.09869464843015763, + -1.9888659899819219, 1.6245510003031853, -1.8240356816949115, -1.1160775047609688, -0.9717920085031029, + 0.026054846056297265, -1.5990410562524549, 0.6191498034442082, 1.4318181978005144, 1.1449640789498945, + 1.435002701727269, -1.8991365043582489, 0.9679929619919578, -1.9806014397574234, 1.4536549482052994, + 1.2369898149063783, 0.0942097559544548, -0.44988290575276135, 0.6393419762034132, -0.6790983093222227, + -0.33932133359722183, -1.3323692893954657, 0.051614649124500644, -1.4850113224086279, -0.6288795685974646, + 1.6283138375987471, -1.7789341475324125, 0.9641353407036455, -0.8859758584446853, -1.3905521072955613, + 1.248121111974715, -0.148429098769566, -0.4602995590886545, 1.1003026112521521, -0.6827850627285645, + 1.9771039126131766, 1.227966495460028, -1.9054616874551513, -0.17250407406937818, -1.5976498513957544, + 0.04439303257666172, 1.51163887169748, 0.646404644877455, -1.721204116533527, -1.491970732730513, + 0.2989738629420451, -0.9550137794523614, 1.2025118418310514, 0.17278672215112767, -1.7073672613962216, + -0.6869493466963412, 1.4221175269968072, 0.8136078027936655, -0.09457917718456343, -0.6500956499585442, + -1.0403599385456666, -1.4197371934035345, 1.2190245805380533, 1.4761477510781154, 0.8005307377799946, + 1.5249741316932477, 1.0855648961952538, 1.9224829076060752, 0.2869108546051615, -1.5091598215762332, + 1.4501191712612789, 1.8870984229831107, -1.7302116354922958, -0.3545753940936196, 1.442437478066842, + -0.7823996675200249, -1.131407428223576, -0.5013183847734197, -1.4979613863674297, -0.24415013272007524, + -0.6627159008601602, -1.0181273908903332, 1.0328362758216585, 1.260250651113478, 1.0138156752044916, + 0.5371615864333288, 1.8553073783399325, 1.951357678034606, 0.7194607934648083, -1.0589630923826236, + -0.9620500497996112, 0.67763129815489, -1.6212754527879456, 0.8824272362591019, -1.8034693359581802, + 1.7422096047831568, -1.010660238079712, -1.8120346724245788, 0.7326343612924848, 0.8912492318672749, + 0.669928902828123, 1.5191540472733163, 1.9408210365763843, -0.2675149665538239, -1.5368478010395314, + 1.1378158248590227, -0.14340274845268297, -1.0420396266499052, 1.4238975359786146, 1.0186548966434374, + -0.37302282044451296, -1.6665521929489486, -1.9414179347758749, 1.9845099037831098, 0.25190468995929116, + -1.3565033558826896, -1.430084412385555, -1.9049229412063653, 1.221645795300951, -0.5219823891712627, + 1.9368378730562315, 0.7035479701902823, -0.3754047433995442, -0.14526093695990294, 0.17885379911634125, + -1.7453384743240203, 0.8465052490601339, -1.6823293227525946, -0.8161516604165566, 1.8312443096922442, + -1.9192510920287678, 0.37723942145518574, 1.725107407940751, 0.21381615826643507, -1.1716001801855649, + -1.0611345679669162, -1.9732355910502637, -0.4461777828323239, -0.23052068805350423, 1.218942575723136, + -0.22572812681057552, 1.9428383183668947, -0.06997110252961747, 0.2238383470247678, -1.178747770654164, + 1.289511874981887, -1.9756722906104285, -0.42650188553983703, -0.5494388087356263, 0.6619450518994654, + -0.8233262341940337, -0.551580552700945, 0.5817278377322888, 1.8685269618613036, -1.3000953227319512, + -1.4283838275880294, 0.5999358561474102, 1.0645958312240245, 0.014353697508937557, 1.2413161019277963, + -0.6897291610817131, 1.12297456609934, -1.8432752527139797, 0.9084575027035138, -0.1243916597867818, + 0.33130042087381195, 0.66895259903393, 1.6557830983974302, 0.08287276029065538, -1.5357530396168295, + 1.2343785017416087, 0.12902705924095592, -1.3495271524839367, -0.8728580889008626, -0.5244139433889465, + -0.3879006179068245, 0.9188768930207276, 0.7793805226934829, -1.4393071368258976, 1.613705376668844, + -1.0843958887438374, 1.7124014800074736, -0.9964777990871019, -0.8670020603088275, -1.779081809150778, + 0.12646904349631516, 1.6571805108463433, 0.12600417813102993, 0.49996900162664826, 1.9525284377831484, + -1.9652455426976498, 0.9494015016667543, 0.38443667884960586, -1.7724098362330505, -0.4082684743558227, + -1.7879879942659969, 0.08275123875162116, 1.7475525868209036, 0.6792010104118127, 0.039437186277630154, + 1.2836761397306624, 0.43674745284168726, -0.758347022092078, -1.7870493658991657, -1.7978809807072968, + 1.0584586913661846, -1.1157899056211305, 1.2701741054323072, 0.4374863843249699, -0.8214325870445185, + -0.7728421127264511, -1.5577282427526624, -0.23594830217496732, -0.5391955305244664, -0.5624792843906574, + -0.1938002353617021, 0.5223396103139288, 1.538375944848121, -1.9694136326451632, 0.8633772521280756, + 0.7208433609240901, 1.7666620769537023, -0.43532694305019426, 0.411879568237544, -1.2098790166305973, + -0.36423075107954883, 1.8891146830740784, 1.748054742284471, -0.24142532135673545, 0.5927554266793296, + -0.9152877032152276, -1.8063969967782656, 0.48207466271184707, 1.213706670479203, 1.6604107491663145, + 1.7558340275933135, -1.7932219074543312, 1.9856023833866967, 0.4913618217526574, 0.6789406744471576, + -1.2508048490603496, -1.3830358311750315, 1.217800508245622, -1.363187904834711, -1.9078698738816984, + -0.28177773798420525, 0.5582837083561785, -1.3049002248929513, 0.08913900413869147, 1.6980987887729269, + -1.1912978387544912, -0.8931915280846656, -0.18855808865450108, -0.45803053755017054, -0.8712638141863449, + 1.7094200997790034, -1.663338331843227, 1.2015315071432422, 1.7919441084478605, 0.31409629120692095, + -0.2597694212249051, -0.9168044364289045, 1.216621830723616, -1.6662413518480443, -0.8080880254762635, + 1.4073145445825483, 0.545639136243703, 1.4013255263406696, -0.5546097014600111, 1.0974522886316729, + 1.5316488607435108, 1.834526451907207, -1.9144935270991548, -0.7645822781961176, 0.06654825465795611, + 0.33190077251492234, 1.5544507881908975, -0.3538304760036741, -0.9572664339018262, 0.9528358413200575, + -1.6745715540252952, 0.507318989134915, -0.20165453900363595, -0.785857858404909, 0.8550126703052818, + -0.0012624172438071568, -0.9598028440876272, 1.2244820355277088, 0.2255077411064299, 1.8056720264962882, + 0.037756919076751494, -1.343540974005494, -0.5094599890534655, -1.9469287581463544, -1.9209187756131971, + -1.9807700974543714, 1.6311835876110612, 1.3993893579005068, -0.43399651753379054, 0.5299120477626555, + -1.9670191857782173, -0.5248122259552108, 1.7175353497229144, 1.7289010829074938, 0.5729177740933844, + 1.5352457320381498, 0.4553403180302915, 0.29793068106143483, -1.5797876923914638, -0.40357306624290423, + 1.265178047524694, 1.0072829549062359, -0.1769350094041604, 0.6540198866379798, -0.568023907344485, + -0.9385372669098198, -0.9547551373309231, 0.9690083247728642, -0.49448270214331114, -0.23248401423341658, + 0.8559950134661438, 1.5269704580997994, -0.6731018952946073, 1.4821498012885321, -0.086653738716163, + 1.5513273624777053, 0.4908490299053412, 0.608900001002695, 0.9347951454225933, -1.8764423825502883, + -0.5383639900183281, -0.9057735867914714, -0.4471390378994089, -0.9345110033957944, -1.2485918565222498, + 1.3231857428964897, 0.14498731120726926, 1.2862055594556612, 1.3345878394286323, -1.2756222907563037, + -1.3083824064683531, 1.5045780173000916, -0.3988401099234551, -1.8345796826462557, 0.9521512093444242, + 1.37476939261499, 1.7326371780327987, -1.1085753439494548, -1.9216876869475268, 0.32444794625794593, + -0.35462190108024316, -1.6069584401996142, -0.6840247431355166, -0.6774334428214628, -0.6167820243082769, + -0.7218673863781015, 0.4950955340603276, -1.9104643839392264, -1.8658682955390713, 1.6439842658028825, + -1.9665747219603489, 0.877772486257161, -1.447763157358751, -1.602252241120052, 0.44016487530366266, + 1.010332237856809, 1.1627057025546774, 1.3589381505180222, 1.6636605093405308, -0.30457004904529494, + 1.2269612708937165, 0.6454558582632632, -0.3281883947244264, 0.8131649340163314, -1.0598316735137594, + -0.42299178838843243, 1.2938482948248247, -1.9502364601006974, 0.33584798243720115, 1.960552892008372, + -1.5357606602986342, 0.7325579192102154, -1.7068807453175427, -0.27946385324092926, -1.8639767385643475, + 0.026428176272640158, 1.8041446105579393, -1.9753854510739615, -1.2061334845302873, 0.7372916274629793, + 1.2179795802972748, -1.3143781251446196, 0.682430888983494, 1.765010566861033, 0.7860745956875865, + 1.5788062358182353, -1.575242377037637, -1.6479554611466636, 1.2726425469891218, 1.3214340571366705, + -1.1176914021551418, -0.28104090150176475, -0.8644022366981297, -1.0077418811855443, 0.015316356685453059, + -1.8951979061169633, -0.5196653139834044, 0.8125953302477926, -1.823298987094664, -1.772714760583118, + -0.42385130914514946, 0.8535556140476261, -0.21002977910714105, -1.0489096773038966, 1.693320576211045, + -0.7552501574944035, -1.70051398713119, 1.7993076735058722, -1.5431556753936784, -0.44776937868310007, + -0.0034131973262736537, -0.08787851222156462, 0.1325414438300383, -1.9039783696720445, 0.9518867483514475, + -0.9603800917394327, -0.0648448360480609, 0.5218985531710549, -1.6635189965029777, 1.9924108002877654, + -0.26339506136937363, 1.1928641143061727, 0.3810845521550812, 1.1123779174325765, 1.2834151041594684, + 1.7494397068262924, 0.7406099190962179, -1.0134737497144135, 1.055889762638893, -0.8394462038526926, + -0.28477578596133046, -1.3393678408508727, -0.012496365797505682, -0.03748743063321758, + -0.09254003926345167, -0.19080760280617248, 0.34473162850229677, 1.3043786112928455, 0.33958283219176444, + -1.0099885582662163, 1.5669995819775036, -0.2714415980354179, 1.3346740594279858, 0.602475697679103, + -0.27309507620431717, -0.4346043241348738, 1.12017383631626, -1.6915194091713959, -1.1923394655016413, + -1.8515166462293298, -0.591963112025546, 0.5023793945680568, -0.4921032884548957, -0.4819170803548376, + -1.4966363280675248, -1.421677402946293, -1.6850285211377178, -1.9944566759402251, 0.23292198758882154, + 0.8553269921007791, -1.948229679796313, -0.13143026366451505, -1.6538966244541689, -0.9216045649630527, + 1.7264765451863706, 1.8222181199058705, 1.9306597140261852, -0.42071370405153363, -0.994060577839833, + 1.3461298706581069, 0.6155835113186514, -0.4233898713404436, 1.0628290316641253, -1.882099316617622, + 1.0589656835123407, -1.1563854484115064, 0.06709310867074869, 0.6384730812451886, 0.10574897066907951, + -1.702659493082189, 1.532461866391916, -1.4099555819463054, 0.9274818832867693, 1.5037867997803227, + 0.7264189517035895, 1.739697756396633, -0.6088742760292636, -1.6856099921156718, -1.772602909046058, + 0.3867733521889116, 0.43414404525556716, -1.4477323883785695, -0.5550413636141851, -1.0684298188676138, + 0.49963854275697717, -0.6428666361617692, -1.2461600431717166, 0.139396634056558, -0.8713484091078527, + 0.9279324928307542, -0.860886224082325, 0.6052214303020085, -0.07407732844211701, 0.8385159528403303, + -1.2883611258333367, -1.594397303034099, -1.288952302189137, -1.5793499078181599, -0.6082222788183742, + 0.35912648776686495, -0.607260423160894, -1.253926812856471, 0.5102013763540754, -1.6947762434136582, + -1.675889179914285, -1.7314605301275563, 1.5448045270744153, -0.6686717741699395, -0.5356827185210964, + 0.12358215343481938, -0.6340124361185859, -0.2817760753415781, 1.164433875881067, -1.0118173150641265, + 1.8268866563720367, 1.7413521495755342, -1.3276318599091654, -1.7238317321272323, -1.7921370418547173, + 0.3558056543449375, 1.6918022479955095, 0.9222053317838856, -0.052028097382029515, 0.5122787494307435, + -0.32626022494247575, 0.15032399126484997, 0.8080660425425252, 0.8796007677651145, 1.6881141214849356, + -1.9923518519205325, 1.1594510791129853, 1.8780756936993601, -1.7055973395367428, 0.7022611016052407, + -1.2390916075946476, 1.489326601979406, 0.09129782431581734, -1.0503781321438632, -1.830361985060848, + -1.6234239674810462, -1.9400691667730579, -1.775525052900898, 0.49354389704319335, -0.294835573216341, + 1.6593428114830902, 1.1178351651841174, 1.392043569467992, -0.23893431171828805, -1.4623844403945752, + -1.5343016105239773, 1.1865854046143252, -0.030035182509464242, -0.23319636854325143, 0.16118837392623053, + -1.4249606736799842, -0.10348851980420637, -0.512221808349385, 1.3638877591460936, 1.6510345373891697, + 0.5403817230171999, -1.7568167360776457, 1.582586423119313, 1.8459420743954364, -0.0937741201677813, + -0.07681514921378785, -1.6485771749332665, 1.405045845579937, -1.380639120572158, 1.3920918243796292, + 1.8133826455276356, -1.6212352715972855, -1.6619339109060558, 0.5860174449496753, 0.8058352539157623, + 1.2357143629814047, 0.22041032204430522, 0.7808171688721632, -1.8606664108725086, 1.8543721730560465, + -1.2809372837528121, 1.8923388155485306, 0.7298982025806318, 0.08659163513354162, -1.7687987574607682, + -0.4763155379116144, -1.5473458996242258, 1.5866611468929568, -1.5386305359062193, 0.2081267635605304, + -1.381273405047125, 0.29199706105659384, 0.5721917482420489, -1.6925391452117218, 0.7916231259738984, + -1.4231755623012, 0.3405385083200416, 0.12077341714245371, -1.203092512782769, 0.9840212059218176, + 0.8659573692212614, -1.7404045592057091, -1.200177869243232, 1.9929041194760506, 1.6176296085808595, + 1.15438166620096, 0.41001019693662766, 1.545005651638757, 1.4880131034252333, 0.6483595755736316, + 1.9934284869908785, -0.4425968131492626, 1.1431733762266791, -1.5542365498656396, -1.4707761026218984, + 1.6864275163379014, 1.8816548331946308, -1.896526297348939, -1.2767368243378723, 1.7876805565314937, + -1.9714294114028945, -0.4825746666154789, -1.7490480229694754, -1.6790035262680618, -1.89004754358142, + 0.37907912536042243, 1.2624866059593396, 0.10773840252143074, -0.6872687299448277, 1.9925905119168483, + 1.6613584791100244, 1.3497294256582544, 0.14770827679848963, -0.386793508812298, 1.6398868686034165, + -0.41758827391224607, 1.7067458043950365, 1.0278221550032498, 1.7159317753776575, 0.5096805329124043, + 0.9834561119760092, 0.09498575572277979, 1.9951803301511655, -0.378982487863734, -0.8263817670300764, + -1.0229038853801748, 0.3913502332282688, -0.6996774883339203, 1.8060537458340287, -1.5307723308680314, + 1.8274255271921316, 0.3374273766186473, 0.504303338314112, -1.6080869793956403, -1.158919549492441, + 1.5813373512787257, 0.9432537197048445, 0.7842306706160116, 1.3546056018625912, -0.9933122900788947, + 1.0300108626811424, -1.8830535281855107, 1.2388444130694714, 1.5820432556020423, 1.7375339367053177, + 1.6102618826367072, -0.8665078865132916, -0.6732392396345555, -0.6529773281138276, 0.10547233393450206, + -0.6941730912707627, 0.910541741327461, 0.9098240508320066, -1.6363195940333455, 0.5463776877988993, + -0.33425194180314666, 1.0461836058570837, -0.9752798490633703, 0.8152462082002794, -1.5875859561094057, + -1.9077898642385165, -0.6579383417965277, -1.3678615233634694, 1.285752623992095, -1.258025248987475, + -1.734869037268262, 1.0429220147507374, 0.7326509890514723, 1.9986880192758498, 0.6509156089882868, + 1.3547313230945077, -0.5885137816765393, 0.17851497147549722, -0.48206201174941743, 0.7757155223228231, + 1.5797039076642765, -1.8365491055202483, -0.5557915214323543, 0.8098649490633258, -0.9482838270042748, + 1.7240038491867988, -1.727136737786945, 1.504375516257281, 1.8508968488624733, -1.781048252815335, + 0.29406576456043076, 1.3929052384310125, -0.5672101671346796, 1.1612683698036195, 1.0365664134132775, + 0.52678187508665, 0.3071827595777501, 0.7041562182045773, -1.1798720159212701, -0.3576278976339484, + -1.816947298419736, 1.1348446198592024, -1.5847397567297339, 0.3604508734069167, 0.46593859852538255, + -1.7777395155730495, 1.6287619990007203, 0.05390138452125459, -1.6910054822794427, -1.1955797005332123, + 0.1527150449913588, 0.7396002223852989, -0.18108018677808335, 1.5581070045362102, -1.8797418765958547, + 0.8650557162994215, -1.611948653627663, 1.1857525788641103, -1.0890741767763679, -0.7623497273737962, + -0.5452689372225361, 1.4851221090925062, -1.5985252342075071, 0.020384445673001572, -0.9046335836897486, + 0.553216643980825, -1.9081035245013105, 1.1735365048310928, 0.1195026810678641, -0.13729942676911122, + -1.5259805046535062, 0.18920225926555112, 0.6705947977532496, 1.0030469424172317, 0.7610916391668443, + -1.2463291870058537, -1.0249840040963818, -0.19173741736380734, -0.6665958901111466, 0.4839557261866174, + -0.3485528787244352, -1.5242944187930236, -0.14435052683644933, 1.4529331704829067, 0.6458683077621643, + -0.20459459850901585, 0.631062246518785, -0.8791486672647375, -1.2386308323664368, -0.7619627858914138, + 1.2107017989790583, -0.33473019147897354, -1.9130923107708755, -1.7056520026340571, 0.8880970113501672, + 1.3775061802643274, 1.6842143744397662, 0.8956663807740703, -0.14207737810877585, 1.15151215723502, + 1.9554721163330937, -0.7192433376047749, 1.9055716193191357, -1.7324310669463001, -0.9426609664206191, + 1.4304692916943642, -1.4208060554924176, -1.6703510665846304, -0.11789403764850537, 1.1035841327225944, + 1.19193146120024, -1.735646676252749, 1.9517684584286812, 1.800446091627931, 1.2643875465217338, + -1.2016658470474093, 1.096389364153712, 1.5201049232112753, 0.30449137858893494, -1.1869295719311896, + -0.8315768582782672, 0.4346498942421677, 1.8028550790017723, 1.8857086005150672, 1.1520127103731062, + -0.1519769293665103, -1.9810749314865586, 0.6722940736723881, -0.8875375147759952, -0.7891504144449426, + -1.6946353432210914, -0.4359020089062122, -0.6980599672246557, -0.07982954815920884, 0.19084722828248868, + 0.6845828835105898, 1.361502619024912, -0.6198314937755827, 0.43206457200540616, -0.4275542957416594, + -0.5942951907386576, -1.2680930727916406, 0.8768595741043272, -1.621734829642608, -0.5341763533991353, + 0.5480403433778003, 0.04432939753449716, 0.6820566847717622, 1.9586624282689744, 0.32087743739982866, + -0.6031550453761696, -1.9650369588045296, -0.01282494430097092, 0.39651817911017506, -1.3672271702505698, + 0.5253796884817046, -0.9414239928629815, -0.8381706914260612, -1.9383642756327601, -1.035705422287875, + -1.206785655560016, 1.2381881481085397, 1.2976332585562043, 1.6406249791463123, 0.10308205600753784, + -0.5475238811744738, 0.29899086302238675, -0.8482038855976217, -0.5137345687600776, -0.9065955517315878, + 1.3104099207076656, 0.5025101714926583, -0.19511985691542133, -1.7979503287642702, -1.902372134245554, + -0.8024648307653406, -0.8296523085064926, -0.9765719040493881, -0.32443266986702035, 0.6539733476893286, + 0.5973369569721942, 0.9620679364522129, -1.1326092851006697, -0.5203786703966227, -1.9445684352154542, + 1.120817472266788, -0.7832944208518358, 1.7728569415904172, 0.4746637992232019, -0.6331056497466134, + -0.9560458440421451, 0.774229832283905, 0.9430691283749084, 1.7415084291941643, 1.7617290689369067, + -1.241891773067345, -0.4745318258222131, 0.170886449126332, 0.280332317095926, 1.797798588989031, + 1.4955462166754243, -1.9714454758512092, -1.0474834990256507, -1.4515647324997412, 0.5231202814417619, + -0.8247619693573318, 0.842919528017017, 0.3442132040114725, -0.7990310505752118, -0.3629900475529535, + 0.3295033891840484, 1.7075620006843595, -1.7886707971887938, -0.4229652804748518, -1.2955282842854743, + -1.178898100821427, 1.0539302135104265, 0.07818911036924803, -1.572166269506143, -0.9256373147747796, + -0.4523743452766471, 0.870552332090627, -0.14850807893624118, 0.4287510203485754, 0.7234083474785322, + -1.8786486629917283, -1.3259155700503378, -0.347806801338562, 1.3631756709064557, -1.8448244492904076, + 1.0352554853447788, 1.2734756957434143, -1.8995503975543802, 1.3929138458956896, 0.4683038046394792, + -0.2679051570033124, 0.8144995104717001, 1.4932193857012948, -1.6953990593425603, 1.402756253554589, + 0.3808795564720828, 0.8990244719837941, 0.40911995577962923, -1.1361078826263524, 0.17385823702487802, + -0.8186745793721117, 1.7076078317166345, -1.236408907822983, 1.6848698673082323, 1.1057305891182168, + 0.684172881489129, 0.6339043483376958, -0.5196970144327304, 0.3521577727883862, 1.650116506377742, + 0.2682799961840745, -0.4735884200746394, -0.37060613769970896, 1.1147917782653982, 1.7887865852421685, + 1.6978458641851324, 0.27152326646567104, -0.7479948142297932, -1.2016639963842959, 0.49444674806749767, + -1.3781311297932053, 1.6512147519128018, 1.7416186949120336, 0.045136221165612334, 0.5209559439848652, + 1.259816622507529, -1.4608419771090535, 0.044609757110229076, 0.9224148305655282, 1.001777160557932, + 0.6923994259395219, -1.7332027007479764, 0.1740919172843176, -1.6691024022127667, 0.020765839823362775, + -0.5411621161873468, -1.9512100832757255, 1.2205713100210618, -1.1974515945322022, -1.8219003860967407, + -0.20156369520917128, 0.8953615730514599, -0.6851191939445931, -0.9784599914417944, 0.6357787809646211, + 1.2543221592701617, 1.154804559643825, -1.3175939582140597, -0.6021071761808416, -1.3594268073464137, + 1.445454919786025, 1.6373224793127816, 0.4773095450916047, -0.5842274784088719, -0.7097055492598594, + 1.221870791596726, 0.5841175201744218, -0.6088866249685312, -1.0239540299091434, -1.5776221817975893, + -1.0528530592443435, -1.7599610522525682, -0.009359693820525372, -1.5849216381687343, 1.1123817614821956, + 0.48090947307860965, 0.5755896514757204, -0.49340183569225626, 0.38999288587119985, 1.9076108980193105, + 1.7293269863481244, -1.3443206598919062, 0.001652373471082491, -0.3612982015066102, -1.666179661360145, + -1.5652606963428637, 1.2847799597537284, -1.8746367656966019, -0.5051948920106453, 1.480902450916811, + -0.6574579095355038, 0.41299775479925227, 0.0038296277371880905, -0.49511696025555096, + 0.33076498894339057, 0.07531759452269782, 0.3925650285564686, 0.04553443036132965, -1.2835239354793773, + 1.022597041467466, -0.14452142564653414, -0.6834698246070756, -0.26296934920727555, -0.1553100097881268, + -0.7486942835272403, 1.6436947866485108, -1.453108500403955, -1.142814165110103, 0.8561776338025142, + 0.9114113939243937, -1.6289171667242108, -0.8765736035046707, -1.3197920400322358, -0.3734912789205094, + 1.5370462570492807, 1.0649239894263536, 0.8535930082816217, -1.6961129271355224, 1.4179270256764838, + 1.4041314424292972, -0.0700614584636794, -0.5005397839840926, 0.23183520631057597, -0.10990588024738823, + -1.147089599229715, 1.8964022235002318, 0.03743472262799674, 0.15742108441832148, -0.7185738786020632, + -1.1856642140796065, -1.044955084073341, 0.6700862122512099, -1.6025569167524276, 0.5441522921772082, + 1.756702564317389, 0.1676615923964384, 1.9361822628230643, 1.7064197592361863, 1.6880082988048626, + 1.4936942432907419, -0.5799520435760606, 0.990660775656627, -1.9229488942439295, 0.3155933315458741, + 0.27140426916735727, -0.5632551338077167, -1.4590940918768025, 1.5983827694310806, -1.4330645192158764, + 0.8003275119389475, -1.9841806470403869, -1.8395944973457397, -1.5552105471701392, 1.6842460594625583, + -1.7017713102134175, -1.3427034266215978, 0.6775219782398878, -0.3351302524035349, -0.022448660326078063, + -1.3500598715209335, -0.5202469644373027, 1.921619202644619, 1.3872706825409837, 1.396529188979435, + 0.07083965303226236, 0.6914877624468483, -0.29224764274698156, -1.751452537465684, 1.0064318828131888, + 1.5123878406285325, 0.4033025380645654, -0.9820341018265148, 1.7861648401668955, -1.0181815370408254, + -1.903647522119213, -0.6561248619205413, -0.8699254533593299, 1.7173299560789754, -1.9229485367712744, + -0.690953419069066, 1.2737422710672446, 1.1066566790733132, 0.8606336717720096, 0.8621209988717888, + -0.10083318929855434, 1.3453155024821513, 0.10215147045998396, 1.3751499812523518, -1.2192061255665605, + -1.0236206808587962, -1.533123935767816, 0.15060837984637665, 0.20593075721578913, -0.1420114293640644, + 0.17633888228692118, -0.5175320445499185, 1.3087124943884705, 1.8324200965520019, -0.4738758668846961, + 0.5584879952099433, 1.9477363800804062, -1.8628690990218209, 1.42344054481267, 1.9288039965215793, + -1.8566049760751664, -1.2884135845361993, -1.19038243691399, 0.7683237483024614, 0.4802985717653341, + 1.4329557197969898, 0.9128985393650337, 0.9461407015680665, -1.65078779921165, 0.4192615449543542, + 0.32096438114763437, -0.5129450234494479, 0.631644434863107, -0.47335834996781756, 1.3100891597589213, + -1.4790022015467565, 1.725986057822876, 0.8701026300053698, -0.6819596038865807, 0.4664857353751577, + -0.7112013772427535, 0.12572492597099938, -1.7769733290901675, -1.6440093209585278, 1.125533492349331, + 0.19218031047016204, 0.02403105386135085, -1.7090825450471296, -1.312565935939367, 0.03858791868454148, + -1.2641036387182139, 0.904155292993349, 1.3561478854255968, -1.539635445223043, 1.9235783529452943, + -1.874310049415433, -1.7773733191470518, 0.11118483759793829, -1.0886413092825036 + ], + "dims": [1, 1, 2560], + "type": "float32" + }, + { + "data": [ + -1.0123639551600512, -0.1262791332695521, -0.5528788189121379, -0.9891722578914903, -0.8541177111117033, + -1.842327110585285, -0.8753504664996115, 1.645494290881846, 0.12876899266900654, 0.5739158499727086, + 1.1027206966256946, -1.3155458981065662, -1.1433211051679475, 1.4367855916029315, 1.4674402192278242, + -1.3373231059554618, 0.8170172046647917, 1.1074697240531268, -1.6004007249577086, 1.4644646696571568, + -1.827927680383385, 1.4548611857965401, -0.8614990138298273, 1.6706016048131325, 1.5096827794979042, + -0.4651782953949448, -0.8577219028996828, -1.299490913831483, -1.339145989117756, -1.0017632367283333, + -1.3586772742922255, -1.8799261724983776, 0.059417093938870735, -0.6646734157727456, 0.5388638764003799, + -0.6378909942726629, -1.1647516486356562, -0.058721485739401835, -1.814477796499844, -1.189167849669282, + -0.6380012350722168, 0.5285662507696332, -0.534701982091482, 0.5570990437739303, -1.755585696977759, + 1.27238726136709, -1.9571057028071568, 0.04195657651183726, -1.9047024137637942, -0.5116506294039525, + 0.9926189908254566, 1.7759871807627992, -1.301591689492625, -0.9524123108659142, 0.524043944088068, + 0.5401946307447156, 1.2036398911372181, 0.3219194319137575, 1.9433711281899884, 0.33648919584354076, + 0.9772519308773946, -0.5575502080369272, -0.7345410843307336, 0.5778449333097511, -1.9408006240005902, + 1.9819202164371932, 0.8468700855540847, -1.3899691404202503, 0.15850835128301544, 0.49059781858603113, + 1.7764286491001, 1.7946165130578535, -1.16168298050607, 1.032789240397304, -0.7706982026329863, + -1.1038032957670127, 1.1838096309519006, -1.6323318982074788, 1.1064057844588042, 0.391546384023683, + 0.8726810963884386, -0.28563916045025906, -1.960002018275822, -1.3867181381833626, 1.921900210546779, + 0.23545042298506935, 1.84976268271061, 1.0525315891685612, -0.7472942377034979, 0.7577710804753881, + 0.3083793892222504, 1.327301485825914, -1.5659874146221533, -0.8978311834083046, -0.62907789098844, + 0.4870403076019034, 0.630783738162723, -1.9216326727288857, 0.9012236985202584, -0.7852645198565309, + -0.6937624671663629, -0.17653350051115613, -0.5561473457869717, 0.6698782481609369, 0.23362849702887267, + -1.2054404784680655, 0.7698259603739315, -1.1714183267177214, 1.1184512448333477, -1.6690588449303805, + 0.9147154841361029, -0.8955722282828678, 0.5231382201829353, -1.0773616491852902, -1.8410207920577744, + -1.1337885739036437, 0.5219438601865445, -0.5731516829203676, 1.5757208446602124, -1.8470713081756802, + -0.14985389360149082, -1.9692981274100303, -0.9532002408436595, -0.9539842568916299, 0.17378022347200517, + -0.41623035688267596, 1.6481595173848254, -1.026339675363599, -0.9699532421892103, 1.9461597162595536, + 1.3640648912196438, 1.1391500889364954, -0.5595329932725566, 0.43069270118926184, 0.9216291855227485, + -1.440051073691266, -1.47987236009598, 0.9087155921703598, 1.2787682017568338, 1.2363303394454128, + -0.49711585260182556, -0.9387884495809562, -0.19011148798368716, -0.08611760612433539, 0.8656095244085549, + -1.4316244544821588, -1.558915825119862, -0.6738541489070196, 0.7358531504667214, -1.5716157542957179, + -0.3210549969144969, -0.20072745672949832, -0.2416365577548918, 1.5081023632879136, -1.547604775100064, + -0.6334244403609635, 0.14810360033380032, -0.7978288773497635, 0.6204344672116999, 0.4773642826492761, + 1.2087249539596163, -0.6643075818626354, -0.4170560884596144, -1.6192321024457321, 0.7844847722786517, + -0.629690651133866, -0.7380758723814482, 0.303414620658204, 1.5479875220490822, -1.198103302774089, + 0.9760982659095188, -0.7574001500859886, -1.2724614749813545, 1.0658176639069543, -0.8843666652730136, + -0.6427064732600343, -1.4416669867869603, 1.473657450100351, -0.8344994942004691, -1.4224435385472942, + -1.0338023533751777, -1.933568422908838, -1.0802998481520287, -0.38091309180010224, 1.6199945117010506, + -1.702101910236685, -1.4725385504086255, -1.413591341417039, 0.540278745507603, 0.3517718642795238, + -1.590795907883174, -1.765499823368284, 1.7366923492439614, 0.4582221558773192, -0.10581682008337268, + -0.18516227544796227, -0.21097779387988158, -0.1428735544745896, -1.97510241493318, 0.32449001731988947, + -0.1832218746003349, 0.26181337286546746, -1.1227369967165552, 0.35351574098454375, -0.21205956319428498, + -1.3866497212089195, 0.5946688412485415, -0.3417425538750871, -0.33633058083047906, -1.7852940300594273, + 1.9919312461265548, -0.7629882135863388, 1.4620310920385196, 1.1115061446942711, -0.9057539166302142, + 1.775862903430335, -1.1324751374031425, -1.5851970376790732, 0.9843604936500894, 1.5734177841900427, + 0.9515914445205862, 0.034323622285483246, 0.8075573695504703, 0.5332420240003275, 0.7767308358623826, + 1.0329131214994085, -0.9838298807872725, -1.7429868813963063, 0.03740922197745089, -1.3794671490283932, + 0.9772124799843054, 1.6546060756751624, -1.345806362676182, -1.620585515255308, 1.3498272448019941, + -0.25283974040314394, 1.8309785362540882, 0.8336568766196351, -1.407378961144727, -0.8870392599067882, + 1.5455801491463914, 1.1404840595611354, 1.9778865957841072, 1.9026326233043243, -1.2919286899267508, + 1.5194536255103763, -0.40024201189426734, -0.38767200629106124, 0.37883119550528654, -0.8971148848399899, + 1.1472506966060552, 1.6769048658537242, 0.39834963390946676, -0.8584979863189526, 1.4851684858856853, + -1.9898922021489387, 1.7271323520062838, 1.8848497191146922, 1.7439889790675318, 0.5311134881425099, + 0.2302960173495876, -1.6217864910988453, 0.28492260856667784, 1.3550969896689358, 1.6762026245924515, + 0.45464402605973575, 1.8447468286497628, 1.7489125819896838, -0.7745650567526248, 0.6473255323813127, + 1.88270574713889, -1.4231865592800421, 1.406236181817195, -0.05820536366515672, -1.9830176146937077, + -0.927096735728453, -1.37521200952383, 1.6293084827507869, 0.18916714867483186, -0.3559388864834574, + -0.0626685044384443, 1.4510888124049117, 0.00665671994549033, 0.5852250937009087, -1.964947150735517, + 1.086994355276114, -1.5545621604146378, -0.6702039017668291, 0.15273130009205538, -1.354848404243989, + 0.8081822753111396, -0.2990136329330131, -0.1334268300545549, 1.2936295445817017, -0.5276761138383153, + 0.06209853112125252, -0.35227980331045927, -0.6683541541821878, 1.5365781152706175, -0.5227637702649135, + -0.43751245897261537, 0.39166051967309556, 0.6145502882685348, 0.6764920150128493, -0.46478346293163764, + 0.40093484640123567, -1.4385605602950564, 1.5318810200296449, -0.7902920012169599, -0.22815329205907098, + 1.5159148518766017, 1.7440445423086697, -0.7705868478778743, -1.0446035338845894, 1.4407728607631372, + -1.7690868678646723, -1.9956594357087072, 0.9165504950260104, 1.1647979922386025, 1.7626373785022524, + -0.3262003585763962, -0.7291462643423232, 1.2691368673965409, 0.9833027096614515, 0.7052758987187504, + -1.4080008451270958, 0.2004861907693547, 1.92413536100345, 1.8633978379666134, -1.5597901041000588, + -1.232525418601906, -1.9326471509575835, -0.23851047841947803, -1.745957663852197, -1.027455630245683, + 1.7842373183009093, 0.7098705198166604, -1.3523419086313861, 0.2493915779920206, 0.5836072040016118, + 0.8452857075528275, -0.6044200471234227, 1.335947146234287, 1.9535634253874816, 1.8737477649440653, + 1.6787256628480751, 1.3475059469256392, -0.9023420902836907, -1.815324493360138, -1.7487338231501415, + -0.08107787176718784, 1.178869071718574, 0.5869021791922719, 0.1289991861916615, 0.9871714466975163, + -0.7828180891664971, -0.9162265218952319, 0.7883323334301799, -0.7738825207321494, 1.0578051800827781, + -0.48483804389576335, -1.7003938250158095, -1.7474401518911815, -0.6024807198720463, 1.470072074418848, + -0.809852698248462, -1.8087803758512981, 1.1275613461510172, 0.9110052554791794, -1.4827836388852713, + -1.3641845213240744, -1.9188108209559402, 0.8859208024949954, 1.7438050845669144, 0.14476912919518536, + 0.6121128834023981, -1.7692670213619586, 0.023661688752081744, 1.1007625036098432, 1.1758330104763122, + 0.4546664062325476, 0.022499008403786824, 0.8120850018523171, -0.7886059301759083, 0.8107426171843777, + 0.015751759753425354, 1.9515003227015306, -0.3285612629290764, 1.758730602588753, -1.9178063288185045, + -1.3319225925070368, 0.5970900239608552, 1.8634221473873263, -0.7483844730402502, 0.0851383845623852, + -0.10037389959678844, 1.8601880295663413, -0.5358906627108242, 1.5311027975069011, -0.7567148434480719, + 1.7810484758849983, 1.5004941791198378, 1.238866744077014, 0.019238796977725237, 0.7314924609545477, + -0.6404106749076366, -0.30544348502988683, -0.7754562102568752, -1.1903829239480253, -0.7557972926946839, + 1.418804956107497, 1.6841275666684883, -0.35403092145419013, 0.3072276436064163, 0.4941160183076647, + -0.010460638654985033, -0.7496577784263767, -0.05957826320949966, -0.5349743628709929, + 0.44780861823397355, -1.9548584156880642, -0.6834407857845042, -0.6574778495500677, -0.2568872307434864, + -0.8179424074332058, -1.9399599284052886, -1.7438777236599172, -1.9046697213047699, 0.33576417528481173, + 1.0390831565369494, -1.2867357835981865, -0.9105779330773034, -1.2600968940701254, 1.7546113033912878, + 0.8638193166816803, 1.915934439034265, -0.18936860893703056, 1.6490561383957179, -1.7404200826424407, + 0.15942118157817387, 1.174512061322961, 1.087287672904493, 1.182852158431765, -0.12430741231511089, + 1.6711861366379157, 0.7124940145742862, -0.3946246470773911, -0.7754542725640272, 0.9539784330716907, + -0.5716889107776746, 1.4262896723570924, 1.4675456840569163, -1.7077488525524833, -1.6888666810589683, + 1.2108429896458865, -0.30524840414522547, -0.18167408305726607, -0.2569749511019337, 1.2912167486614727, + 0.6208472747047127, 0.9472464500515958, -1.1302136544927563, -1.478282134349313, 0.4848945322578242, + 0.8298435424742152, 1.6932133553283863, -1.4458048451455756, 1.4088139925156833, 0.505348371415975, + -0.21105001864112882, -0.04858175142791943, 1.3570555900503694, 1.2673714070205957, 1.1469844853077413, + 1.2000011591622064, -1.0780533577358637, 0.37698814259115565, 0.6997609434842227, -0.7604196150995675, + 1.8410835681246196, -1.6836663805912915, -1.6352482015283725, -0.4811456756273813, 1.3045848106454878, + -0.823583139102726, -0.646527859067727, 1.5092372843244393, -1.0424659042584983, -1.9448695809676995, + 0.2678845821239566, -0.6194354091338141, -0.3172475643478627, 0.16481577119936563, -1.1026554846901258, + -0.8352503899270465, -1.8149755146432849, 0.4839677208020152, -1.9367901959501284, 0.8680859459275245, + -1.2035834761537227, -0.8748603808576707, -1.8417628093555134, -1.0429294120821577, -0.2520578638761588, + 1.7833216800296539, 1.4367696159460968, 1.8669976567111535, -1.405562858069989, 1.0576377264778563, + -0.4569713929987014, 1.3255011842556819, 1.6171166029195225, 1.0403552739195874, 1.2264321768656314, + -0.47396544132443275, 0.7118492170263346, 0.6260191876547241, -1.2179712214091793, 1.5120789908822676, + -1.657525645319189, -1.2991286032659461, 0.22202239748400387, -0.5389051448124622, -1.594992260705033, + -0.0487688918807363, 1.1512759563478916, 1.4679486318272383, 0.8813613284468369, 1.4328674044139706, + 1.9999268579039367, -0.47950159568339323, -1.2281571927849866, -0.4554451947054856, -0.15429012922622043, + 0.19786052464284776, -0.3680312279906497, -0.18825645901610866, 0.13700608028054084, 0.11417316734012051, + -0.6463349589003959, -0.3770634118502656, -1.9950465240002844, 1.4676192894281632, -1.6060448800367215, + -0.6182395877160713, 1.2695963682598732, 1.4459649727588744, -1.88317964468765, 1.4240536704934144, + -1.5317465035623874, 1.9497915396745666, -1.991995390985421, -1.4828801801030478, -0.03471214257257316, + 1.0775554630217314, 0.38086611278178495, -0.1958126129950628, 0.3711869657718534, 1.7307000355063105, + 1.3370240564962907, -0.6892941432270163, -0.6176252997554714, -1.2761391889511922, -1.9463694295411074, + -0.820841598715079, -0.42139807880407787, -0.7976620746519121, 1.934915457164431, -0.4497028214466532, + 0.5258289450636102, -0.3002339930414486, 1.4317770429007526, -0.10670432773391081, 1.477187387671167, + 1.6422292268536607, 0.30393726544465505, -0.16649028812518285, -1.1690831963968895, -0.7043985973685203, + 0.47350023974648625, 1.657836032561474, 0.16219091081621606, 1.2307861721904754, -0.7270831655242516, + -1.0574142264570137, -1.9134290652692378, 0.1585901752075669, -1.922458865955397, -0.5216475421535529, + 1.4438431375673586, -0.2874803852531551, 0.3370681487492808, 0.9173203757850725, -1.3751125170831138, + -1.014212305918492, -1.5475568694685897, -0.5834419852252983, -0.022709263779811195, 1.4145035718404255, + -1.9267536984073965, -1.871498186038706, 0.525620783057489, -0.48663912480579086, -1.0308451661848164, + -0.1369351560707246, -0.7876105221422698, 0.6955249722293555, 0.29453585260072757, -0.06514154829755281, + -1.6429080966047698, 0.15520901396599296, -1.518429991046033, 0.6839405241853216, -1.8300625086431346, + -0.15898426442765246, 1.9278290352285792, -0.7150445644634695, 0.34034454186145346, -0.2506887167667946, + 0.2912251513885442, 0.10434269155791664, -0.5637887420304368, 0.031008416285043694, -1.0816174134360272, + 0.05203114530680164, 0.03172694813978172, -0.7646387549793285, 0.36213414786228526, 0.0060869909269349876, + -1.0367311632092022, -1.0684702277942222, 1.1874407786461294, -1.9290032593242623, -1.8550268296137276, + -1.161269082907582, -0.18656240236501098, 1.1070044180055767, 1.6261946461581385, -0.5698373516978554, + 0.9631347920513802, -1.0985201941795912, -0.45509125721838917, 1.6535643092193615, 1.9696290288271951, + 1.0266341388473261, -0.23790773845234447, 0.2828088466454748, -1.6537561622154024, -0.2286353308074256, + 0.8766588558049788, -0.8195788808423616, 0.7718354518479451, 0.46796484124644344, -1.7212327769837446, + 0.0658435971926874, -0.6407624425160288, -0.5647885630487526, -0.5284202936353299, 1.8818650199438771, + 0.29252062160862025, 0.09136912052125101, 0.4321630239196912, -1.1485094277982304, 1.6036235307678686, + 0.3318334927588493, 1.7219946936827109, -0.09166860362313312, -0.9321185046623404, -0.5842230824766759, + 0.5762857089716649, 0.6237761258836967, 1.3257989135149089, 1.65675048758645, -1.0060167288419342, + -0.08448091333478214, -1.2793076427969634, 0.7514972175750367, -0.4193024725154899, 1.427794959994305, + 0.9558973375817734, -0.00039143542951691757, 1.7030425931606343, 1.8219801925309609, 1.7260980421968792, + -1.9249357614979115, 1.6285038041870772, -1.6118493301059527, -0.4294907666236245, -0.1659993953929053, + -0.5722726383494532, 1.2935829105972072, -0.06859448172655114, 0.6177602091273879, -1.3370886026529494, + -0.8003712871381898, 1.4776462171750593, 1.4184671800982604, 0.5433276418773598, -1.1103872044287346, + -0.7572146251109908, 1.1710857107940438, -1.8705799333769377, -0.31024900334903194, 0.34866139709491595, + 1.4866168061361806, -1.9774625782466435, -1.9891386648785518, -1.7735018436923857, -1.5751766748778406, + 0.7218521749338684, -1.9390531947989702, 0.10502871018805493, -0.6908737000286438, -0.583334654840761, + 1.465181746808006, 0.9232784443998208, -0.37400862804876933, 0.6661364913985333, 1.8688403166211724, + -0.8717922772171374, 1.40123243258902, -1.9913513342271694, 0.7369262986601814, 1.2562396521830914, + 0.7638029152143444, -1.8164465814226718, 1.7184240047901733, -0.911895923498772, -0.43161449539234464, + -1.09011215721989, -1.865570383600069, -1.2232212962752618, -1.943030725162366, -1.3198980808588407, + 0.0564583162685901, -1.1298601037432707, 0.6392469941959655, -1.8442933136946733, -1.0692331296192306, + -0.29525834417416963, 1.1184299311108798, -1.6129180448223925, -1.2727580333965411, -0.9415967651718447, + -0.2597646669604643, -1.511150740860189, -1.769101860129168, -0.9185489030191736, 1.6841872338621604, + 1.7266136417112579, -0.6047332956995355, -0.4784036452377798, -0.6987121488961696, -0.3950169895430573, + 0.29099877073820757, -1.5250611710167732, -0.5876293953125105, -0.5938486494168753, -1.0021656820999798, + -1.9666708037201044, -1.272140943592933, 0.7880251982149247, 0.11964755378902137, -0.3901422866566291, + 0.7163616669643105, -0.04395212244207691, -1.0791955402155438, -1.2675936648298745, 1.0655012879795382, + -1.2960016160353804, 1.1268487593724137, -1.4561611267402474, 1.7853064994708516, 0.9817883639607627, + -1.509195648143982, -0.5791158763214721, 0.2952226835939813, 0.978029471962218, -1.7020877610480865, + 1.2949154364706787, -1.8669978207007674, 1.9804087372291983, -0.1920592681769353, -1.3129464854527964, + -0.13084211958976688, -0.25279655392730405, -1.8414897577067046, -0.7363208735547797, -0.6909260581968182, + -1.8811392695885178, -1.5901068180742568, 1.758878672856656, 1.5387787983193055, 1.6713805051822828, + 0.28500585759464503, 1.3914306792968247, -1.112480424362695, 0.43326162263712487, -1.8142315585546145, + 0.04023793859339708, -0.29805331377366073, -1.5940199056342657, 0.598129666067309, 0.625004812581027, + -0.911960460437454, -1.973405398299871, -0.7574758313972065, 1.6261060948310595, -1.0639316504874738, + -0.8549167612511983, -1.6781924250988283, -1.2461164334164385, -1.4396767476893544, 0.588676315376695, + -1.3923513282535058, 1.960640665522404, -1.1216598084556173, 0.29702774865635373, -1.1441990771482464, + -0.9733601129567919, -0.13533496827900127, 1.2875809157665516, -1.004348467034836, 1.501216437295625, + -1.7257128690349832, 1.3038540536955745, -0.23514094567048183, -1.0545846838443325, -1.5628126421353628, + 1.758225843292891, -1.752217343717482, 0.8827182187480176, 1.9633500079396518, 1.4124055174643644, + -1.6009057792139894, -0.124257691420528, -1.6361563376854855, -0.7163857270237415, 0.5991086423774714, + 1.7584781739562239, -1.625063774845441, -0.37572359945414213, -1.9506995916793395, -1.951072499257542, + -0.8315895595505731, -0.7002195813028456, 0.31048848147933406, 0.19118037223773499, -0.836819966187166, + 0.8992259497849764, 0.2769262236848036, -0.8190887502725346, 0.5908744335005158, 0.8309308801063224, + 1.828667115346997, 0.050270920754735826, -1.4675310474376078, 1.6474157726281966, 1.4412270025481906, + -0.071799132580602, 1.2723657902431542, 0.9847345744230269, -1.7967618044169429, -0.38502605748464447, + -0.9911154903546722, -0.9911306363398822, -0.9822148063420846, 1.404660627884402, -1.7223894428279198, + -1.906376932077218, 1.1267093944315762, 0.005733157947431344, 1.5499657009743384, -1.2918427917389055, + 1.8878866260898972, -1.450986879628605, -1.0670858020560647, 1.012435839252528, -1.0904895043171203, + -0.7636238525274122, 1.8215658692720762, 0.802604215057829, 1.6666057955071523, -0.6857256630224935, + -0.5501356674470292, 0.810089459752044, -1.6169276394201413, -1.7364078843810304, -1.1867030927097977, + -0.5172860730134063, -0.8556026745046195, 1.7395171980402528, 0.8977518661224195, -0.715248100272647, + 0.42642471199620147, -0.1359154018671509, -1.017818228497351, -0.00905895348889274, -1.6541703500137865, + 0.5001119548133026, 1.7346626988667904, -1.1674654589916598, -0.735219697011062, -0.1962670855393469, + 0.21602710767932987, 1.634800475118543, -1.614549402431435, -1.86469031751599, -1.0722793725135675, + -1.6751255258166085, -0.34784828468612705, -1.333517864681233, 0.2286247270719235, -0.9686327464308748, + 1.8030391927221219, -1.260653677947687, -1.3291707209191737, -1.7464317874557151, -0.6022055677476486, + 0.4234332187817236, -1.3942184957445614, -0.49460322127068856, 0.846493215661404, -1.3779621390020038, + 1.6170651737934367, -1.5949106936017516, -1.993666650794867, 0.517274246668932, 0.636335391123283, + 0.09728792236496986, 0.21871120388837983, 1.9033150412817141, 0.3761639225094582, -1.7448084623773052, + 1.001122609393847, -0.44673552515686765, -0.6566748795770616, 1.0022029703662332, 0.49152185517814306, + -0.19632140501675632, 0.6593309755584418, -1.42607069406814, -0.2499327686935615, -1.4645970035455589, + -1.8214929827258137, 1.7849263457214155, -0.46930999932929396, 0.930852011498847, 1.4657054090327062, + -0.8598379219960437, 0.21923117934684644, -0.2719980917718665, 1.1382814204088554, -1.4437234293121408, + -0.08437654030814734, 0.9230522551879243, 0.17552818859532504, -1.3982719952892557, 0.7727609217240659, + -1.3512364654797944, 0.4217546307725639, 1.8959084151076748, -1.9009721763514387, -0.2084686501701638, + 1.34507209606525, 1.4812500373319821, -0.25452451515050356, -0.7547650359872655, -1.5901912558006952, + -0.617303822265475, -1.568626713241147, 1.3511951372228292, 0.46680417658438866, -0.9843992974612537, + 0.21141544095185427, 1.9555502838890186, -0.5622924926922526, -0.7074175111692584, 1.6856741408764497, + 1.4329492504371748, -0.6904233032298688, -0.16570616327044707, -0.5819404191250754, 1.5298308400435117, + 1.2873904282242874, 1.75253332340609, -0.3969229369696805, 0.9712496560090953, -0.984449102903409, + 1.845837132921134, -1.23000834955623, 0.25823037305241403, 0.6562595586377551, 1.426434937488695, + 1.2050365141327637, 0.35112023386450986, -1.423781157416867, -1.22442697877245, 0.8857806584751993, + -0.27941851495344316, 0.383573200806417, -1.6546531309712131, -1.0620419037179136, -1.6487673588042684, + -1.1583303816085477, 0.11883432462925647, 1.8623629910270258, 0.7814730455397738, 1.3892839510915138, + -1.2109955091247775, -1.1531820955625154, 1.4249824872214445, 1.878872910977651, 1.96640914460413, + -1.7574195668520565, -0.17080472539130565, -1.2334517024508829, -0.042203796430729135, + -0.5119900340796173, 1.8245076363940829, -1.9504677488809792, -0.15838646478579133, 1.8653261691127963, + 0.9818831615417336, -1.1498049891062543, 1.8177635453973222, 1.0307183336158117, -0.9459693602670747, + -0.6967235469867861, -0.6765683802163993, -0.25117711682611343, 1.7085642032810782, 1.2354232326963057, + 0.16304953424899615, 0.8288006493055686, 0.7426908331423769, -1.8567733398412107, 1.9363614866712426, + -1.1491819532508236, 0.4456745704746172, -1.5200940589370164, -1.2921806881240192, -1.8425344342113679, + 1.385258197718982, -1.4468713069952912, -1.7194028306161009, 0.11112342320496005, -0.46993304971516636, + -0.3532497592673005, -0.20705867891006324, 1.1503266436423507, -0.8615322336734987, 1.4243814245202273, + 0.44605877897824087, 1.4291899741133527, -0.5503339260133986, 1.6845285529242942, -0.11203488733585854, + 0.7838364265704332, -0.3478757678382811, 1.4617786521240204, 0.6797786349237454, -1.4070556983478548, + 1.5368835556999283, 0.7270943692880465, -0.3943204095728534, -1.0159540349760245, 1.3420827647018054, + 1.4944969205796248, -1.3846348875424548, -1.1070204938214987, -1.5623431163391235, -0.7285340046827864, + -1.6146739312377756, 0.7914876850628412, -1.4285275663878165, -0.2102551967847095, -0.6036343290031185, + -0.7863519667198808, -0.5232027574551195, -1.5248951664568793, 0.6403374226115135, -1.3332654904167054, + -1.6013714748847017, 1.5264343369773945, -1.4659567584783701, -0.1255742527269854, -0.05787364846985188, + 1.0630262325298858, -1.0251739144160261, 0.6341878529485214, -0.37708094500223766, 1.5820017651752964, + -0.9754158194342759, 0.021470140570154506, -0.270780715342573, 0.8513662098629382, 1.812913501299759, + -0.5507306480213909, 1.6252304329937193, 1.6166807780171624, -1.0015815783816109, -1.6525598491008084, + -1.7640141719015405, 1.4567526655551655, 1.4314017477260261, 1.1147464550993922, -0.8609069210724725, + -0.8835445478716384, -0.8807052221352922, -1.054638273386189, -0.9307483447707048, 0.6532758793412583, + 0.2251469563444708, -1.4983410691150256, -1.355149970924261, 1.825265403016327, 0.3882557252462826, + -1.2005370275411478, 0.5167632305655108, 1.258505468633845, 0.09615276706317388, -0.11253109876707157, + 1.301050271249565, -0.5926127701907813, 1.6730534476647492, 1.2040207312610214, -1.5220497985479415, + 1.7064305519329883, -0.7793546956972763, 1.0964366887160475, -0.8642894018251441, -1.9411871186407836, + 0.074716954760353, -0.14369204309870298, 1.8393941069756865, -1.3769308839327685, -0.30862200557660646, + -1.993793823331563, 0.3709006233167482, -1.6557247604820402, 0.32053951400820324, 1.418554947267494, + -0.14801920801346657, 0.25882446183466357, -0.30227778350472967, -1.4281993644549162, -1.2907922764091362, + 0.9110864171884971, 0.613974184551024, 0.6697305289032087, 0.8489421527088039, 1.498148243683703, + -1.4269397350154973, 0.6132189565263042, 1.8741137765083877, -0.05705777446194382, -0.8855796810622429, + -0.8656995527854097, -0.5082467483431357, 0.4332387677470342, -0.7541429782381526, 1.305940642158534, + 1.3554774725998202, -1.2111195490457929, -1.4381676776657422, -0.5207599119467634, -1.8228914923142483, + -0.1877702958456453, 0.36230894851909135, 0.17851993376959374, 1.7927379289246463, 1.1368406732884377, + -0.9950802434664396, 1.6043789056058433, 0.6484661390901874, 0.933455445947418, 1.0965484754420745, + -0.6939481831648528, 1.2328397545284568, 0.9378872541025238, 0.34445900641106775, 0.1278365294249939, + -1.0258774152176224, 1.0892898808290514, -1.067964672512037, 0.44689053769430487, 1.9413539674195102, + 0.9528679183933839, 1.7587111895733223, 0.5576643641512362, -1.9897364013123235, 0.47036230007388813, + -1.7897636522061902, 1.5010421043239726, -0.6901446605239645, 0.45393761765945406, 0.2829324542693943, + 0.7192447802434581, -1.1331672455879191, 0.17337802502719768, 0.49963004068736083, -0.7815511951172738, + 0.9636628323801011, -0.8323427533238403, 1.2095900671544157, 1.934951536397083, 0.5200093238971917, + 1.158643992086569, -1.5891421117437572, 1.1934443042649656, 0.8276249819736909, 1.4212639972350827, + -0.8419641775597322, 1.2661381994885206, 0.06158022749540493, 0.46246628371777465, 0.7573951522202043, + 0.7619102543908491, -1.9354772481952196, -0.08541307148121735, 0.32286401465719994, 1.170684883004844, + 0.7749781307966064, 0.07936620410936168, 0.2315148010116328, 1.263046672689227, 1.031153500041193, + -1.4759863901212418, 0.9873606549190814, 0.7777062346878534, -1.7228292836209809, -1.5782061917261343, + -1.9224252408205036, 1.6523756827872802, 0.2993123495389698, -1.0626357773248643, 1.0197875082835361, + 1.9808878982391658, -1.7335070696659765, -1.3148099108168987, -1.9494880550804297, 0.17268926636423831, + 1.242098464223531, 1.1180045679110648, -0.6736825181756254, 1.0154763824746373, -0.19957976202073535, + 1.1193883202782464, -1.1478784045364039, -1.1330343367225462, -0.8689855158445736, 1.0116971724295913, + 1.2506562934116516, 0.9967792918784815, -1.0610583651118048, -0.532297674201299, -0.7150119770938783, + -1.2232688695213954, -0.4743547399060102, -0.46434726873258025, 1.0115464987548544, -1.0177528333698103, + 1.733606846210808, -0.8155729251191701, 0.5232276995816267, 1.7914758177269183, 1.0157926433716051, + -1.8351739772356694, 1.2012667794467289, 0.8545667843161366, 1.294851140223634, -1.3948286719024399, + -1.1466399375859053, 1.3041194739119701, 0.7377935849631942, -0.7995103884432107, 1.6159710967964172, + 1.1620542442509896, 0.0055755030499886615, -1.0334307489311154, -1.5841651529973335, 1.4111075346503084, + -1.2009959445559417, 1.9720845657548844, -0.8035695524416591, -0.14331494930025102, -1.4264826197967428, + 1.7586100751820988, -0.739427319576853, 1.0666567765029669, -1.5916344843881065, -1.4350385771936356, + 0.42668754843352463, -1.6484548280031257, -0.37728075139212613, -0.7687371980269919, -0.8179781982796062, + 0.465506361112749, 1.3840790073759015, 1.0880219954714985, -0.7036856119504717, -0.6963478247947235, + 1.0979632285661367, -1.3810178288903288, 0.8156134298411253, 0.10032276394202011, -1.6008032936016479, + 0.18932441617959217, 1.3551150619453578, 1.3534267176398442, -1.1276635081675348, 0.6608005967002919, + 0.793182461268664, 1.398769261405878, -1.2123058565244103, -0.08803245110073288, -0.2893447181014084, + -0.9972961711861705, 1.5618332897004663, 1.1591927593779072, 0.511047619279922, -1.970138349713964, + -1.3628804504805752, 0.2782295809685751, 0.30358322411230354, 1.9514398542744873, 0.25960063763317454, + 0.4976234205537926, 0.012047558099780531, 1.79534431915478, -1.7723391117592922, -1.9992623394682916, + -0.4524505481322034, 1.3804610881366495, 1.1587664810582172, 0.5111716739430667, -1.6928217537440542, + -1.0278751605013383, 0.20893412968684988, 0.5871739815665329, -0.8412950581167742, -0.4077765748738731, + 1.7498754266646204, 0.8583271186920243, 0.5762482317954367, -1.8599099610537024, -0.19242912582490845, + 1.2512291284228754, -0.8441763329152305, 0.26980735485206075, 1.4044456507515894, -0.8516268695811835, + 1.4493090144656193, -1.3915783403234894, 0.35557624127716814, 0.17226516309619733, -0.5021504124493701, + -0.766188811190383, 1.1332244159180078, 0.011135590774230764, 1.8851307343362258, 0.9148262788018782, + -0.8299956158151707, -1.6057691996197043, 0.5678238711359924, 1.8008767630518667, -0.586304639586193, + 0.47029839294118503, -1.463460016599707, 0.20856103503853962, 1.2545845494965118, -1.0729668619560213, + -0.14947337388785709, -0.20035875199434283, -0.07202935566940027, 0.05533721254453372, 1.3677731442776313, + 1.7011893855187177, -1.7202195328563636, 1.9488792451860384, 1.3096386167232117, -0.5132153326702822, + 0.5616165083457831, 0.4157359121447879, 1.8006481124839855, 0.230442477572935, 1.1686013774607265, + -1.7879670674147912, -0.842370723742838, -0.7927388944332199, -0.9586442598316518, -1.54708954114047, + -1.2956507442445577, -1.4031204732951874, -1.6120562181795481, 0.6283505387959369, 1.9223686678649798, + 0.12298814371626143, 1.6278360280654836, -1.6223557147160461, 0.43054457669015456, -1.7908842288361821, + 0.5775385836169233, -0.15097004219870414, -1.2290692851647318, -0.620782793926316, 1.1062043604891931, + -1.9746433547898716, 0.6382174626600765, 1.749692571795931, -1.8339775549081967, -1.464038954875173, + -0.9639795224425223, -0.990228592162139, -0.9403728223487793, -1.6685943188697578, 0.07041255085387288, + 1.5308882897823413, -0.47253846241489494, 0.8106189739961147, -0.2270261582976314, 1.8799983165866934, + -1.2472611795739992, -0.15798247303785384, 1.1021702350596438, 1.554345756888078, -0.555551779439396, + -0.2519441527853594, 1.6402787480890648, 1.0543147407284197, 1.6335920152443828, 1.6300453015691483, + -1.0481818985064661, -0.7279789339473091, 0.4888000242214119, 0.48732772667936697, 0.8584068837243475, + -0.2507103566018847, -0.4408909179300613, 1.1233796101976283, -0.4632288097116861, 0.2476082637987398, + 1.0440154957174563, -0.8869375992382329, -1.4091245527531013, 1.1244796013617524, 0.6892639663640718, + -0.022508581368523295, -0.1947662213695871, -0.03308275174489772, -1.5224685354120124, -1.749461912732114, + -0.8082369514527556, -0.25696601764524907, -0.29739460489203395, -0.7761972897539025, 0.11904461166545932, + 0.12301125764257481, -0.42555313462267197, 1.193050472577613, 0.41286096527825666, 0.19893639265378305, + 0.8311360551038502, 0.07310473388978878, 0.6685202630904961, 1.3372821732384246, -0.12066792119153114, + 0.05609759368011602, 1.0451437108149522, 1.662035784060775, -1.1922702285472315, 0.6416834232307735, + 0.6055133359661493, 0.7942484561725927, 1.7276358502409623, -1.2864330076965746, -1.0918326505522646, + 1.1998891150473536, -0.28628797182602295, -0.5716447940943157, -0.5689380516868878, -0.12795378416769676, + -0.6791193619117468, 0.10038501986448622, 1.7063771482852008, 1.194191767284221, 0.6647694331145066, + -1.7348358815181681, 0.2817755499501846, -1.8525340342841066, 0.9851618492112006, -1.1059877587552673, + -1.5295669583116638, -1.0505227522752554, -0.9487207502273058, 0.809796237150997, -1.8498482833968302, + 0.48019511430536, 1.1226579267420975, 0.742760641350614, 0.29074715598422696, 1.3377685110252564, + 1.0867564725571857, 0.18478374884500237, -0.3164295444608616, 1.622419371685548, -1.1918941739540632, + -1.85979992788012, -0.12624583269268985, -0.7280639736725645, 1.7648123921702625, -0.2906781194024406, + -0.09262137204418774, 0.4138200526064937, 0.04327935477419942, 0.9797051661471672, 1.6815036980821159, + 1.1825812431197686, -0.4448889424945657, -1.7444130322963716, -1.2294019413222204, 0.050868336661046065, + -0.28346687593229625, 1.76715189128082, 1.611534349398629, -1.1862000782338198, -0.18760651204489776, + -0.6355147050302099, -1.0693954841485978, -1.3343928935517813, -0.027843745181175272, 1.8949354550001445, + 1.1947313752564401, -1.705187723165774, 0.7263276294399761, -1.7205254544798727, 1.3224746522778261, + 1.1708741480141782, -0.521723626523988, 0.14495711998638772, 0.1539508663177509, 1.3474086945015884, + 1.629428728782094, -0.4727414254013107, 0.6353417064465656, -0.03451981234038648, -0.5836884681507799, + -0.20106395276242583, -1.7734283674131142, 0.32803960400104515, 1.9133097530382273, -0.24540943135655446, + 0.5453897283838289, -0.5427539051784143, -0.6173702928258038, 0.9933437009783477, -0.3068382522087365, + -0.9721176288546554, -0.2527049737425271, -1.273295878249181, 1.0233037474978097, -0.3711521481099638, + 0.5595202642561681, 0.8557760663482039, -1.8239035041840888, 0.9147544038362989, 1.7978278449861724, + -1.4194290997912544, 1.963345131841125, 0.23090855457281645, -1.3144424697360408, -0.1798082266955383, + -1.2650654602972473, -0.18679991460173095, -0.17629662209988428, -0.18349530363472422, + -0.5475428774726536, -1.04729696306776, -0.6411208082297168, -0.7308839648617003, -1.3455838296857312, + 1.3597877917216312, -1.8340993142240434, -0.6466367039451653, -1.8498009040527492, 1.50126492806735, + -0.8473870152631546, -0.566754366425382, -0.11291722444101016, -1.204932180016593, -1.3014629360740422, + 0.8120263473720843, -1.8930799484123453, 0.5895700763514027, 1.7580882681692787, 1.477127335430259, + 1.2047061381469248, 0.16485767274321272, 0.23788298792881424, 1.2386497309565767, -0.05281444305869076, + 0.17334853252953497, -0.9480401485516037, -0.12016072866190708, -1.5973543471706693, 0.9977273313654775, + -1.859593065673394, -1.22159503592711, 1.1674399144996856, 1.2941842856062413, -1.1135548933000354, + -0.9788839012477277, 1.5718286586532253, -0.1759982975417227, 1.1031262106075763, -0.8778538891016368, + -1.0912009563680698, -1.4342020269338933, 0.7224191131816049, -1.2061468497546022, 1.5502197738160985, + -0.4251181088474105, -0.3206510906592923, -0.20288873333289725, -0.3560455119064194, -1.7508196425056557, + 1.5986301909131457, -1.4438601542872123, -1.7892166215086185, 0.7616375057554672, 1.6373087559633968, + 0.384206177448279, 0.4567136418012572, 1.2972172920819673, -1.1093690558377691, 1.6750979445730616, + 0.90908880550879, 1.2770950805973325, -0.30072973792624325, -0.30858954557714835, -1.9794583286836787, + 1.4463537661478734, 1.0183415951718624, 0.19309738632155593, -1.9449394663020856, 0.7108312298345272, + -1.131148529447878, 1.5259401710667637, -1.2736225271446795, 1.449852979043179, -0.5704566653927854, + -0.3074713127723667, 1.3993673001014262, 1.7718101827800963, -1.4492416161385497, 1.204820691151994, + 0.537241600964693, -1.5810854626566924, -1.5966679426112087, -0.6946214128685, 0.7401275132249641, + -1.7160959560539784, 0.4050021853660404, -1.2105835684442203, 1.7944565918560151, -0.6818406209243566, + -1.2359644949792958, 0.49907683448423157, -0.8405824207686488, 0.4689270476935219, -1.4699087331797918, + 1.43264169315706, 0.10499119180317251, -1.9830520821430992, -1.5927469176409472, -0.7632048947095695, + 1.303303968850459, 0.5773554905161333, 0.6761130632322967, 0.9023788989770569, 1.8960847479504084, + 0.8846144527507596, 0.9891128512774987, -0.7137249307539442, 1.267181508288493, -1.5113665535004523, + -0.17564834166799148, 0.9032525164747707, -0.25795858643541525, -0.2153155122989885, 0.14650372443070392, + -0.9984704344431465, 0.19261182317116177, -1.568804815745514, -1.8496400745041237, 0.7219284702138937, + 0.47816525621959105, -0.3273800096758981, -1.3390312793543542, -1.9838474983553418, 1.4066433632074071, + 1.9258390074704312, 0.4520281509654822, -0.30119846185025256, -1.8086624212334463, 0.851460380433025, + -0.4149432793401253, 1.4970655678791776, -0.9139304175330043, -1.1721517169571731, -1.9882366923130537, + 0.20630155701555886, 0.0891351853539204, -0.18485046053672338, -0.5253430902979694, 1.1136150007281822, + -1.072256739674419, 0.5677711226994742, -1.6986682182236068, -1.373143853609375, 0.5391705517446521, + 1.615483488858379, -0.18222418110590155, 1.5270125615115786, 1.4186450525284275, 0.6856859039097802, + 0.5948037341597869, -1.0097732940745248, -0.7260016082299225, -0.1705798585617213, -1.4460592122059417, + 0.2804912469966476, -0.0574570618149588, -1.4226509159038505, -1.3490817825559507, 0.7561887451342573, + -1.0315310280500372, 0.7865868802852711, -1.7955739447423928, -0.20476732094967787, 0.6532859024525468, + -1.398626809307176, 0.31416850473475755, -1.0474173751356446, 0.049027524534579925, 1.335442264825483, + 0.5839485880852768, 0.8416818491436544, 0.7729008830376998, 1.7957935152184445, -0.20047560204525272, + 1.9653799460331678, -0.756998178675067, -0.12357101901807699, -1.4272827743751613, 0.7149414745051672, + 1.4783565252719182, -1.2368177109511205, -1.4571248051607144, -0.7948678149157731, -0.6295946982419727, + -0.022851757488315805, -0.07947620035768654, -1.3106359681202076, 0.1591438592300909, -1.4970586188027868, + 1.3181273904865316, -1.508591213967403, 0.5722257787143228, 0.774539967054146, -0.5579675263215638, + -0.801690277809052, -0.8966439545169163, -0.2168181087774288, 1.8549965661558616, 0.7870136331314779, + 1.0426166176054243, 1.2052992540989846, -0.6116512580549873, -1.7800528483131748, 0.6162047118916432, + -1.1406795391578877, -1.3126212462178328, -0.1255252753148266, -0.048214851156274996, -1.7513823416941525, + -1.9966724157135571, 1.468282137353885, 1.1596808879879097, 1.848952713577705, 1.9276331797246486, + -0.6082295997412146, 1.9590194651252002, -0.6705403599782791, 0.8982591946264264, -1.8582005994721253, + -0.6224103416017206, 1.3118474535601639, 1.927285880838153, -1.3435831019941835, -0.02035775798119932, + -0.258091815197548, 1.5685276792778557, -1.7504336743073416, -1.3270808448193447, 1.9609655615175043, + -0.5002114597187894, -0.8302889305621663, 0.6662682285835677, -1.3588868202703237, -1.4263374077454936, + -1.117653746556062, 1.6959423725848142, -0.3368698386266633, 0.6329184444264122, 1.4360518922995382, + 0.2209792086889042, -1.7826312330601093, -1.9055378329489479, 1.8363537758423742, 1.8612237061845747, + -1.163857834211714, 0.38823573714522475, 0.9933133475252713, -1.769852560129741, 0.6303163049709841, + 1.6352278260339865, -1.4220707937174062, 0.4996182092181929, 0.1748538915264719, -1.389807604688972, + -0.5041547053983226, -0.7755917479953034, 0.33822942573796055, -1.3957767536429841, -0.16066323457963172, + 1.8426173458683097, 1.0912529333551886, 0.04454407634104118, 1.4585397734066836, 1.314915917164475, + 1.0930141444320949, -0.9720567164640972, -0.5831452038265033, 0.8082335756515109, 0.4358913655339238, + 0.8310387682873994, -0.8242800720840835, -0.47497624245619896, 0.000058968841639917, -1.6746583184349388, + 0.2586283765233146, -0.03952361428650608, 1.9572062803747263, -1.364317103129661, 0.16484595584710782, + 0.6889848970954304, 0.33625779127527444, 0.28142293472509294, -1.5510992496482494, 1.6785313595707674, + 0.4921495479711657, -0.42294403727168906, -0.10192465238332815, 1.583070264702826, 0.6464143795128816, + 0.7706704090619576, -0.45316577898360944, -0.6156337052461307, 0.2949317256431403, -1.1153946167003506, + 0.23143632095143918, -1.187495465719234, -0.754948635807529, -1.090644714217727, 0.8562387289761135, + 1.4209567719285578, -1.867698005779011, 1.3320884849513037, 0.5619380450950349, 1.8886416226851166, + -1.7314027359692306, -1.0362482885730966, 0.9807231768105664, -1.3689591083054822, 1.5694772951886131, + 0.4400722090716478, 0.9539178709143741, 0.4832872148319014, 0.23471769113792984, 1.9643745055943542, + 1.1325801513292664, -0.43752654225713705, 0.4538975778222154, -1.6157155513403065, -0.961125955159364, + -1.1751535270699955, 1.1277536127856669, 0.11594556933087752, -1.9276503102738447, -1.5774089828974898, + -1.0029301039964427, 0.4455245589428616, -0.4739643281334569, -0.8513671370845639, -1.0336436816615233, + 0.6626347865920605, 0.7885413873550009, -0.013439463013608766, 0.6488139123172507, 1.4291110296253855, + -1.9761431450732667, -0.5012957954679527, -1.1585910698227027, -1.1021436093929422, 0.036919383239228054, + 1.8089329170710071, -0.005231354013648826, -1.2082234644042886, -0.3456887578591781, -0.8017405353429492, + -0.5345375675659492, 0.8534420279659507, 1.7447905633469585, 0.43817127727920724, -1.8499205219957702, + 1.4797731845522186, -1.5443888715914138, 1.0131225647292235, 0.34701885989022063, -0.41455116200553377, + 0.40209313291363724, 0.11900781713274, -0.9935386808363758, -1.8322340002578068, 0.9811839836103111, + 1.502154507354498, 0.8891949169357574, 0.899071159318308, 1.752337905147142, -0.04599799842734953, + -1.6347681983052045, -0.5522741247690393, 1.505215487771519, 0.8504281241898015, 1.8693941265525265, + -1.1512863792577441, -1.8748160415118837, -1.879939448107617, 0.9353149913960506, -1.077101932112896, + 1.2322050595012843, 1.2672982902122678, 0.9384368132472858, 1.7274119921052788, -0.9601726232935137, + 0.19420343716687505, -0.7830049935581602, -1.9099470296794694, -1.213386784368356, 1.6800660417837605, + -0.9282638481321719, 0.5088239004955142, -0.5528513330962577, -0.4235136044745138, 1.6316021980530238, + -0.3087654696690505, -0.10527992793999896, -1.4364007982935343, -0.4455364976497762, -1.3433044303003099, + 0.6517505064656408, 1.6050250028051813, 1.6490276577492855, 1.9140119353414144, -0.7684496098140174, + -1.01738188731548, -1.1250647161193914, -1.6586222112755102, 1.1599068196677091, 0.795751774794466, + 0.5733174614685748, -1.3655937932875277, -1.507254849973065, -0.2831083653801638, 1.3241227396573514, + 1.574957068221127, -0.31194765030973937, 0.4008126582755933, 0.43635579619776443, -1.3214048572867325, + -0.8447194221435215, -1.1526249262582748, 0.7073544609451421, 1.17078844004073, 0.2425026449956018, + -1.7518561882120753, 1.6591407848437605, 0.06616038448738504, -1.928680221520497, -1.0504809684365677, + -1.0974712342176778, 1.6344494477175475, -0.4129201382527832, -1.7111789594333953, -1.6070549808753904, + -1.2456702084965565, -0.012663680475193395, -1.1305840149083926, 0.734392120651302, 0.18651679771884844, + -0.22974141381305735, 1.9415149817194726, 1.9280078232850126, -1.2072428658632042, -0.14782869942839927, + -1.6523593328098034, 0.4844141001145905, 1.0492278525622805, 0.5924539450553175, 0.848097235977705, + 1.8881210898619676, 0.20004070245023797, 1.4305799893712425, -0.9082660328332564, -0.14268688754147085, + -1.1201061991671262, 1.1399839045134712, -1.5579448101377515, 0.5516078322124933, 0.3365679579810825, + -0.636402425334972, 0.7364990374614768, -1.2657328423109204, 1.4084870144147636, -0.5538274490613713, + 0.43684201943536305, -0.706532199493215, -1.7678543182116737, 0.5086879667154935, -0.8888826793267235, + -1.0510640830474856, -1.775013227468511, -0.7345226367397419, 0.9474796694127203, -0.649964939391042, + -1.5189099534245498, -1.9260526549789567, -0.457330394781688, -0.9340352374682741, -1.3868164983748459, + 0.553888202560878, 0.36818698767921365, -0.9382183717778192, -1.0829596839250488, 1.3646658325042367, + -1.3240476722940633, 1.9816923192707012, 0.5300123477141927, 0.0790085088366057, -0.4455760575475125, + 0.48463653297167486, -0.7788483158746828, 0.8253771773416885, -0.6823431576948558, 1.7776737534704772, + 0.5497713214586923, 1.4452464852137838, 0.23037431004796094, 0.31188142524786766, -1.7543267850797761, + 0.6063452856820941, 1.207122989395999, -1.926907332363795, 0.45038239145265724, 1.0988284911574286, + -1.6007436457047142, -0.4728890687538678, 0.3195037474199047, 0.8855124961325762, -0.2555993730577626, + 1.8813620496087493, 1.8900177166377103, 0.09592367474164032, 1.8974568987778628, -1.1058972953708501, + -1.1512017435907103, 0.40201549011430693, -0.1831060132280804, 0.22245091899613723, -1.1866541831479385, + -0.5451040730392975, 0.9199451579519691, -0.42060255461704177, -1.3791747236441925, -0.3024448490768936, + -1.6611455107283604, -0.3106541240888907, -0.9498356682876157, 1.769660410309836, 1.598216213741022, + 0.4623205859503434, 0.03664778458072249, -0.6655252973523123, -0.7325818423601653, -1.591871681771024, + 0.9451427301297981, -0.8203468674560934, -1.5069504221011005, 0.7243170638862324, 1.1839749702725175, + -0.7128348341329511, -1.5076965090949397, -1.59865172895221, 0.08910680490749368, 1.5717880586471278, + 1.8951504684652152, 0.42550207471805646, -0.128409822054123, 0.9896766313315162, 0.34808462644009275, + -1.7082990487472571, -1.0459982270685435, -1.1132292311691874, -0.3022325459842996, -1.7274216318536348, + -0.11775921716410043, 0.7403577685290532, 1.188824227090608, 1.387282721223393, 1.0688709331799577, + -0.6395615121564866, 0.8142138261269114, -1.4467483751545576, 0.8996177593321626, -1.8193866881462766, + 0.08924208518874632, -1.405297919708996, 0.31754790231458685, 0.9823851818369507, -0.49590144528424585, + 1.4194064220328588, 0.9729299634967079, -0.46170090347918613, -1.634203532024186, -1.3139980454214522, + -1.8469876250843802, 0.710926322864931, -0.4029599381569682, 1.7246833539931403, -1.4088169680807, + -1.9165388068372708, 0.21804317359714798, 0.3898186987610348, -1.6118063668363405, 0.28583673086194583, + 1.7015683175211391, -1.2836168642070582, 0.8463494611619371, -1.1625839799245696, 0.2640138032690018, + 0.5041551687310717, -1.755824925370514, -1.177748867346489, -0.3829444120449246, 0.8360805034202388, + 0.05022254918868896, 0.4276469609032256, -0.8235567451730139, -1.94062145827791, 0.35097020666890355, + -0.6636358495150212, -1.2587452298002964, -1.6575362996910616, 1.060467707494971, 0.9661831305087887, + -1.77149852066976, 0.18844425542251564, 0.0897431507375952, -0.8281105706620897, 1.5632959207508623, + 0.07141359825195703, -1.0844701043735414, -1.905443802421968, -0.27588161311845294, -0.40342423607415956, + -0.34332304727825136, 0.0176022295550462, 1.94359831375926, 0.09702777017089215, -0.11695098349068722, + -1.2810374187149858, 0.37597456306160204, 1.7631374725308877, -0.7830266259108773, -0.5605784036815882, + -0.4409773606270875, -0.49636250754717803, -1.549108447216227, 0.6261185797820117, -1.260881611110821, + 1.691411217905281, 0.899655093658585, -1.0875528162122174, -0.7120948701980732, 1.8214705523154269, + 1.3010380076854968, -0.4492643980075144, 0.9914465230608682, 1.7590027691290615, 0.8514661670055963, + 1.5263431803492642, -1.7260779024351258, -0.14589666108296218, 0.18011804793376918, 0.7175880982696512, + -1.0399388762140145, -1.0480376846250712, 1.5656146512942648, 0.4435540525930799, -0.4175857955829816, + -1.8218436496980575, -0.9346408060646185, 1.40089015453285, 1.6926667426168764, 0.4187248632147291, + -0.6755275086264145, -0.7011229448771363, -0.9528087286614646, 0.0730922604589237, 0.6252467328216804, + -0.5573518555770702, 0.9864121888624755, 0.6646486706800783, -0.8405364163020792, -1.0505815213688878, + 0.8989991238262265, -0.5022516947851985, -1.784806766373471, 1.2637708002659025, 0.5065772818030325, + 0.9973024415787677, 0.08348064671549338, -1.757630249437522, -0.2016005631449005, -1.0360120513086803, + 1.786128872822113, -1.2720919225213843, 1.3430514506545927, -1.0762325516117865, -1.0578995596104255, + -0.47242526972271204, 0.05539697038437996, -1.08558757453563, -1.404337586710036, -0.059247790489052043, + 1.3998034069978171, 1.6067367856721155, -0.5185826391883994, -0.051896682542695416, 0.11112005542023429, + -1.2231398633939348, 0.2886372299242277, 0.6564469519248641, 1.67404118804063, -1.538487261793886, + 0.18551945213331145, -1.1342837192256061, -0.3318725405647953, -1.4531152273595227, -0.6934713826285721, + -0.24436235286417052, 0.6776292484438171, 0.8871814678850702, 0.41826798275898014, 1.0161513742931785, + -0.13947907673300097, -0.7736759327375049, -0.43981678279829683, 1.635191807530191, 1.3044401854805878, + -0.3097446711021723, 1.8125726195847056, -0.26127912212234694, 0.8564630403854094, -1.519521818793156, + -0.5727391479884938, 1.7015469847109976, 0.663240965083145, 0.31064120951508656, 1.4030451184981052, + 0.3325065959732836, -0.7178902057747756, 0.6090652378284442, 0.8426138183122633, -0.580146652112278, + -0.6076938097212707, -1.6599273271373782, 0.29960912457791444, -1.6741835731853065, 1.5428301790607195, + 0.8970548194971704, 1.6066845600081736, 0.5404165757730146, -1.9537941867764292, -1.5234595572340748, + 0.5293735217702951, -0.64620260665742, 1.8818640992235771, -1.7237764606754276, -0.8040024538741264, + 0.0642546885214017, 1.4395299343641659, -1.462587128675942, 0.011882540823848764, 0.12033421748154716, + 0.5458210215408705, -1.5141295301316422, 1.8809343680577308, -1.8801856621666753, -1.901376259472575, + -1.4374202976060095, 0.8473513507453765, -0.896351895119154, 0.457751001832321, 1.876657552919962, + 1.267733433184599, -0.30894648094866195, -1.8178016120669414, -0.7711776919446018, 0.29038564786361576, + 1.6396189720781438, 0.9597929848181161, -0.34227788522140834, -1.4450087753233527, -1.7068508353679626, + 0.8426935759536303, 0.7173810205823674, -0.6580236891322881, 0.8663322021812405, -0.38112472550089915, + -0.3331447946260786, 1.8551673806318556, -0.6731525492100126, -0.009001319785657103, 0.41039833755685784, + 0.025091358839535616, -0.49823213412251555, -0.9827448714264726, 1.1077851800046377, 0.5740585983905078, + -0.7235926614954762, 0.5059901875180826, -0.850177898505664, -0.05453987892121592, 0.8633840127545733, + -0.3153969644106205, -1.3028092681229868, 0.7523083527030074, 1.413775575813558, 1.0697458650110754, + -1.1839403319780022, 1.5167022074836893, -0.36486781099211996, -1.7835462010879564, -0.6061285803342944, + 1.9969466536022722, 1.9531672204642883, 0.7967381388222403, 1.0934589095880973, 1.6405590176012312, + 0.3501113568054244, -0.8786338692108497, -0.0545508019996932, 1.4464849975584952, -1.8853956921596513, + 0.25983013847132774, 0.8440414107184964, -1.8818826057620326, -0.22674619971532906, -0.35951513414106007, + 1.6757192364875237, -0.15819503713874195, 1.6691357866915144, 1.8534771980207108, -1.8297709996602967, + -1.6514392305036534, -0.2343385561012088, -1.820925987823773, -1.2556074451315578, 1.0621490016055715, + 1.2109203955756476, -1.1500152481919903, 1.0466452723330733, -0.5833431814034746, -1.0817348313127493, + -0.40295742758029984, 1.0986483593098368, -0.42704879342020696, 0.7065668399686658, -1.7278295290305223, + -0.7183051438456021, 0.15944089245801774, -1.0276995188812146, -1.6398474518024653, -1.2318071869917695, + 0.23333723988807886, 0.4626060172442088, 1.2255286520425894, 0.9652309086097803, -1.5254810192280601, + -0.6683416541099767, -1.9000628944332894, -0.7244291249780632, 1.0347731523086239, -0.9009629081875952, + -0.11734057204667625, -0.6698286923748489, -0.9472592207823913, -1.8286232413501864, 0.2898215560382935, + -1.422921306299517, 0.6696091233175077, -1.014229444534946, 0.7139087775492996, -0.23241859018004174, + 0.9620272535910068, 1.1497473812544134, -0.8640235723632799, -1.3121563547519584, -1.0396316763878488, + 1.7769188425035614, -1.014039070903868, 0.4489833453895331, 1.0763456360970807, 0.6229672422660295, + -1.7464692834364435, -0.3663893352922596, -0.8769410861076103, 1.0608706111705635, -0.9262810932490861, + -0.5468726859924029, 1.7966956706569919, -1.0663859467239112, 0.7378193848221777, 1.068192632208282, + 1.3312476842417755, -1.5902814653913628, -1.0131078061920382, -0.6235296781383814, 0.37504339068020176, + 0.9126508132330242, 0.3999532385546267, 0.6552059838941, -1.6053942866342332, 1.7900258342291853, + 0.08171912833062756, 1.6137883979635745, 1.3466843147948442, 1.8505801094553158, 1.3728813966930913, + -0.4473279660140852, -0.20009909626620814, 1.4067472413245437, 0.36658966226851764, 1.4566800897303072, + -0.11633958899045194, -1.9458410018060368, 0.5651869174802018, -0.9885077925334622, 0.24385043055374833, + -0.4407908079663816, -1.7015252126482139, 1.5396273477916198, -0.9801103159055833, 0.9331708410017399, + 0.058036076446482454, 0.29277070369481883, 1.6896333641554682, -0.2872886303585469, 0.2981100430160728, + 0.1670720357805502, -1.6828245857476496, -1.4681960401028125, -0.9933436100210251, -1.827639383468739, + 0.08433714147463611, -1.1318904274562795, 0.9840669856671846, 0.8204547128989219, 0.5959008566248984, + -0.22424536381303728, 1.765380932910376, 1.050492887173749, 0.8249285352430338, 1.5823516671950122, + -1.4695844512182843, 1.4009128159343485, 1.0886951647082785, -0.4963319371911856, -1.8633848779197413, + 0.660465126445569, -1.2319891082878298, 1.6547000157065659, -1.6403428022350113, -1.2308283749125177, + 0.9142339764828238, 0.18691349086990705, -1.148271069003111, -1.266859733272054, -1.4482873768560758, + 1.6888579757850914, 1.5392518897570104, 0.41499451567073464, -1.0517290742419663, -0.9856143466540894, + 0.704611691207357, -0.27871441123648655, 0.445828139270918, -0.8125969294930622, 1.521716695437079, + 0.5657668386735519, -1.813374372841099, 1.0076529676525672, -0.5864288471977783, 0.5855480422270194, + -1.8330974772064481, 0.9782157266479414, -1.6230556142249775, 0.5265126362718373, -1.6878701852107563, + -0.3955226747487526, -1.3888929741627605, 0.2905034183357449, 1.0489208524387843, -0.3118857187498678, + -0.6289506096761981, -0.05735383950307149, 1.8668941791416147, 0.8898345005884769, 1.7147482078759548, + -0.12387314928310289, 0.2298818139402634, 1.9294076224252024, -0.43580099597679656, 1.7512542893273144, + -1.258214124547644, 0.9779750741630782, -0.2566261319632144, -1.9813300069235993, -1.3498734101224414, + 0.7506344777083953, 1.8867470646651894, -1.918953273635191, 1.7429571494233906, 0.7638060343526085, + -0.44782770384121484, -1.1300950570142518, -1.4753506380821149 + ], + "dims": [2560], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.027313262224197388, -0.005701353773474693, 0.1959753781557083, 0.10011828690767288, -3.6098804473876953, + 0.00864929985255003, -0.011981655843555927, -0.11036527156829834, 0.6647213101387024, 0.276733934879303, + -0.6354819536209106, -0.014075735583901405, -0.059462033212184906, -0.3388662040233612, + -0.017422985285520554, -0.043299876153469086, -0.16756349802017212, -0.07582926005125046, + -0.16514767706394196, 2.9962074756622314, -2.600733757019043, 0.04413439333438873, 0.07896167039871216, + 1.1207873821258545, -0.032255738973617554, -0.09964963793754578, -2.1782073974609375, + -0.01814177632331848, 0.08586198836565018, 0.380964457988739, -0.01918521337211132, 0.006902141962200403, + 0.0669674500823021, -0.09234043955802917, -1.0496017932891846, 0.020094068720936775, -0.11474193632602692, + -0.056350305676460266, -2.2275612354278564, -2.648808240890503, -0.017779357731342316, + -0.2514607608318329, 0.008559616282582283, 0.010673644952476025, -0.32376542687416077, + -0.16903237998485565, 0.026010606437921524, -2.163571357727051, 0.35461699962615967, 1.5194188356399536, + -1.1094666719436646, -0.012471643276512623, 0.0767873078584671, -0.21644049882888794, + -0.043257202953100204, 0.001341399853117764, -0.1367240697145462, 0.005313852336257696, 2.144134759902954, + 0.11904949694871902, -0.26428619027137756, 0.014375614002346992, 0.06913577765226364, -4.196413516998291, + -5.172718524932861, 0.06162356957793236, -0.0010976337362080812, 0.21020400524139404, 4.567638397216797, + -0.059758733958005905, 5.990215301513672, 0.19405193626880646, 0.003011247143149376, -0.1036064475774765, + -0.016247211024165154, -0.12790939211845398, -0.08561908453702927, 0.25051021575927734, + 0.07514326274394989, -0.6767844557762146, 1.5661166906356812, -4.326471328735352, 0.07481537014245987, + -0.7969828248023987, -0.45468205213546753, 0.21233250200748444, 3.420551061630249, -0.0759267508983612, + 0.16086462140083313, -0.3939729928970337, 1.3957020044326782, -0.2972649931907654, -0.31666669249534607, + 0.35118427872657776, 0.3117898404598236, 0.088602215051651, 0.17165301740169525, -0.8542330265045166, + -0.06893759965896606, 0.08126193284988403, 0.02327258512377739, -0.1314769983291626, 0.035079699009656906, + 1.2096712589263916, 0.9461245536804199, -6.337772846221924, 5.575413703918457, -3.9876515865325928, + 0.01430205162614584, -2.093717098236084, -0.056584782898426056, 0.05612698942422867, -0.01935030147433281, + 0.0010159225203096867, -0.38109132647514343, -0.000587565591558814, 0.12273997068405151, + -1.4854758977890015, 0.016024703159928322, -0.05192752555012703, -0.257480651140213, -4.023406982421875, + 0.03150588274002075, -5.065948486328125, -0.07601942121982574, -0.04482676833868027, -0.01937261037528515, + -6.9667582511901855, -0.05368780344724655, 0.6142992377281189, 0.3128206431865692, -0.3862888216972351, + 0.053061991930007935, -0.24360240995883942, -0.018439287319779396, 0.1868235021829605, + 0.005632609128952026, -0.10385553538799286, 1.2077943086624146, -0.07107469439506531, 1.771382212638855, + 0.3696843981742859, -0.31587034463882446, 0.0002820117224473506, 0.055834002792835236, + -0.7621694803237915, -3.773604393005371, -0.12602387368679047, 0.8626934289932251, -4.139935493469238, + -0.08643748611211777, -2.25795841217041, -0.025201046839356422, -0.28647178411483765, -0.5088312029838562, + -2.3566224575042725, -0.20447342097759247, 0.3922976553440094, 0.047735944390296936, -0.09984598308801651, + 4.436963081359863, 0.17725177109241486, 0.01968466490507126, -0.4080508351325989, 0.600350558757782, + -0.1489681750535965, -2.3178586959838867, 0.010645782575011253, -0.5052445530891418, 0.12876634299755096, + 2.72904109954834, 0.007315368857234716, 0.503023624420166, 0.9695355892181396, 0.4959081709384918, + -3.562389612197876, 0.3780525028705597, -0.194877028465271, -1.0815603733062744, 0.6436595320701599, + -0.10088582336902618, -0.06308454275131226, -3.7394943237304688, -0.0011674398556351662, + -0.19378826022148132, -0.2329375147819519, 0.029814809560775757, 0.20438098907470703, + -0.23114298284053802, 0.026816120371222496, -7.350013256072998, -0.011900502257049084, 2.0180928707122803, + 0.20987474918365479, -3.209254503250122, -7.5496602058410645, 0.232008695602417, 0.0027162893675267696, + -0.4211888611316681, 0.287914901971817, 0.028367964550852776, 0.015583046711981297, 0.07393462210893631, + 0.6514078974723816, -0.04090245068073273, -0.004561522509902716, 0.2931022346019745, -0.4355356991291046, + -0.1867547184228897, -5.984931945800781, 0.044270407408475876, -0.35987964272499084, + -0.033762961626052856, -0.2677021622657776, -0.013161826878786087, -0.010206296108663082, + -4.528798580169678, 0.4174078106880188, 0.12906667590141296, 0.04690857604146004, -0.08034832775592804, + 1.5188398361206055, 1.3247699737548828, 0.011872933246195316, 0.055544108152389526, 0.0025585023686289787, + -7.696174621582031, 0.030730921775102615, 0.039231084287166595, -0.4407111704349518, -0.3110845386981964, + 2.284346342086792, -0.027610689401626587, 0.09054349362850189, 1.7885178327560425, -0.11802572757005692, + 0.03795969486236572, 2.373623847961426, 0.11311819404363632, 0.009557336568832397, -0.02887658029794693, + -0.28853726387023926, -0.17708882689476013, -0.22821268439292908, 0.0237746462225914, 3.257477283477783, + 0.2507217526435852, 0.17421714961528778, -0.12231585383415222, 0.18179824948310852, -0.3428541123867035, + 0.024907970800995827, 0.2441745400428772, 1.13312828540802, -0.0009440237190574408, -0.594701886177063, + -0.008615869097411633, 4.071537017822266, 0.6198470592498779, -0.3097928464412689, -0.4404515027999878, + -7.008431911468506, 0.024559520184993744, 0.1267288327217102, 0.2140975296497345, 0.5778637528419495, + -0.03296203166246414, 0.8842242360115051, 0.16367295384407043, -0.3035202920436859, -0.09384048730134964, + 0.6805808544158936, -0.2706672251224518, -1.429656982421875, -0.1497703641653061, 0.4302230775356293, + -1.864505648612976, 0.01007054653018713, 0.23598365485668182, -0.08086620271205902, 0.001842734171077609, + 0.08458849042654037, 0.3059651553630829, -0.06515960395336151, -3.803208589553833, -0.41865429282188416, + 0.2828770875930786, 3.459416151046753, -0.00129605398979038, -9.578699111938477, -0.06560757756233215, + 0.026055261492729187, 0.0672057718038559, 0.08423102647066116, -1.3624160289764404, 0.013521464541554451, + -0.027731282636523247, -0.9650477766990662, -0.012694457545876503, -0.2116907835006714, + -0.10714730620384216, 0.0034909709356725216, 1.5338910818099976, 0.0006434338865801692, + 0.1618947833776474, -0.10659407079219818, -6.774624347686768, -0.08567759394645691, 0.5162889361381531, + 0.11074300855398178, -0.09961605817079544, -0.005474632140249014, 0.1132681593298912, 0.10878968983888626, + -0.4140564203262329, -6.274385452270508, -3.410104274749756, -0.2155490219593048, 0.13330507278442383, + -0.2973288297653198, -0.5738739371299744, 0.3465871810913086, -0.2567448318004608, -0.13507360219955444, + -0.014550707302987576, 0.039058394730091095, -0.25891509652137756, -0.30598220229148865, + -0.14163219928741455, 1.2217881679534912, -0.2967555820941925, 0.024605438113212585, -0.03864026814699173, + -1.4379907846450806, -3.0257911682128906, 10.609665870666504, -0.0002576113329268992, 0.1658751666545868, + -0.01822504773736, 0.1141287237405777, -0.3072766363620758, 2.9927172660827637, 0.42983293533325195, + 0.9799204468727112, -0.007520963903516531, 3.565046787261963, -0.18206597864627838, 1.1247198581695557, + -0.0011717785382643342, 0.0026591955684125423, 3.689824104309082, 0.03598639369010925, + -0.09997520595788956, 0.06576227396726608, -1.7916548252105713, 0.030312752351164818, -0.4527510106563568, + -0.26613515615463257, 0.025749003514647484, -0.17866003513336182, 0.18729515373706818, + 0.003528681118041277, -0.1579633355140686, 1.070467472076416, -0.20637144148349762, -0.10882926732301712, + -6.439236640930176, -0.25033196806907654, -0.26708030700683594, 0.036800775676965714, -1.9130735397338867, + 0.11082696169614792, 0.10686857253313065, 7.136363506317139, -2.1805343627929688, 0.002802944276481867, + -1.0081117153167725, -0.08366546779870987, -0.07263432443141937, -0.011199882254004478, + -0.015524221584200859, -0.008838756941258907, -0.005488056223839521, 0.6502953767776489, + -0.010726823471486568, 0.41685575246810913, -0.23590049147605896, -0.0868658497929573, + -0.07914192229509354, 0.22732190787792206, 0.0985199362039566, 0.013477811589837074, 0.5970719456672668, + -0.12020514905452728, -0.0009808604372665286, 0.1139480322599411, -0.5872443914413452, -5.610537528991699, + 0.14893069863319397, 0.44541916251182556, 0.599539041519165, -0.028194887563586235, -0.42580458521842957, + -0.24352392554283142, 0.25486475229263306, -3.251058578491211, 0.042388759553432465, -0.15446388721466064, + -0.01016155257821083, 0.07647261768579483, 4.707305431365967, -0.0834866315126419, -0.21240641176700592, + 0.34789028763771057, 0.0710633248090744, 0.013448074460029602, -0.18779638409614563, 0.022113602608442307, + -1.8543815612792969, 0.012882213108241558, 0.0508059561252594, -2.1125378608703613, 1.00347900390625, + 0.34287792444229126, 0.023498279973864555, 0.2604916989803314, -0.854418158531189, 0.3368889391422272, + -3.5361156463623047, -0.3238249719142914, 0.09940877556800842, 0.011137581430375576, -0.09505806118249893, + 1.4575674533843994, 0.18798890709877014, 0.13481135666370392, -3.1009016036987305, -0.0046508763916790485, + -0.002944883657619357, 0.008598391897976398, -0.05753857269883156, -0.007956058718264103, + 0.7023902535438538, -2.114570140838623, 0.6187217235565186, 4.448208808898926, 2.5069539546966553, + -0.10476846992969513, -0.04466601461172104, 0.32297447323799133, 0.06604880094528198, + -0.0016604098491370678, -2.8530216217041016, -1.2369211912155151, -0.02766953594982624, + -0.025159431621432304, 0.0029653196688741446, 0.04569535329937935, 0.03927958756685257, + -0.0021295847836881876, 0.024881726130843163, 0.028491219505667686, -0.0042065782472491264, + -0.05266435816884041, -0.08988969027996063, -0.04083021357655525, -0.040847159922122955, + 3.154191732406616, -0.06132543459534645, -0.7507759928703308, -0.029571423307061195, 0.03537856787443161, + -1.4017058610916138, 0.3888748586177826, 0.9719987511634827, -0.010947618633508682, 3.847195863723755, + 1.015498161315918, 0.012234801426529884, -0.3849196434020996, 0.5072981119155884, -0.07829593122005463, + 0.2524659037590027, -0.13102610409259796, 0.020525088533759117, 0.15267324447631836, -0.11044808477163315, + 0.008630136027932167, 0.0009689829312264919, 2.615210771560669, 0.3638320863246918, 0.2452821582555771, + 0.01092306338250637, 0.03127167001366615, -3.899691104888916, 0.16573800146579742, -0.06733611971139908, + -0.39246127009391785, 5.207749843597412, 0.05021298676729202, 0.17778877913951874, -1.4260956048965454, + 0.19870443642139435, -0.27705708146095276, -0.11092191934585571, -0.09528861939907074, + -0.5703115463256836, 0.5077508687973022, 0.3938415050506592, 0.47991737723350525, 0.7821948528289795, + -0.2891596853733063, -0.3829837143421173, -0.010832893662154675, 0.15224608778953552, + -0.00014581253344658762, 0.00025647180154919624, 0.02536843903362751, -0.06366542726755142, + -8.023703575134277, 0.027589797973632812, -0.1799485832452774, -0.2505863904953003, -0.3841714859008789, + -0.00031740960548631847, -0.04642002657055855, 0.38759565353393555, -0.05341910943388939, + -0.37632811069488525, 0.6983012557029724, -0.10781889408826828, -0.0007781427120789886, + -0.0877101942896843, -0.5221861600875854, -0.07871037721633911, -2.2496471405029297, + -0.042697690427303314, 2.38197922706604, 0.035262834280729294, 0.0695495679974556, 1.6927565336227417, + -10.396214485168457, 0.05338706448674202, -2.813828468322754, -3.691652536392212, -0.008508461527526379, + 1.570121169090271, -0.4011033773422241, -0.24479898810386658, 0.30835238099098206, 0.1998486965894699, + 0.0945337787270546, 0.39656326174736023, -0.23758645355701447, 0.1661674976348877, -0.04912934452295303, + 0.024212513118982315, -1.0319569110870361, 0.04704924300312996, -0.058226123452186584, + -1.6492913961410522, 0.6406868100166321, -0.005447663366794586, 0.19865849614143372, -0.3373563289642334, + -0.03675329312682152, -0.19241032004356384, -0.43262550234794617, -0.08300381153821945, + -0.014068910852074623, 0.11309102177619934, -0.02719084918498993, 0.2096254676580429, + -0.02292095310986042, 0.4072689712047577, 0.0003724964044522494, 0.20711149275302887, 1.0793871879577637, + 0.06120060756802559, 0.11688049882650375, -0.0023522432893514633, -0.9283630847930908, 2.477475881576538, + 0.26047614216804504, 0.7143173813819885, -1.4795730113983154, -0.15119962394237518, -1.4587875604629517, + -0.03378799930214882, -0.3518248498439789, 0.1747346669435501, 0.002720446093007922, 0.865147590637207, + 0.015568590722978115, 0.1952929049730301, 0.1818414330482483, -1.4265116453170776, 0.2012042999267578, + -0.2151491343975067, 0.11098571866750717, -0.16003955900669098, 7.798532962799072, 0.299221396446228, + -1.0280503034591675, -0.1838797777891159, 0.005458994302898645, -0.13420982658863068, + -0.06905319541692734, -1.8678100109100342, 0.40917104482650757, -0.09650467336177826, 0.2953720986843109, + 0.008414564654231071, 0.1998010128736496, 0.34882158041000366, -0.17196929454803467, 0.031611330807209015, + 0.08629407733678818, -0.1856321394443512, -0.22879824042320251, -0.09241079539060593, 0.2628664970397949, + 0.03050280548632145, 0.15829861164093018, -0.06391621381044388, -0.01048242673277855, + -0.010927671566605568, 1.0013593435287476, -0.15796290338039398, -0.10746872425079346, + -0.0013137627393007278, -0.2024063915014267, 0.0005700114998035133, 0.03609214723110199, + -0.4168614447116852, 0.12660957872867584, -0.005800928454846144, -0.5319929718971252, 0.32967525720596313, + 0.028021158650517464, -1.217489242553711, -0.09096843004226685, 2.344956636428833, 5.432365894317627, + 0.7993219494819641, -0.15543963015079498, -0.0007157247746363282, -0.08481337875127792, + -0.12131065130233765, 1.1516586542129517, -0.01504613272845745, 0.03704383224248886, 0.004402496851980686, + -0.581766664981842, 0.07592090964317322, 0.3745843768119812, -0.4989067614078522, 0.04438084363937378, + 3.175107479095459, -4.4077911376953125, -0.002988559426739812, 1.0332038402557373, 0.006027049385011196, + -0.0018258332274854183, 2.3033533096313477, -0.10134122520685196, 0.02520211599767208, + 0.005497010890394449, 0.0003968894889112562, -0.00029831190477125347, 1.2718113660812378, + -0.34650272130966187, 9.8225736618042, -6.33831787109375, -0.9639870524406433, -4.028343677520752, + 0.016925739124417305, -0.0449683852493763, 0.6271276473999023, -0.13903772830963135, -0.06179821118712425, + 5.860689640045166, -0.005071749445050955, 0.5026626586914062, 0.3309268057346344, 2.2567200660705566, + -4.23521089553833, -0.01613122597336769, -0.02665529027581215, 0.04668727517127991, 3.150425434112549, + -0.0052042026072740555, 1.067151427268982, 0.025044972077012062, -1.325771689414978, -0.1094195619225502, + 0.26904425024986267, 0.6204038262367249, 0.006285298615694046, 0.002915250603109598, -0.6165238618850708, + -4.090943813323975, 2.8669519424438477, -0.09453117847442627, -0.09316729754209518, 0.034191157668828964, + -6.707476615905762, -0.20231620967388153, -1.6191682815551758, 2.0373117923736572, -0.10501966625452042, + 0.0019581259693950415, 0.21420015394687653, 0.0156276635825634, 7.224427223205566, 0.1236666664481163, + 0.294806569814682, -0.0061331382021307945, -0.10612531006336212, -0.8333144187927246, + -0.001029952079989016, 0.38204053044319153, -0.03597458079457283, -1.41422438621521, -0.2833155691623688, + -0.0006075138808228076, -0.3701440095901489, 0.1309424191713333, 0.06839437037706375, + -0.0017361472127959132, -1.69569993019104, -0.20629459619522095, -0.5999218225479126, 0.114132359623909, + 6.6828436851501465, -0.36263933777809143, 0.41539111733436584, 0.022192703559994698, -0.06610587984323502, + -1.683022141456604, -0.2835130989551544, 0.27643388509750366, -0.6247501373291016, -1.421617865562439, + -0.08159351348876953, 0.005017416086047888, -2.026592493057251, 0.0009393739746883512, 1.760980486869812, + 0.00019237841479480267, 0.0022294363006949425, 0.22415778040885925, -0.09657209366559982, + -3.056180953979492, -0.24515365064144135, -1.7638490200042725, -1.900456428527832, 1.7747641801834106, + -0.9473960399627686, 0.27619242668151855, -0.11893711239099503, 0.7769895792007446, 0.09835439175367355, + 0.0019296495011076331, -0.043601375073194504, 0.03626292571425438, 0.1591210663318634, + 0.45964139699935913, -0.06853707879781723, -2.4563350677490234, -0.13421472907066345, + 0.040955424308776855, -0.2855738699436188, 0.11433675140142441, 0.00306497560814023, -0.48573875427246094, + -0.046301428228616714, 0.6907474994659424, -3.983771800994873, 2.3131954669952393, 0.05256381258368492, + -0.0911293551325798, -1.8945766687393188, 0.03453084081411362, 1.8747694492340088, 0.3433213233947754, + -1.1485600471496582, 1.6418366432189941, -0.13894057273864746, -3.1275031566619873, -0.9726752638816833, + -0.0012102212058380246, 3.898921251296997, 0.2646528482437134, 0.01665414497256279, -0.06312943994998932, + -3.0655016899108887, 0.024803519248962402, -0.25584203004837036, -1.3387784957885742, + -0.03684002161026001, -0.524848461151123, -0.9969499707221985, -1.8777778148651123, -0.14820723235607147, + -8.182234764099121, 0.015234949067234993, -0.010302969254553318, 0.10785042494535446, + -0.02237902209162712, -2.500221014022827, -0.006121153011918068, 0.054380882531404495, 2.607618808746338, + -0.48403942584991455, 1.7271841764450073, -0.054084569215774536, 0.04733904451131821, 0.29113033413887024, + 0.3090323507785797, -0.4069989323616028, 0.827186644077301, -0.8676308393478394, -0.18980173766613007, + 0.017093969509005547, 0.05046425387263298, 0.025303032249212265, -0.9938563704490662, 0.0307749193161726, + -0.003506980137899518, -2.145794153213501, 0.08889109641313553, -0.2744760513305664, 0.02533753775060177, + -0.008416163735091686, 1.6139867305755615, -6.39102840423584, -4.842134952545166, -0.7291613817214966, + 0.9694556593894958, 0.07247202843427658, 0.005149913020431995, 0.0029090321622788906, -0.1867554932832718, + 0.0015806574374437332, -0.04847263917326927, -0.18512502312660217, -0.04184968024492264, + -1.2331782579421997, -0.6159178018569946, 0.025481248274445534, 0.11850030720233917, -0.2734290063381195, + 0.26392263174057007, -0.2278929203748703, -2.4300358295440674, 0.0007563972030766308, 1.2603007555007935, + -0.009525062516331673, -2.5598459243774414, -0.1015859916806221, -0.3136966824531555, + -0.0023580349516123533, 3.0281076431274414, 0.0851983055472374, 0.18700700998306274, + -0.008541906252503395, -0.007119827438145876, 0.42274990677833557, -0.06235692277550697, + 0.3246764540672302, 0.047069381922483444, 0.00011004792031599209, -0.49105241894721985, + 0.041874051094055176, 0.013326031155884266, -2.525364875793457, 0.5126351714134216, -0.01582668349146843, + -0.3391125500202179, -0.1049877479672432, -0.36534854769706726, 0.027926098555326462, + 0.004374756012111902, -0.10876958072185516, 0.579942524433136, 0.34367814660072327, 0.12710949778556824, + -0.28762391209602356, 0.028134111315011978, 0.001072783605195582, -0.430772066116333, 0.21052536368370056, + 0.09690351784229279, 0.000786184798926115, 0.06906910240650177, -3.5896573066711426, 0.24118882417678833, + -3.176041841506958, 1.3121603727340698, -0.40836477279663086, -7.590582370758057, -1.9390276670455933, + -0.06406442821025848, 0.00011302023631287739, 0.013246525079011917, 0.21886053681373596, + 0.090825155377388, -0.06342892348766327, -0.14027893543243408, 0.017751706764101982, 0.11045858263969421, + -0.05397825688123703, 0.2152465432882309, 0.14184458553791046, -1.6443814039230347, -1.023624300956726, + 0.050706081092357635, -0.8185511231422424, -0.009972692467272282, -1.6231411695480347, + -0.010527506470680237, 1.5382870435714722, -2.6943516731262207, 0.965884804725647, -0.5423170924186707, + -2.0661613941192627, -0.4436858892440796, 0.0058816042728722095, -0.665194034576416, 0.8273401260375977, + 0.10996203124523163, -0.1316700130701065, 0.027179520577192307, -0.2735114097595215, -0.10301132500171661, + -1.906333565711975, -0.32074108719825745, 0.4478001892566681, -1.1052520275115967, 0.009423047304153442, + 0.5322814583778381, -0.004648196045309305, -0.009632693603634834, -0.7735386490821838, + 0.005249344743788242, 0.11850841343402863, -0.0034776863176375628, -0.1439099758863449, + 0.2767007648944855, -2.8716399669647217, -0.16290035843849182, -0.1801692247390747, 0.19117145240306854, + -0.7634338736534119, -0.29985561966896057, 0.009378351271152496, -0.6186265349388123, + -0.13845475018024445, 0.03558935597538948, -0.20145508646965027, -0.5337783694267273, 0.28876203298568726, + -0.5732369422912598, 0.03304499760270119, 0.8687714338302612, -0.2524224817752838, 0.4371426999568939, + 0.03568745777010918, 0.4382450580596924, 0.03245728462934494, -0.14247629046440125, 0.7598915696144104, + 0.30114904046058655, -0.21331092715263367, -0.0028205476701259613, 0.09227168560028076, + 0.008056613616645336, 1.635034203529358, -0.166751429438591, -0.020675446838140488, -1.2244166135787964, + -0.10547340661287308, 2.802537441253662, -0.004014655947685242, 3.690307140350342, 0.0017954192589968443, + -0.45281466841697693, 0.020796259865164757, 1.5265557765960693, 0.20084713399410248, -0.21376214921474457, + -0.025406286120414734, 1.9211277961730957, -0.7583361268043518, 4.267587661743164, -4.551294803619385, + -0.08887865394353867, 0.07695532590150833, -0.17959536612033844, 0.5096666216850281, -9.957548141479492, + -0.11618410050868988, 0.09543278813362122, 0.270590603351593, 0.024046115577220917, -0.245524600148201, + -0.307966023683548, 2.2781827449798584, -0.14958485960960388, -0.06977607309818268, 2.3428077697753906, + 0.8067795038223267, 0.9448233842849731, 0.35110360383987427, 0.4814533293247223, 0.00956026278436184, + -0.03395213186740875, 0.2255835384130478, 0.4806722402572632, 0.0005861452082172036, -0.5671629309654236, + 0.5004423260688782, 7.04985237121582, -0.41439759731292725, -0.2847898304462433, -0.10965365916490555, + 0.3427604138851166, 0.12897160649299622, -0.6046913266181946, -0.1840457171201706, 0.002393739065155387, + 0.41798320412635803, 3.0662004947662354, -0.0002512158825993538, 3.1039047241210938, -0.6795744895935059, + 0.3395420014858246, 0.33144497871398926, 6.939244270324707, 0.0011752373538911343, 0.09660591185092926, + 0.4894343912601471, 4.00507116317749, 0.005761822685599327, 0.5680277943611145, 3.315598726272583, + -0.5373169779777527, 0.44444069266319275, -0.0027646978851407766, 1.6829670667648315, -2.6327450275421143, + 7.026787757873535, 1.5361366271972656, -0.3418486714363098, -0.21611177921295166, -0.05756537616252899, + -0.023443035781383514, 0.23109422624111176, 0.21568432450294495, 1.236294150352478, -0.4581376612186432, + -4.188883304595947, 0.28338590264320374, -0.04076884314417839, -0.0198111142963171, -1.1872543096542358, + -0.062372151762247086, 2.7870607376098633, -0.6517040729522705, -1.7516529560089111, -1.8091800212860107, + -0.15286022424697876, 0.09354468435049057, 0.11334053426980972, -1.8259668350219727, + -0.017136069014668465, 0.12429703027009964, 0.00773972412571311, -1.0446529388427734, 0.2356342375278473, + 0.9886437654495239, 0.18150633573532104, 0.4682118594646454, -0.11415664851665497, 0.003153775818645954, + 0.03332524746656418, -1.1180988550186157, 4.163827896118164, -0.18159173429012299, 0.0999181717634201, + 2.058738946914673, 4.2595014572143555, 0.010485258884727955, 0.16270849108695984, -0.9842506051063538, + -0.0003203578235115856, 6.040156364440918, 1.308574914932251, -0.36853301525115967, -0.29669252038002014, + 0.2741331160068512, -0.3422505855560303, -0.7587988972663879, 2.9686927795410156, 0.8209773898124695, + -0.0007857486489228904, 0.02395622618496418, 0.05102493613958359, 0.15041151642799377, + -0.002741719363257289, -4.980742454528809, 0.2880830764770508, 0.24828404188156128, 0.15224777162075043, + 0.13059845566749573, 1.5662918090820312, -0.8474074006080627, 0.08810389041900635, -0.2630780041217804, + 0.43268874287605286, -0.001932788989506662, 0.012391841039061546, 3.4719245433807373, + -0.0024825239088386297, -0.11508434265851974, 0.40480512380599976, 0.4639693796634674, 1.0610097646713257, + -0.3626452386379242, -0.18480324745178223, 0.2711804509162903, 0.21260038018226624, -0.02785545215010643, + -0.11340377479791641, -0.027071641758084297, -0.22035427391529083, -0.10667331516742706, + 0.16908612847328186, 0.10428506135940552, 0.1233300194144249, -0.06643304973840714, 0.2004912942647934, + 0.09342359751462936, 0.03175133094191551, -1.5805491209030151, -0.4311752915382385, 0.5455954074859619, + 0.15516425669193268, -0.04940091818571091, -0.1447768211364746, 0.21044854819774628, 4.243553161621094, + 0.2046128362417221, -0.2688649892807007, 8.694140434265137, 1.6753796339035034, 0.05555065721273422, + -0.16497156023979187, 0.1828891634941101, 1.505862832069397, 0.08261094242334366, 3.104039430618286, + 5.931700706481934, 0.4487259089946747, -0.011016261763870716, 0.012441087514162064, -0.5082470774650574, + -0.11641799658536911, 0.01356368139386177, -0.01659458689391613, 7.667582988739014, -0.346441388130188, + -5.981383323669434, 7.6806206703186035, -0.0383722260594368, 0.4603561460971832, -0.16010521352291107, + -3.173022985458374, -0.4581749737262726, 0.06485684961080551, -0.4382486939430237, 0.25564998388290405, + 0.21537242829799652, 8.642731666564941, -0.00621763477101922, 1.9488575458526611, -0.030582817271351814, + 0.00024394349020440131, -0.015434199944138527, 1.1591683626174927, 0.29453498125076294, + -0.06397286057472229, 0.2931598126888275, 3.056126594543457, 0.35364240407943726, -0.07242457568645477, + 0.19602720439434052, 3.9850471019744873, -0.141653373837471, 0.7269659042358398, 0.020487932488322258, + -1.1716344356536865, -0.13872867822647095, 0.004658155608922243, 0.0002524183946661651, + -0.3965473771095276, 0.058615222573280334, -0.005210256204009056, -6.64661169052124, + -0.003978024236857891, -0.11946024745702744, -0.15917553007602692, -2.0323069095611572, 3.694988250732422, + -0.2484176903963089, 1.117063283920288, 0.04259220138192177, 5.167333126068115, -4.669308185577393, + 0.0845528095960617, 0.011936561204493046, 5.875088691711426, -0.032051537185907364, -1.2815053462982178, + -0.010812301188707352, -0.021371034905314445, -0.0037929099053144455, 0.09110360592603683, + -0.2871546745300293, -0.2895969748497009, 0.2745248079299927, -2.462489366531372, -1.4751100540161133, + 0.6493473052978516, 0.105231374502182, -0.11311662197113037, 5.5921311378479, -0.11590596288442612, + -4.284017562866211, -3.0435032844543457, 0.2767241597175598, -0.2098703682422638, -0.008056361228227615, + 3.738365888595581, 0.03918733820319176, 0.5250627994537354, -0.03754068538546562, -1.1869362592697144, + 0.0016376300482079387, -0.19201913475990295, -0.12353025376796722, -0.038338642567396164, + -0.5192811489105225, -0.07935589551925659, 2.1531660556793213, -0.002360888756811619, 0.3615436553955078, + -1.499021291732788, 0.6402538418769836, -2.886809825897217, 2.502922296524048, -0.0014745928347110748, + -0.09228605031967163, -0.14953434467315674, 0.2779182493686676, 2.071781635284424, -0.2248198240995407, + 0.5830495357513428, -0.1257641464471817, -0.06734845042228699, 0.003910396713763475, -1.285413146018982, + -5.392889976501465, -0.0003311980399303138, 8.632763862609863, 0.1819709688425064, -0.013432486914098263, + -0.019152071326971054, -0.026376325637102127 + ], + "dims": [1, 1, 1280], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 96ced2bdf9216..0b05d6c672257 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1336,6 +1336,8 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "bias-add.jsonc", + "bias-split-gelu.jsonc", "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", diff --git a/onnxruntime/contrib_ops/js/bias_add.cc b/onnxruntime/contrib_ops/js/bias_add.cc new file mode 100644 index 0000000000000..9e70dead6a5da --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_add.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bias_add.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + BiasAdd, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + BiasAdd); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_add.h b/onnxruntime/contrib_ops/js/bias_add.h new file mode 100644 index 0000000000000..62a4df9bcdf34 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_add.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(BiasAdd, BiasAdd); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.cc b/onnxruntime/contrib_ops/js/bias_split_gelu.cc new file mode 100644 index 0000000000000..e16aa4367d1c7 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bias_split_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + BiasSplitGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + BiasSplitGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.h b/onnxruntime/contrib_ops/js/bias_split_gelu.h new file mode 100644 index 0000000000000..3b3b41c0ca1f3 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(BiasSplitGelu, BiasSplitGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 0bf6a4a168e68..4641b006a7785 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,6 +8,8 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); template <> @@ -19,6 +21,8 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { From f8a8452a6b4d494c9d60eda6708309bcbcdea1fa Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 3 Oct 2023 13:39:50 -0700 Subject: [PATCH 06/10] [js/webgpu] fix pad operator (#17775) fix pad operator --- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index c2f89fd2845df..ccaf9f16ea770 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -209,7 +209,7 @@ const createPadProgramInfo = const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => { if (inputs.length > 1) { const bigInt64Pads = inputs[1].getBigInt64Array(); - const value = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : 0.0; + const value = (inputs.length >= 3 && inputs[2].data) ? inputs[2].getFloat32Array()[0] : 0.0; const inputRank = inputs[0].dims.length; const updatePads = new Int32Array(2 * inputRank).fill(0); @@ -220,7 +220,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]); } } else { - bigInt64Pads.forEach((i, v) => updatePads[Number(i)] = (Number(v))); + bigInt64Pads.forEach((v, i) => updatePads[Number(i)] = (Number(v))); } const pads: number[] = []; From 992f3e460945070f7239d60ba4d5eff0b1ab5897 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 4 Oct 2023 05:28:21 +0800 Subject: [PATCH 07/10] [js/webgpu] Support where (#17544) Supported type: float. int32_t, uint32_t, bool. Case where_broadcast.jsonc is not enabled due to https://github.com/microsoft/onnxruntime/issues/17405. ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/common.ts | 35 ++++ js/web/lib/wasm/jsep/webgpu/ops/where.ts | 110 +++++++++++ js/web/test/data/ops/where.jsonc | 172 ++++++++++++++++++ js/web/test/data/ops/where_broadcast.jsonc | 84 +++++++++ js/web/test/suite-test-list.jsonc | 5 +- .../providers/js/js_execution_provider.cc | 6 + .../core/providers/js/operators/where.cc | 41 +++++ 9 files changed, 455 insertions(+), 1 deletion(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/where.ts create mode 100644 js/web/test/data/ops/where.jsonc create mode 100644 js/web/test/data/ops/where_broadcast.jsonc create mode 100644 onnxruntime/core/providers/js/operators/where.cc diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 4e33368a7aa65..44003021293b0 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -96,3 +96,4 @@ Do not modify directly.* | Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Where | ai.onnx(9-15,16+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 2fba39c939a16..40309c1849bcc 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -28,6 +28,7 @@ import {parseSplitAttributes, split} from './ops/split'; import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; +import {where} from './ops/where'; import {ComputeContext} from './types'; export type RunFunction = (context: ComputeContext, attribute?: unknown) => void; @@ -116,4 +117,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], + ['Where', [where]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0ab777bfbdee9..fb800d66b59a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -102,6 +102,16 @@ export interface IndicesHelper { */ readonly indicesToOffset: (varIndices: string) => string; + /** + * WGSL code of an `u32` expression for getting original offset from broadcasted indices. + * + * @param varIndices - a `type.indices` expression representing the output indices. + * @param output - output IndicesHelper. + * + * @returns an `u32` expression + */ + readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; + /** * WGSL code of generating an indices literal * @@ -262,6 +272,7 @@ const createIndicesHelper = const implementationUsed = { offsetToIndices: false, indicesToOffset: false, + broadcastedIndicesToOffset: false, set: false, setByIndices: false, get: false, @@ -310,6 +321,26 @@ const createIndicesHelper = return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; }; + const broadcastedIndicesToOffsetImplementation: {[key: string]: string} = {}; + const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + implementationUsed.broadcastedIndicesToOffset = true; + const implKey = `${output.name}broadcastedIndicesTo${name}Offset`; + if (implKey in broadcastedIndicesToOffsetImplementation) { + return `${implKey}(${varIndices})`; + } + const offsets = []; + for (let i = shape.length - 1; i >= 0; i--) { + const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); + offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); + } + broadcastedIndicesToOffsetImplementation[implKey] = + `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + }`; + + return `${implKey}(${varIndices})`; + }; + const indices = (...init: ReadonlyArray) => rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; @@ -462,6 +493,9 @@ const createIndicesHelper = if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); } + if (implementationUsed.broadcastedIndicesToOffset) { + Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl)); + } if (implementationUsed.set) { impls.push(setImplementation); } @@ -482,6 +516,7 @@ const createIndicesHelper = type, offsetToIndices, indicesToOffset, + broadcastedIndicesToOffset, indices, indicesGet, indicesSet, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts new file mode 100644 index 0000000000000..4c595bb90b4bc --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +const createWhereOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, + typeOutput: number) => { + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + const output = outputVariable('outputData', typeOutput, dimsOutput, 4); + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + + let assignment: string; + const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; + if (!isBroadcast) { + assignment = output.setByOffset( + 'global_idx', + expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); + } else { + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + return ` + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let indexC${x} = offsetC${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); + `; + }; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } + } + + return ` + ${shaderHelper.declareVariables(c, a, b, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const dimsA = inputs[1].dims; + const dimsB = inputs[2].dims; + const dimsC = inputs[0].dims; + const outputDataType = inputs[1].dataType; + + const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); + let outputShape = dimsA; + let outputSize = ShapeUtil.size(dimsA); + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); + if (!calculatedShape) { + throw new Error('Can\'t perform where op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + } + + return { + ...metadata, + getShaderSource: (shaderHelper) => + createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), + outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}) + }; +}; + +const createWhereOpProgramInfoLoader = (inputs: readonly TensorView[], name: string): ProgramInfoLoader => { + const inputTypes = [GpuDataType.default, GpuDataType.default, GpuDataType.default]; + const metadata: ProgramMetadata = {name, inputTypes}; + return {...metadata, get: () => createWhereOpProgramInfo(metadata, inputs)}; +}; + +export const where = (context: ComputeContext): void => { + context.compute(createWhereOpProgramInfoLoader(context.inputs, 'Where')); +}; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc new file mode 100644 index 0000000000000..047fd6fd7511b --- /dev/null +++ b/js/web/test/data/ops/where.jsonc @@ -0,0 +1,172 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] float32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0], + "dims": [8], + "type": "float32" + }, + { + "data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] int32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "int32" + }, + { + "data": [1, 3, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 3, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] uint32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "uint32" + }, + { + "data": [1, 4294967295, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [4, 4294967295, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] bool T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [true, true, true, true, true, true, true, true], + "dims": [8], + "type": "float32" + }, + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3 3] T[3 3] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, true, true, true, true, false, false, false, false], + "dims": [3, 3], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where_broadcast.jsonc b/js/web/test/data/ops/where_broadcast.jsonc new file mode 100644 index 0000000000000..ad97177bb101b --- /dev/null +++ b/js/web/test/data/ops/where_broadcast.jsonc @@ -0,0 +1,84 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 6] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [ + true, + true, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + true, + true, + true, + true, + true, + true + ], + "dims": [3, 6], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + }, + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 1] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, false, true], + "dims": [3, 1], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 0b05d6c672257..c80f0b04a9abc 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1388,7 +1388,10 @@ "tan.jsonc", "tile.jsonc", "transpose.jsonc", - "transpose_int32_uint32.jsonc" + "transpose_int32_uint32.jsonc", + "where.jsonc" + // Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed. + //"where_broadcast.jsonc", //"xor.jsonc" ] }, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 444f50958eb7e..1b1cbbf8dead9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -229,6 +229,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); @@ -496,6 +499,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc new file mode 100644 index 0000000000000..2f8f5e275aa98 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/where.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +JSEP_KERNEL_IMPL(Where, Where) +REG_ELEMENTWISE_VERSIONED_KERNEL(Where, 9, 15, Where); +REG_ELEMENTWISE_KERNEL(Where, 16, Where); +} // namespace js +} // namespace onnxruntime From 8e6019af2ebcca0bbb1da425bc882e74a3843015 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 3 Oct 2023 16:24:33 -0700 Subject: [PATCH 08/10] [QNN EP] Enable QNN Saver for debugging issues (#17747) ### Description - Enables option to use the QNN Saver backend for dumping QNN API calls to file. - Adds logic to read environment variable `ORT_UNIT_TEST_ENABLE_QNN_SAVER` from QNN EP unit tests. If enabled, unit tests will use the QNN Saver backend and dump files to `./saver_output/`. ### Motivation and Context QNN Saver makes it easier to debug issues when unit tests fail. The output files generated by QNN Saver can be used to replay the exact QNN API calls that lead to a specific error condition. QNN Saver dumps QNN API calls (and weights) to disk. - saver_output/saver_output.c: C file containing all QNN API calls. - saver_output/params.bin: binary file containing all input/output/parameter tensor data provided during tensor creation, op config validation, and graph execution. Enabling the QNN Saver backend has 2 note-worthy effects: 1. All QNN API calls will succeed. 2. Inference output returns dummy data. Because the output files from QNN Saver are always overwritten, it is recommended to run individual unit tests via the `--gtest_filter` command-line option. Example (linux): ```shell $ ORT_UNIT_TEST_ENABLE_QNN_SAVER=1 ./onnxruntime_test_all --gtest_filter=QnnHTPBackendTests.Resize_DownSample_Linear_AlignCorners ``` --- .../core/session/onnxruntime_c_api.h | 3 + .../qnn/builder/qnn_backend_manager.cc | 200 ++++++++++++------ .../qnn/builder/qnn_backend_manager.h | 21 +- .../providers/qnn/qnn_execution_provider.cc | 24 ++- .../providers/qnn/qnn_execution_provider.h | 1 - onnxruntime/test/onnx/main.cc | 3 + .../test/providers/qnn/qnn_basic_test.cc | 45 +++- .../test/providers/qnn/qnn_test_utils.cc | 19 +- .../test/providers/qnn/qnn_test_utils.h | 24 ++- 9 files changed, 259 insertions(+), 81 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4b911e3482e90..486e2ff2b90a2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3597,6 +3597,9 @@ struct OrtApi { * "rpc_control_latency": QNN RPC control latency. * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will + * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and + * may alter model/EP partitioning. Use only for debugging. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 8e31124ce4c85..e2083371acca4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -30,12 +30,20 @@ typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(const QnnSystemI constexpr const char* QNN_PROVIDER = "ORTQNNEP"; +static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnInterface_t* qnn_interface) { + return qnn_interface->apiVersion.coreApiVersion; +} + +static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnSystemInterface_t* qnn_interface) { + return qnn_interface->systemApiVersion; +} + template -Status QnnBackendManager::GetQnnInterfaceProviders(const char* lib_path, - const char* interface_provider_name, - void** backend_lib_handle, - T*** interface_providers, - uint32_t& num_providers) { +Status QnnBackendManager::GetQnnInterfaceProvider(const char* lib_path, + const char* interface_provider_name, + void** backend_lib_handle, + Qnn_Version_t req_version, + T** interface_provider) { std::string error_msg; *backend_lib_handle = LoadLib(lib_path, static_cast(DlOpenFlag::DL_NOW) | static_cast(DlOpenFlag::DL_LOCAL), @@ -47,10 +55,36 @@ Status QnnBackendManager::GetQnnInterfaceProviders(const char* lib_path, GetInterfaceProviders = ResolveSymbol(*backend_lib_handle, interface_provider_name, *logger_); ORT_RETURN_IF(nullptr == GetInterfaceProviders, "Failed to get QNN providers!"); - auto result = GetInterfaceProviders((const T***)interface_providers, &num_providers); + T** interface_providers{nullptr}; + uint32_t num_providers{0}; + + auto result = GetInterfaceProviders((const T***)&interface_providers, &num_providers); ORT_RETURN_IF((QNN_SUCCESS != result || nullptr == *interface_providers || 0 == num_providers), "Failed to get QNN providers."); + bool found_valid_interface{false}; + for (size_t pIdx = 0; pIdx < num_providers; pIdx++) { + Qnn_Version_t interface_version = GetQnnInterfaceApiVersion(interface_providers[pIdx]); + + LOGS_DEFAULT(VERBOSE) << lib_path << " interface version: " << interface_version.major << "." + << interface_version.minor << "." << interface_version.patch; + + // Check the interface's API version against the required version. + // Major versions must match. The interface's minor version must be greater OR equal with a suitable patch version. + if (interface_version.major == req_version.major) { + bool minor_and_patch_version_ok = (interface_version.minor > req_version.minor) || + (interface_version.minor == req_version.minor && + interface_version.patch >= req_version.patch); + if (minor_and_patch_version_ok) { + found_valid_interface = true; + *interface_provider = interface_providers[pIdx]; + break; + } + } + } + + ORT_RETURN_IF_NOT(found_valid_interface, "Unable to find a valid interface for ", lib_path); + return Status::OK(); } @@ -76,38 +110,89 @@ void QnnBackendManager::SetQnnBackendType(uint32_t backend_id) { } Status QnnBackendManager::LoadBackend() { - QnnInterface_t** interface_providers{nullptr}; - uint32_t num_providers{0}; - auto rt = GetQnnInterfaceProviders(backend_path_.c_str(), - "QnnInterface_getProviders", - &backend_lib_handle_, - &interface_providers, - num_providers); + QnnInterface_t* backend_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(backend_path_.c_str(), + "QnnInterface_getProviders", + &backend_lib_handle_, + {QNN_API_VERSION_MAJOR, + QNN_API_VERSION_MINOR, + QNN_API_VERSION_PATCH}, + &backend_interface_provider); ORT_RETURN_IF_ERROR(rt); + qnn_interface_ = backend_interface_provider->QNN_INTERFACE_VER_NAME; + auto backend_id = backend_interface_provider->backendId; + SetQnnBackendType(backend_id); - bool found_valid_interface{false}; - LOGS_DEFAULT(VERBOSE) << "QNN_API_VERSION_MAJOR: " << QNN_API_VERSION_MAJOR - << " QNN_API_VERSION_MINOR: " << QNN_API_VERSION_MINOR; - for (size_t pIdx = 0; pIdx < num_providers; pIdx++) { - LOGS_DEFAULT(VERBOSE) << "interface_providers major: " << interface_providers[pIdx]->apiVersion.coreApiVersion.major - << " interface_providers minor: " << interface_providers[pIdx]->apiVersion.coreApiVersion.minor; - if (QNN_API_VERSION_MAJOR == interface_providers[pIdx]->apiVersion.coreApiVersion.major && - QNN_API_VERSION_MINOR <= interface_providers[pIdx]->apiVersion.coreApiVersion.minor) { - found_valid_interface = true; - qnn_interface_ = interface_providers[pIdx]->QNN_INTERFACE_VER_NAME; - auto backend_id = interface_providers[pIdx]->backendId; - SetQnnBackendType(backend_id); - - LOGS_DEFAULT(INFO) << "Found valid interface, version: " << QNN_API_VERSION_MAJOR - << "." << QNN_API_VERSION_MINOR - << " backend provider name: " << interface_providers[pIdx]->providerName - << " backend id: " << backend_id; - break; + Qnn_Version_t backend_interface_version = GetQnnInterfaceApiVersion(backend_interface_provider); + LOGS_DEFAULT(INFO) << "Found valid interface, version: " << backend_interface_version.major + << "." << backend_interface_version.minor << "." << backend_interface_version.patch + << " backend provider name: " << backend_interface_provider->providerName + << " backend id: " << backend_id; + + return Status::OK(); +} + +// Loads the intended backend (e.g., HTP, CPU, etc) to get its type, and then +// sets QNN Saver as the active backend. QNN op builders will still see the intended backend (e.g., HTP) +// as the backend type to ensure they emit the expected QNN API calls. +// +// QNN Saver is a "debugging" backend that serializes all QNN API calls (and weights) into local files. +// This information can be used to debug issues by replaying QNN API calls with another backend. +Status QnnBackendManager::LoadQnnSaverBackend() { + void* backend_lib_handle = nullptr; + + // Helper that unloads the intended backend library handle when the `unload_backend_lib` variable + // goes out of scope. Similar to `defer` in other languages. + auto unload_backend_lib = gsl::finally([&] { + if (backend_lib_handle != nullptr) { + auto result = UnloadLib(backend_lib_handle); + if (Status::OK() != result) { + ORT_THROW("Failed to unload backend library."); + } } - } + }); + + // Load the intended backend (e.g., HTP, CPU) to ensure it is valid and to get its type. + QnnInterface_t* backend_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(backend_path_.c_str(), + "QnnInterface_getProviders", + &backend_lib_handle, + {QNN_API_VERSION_MAJOR, + QNN_API_VERSION_MINOR, + QNN_API_VERSION_PATCH}, + &backend_interface_provider); + ORT_RETURN_IF_ERROR(rt); - ORT_RETURN_IF_NOT(found_valid_interface, "Unable to find a valid interface."); + // Set the "intended" backend type so that QNN builders still make the expected QNN API calls. + auto backend_id = backend_interface_provider->backendId; + SetQnnBackendType(backend_id); + + // Load the QNN Saver backend and set it as the activate backend. + QnnInterface_t* saver_interface_provider{nullptr}; + auto saver_rt = GetQnnInterfaceProvider(qnn_saver_path_.c_str(), + "QnnInterface_getProviders", + &backend_lib_handle_, // NOTE: QNN Saver library handle is set + {QNN_API_VERSION_MAJOR, + QNN_API_VERSION_MINOR, + QNN_API_VERSION_PATCH}, + &saver_interface_provider); + ORT_RETURN_IF_ERROR(saver_rt); + qnn_interface_ = saver_interface_provider->QNN_INTERFACE_VER_NAME; // NOTE: QNN Saver will provide the interfaces + + Qnn_Version_t backend_interface_version = GetQnnInterfaceApiVersion(backend_interface_provider); + Qnn_Version_t saver_interface_version = GetQnnInterfaceApiVersion(saver_interface_provider); + + LOGS_DEFAULT(INFO) << "Using QNN Saver version: " << saver_interface_version.major << "." + << saver_interface_version.minor << "." << saver_interface_version.patch + << " provider name : " << saver_interface_provider->providerName; + + LOGS_DEFAULT(INFO) << "Intended backend provider name: " << backend_interface_provider->providerName + << " backend id: " << backend_id + << " interface version: " << backend_interface_version.major + << "." << backend_interface_version.minor << "." << backend_interface_version.patch; return Status::OK(); } @@ -120,34 +205,22 @@ Status QnnBackendManager::LoadQnnSystemLib() { #endif // #ifdef _WIN32 std::filesystem::path lib_file_path(backend_path_.c_str()); std::string sys_file_path(lib_file_path.remove_filename().string() + system_lib_file); - QnnSystemInterface_t** system_interface_providers{nullptr}; - uint32_t num_providers = 0; - auto rt = GetQnnInterfaceProviders(sys_file_path.c_str(), - "QnnSystemInterface_getProviders", - &system_lib_handle_, - &system_interface_providers, - num_providers); + QnnSystemInterface_t* system_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(sys_file_path.c_str(), + "QnnSystemInterface_getProviders", + &system_lib_handle_, + {QNN_SYSTEM_API_VERSION_MAJOR, + QNN_SYSTEM_API_VERSION_MINOR, + QNN_SYSTEM_API_VERSION_PATCH}, + &system_interface_provider); ORT_RETURN_IF_ERROR(rt); + Qnn_Version_t system_interface_version = GetQnnInterfaceApiVersion(system_interface_provider); + qnn_sys_interface_ = system_interface_provider->QNN_SYSTEM_INTERFACE_VER_NAME; - bool found_valid_interface{false}; - for (size_t pIdx = 0; pIdx < num_providers; pIdx++) { - LOGS_DEFAULT(VERBOSE) << "system_interface_providers major: " << system_interface_providers[pIdx]->systemApiVersion.major - << " system_interface_providers minor: " << system_interface_providers[pIdx]->systemApiVersion.minor; - int64_t systems_version_major = static_cast(system_interface_providers[pIdx]->systemApiVersion.major); - int64_t systems_version_minor = static_cast(system_interface_providers[pIdx]->systemApiVersion.minor); - if (systems_version_major == QNN_SYSTEM_API_VERSION_MAJOR && - systems_version_minor >= QNN_SYSTEM_API_VERSION_MINOR) { - found_valid_interface = true; - qnn_sys_interface_ = system_interface_providers[pIdx]->QNN_SYSTEM_INTERFACE_VER_NAME; - LOGS_DEFAULT(INFO) << "Found valid system interface, version: " << QNN_API_VERSION_MAJOR - << "." << QNN_API_VERSION_MINOR - << " backend provider name: " << system_interface_providers[pIdx]->providerName; - break; - } - } - - ORT_RETURN_IF_NOT(found_valid_interface, "Unable to find a valid system interface."); + LOGS_DEFAULT(INFO) << "Found valid system interface, version: " << system_interface_version.major + << "." << system_interface_version.minor + << " backend provider name: " << system_interface_provider->providerName; return Status::OK(); } @@ -643,7 +716,12 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ return Status::OK(); } - ORT_RETURN_IF_ERROR(LoadBackend()); + if (qnn_saver_path_.empty()) { + ORT_RETURN_IF_ERROR(LoadBackend()); + } else { + ORT_RETURN_IF_ERROR(LoadQnnSaverBackend()); + } + LOGS(logger, VERBOSE) << "LoadBackend succeed."; if (load_from_cached_context) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 4ca63a042c103..402f842c7a4bf 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -25,14 +25,16 @@ class QnnModel; class QnnBackendManager { public: - QnnBackendManager(std::string backend_path, + QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level, uint32_t rpc_control_latency, - HtpPerformanceMode htp_performance_mode) + HtpPerformanceMode htp_performance_mode, + std::string&& qnn_saver_path) : backend_path_(backend_path), profiling_level_(profiling_level), rpc_control_latency_(rpc_control_latency), - htp_performance_mode_(htp_performance_mode) { + htp_performance_mode_(htp_performance_mode), + qnn_saver_path_(qnn_saver_path) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); @@ -140,6 +142,8 @@ class QnnBackendManager { Status LoadQnnSystemLib(); + Status LoadQnnSaverBackend(); + Status UnloadLib(void* handle); void* LibFunction(void* handle, const char* symbol, std::string& error_msg); @@ -155,11 +159,11 @@ class QnnBackendManager { } template - Status GetQnnInterfaceProviders(const char* lib_path, - const char* interface_provider_name, - void** backend_lib_handle, - T*** interface_providers, - uint32_t& num_providers); + Status GetQnnInterfaceProvider(const char* lib_path, + const char* interface_provider_name, + void** backend_lib_handle, + Qnn_Version_t req_version, + T** interface_provider); bool IsDevicePropertySupported(); @@ -210,6 +214,7 @@ class QnnBackendManager { #ifdef _WIN32 std::set mod_handles_; #endif + const std::string qnn_saver_path_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 7bbfe807da0f2..ec5316eb13ce1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -104,9 +104,10 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio static const std::string BACKEND_PATH = "backend_path"; auto backend_path_pos = runtime_options_.find(BACKEND_PATH); + std::string backend_path; if (backend_path_pos != runtime_options_.end()) { - backend_path_ = backend_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path_; + backend_path = backend_path_pos->second; + LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path; } else { LOGS_DEFAULT(ERROR) << "No backend path provided."; } @@ -131,10 +132,21 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio ParseHtpPerformanceMode(htp_performance_mode_pos->second); } - qnn_backend_manager_ = std::make_unique(backend_path_, - profiling_level_, - rpc_control_latency_, - htp_performance_mode_); + // Enable use of QNN Saver if the user provides a path the QNN Saver backend library. + static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path"; + std::string qnn_saver_path; + auto qnn_saver_path_pos = runtime_options_.find(QNN_SAVER_PATH_KEY); + if (qnn_saver_path_pos != runtime_options_.end()) { + qnn_saver_path = qnn_saver_path_pos->second; + LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path; + } + + qnn_backend_manager_ = std::make_unique( + std::move(backend_path), + profiling_level_, + rpc_control_latency_, + htp_performance_mode_, + std::move(qnn_saver_path)); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 2fe507b70a6ab..3827e2044e2b1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -58,7 +58,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: ProviderOptions runtime_options_; - std::string backend_path_; qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF; qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; std::unique_ptr qnn_backend_manager_; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 062ca4ece86bf..287d657a2ce28 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -56,6 +56,7 @@ void usage() { "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" + "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -477,6 +478,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_performance_mode. select from: " + str); } + } else if (key == "qnn_saver_path") { + // no validation } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable', 'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode'])"); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index a441e828c0cc6..5f63813d8d84e 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -172,7 +173,7 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { // The models passed to this function are subgraphs extracted from a larger model that exhibited // shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output // types and shapes. -static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp) { +static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false) { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! @@ -183,8 +184,14 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp) { #if defined(_WIN32) options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; + if (enable_qnn_saver) { + options["qnn_saver_path"] = "QnnSaver.dll"; + } #else options["backend_path"] = use_htp ? "libQnnHtp.so" : "libQnnCpu.so"; + if (enable_qnn_saver) { + options["qnn_saver_path"] = "libQnnSaver.so"; + } #endif so.AppendExecutionProvider("QNN", options); @@ -226,7 +233,7 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp) { auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); std::vector output_shape = typeshape.GetShape(); - ASSERT_THAT(output_shape, ::testing::ElementsAre(1, 6, 7, 10)); + EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 6, 7, 10)); } // Test shape inference of NHWC Resize operator (opset 11) that uses @@ -253,6 +260,23 @@ TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset18) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", false); } +// Test that QNN Saver generates the expected files for a model meant to run on the QNN CPU backend. +TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) { + const std::filesystem::path qnn_saver_output_dir = "saver_output"; + + // Remove pre-existing QNN Saver output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_saver_output_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); + + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + false, // use_htp + true); // enable_qnn_saver + + // Check that QNN Saver output files exist. + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Test shape inference of QDQ NHWC Resize operator (opset 18) that uses @@ -261,6 +285,23 @@ TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", true); } +// Test that QNN Saver generates the expected files for a model meant to run on the QNN HTP backend. +TEST_F(QnnHTPBackendTests, QnnSaver_OutputFiles) { + const std::filesystem::path qnn_saver_output_dir = "saver_output"; + + // Remove pre-existing QNN Saver output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_saver_output_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); + + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + true, // use_htp + true); // enable_qnn_saver + + // Check that QNN Saver output files exist. + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 51df93f8853ec..a067c9c53e57a 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -9,6 +9,7 @@ #include "test/util/include/default_providers.h" #include "test/util/include/test/test_environment.h" +#include "core/platform/env_var_utils.h" #include "core/common/span_utils.h" #include "core/framework/compute_capability.h" #include "core/graph/graph.h" @@ -41,7 +42,22 @@ std::vector GetFloatDataInRange(float min_val, float max_val, size_t num_ return data; } -void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOptions& provider_options, +void TryEnableQNNSaver(ProviderOptions& qnn_options) { + // Allow dumping QNN API calls to file by setting an environment variable that enables the QNN Saver backend. + constexpr auto kEnableQNNSaverEnvironmentVariableName = "ORT_UNIT_TEST_ENABLE_QNN_SAVER"; + static std::optional enable_qnn_saver = onnxruntime::ParseEnvironmentVariable( + kEnableQNNSaverEnvironmentVariableName); + + if (enable_qnn_saver.has_value() && *enable_qnn_saver != 0) { +#if defined(_WIN32) + qnn_options["qnn_saver_path"] = "QnnSaver.dll"; +#else + qnn_options["qnn_saver_path"] = "libQnnSaver.so"; +#endif // defined(_WIN32) + } +} + +void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity) { EPVerificationParams verification_params; @@ -65,6 +81,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOption // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + TryEnableQNNSaver(provider_options); RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 14c62f98f6a3e..b4c84d893c828 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -220,6 +220,25 @@ void InferenceModel(const std::string& model_data, const char* log_id, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals); +/** + * If the ORT_UNIT_TEST_ENABLE_QNN_SAVER environment variable is enabled (set to 1), this function modifies + * the QNN EP provider options to enable the QNN Saver backend, which dumps QNN API calls (and weights) to disk. + * + * - saver_output/saver_output.c: C file containing all QNN API calls. + * - saver_output/params.bin: binary file containing all input/output/parameter tensor data provided during tensor + * creation, op config validation, and graph execution. + * + * Enabling the QNN Saver backend has 2 note-worthy effects: + * 1. All QNN API calls will succeed. + * 2. Inference output returns dummy data. + * + * Because output files from QNN Saver are always overwritten, it is recommended to run individual unit tests via the + * --gtest_filter command-line option. Ex: --gtest_filter=QnnHTPBackendTests.Resize_DownSample_Linear_AlignCorners + * + * \param qnn_options QNN EP provider options that may be modified to enable QNN Saver. + */ +void TryEnableQNNSaver(ProviderOptions& qnn_options); + /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: * @@ -240,7 +259,7 @@ void InferenceModel(const std::string& model_data, const char* log_id, */ template inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn& qdq_model_fn, - const ProviderOptions& qnn_options, int opset_version, + ProviderOptions qnn_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-4f, logging::Severity log_severity = logging::Severity::kERROR) { // Add kMSDomain to cover contrib op like Gelu @@ -300,6 +319,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe qdq_model.ToProto().SerializeToString(&qdq_model_data); // Run QDQ model on QNN EP and collect outputs. + TryEnableQNNSaver(qnn_options); std::vector qnn_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); @@ -538,7 +558,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_typ * \param fp32_abs_err The acceptable error between CPU EP and QNN EP. * \param log_severity The logger's minimum severity level. */ -void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOptions& provider_options, +void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR); From a05580ed5be3f6311c6184136f953eeb14f48dcd Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 4 Oct 2023 08:01:39 -0700 Subject: [PATCH 09/10] StableDiffusion XL with TensorRT EP (#17748) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Accelerate StableDiffusion XL with TensorRT EP. It is modified from TensorRT demo diffusion, and we updated the design to make the pipeline works with different backend engines. The following result is from A100 80GB with 30 steps of Base, or 30 steps Base & 30 Steps Refiner to generate 1024x1024 images. The engine is built with static input shape, and cuda graph is enabled.   | Batch Size | TRT Latency (ms) | ORT_TRT Latency (ms) | Diff -- | -- | -- | -- | -- Base | 1 | 2714 | 2679 | -1.3% Base & Refiner | 1 | 3593 | 3530 | -1.8% The test environment: onnxruntime-gpu is built from source, and the following packages or libraries are used in this test: * tensorrt==8.6.1.post1 * torch==2.2.0.dev20230920+cu121 * transformers==4.31.0 * diffusers==0.19.3 * onnx==1.14.1 * onnx-graphsurgeon==0.3.27 * polygraphy==0.47.1 * protobuf==3.20.2 * onnxruntime-gpu==1.17.0 (built from source of main branch) * CUDA 12.2.2 * cuDNN 8.9.5.29 * python 3.10.13 --- .../tools/transformers/benchmark_helper.py | 20 +- .../tools/transformers/io_binding_helper.py | 1 + .../models/stable_diffusion/README.md | 16 +- .../models/stable_diffusion/benchmark.py | 1035 +++++++++++++---- .../models/stable_diffusion/demo_txt2img.py | 94 ++ .../stable_diffusion/demo_txt2img_xl.py | 129 ++ .../models/stable_diffusion/demo_utils.py | 255 ++++ .../stable_diffusion/diffusion_models.py | 858 ++++++++++++++ .../stable_diffusion/diffusion_schedulers.py | 721 ++++++++++++ .../models/stable_diffusion/engine_builder.py | 181 +++ .../engine_builder_ort_trt.py | 263 +++++ .../engine_builder_tensorrt.py | 507 ++++++++ .../models/stable_diffusion/models.py | 368 ------ .../onnxruntime_cuda_txt2img.py | 187 +-- .../onnxruntime_tensorrt_txt2img.py | 388 +----- .../stable_diffusion/optimize_pipeline.py | 9 +- .../models/stable_diffusion/ort_optimizer.py | 34 +- .../models/stable_diffusion/ort_utils.py | 250 ++-- .../stable_diffusion/pipeline_img2img_xl.py | 232 ++++ .../pipeline_stable_diffusion.py | 429 +++++++ .../stable_diffusion/pipeline_txt2img.py | 155 +++ .../stable_diffusion/pipeline_txt2img_xl.py | 198 ++++ .../stable_diffusion/requirements-cuda.txt | 16 +- .../requirements-tensorrt.txt | 18 +- .../models/stable_diffusion/trt_utilities.py | 12 + .../python/tools/transformers/onnx_model.py | 6 +- 26 files changed, 5090 insertions(+), 1292 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py delete mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/models.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py create mode 100644 onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 67d3c95922a87..4f898245d01bd 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -542,7 +542,7 @@ def measure_gpu_usage(self): while True: for i in range(device_count): max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.005) # 2ms + time.sleep(0.005) # 5ms if not self.keep_measuring: break return [ @@ -555,7 +555,7 @@ def measure_gpu_usage(self): ] -def measure_memory(is_gpu, func, monitor_type="cuda"): +def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): memory_monitor_type = None if monitor_type == "rocm": memory_monitor_type = RocmMemoryMonitor @@ -565,10 +565,16 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): monitor = memory_monitor_type(False) if is_gpu: - memory_before_test = monitor.measure_gpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_gpu_usage() if memory_before_test is None: return None + if func is None: + return memory_before_test + with ThreadPoolExecutor() as executor: monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) @@ -595,7 +601,13 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): return None # CPU memory - memory_before_test = monitor.measure_cpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_cpu_usage() + + if func is None: + return memory_before_test with ThreadPoolExecutor() as executor: monitor = MemoryMonitor() diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 71c1a21d8f768..de17f195c99cc 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -283,6 +283,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]): if name in self.input_names: if self.enable_cuda_graph: assert self.input_tensors[name].nelement() == tensor.nelement() + assert self.input_tensors[name].dtype == tensor.dtype assert tensor.device.type == "cuda" # Please install cuda-python package with a version corresponding to CUDA in your machine. from cuda import cudart diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 7ffefdd05f215..1fbd5092a719a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -74,15 +74,16 @@ Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, p ### Setup Environment (CUDA) -It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with [CUDA 11.7](https://developer.nvidia.com/cuda-11-7-0-download-archive) or 11.8. +It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8. +If you use CUDA 12.*, you will need build onnxruntime-gpu from source. ``` conda create -n py38 python=3.8 conda activate py38 -pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda.txt ``` - -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.7 and cuDNN 8.5 are used in our tests. +ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended. #### Install Nightly (Optional) @@ -233,18 +234,21 @@ Sometime, it complains ptxas not found when there are multiple CUDA versions ins Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows. -### Run Benchmark with TensorRT and TensorRT execution provider +### Run Benchmark with TensorRT or TensorRT execution provider For TensorRT installation, follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. ``` pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-tensorrt.txt export CUDA_MODULE_LOADING=LAZY python benchmark.py -e tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 --enable_cuda_graph + +python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph +python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph ``` ### Example Benchmark output diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 13126f648d290..f8fda13a35b93 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -10,15 +10,18 @@ import sys import time +import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. import torch +from benchmark_helper import measure_memory SD_MODELS = { "1.5": "runwayml/stable-diffusion-v1-5", "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", + "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", } PROVIDERS = { @@ -43,139 +46,13 @@ def example_prompts(): "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", ] - return prompts + negative_prompt = "bad composition, ugly, abnormal, malformed" - -class CudaMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - - def measure_gpu_usage(self): - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) - - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - print(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - print(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - print("Error fetching GPU information using nvml: %s", error) - return None - - -class RocmMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - rocm_smi_path = "/opt/rocm/libexec/rocm_smi" - if os.path.exists(rocm_smi_path): - if rocm_smi_path not in sys.path: - sys.path.append(rocm_smi_path) - try: - import rocm_smi - - self.rocm_smi = rocm_smi - self.rocm_smi.initializeRsmi() - except ImportError: - self.rocm_smi = None - - def get_used_memory(self, dev): - if self.rocm_smi is None: - return -1 - return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 - - def measure_gpu_usage(self): - device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [f"GPU{i}" for i in range(device_count)] - while True: - for i in range(device_count): - max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] + return prompts, negative_prompt def measure_gpu_memory(monitor_type, func, start_memory=None): - if monitor_type is None: - return None - - monitor = monitor_type(False) - memory_before_test = monitor.measure_gpu_usage() - - if start_memory is None: - start_memory = memory_before_test - if start_memory is None: - return None - if func is None: - return start_memory - - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor() as executor: - monitor = monitor_type() - mem_thread = executor.submit(monitor.measure_gpu_usage) - try: - fn_thread = executor.submit(func) - _ = fn_thread.result() - finally: - monitor.keep_measuring = False - max_usage = mem_thread.result() - - if max_usage is None: - return None - - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") - if len(start_memory) >= 1 and len(max_usage) >= 1 and len(start_memory) == len(max_usage): - # When there are multiple GPUs, we will check the one with maximum usage. - max_used = 0 - for i, memory_before in enumerate(start_memory): - before = memory_before["max_used_MB"] - after = max_usage[i]["max_used_MB"] - used = after - before - max_used = max(max_used, used) - return max_used - return None + return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory) def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool): @@ -256,7 +133,7 @@ def run_ort_pipeline( assert isinstance(pipe, OnnxStableDiffusionPipeline) - prompts = example_prompts() + prompts, negative_prompt = example_prompts() def warmup(): pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) @@ -275,13 +152,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt, + [prompt] * batch_size, height, width, num_inference_steps=steps, - negative_prompt=None, + negative_prompt=[negative_prompt] * batch_size, guidance_scale=7.5, - num_images_per_prompt=batch_size, ).images inference_end = time.time() latency = inference_end - inference_start @@ -320,7 +196,7 @@ def run_torch_pipeline( start_memory, memory_monitor_type, ): - prompts = example_prompts() + prompts, negative_prompt = example_prompts() # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): @@ -342,13 +218,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt=prompt, + prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, guidance_scale=7.5, - negative_prompt=None, - num_images_per_prompt=batch_size, + negative_prompt=[negative_prompt] * batch_size, generator=None, # torch.Generator ).images @@ -427,7 +302,7 @@ def run_ort( def export_and_run_ort( - model_name: str, + version: str, provider: str, batch_size: int, disable_safety_checker: bool, @@ -443,15 +318,19 @@ def export_and_run_ort( assert provider == "CUDAExecutionProvider" from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_cuda_txt2img import OnnxruntimeCudaStableDiffusionPipeline - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name, scheduler=scheduler, requires_safety_checker=not disable_safety_checker, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models @@ -473,7 +352,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_cuda", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -481,6 +360,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -514,7 +394,7 @@ def warmup(): def run_ort_trt( - model_name: str, + version: str, batch_size: int, disable_safety_checker: bool, height: int, @@ -528,8 +408,12 @@ def run_ort_trt( enable_cuda_graph: bool, ): from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_tensorrt_txt2img import OnnxruntimeTensorRTStableDiffusionPipeline + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + assert batch_size <= max_batch_size scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") @@ -544,6 +428,7 @@ def run_ort_trt( max_batch_size=max_batch_size, onnx_opset=17, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines @@ -552,7 +437,7 @@ def run_ort_trt( pipe = pipe.to("cuda") def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipe(["warm up"] * batch_size, negative_prompt=["negative"] * batch_size, num_inference_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -564,7 +449,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -572,6 +457,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -589,7 +475,7 @@ def warmup(): "model_name": model_name, "engine": "onnxruntime", "version": ort_version, - "provider": f"tensorrt{trt_version})", + "provider": f"tensorrt({trt_version})", "directory": pipe.engine_dir, "height": height, "width": width, @@ -606,7 +492,148 @@ def warmup(): } -def run_tensorrt( +def run_ort_trt_static( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, +): + print("[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + short_name = pipeline_info.short_name() + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.ORT_TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + # Load TensorRT engines and pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) + + def warmup(): + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + pipeline.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": pipeline_info.name(), + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt({trt_version})", + "directory": engine_dir, + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static( + work_dir: str, + version: str, model_name: str, batch_size: int, disable_safety_checker: bool, @@ -618,32 +645,79 @@ def run_tensorrt( start_memory, memory_monitor_type, max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, ): - from diffusers import DDIMScheduler - from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + from cuda import cudart + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() assert batch_size <= max_batch_size - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained( - model_name, - custom_pipeline="stable_diffusion_tensorrt_txt2img", - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - requires_safety_checker=not disable_safety_checker, - image_height=height, - image_width=width, + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, + use_cuda_graph=True, + engine_type=engine_type, ) - # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder(model_name, revision="fp16") + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) - pipe = pipe.to("cuda") + # activate engines + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -655,28 +729,225 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break for j in range(batch_count): inference_start = time.time() - images = pipe( + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( [prompt] * batch_size, - num_inference_steps=steps, - ).images + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") - from tensorrt import __version__ as trt_version + pipeline.teardown() + + import tensorrt as trt + + return { + "engine": "tensorrt", + "version": trt.__version__, + "provider": "default", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + import tensorrt as trt + from cuda import cudart + from trt_utilities import init_trt_plugins + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + # Register TensorRT plugins + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) + return pipeline + + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + demo_base.backend.activate_engines(shared_device_memory) + demo_refiner.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latents", + ) + + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") + + demo_base.teardown() + demo_refiner.teardown() return { + "model_name": model_name, "engine": "tensorrt", - "version": trt_version, + "version": trt.__version__, "provider": "default", "height": height, "width": width, @@ -688,7 +959,178 @@ def warmup(): "median_latency": statistics.median(latency_list), "first_run_memory_MB": first_run_memory, "second_run_memory_MB": second_run_memory, - "enable_cuda_graph": False, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_ort_trt_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + from cuda import cudart + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + assert batch_size <= max_batch_size + + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.ORT_TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), # TODO: might not work with CUDA_VISIBLE_DEVICES + ) + return pipeline + + from diffusion_models import PipelineInfo + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latents", + ) + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" + image.save(filename) + print("Image saved to", filename) + + demo_base.teardown() + demo_refiner.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": model_name, + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt{trt_version})", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, } @@ -808,6 +1250,15 @@ def parse_arguments(): help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.", ) + parser.add_argument( + "-w", + "--work_dir", + required=False, + type=str, + default=".", + help="Root directory to save exported onnx models, built engines etc.", + ) + parser.add_argument( "--enable_safety_checker", required=False, @@ -922,28 +1373,31 @@ def main(): args = parse_arguments() print(args) - if args.enable_cuda_graph: - if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): - raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + if args.engine == "onnxruntime": + if args.version in ["2.1"]: + # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model. + # The environment variables shall be set before the first run of Attention or MultiHeadAttention operator. + os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" from packaging import version from onnxruntime import __version__ as ort_version - if version.parse(ort_version) < version.parse("1.16"): - raise ValueError( - "CUDA graph requires ONNX Runtime 1.16. You can install nightly like the following:\n" - " pip uninstall onnxruntime-gpu\n" - " pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/" - ) + if version.parse(ort_version) == version.parse("1.16.0"): + # ORT 1.16 has a bug that might trigger Attention RuntimeError when latest fusion script is applied on clip model. + # The walkaround is to enable fused causal attention, or disable Attention fusion for clip model. + os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1" + + if args.enable_cuda_graph: + if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): + raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + + if version.parse(ort_version) < version.parse("1.16"): + raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later") coloredlogs.install(fmt="%(funcName)20s: %(message)s") - memory_monitor_type = None - if args.provider in ["cuda", "tensorrt"]: - memory_monitor_type = CudaMemoryMonitor - elif args.provider == "rocm": - memory_monitor_type = RocmMemoryMonitor + memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda" start_memory = measure_gpu_memory(memory_monitor_type, None) print("GPU memory used before loading models:", start_memory) @@ -951,89 +1405,157 @@ def main(): sd_model = SD_MODELS[args.version] provider = PROVIDERS[args.provider] if args.engine == "onnxruntime" and args.provider == "tensorrt": - result = run_ort_trt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, - args.enable_cuda_graph, - ) + if "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.tuning: + print( + "Testing OnnxruntimeTensorRTStableDiffusionPipeline with {}.".format( + "static input shape" if args.enable_cuda_graph else "dynamic batch size" + ) + ) + result = run_ort_trt( + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + enable_cuda_graph=args.enable_cuda_graph, + ) + else: + print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_static( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.engine == "onnxruntime" and provider == "CUDAExecutionProvider" and args.pipeline is None: - print("Pipeline is not specified. Trying export and optimize onnx models...") + print( + "Testing OnnxruntimeCudaStableDiffusionPipeline with {} input shape. Backend is ORT CUDA EP.".format( + "static" if args.enable_cuda_graph else "dynamic" + ) + ) result = export_and_run_ort( - sd_model, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.enable_cuda_graph, + version=args.version, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + enable_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( args.pipeline ), "--pipeline should be specified for the directory of ONNX models" - - if args.version in ["2.1"]: - # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model - # This shall be done before the first inference run. - os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" - + print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}") result = run_ort( - sd_model, - args.pipeline, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.tuning, + model_name=sd_model, + directory=args.pipeline, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + tuning=args.tuning, + ) + elif args.engine == "tensorrt" and "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "tensorrt": - result = run_tensorrt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, + print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static( + work_dir=args.work_dir, + version=args.version, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) else: + print( + f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}." + ) result = run_torch( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.enable_torch_compile, - args.use_xformers, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + enable_torch_compile=args.enable_torch_compile, + use_xformers=args.use_xformers, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, ) print(result) @@ -1068,8 +1590,9 @@ def main(): if __name__ == "__main__": + import traceback + try: main() - except Exception as e: - tb = sys.exc_info() - print(e.with_traceback(tb[2])) + except Exception: + traceback.print_exception(*sys.exc_info()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py new file mode 100644 index 0000000000000..f6e00063a6391 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_txt2img import Txt2ImgPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + + args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512): + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + pipeline_info = PipelineInfo(args.version) + pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + pipeline.load_resources(image_height, image_width, batch_size) + + def run_inference(warmup=False): + return pipeline.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="images", + ) + + if not args.disable_cuda_graph: + # inference once to get cuda graph + _image, _latency = run_inference(warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + _image, _latency = run_inference(warmup=True) + + print("[I] Running StableDiffusion pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + _image, _latency = run_inference(warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py new file mode 100644 index 0000000000000..c3a2e4e293cc8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -0,0 +1,129 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_img2img_xl import Img2ImgXLPipeline +from pipeline_txt2img_xl import Txt2ImgXLPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if args.build_dynamic_shape or image_height > 512 or image_width > 512: + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner) + base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + + if args.enable_refiner: + refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True) + refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + refiner.backend.activate_engines(shared_device_memory) + + base.load_resources(image_height, image_width, batch_size) + refiner.load_resources(image_height, image_width, batch_size) + else: + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), base.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + + base.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(enable_refiner: bool, warmup=False): + images, time_base = base.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="latents" if enable_refiner else "images", + ) + + if enable_refiner: + images, time_refiner = refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + ) + return images, time_base + time_refiner + else: + return images, time_base + + if not args.disable_cuda_graph: + # inference once to get cuda graph + images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + + print("[I] Running StableDiffusion XL pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference(args.enable_refiner, warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + base.teardown() + + if args.enable_refiner: + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", pipeline_time)) + print("|------------|--------------|") + refiner.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py new file mode 100644 index 0000000000000..5fdafc463f4e2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import argparse + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_paths + + +class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): + pass + + +def parse_arguments(is_xl: bool, description: str): + parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--engine", + type=str, + default="ORT_TRT", + choices=["ORT_TRT", "TRT"], + help="Backend engine. Default is OnnxRuntime CUDA execution provider.", + ) + + supported_versions = PipelineInfo.supported_versions(is_xl) + parser.add_argument( + "--version", + type=str, + default=supported_versions[-1] if is_xl else "1.5", + choices=supported_versions, + help="Version of Stable Diffusion" + (" XL." if is_xl else "."), + ) + + parser.add_argument( + "--height", + type=int, + default=1024 if is_xl else 512, + help="Height of image to generate (must be multiple of 8).", + ) + parser.add_argument( + "--width", type=int, default=1024 if is_xl else 512, help="Height of image to generate (must be multiple of 8)." + ) + + parser.add_argument( + "--scheduler", + type=str, + default="DDIM", + choices=["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process", + ) + + parser.add_argument( + "--work-dir", + default=".", + help="Root Directory to store torch or ONNX models, built engines and output images etc.", + ) + + parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.") + + parser.add_argument( + "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." + ) + parser.add_argument( + "--repeat-prompt", + type=int, + default=1, + choices=[1, 2, 4, 8, 16], + help="Number of times to repeat the prompt (batch size multiplier).", + ) + + parser.add_argument( + "--denoising-steps", + type=int, + default=30 if is_xl else 50, + help="Number of denoising steps" + (" in each of base and refiner." if is_xl else "."), + ) + + parser.add_argument( + "--guidance", + type=float, + default=5.0 if is_xl else 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", + ) + + # ONNX export + parser.add_argument( + "--onnx-opset", + type=int, + default=17, + choices=range(14, 18), + help="Select ONNX opset version to target for exported models.", + ) + parser.add_argument( + "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models." + ) + parser.add_argument( + "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models." + ) + + # Framework model ckpt + parser.add_argument( + "--framework-model-dir", + default="pytorch_model", + help="Directory for HF saved models. Default is pytorch_model.", + ) + parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints.") + + # Engine build options. + parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine.") + parser.add_argument( + "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size." + ) + parser.add_argument( + "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes." + ) + + # Inference related options + parser.add_argument( + "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance." + ) + parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.") + parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") + parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + + # TensorRT only options + group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") + group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") + group.add_argument( + "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." + ) + group.add_argument( + "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." + ) + group.add_argument( + "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." + ) + + # Pipeline options + if is_xl: + parser.add_argument( + "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipelines." + ) + + args = parser.parse_args() + + # Validate image dimensions + if args.height % 8 != 0 or args.width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + ) + + if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: + print("[I] CUDA Graph is disabled since dynamic input shape is configured.") + args.disable_cuda_graph = True + + print(args) + + return args + + +def repeat_prompt(args): + if not isinstance(args.prompt, list): + raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") + prompt = args.prompt * args.repeat_prompt + + if not isinstance(args.negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" + ) + if len(args.negative_prompt) == 1: + negative_prompt = args.negative_prompt * len(prompt) + else: + negative_prompt = args.negative_prompt + + return prompt, negative_prompt + + +def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + args.work_dir, pipeline_info, engine_type + ) + + # Initialize demo + pipeline = pipeline_class( + pipeline_info, + scheduler=args.scheduler, + output_dir=output_dir, + hf_token=args.hf_token, + verbose=False, + nvtx_profile=args.nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=not args.disable_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + if engine_type == EngineType.ORT_TRT: + # Build TensorRT EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + force_engine_rebuild=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_image_shape=not args.build_dynamic_shape, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + elif engine_type == EngineType.TRT: + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_batch_size=batch_size, + opt_image_height=args.height, + opt_image_width=args.height, + force_export=args.force_onnx_export, + force_optimize=args.force_onnx_optimize, + force_build=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_shape=not args.build_dynamic_shape, + enable_refit=args.build_enable_refit, + enable_preview=args.build_preview_features, + enable_all_tactics=args.build_all_tactics, + timing_cache=timing_cache, + onnx_refit_dir=args.onnx_refit_dir, + ) + + return pipeline diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py new file mode 100644 index 0000000000000..951cd66005f4c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -0,0 +1,858 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion, +# which has the following license: +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import tempfile +from typing import List, Optional + +import onnx +import onnx_graphsurgeon as gs +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from onnx import GraphProto, ModelProto, shape_inference +from ort_optimizer import OrtStableDiffusionOptimizer +from polygraphy.backend.onnx.loader import fold_constants +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from onnxruntime.transformers.onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class TrtOptimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self): + self.graph.cleanup().toposort() + + def get_optimized_onnx_graph(self): + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + + def infer_shapes(self): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: + with tempfile.TemporaryDirectory() as temp_dir: + input_onnx_path = os.path.join(temp_dir, "model.onnx") + onnx.save_model( + onnx_graph, + input_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + output_onnx_path = os.path.join(temp_dir, "model_with_shape.onnx") + onnx.shape_inference.infer_shapes_path(input_onnx_path, output_onnx_path) + onnx_graph = onnx.load(output_onnx_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + + +class PipelineInfo: + def __init__( + self, version: str, is_inpaint: bool = False, is_sd_xl_refiner: bool = False, use_vae_in_xl_base=False + ): + self.version = version + self._is_inpaint = is_inpaint + self._is_sd_xl_refiner = is_sd_xl_refiner + self._use_vae_in_xl_base = use_vae_in_xl_base + + if is_sd_xl_refiner: + assert self.is_sd_xl() + + def is_inpaint(self) -> bool: + return self._is_inpaint + + def is_sd_xl(self) -> bool: + return "xl" in self.version + + def is_sd_xl_base(self) -> bool: + return self.is_sd_xl() and not self._is_sd_xl_refiner + + def is_sd_xl_refiner(self) -> bool: + return self.is_sd_xl() and self._is_sd_xl_refiner + + def use_safetensors(self) -> bool: + return self.is_sd_xl() + + def stages(self) -> List[str]: + if self.is_sd_xl_base(): + return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae_in_xl_base else []) + + if self.is_sd_xl_refiner(): + return ["clip2", "unetxl", "vae"] + + return ["clip", "unet", "vae"] + + def vae_scaling_factor(self) -> float: + return 0.13025 if self.is_sd_xl() else 0.18215 + + @staticmethod + def supported_versions(is_xl: bool): + return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] + + def name(self) -> str: + if self.version == "1.4": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "CompVis/stable-diffusion-v1-4" + elif self.version == "1.5": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "runwayml/stable-diffusion-v1-5" + elif self.version == "2.0-base": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2-base" + elif self.version == "2.0": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2" + elif self.version == "2.1": + return "stabilityai/stable-diffusion-2-1" + elif self.version == "2.1-base": + return "stabilityai/stable-diffusion-2-1-base" + elif self.version == "xl-1.0": + if self.is_sd_xl_refiner(): + return "stabilityai/stable-diffusion-xl-refiner-1.0" + else: + return "stabilityai/stable-diffusion-xl-base-1.0" + + raise ValueError(f"Incorrect version {self.version}") + + def short_name(self) -> str: + return self.name().split("/")[-1].replace("stable-diffusion", "sd") + + def clip_embedding_dim(self): + # TODO: can we read from config instead + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + return 768 + else: + raise ValueError(f"Invalid version {self.version}") + + def clipwithproj_embedding_dim(self): + if self.version in ("xl-1.0"): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + def unet_embedding_dim(self): + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + return 2048 + elif self.version in ("xl-1.0") and self.is_sd_xl_refiner(): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + +class BaseModel: + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16: bool = False, + max_batch_size: int = 16, + embedding_dim: int = 768, + text_maxlen: int = 77, + ): + self.name = self.__class__.__name__ + + self.pipeline_info = pipeline_info + + self.model = model + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_ort_optimizer(self): + model_name_to_model_type = { + "CLIP": "clip", + "UNet": "unet", + "VAE": "vae", + "UNetXL": "unet", + "CLIPWithProj": "clip", + } + model_type = model_name_to_model_type[self.name] + return OrtStableDiffusionOptimizer(model_type) + + def get_model(self): + return self.model + + def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder, **kwargs): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + + if not os.path.exists(model_dir): + model = model_class.from_pretrained( + self.pipeline_info.name(), + subfolder=subfolder, + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + **kwargs, + ).to(self.device) + model.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + + model = model_class.from_pretrained(model_dir).to(self.device) + return model + + def load_model(self, framework_model_dir: str, hf_token: str, subfolder: str): + pass + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT EP""" + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + + if self.name != "CLIP": + if static_image_shape: + profile_id += f"_h_{image_height}_w_{image_width}" + else: + profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" + + return profile_id + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT""" + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + optimizer = self.get_ort_optimizer() + optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + + if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF: + onnx.save_model( + onnx_opt_graph, + optimized_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +class CLIP(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size, + embedding_dim: int = 0, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(), + ) + self.output_hidden_state = pipeline_info.is_sd_xl() + + # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip. + # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + self.clip_skip = clip_skip + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + # The exported onnx model has no hidden_state. For SD-XL, We will add hidden_state to optimized onnx model. + return ["text_embeddings"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_image_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),) + + def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path): + graph: GraphProto = model.graph + hidden_layers = -1 + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + name = graph.node[i].output[j] + if "layers" in name: + hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers) + + assert self.clip_skip >= 0 and self.clip_skip < hidden_layers + + node_output_name = "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers - 1 - self.clip_skip) + + # search the name in outputs of all node + found = False + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + if graph.node[i].output[j] == node_output_name: + found = True + break + if found: + break + if not found: + raise RuntimeError("Failed to find hidden_states graph output in clip") + + # Insert a Cast (fp32 -> fp16) node so that hidden_states has same data type as the first graph output. + graph_output_name = "hidden_states" + cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name]) + cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)]) + + hidden_state = graph.output.add() + hidden_state.CopyFrom( + onnx.helper.make_tensor_value_info( + graph_output_name, + graph.output[0].type.tensor_type.elem_type, + ["B", self.text_maxlen, self.embedding_dim], + ) + ) + + onnx_model = OnnxModel(model) + onnx_model.add_node(cast_node) + onnx_model.save_model_to_file(optimized_onnx_path) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + if self.output_hidden_state: + self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder"): + return self.from_pretrained(CLIPTextModel, framework_model_dir, hf_token, subfolder) + + +class CLIPWithProj(CLIP): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size=16, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.clipwithproj_embedding_dim(), + clip_skip=clip_skip, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder_2"): + return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, hf_token, subfolder) + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + +class UNet(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +class UNetXL(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + time_dim=6, + ): + super().__init__( + pipeline_info, + model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.time_dim = time_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "time_ids": [ + (2 * min_batch, self.time_dim), + (2 * batch_size, self.time_dim), + (2 * max_batch, self.time_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": (1,), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + "text_embeds": (2 * batch_size, 1280), + "time_ids": (2 * batch_size, self.time_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + { + "added_cond_kwargs": { + "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + } + }, + ) + + +# VAE Decoder +class VAE(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfolder: str = "vae_decoder"): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + if not os.path.exists(model_dir): + vae = AutoencoderKL.from_pretrained( + self.pipeline_info.name(), + subfolder="vae", + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + ).to(self.device) + vae.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + vae = AutoencoderKL.from_pretrained(model_dir).to(self.device) + + vae.forward = vae.decode + return vae + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + + +def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): + tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) + + if not os.path.exists(tokenizer_dir): + model = CLIPTokenizer.from_pretrained( + pipeline_info.name(), + subfolder=subfolder, + use_safetensors=pipeline_info.is_sd_xl(), + use_auth_token=hf_token, + ) + model.save_pretrained(tokenizer_dir) + else: + print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}") + model = CLIPTokenizer.from_pretrained(tokenizer_dir) + return model + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, vae_encoder): + super().__init__() + self.vae_encoder = vae_encoder + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="vae_encoder"): + vae = self.from_pretrained(AutoencoderKL, framework_model_dir, hf_token, subfolder) + return TorchVAEEncoder(vae) + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py new file mode 100644 index 0000000000000..13c450a517eba --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -0,0 +1,721 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from utilities.py of TensorRT demo diffusion, which has the following license: +# +# Copyright 2022 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +import torch + + +class DDIMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + clip_sample: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 1, + prediction_type: str = "epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.steps_offset = steps_offset + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.prediction_type = prediction_type + self.device = device + + def configure(self): + variance = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + variance[idx] = self._get_variance(timestep, prev_timestep) + self.variance = torch.from_numpy(variance).to(self.device) + + timesteps = self.timesteps.long().cpu() + self.alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device) + + def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(self.device) + self.timesteps += self.steps_offset + + def step( + self, + model_output, + sample, + idx, + timestep, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: torch.FloatTensor = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + prev_idx = idx + 1 + alpha_prod_t = self.alphas_cumprod[idx] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1) + variance = self.variance[idx] + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + variance = variance ** (0.5) * eta * variance_noise + + prev_sample = prev_sample + variance + + return prev_sample + + def add_noise(self, init_latents, noise, idx, latent_timestep): + sqrt_alpha_prod = self.alphas_cumprod[idx] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[idx]) ** 0.5 + noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise + + return noisy_latents + + +class EulerAncestralDiscreteScheduler: + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + device="cuda", + steps_offset=0, + prediction_type="epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + self.device = device + self.num_train_timesteps = num_train_timesteps + self.steps_offset = steps_offset + self.prediction_type = prediction_type + + def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=self.device) + self.timesteps = torch.from_numpy(timesteps).to(device=self.device) + + def configure(self): + dts = np.zeros(self.num_inference_steps, dtype=np.float32) + sigmas_up = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + dts[idx] = dt + sigmas_up[idx] = sigma_up + + self.dts = torch.from_numpy(dts).to(self.device) + self.sigmas_up = torch.from_numpy(sigmas_up).to(self.device) + + def step( + self, + model_output, + sample, + idx, + timestep, + generator=None, + ): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_up = self.sigmas_up[idx] + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = self.dts[idx] + + prev_sample = sample + derivative * dt + + device = model_output.device + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(device) + + prev_sample = prev_sample + noise * sigma_up + + return prev_sample + + def add_noise(self, original_samples, noise, idx, timestep=None): + step_index = (self.timesteps == timestep).nonzero().item() + noisy_samples = original_samples + noise * self.sigmas[step_index] + return noisy_samples + + +class UniPCMultistepScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: Optional[List[int]] = None, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector if disable_corrector else [] + self.last_sample = None + self.num_train_timesteps = num_train_timesteps + self.solver_order = solver_order + self.prediction_type = prediction_type + self.thresholding = thresholding + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + self.solver_type = solver_type + self.lower_order_final = lower_order_final + + def set_timesteps(self, num_inference_steps: int): + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(self.device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.solver_order + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + if self.predict_x0: + if self.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.prediction_type == "sample": + x0_pred = model_output + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + if self.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.prediction_type == "epsilon": + return model_output + elif self.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = self.timestep_list[-1], prev_timestep + m0 = model_output_list[-1] + x = sample + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1]) + else: + d1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * b_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * b_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + # this_sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = timestep_list[-1], this_timestep + m0 = model_output_list[-1] + x = last_sample + # x_t = this_sample + model_t = this_model_output + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) + else: + d1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_c = torch.linalg.solve(r, b) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + # this_sample=sample, + order=self.this_order, + ) + + # now prepare to run the predictor + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.lower_order_final: + this_order = min(self.solver_order, len(self.timesteps) - step_index) + else: + this_order = self.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + prev_timestep=prev_timestep, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return prev_sample + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=self.device, dtype=original_samples.dtype) + timesteps = timesteps.to(self.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py new file mode 100644 index 0000000000000..64c3c5bc80ecb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -0,0 +1,181 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +from enum import Enum + +import torch +from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL + + +class EngineType(Enum): + ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider + ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider + TRT = 2 # TensorRT + TORCH = 3 # PyTorch + + +def get_engine_type(name: str) -> EngineType: + name_to_type = { + "ORT_CUDA": EngineType.ORT_CUDA, + "ORT_TRT": EngineType.ORT_TRT, + "TRT": EngineType.TRT, + "TORCH": EngineType.TORCH, + } + return name_to_type[name] + + +class EngineBuilder: + def __init__( + self, + engine_type: EngineType, + pipeline_info: PipelineInfo, + device="cuda", + max_batch_size=16, + hf_token=None, + use_cuda_graph=False, + ): + """ + Initializes the Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + device (str | torch.device): + device to run engine + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + self.engine_type = engine_type + self.pipeline_info = pipeline_info + self.max_batch_size = max_batch_size + self.hf_token = hf_token + self.use_cuda_graph = use_cuda_graph + self.device = torch.device(device) + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.stages = pipeline_info.stages() + self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + self.models = {} + self.engines = {} + self.torch_models = {} + + def teardown(self): + for engine in self.engines.values(): + del engine + self.engines = {} + + def get_cached_model_name(self, model_name): + if self.pipeline_info.is_inpaint(): + model_name += "_inpaint" + return model_name + + def get_onnx_path(self, model_name, onnx_dir, opt=True): + engine_name = self.engine_type.name.lower() + onnx_model_dir = os.path.join( + onnx_dir, self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "model.onnx") + + def get_engine_path(self, engine_dir, model_name, profile_id): + return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id) + + def load_models(self, framework_model_dir: str): + # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + delattr(torch.nn.functional, "scaled_dot_product_attention") + + # For TRT or ORT_TRT, we will export fp16 torch model for UNet. + # For ORT_CUDA, we export fp32 model first, then optimize to fp16. + export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] + + if "clip" in self.stages: + self.models["clip"] = CLIP( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "clip2" in self.stages: + self.models["clip2"] = CLIPWithProj( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "unet" in self.stages: + self.models["unet"] = UNet( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) + + if "unetxl" in self.stages: + self.models["unetxl"] = UNetXL( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=4, + time_dim=(5 if self.pipeline_info.is_sd_xl_refiner() else 6), + ) + + # VAE Decoder + if "vae" in self.stages: + self.models["vae"] = VAE( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + ) + + if self.vae_torch_fallback: + self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir, self.hf_token) + + def load_resources(self, image_height, image_width, batch_size): + # Allocate buffers for I/O bindings + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) + + def vae_decode(self, latents): + if self.vae_torch_fallback: + latents = latents.to(dtype=torch.float32) + self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) + images = self.torch_models["vae"](latents)["sample"] + else: + images = self.run_engine("vae", {"latent": latents})["images"] + + return images + + +def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): + root_dir = work_dir or "." + short_name = pipeline_info.short_name() + + # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since + # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model. + onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") + engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") + output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") + timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") + framework_model_dir = os.path.join(root_dir, engine_type.name, "torch_model") + + return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py new file mode 100644 index 0000000000000..253cdcc45bf2e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -0,0 +1,263 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil + +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType + +import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession + +logger = logging.getLogger(__name__) + + +class OrtTensorrtEngine(CudaSession): + def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + self.engine_path = engine_path + self.ort_trt_provider_options = self.get_tensorrt_provider_options( + input_profile, + workspace_size, + fp16, + device_id, + enable_cuda_graph, + ) + + session_options = ort.SessionOptions() + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + print("creating TRT EP session for ", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + ("TensorrtExecutionProvider", self.ort_trt_provider_options), + ], + ) + print("created TRT EP session for ", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + trt_ep_options = { + "device_id": device_id, + "trt_fp16_enable": fp16, + "trt_engine_cache_enable": True, + "trt_timing_cache_enable": True, + "trt_detailed_build_log": True, + "trt_engine_cache_path": self.engine_path, + } + + if enable_cuda_graph: + trt_ep_options["trt_cuda_graph_enable"] = True + + if workspace_size > 0: + trt_ep_options["trt_max_workspace_size"] = workspace_size + + if input_profile: + min_shapes = [] + max_shapes = [] + opt_shapes = [] + for name, profile in input_profile.items(): + assert isinstance(profile, list) and len(profile) == 3 + min_shape = profile[0] + opt_shape = profile[1] + max_shape = profile[2] + assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) + + min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) + opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) + max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) + + trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) + trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) + trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) + + logger.info("trt_ep_options=%s", trt_ep_options) + + return trt_ep_options + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class OrtTensorrtEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + def has_engine_file(self, engine_path): + if os.path.isdir(engine_path): + children = os.scandir(engine_path) + for entry in children: + if entry.is_file() and entry.name.endswith(".engine"): + return True + return False + + def get_work_space_size(self, model_name, max_workspace_size): + gibibyte = 2**30 + workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size + if workspace_size == 0: + _, free_mem, _ = cudart.cudaMemGetInfo() + # The following logic are adopted from TensorRT demo diffusion. + if free_mem > 6 * gibibyte: + workspace_size = free_mem - 4 * gibibyte + return workspace_size + + def build_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_image_shape=True, + max_workspace_size=0, + device_id=0, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if not self.has_engine_file(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info(f"Exporting model: {onnx_path}") + model = model_obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Optimize onnx + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimizing model: %s", onnx_opt_path) + model_obj.optimize_trt(onnx_path, onnx_opt_path) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if not self.has_engine_file(engine_path): + logger.info( + "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", + model_name, + onnx_opt_path, + engine_path, + ) + else: + logger.info("Reuse cached TensorRT engine in directory %s", engine_path) + + input_profile = model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_image_shape=static_image_shape, + ) + + engine = OrtTensorrtEngine( + engine_path, + device_id, + onnx_opt_path, + fp16=True, + input_profile=input_profile, + workspace_size=self.get_work_space_size(model_name, max_workspace_size), + enable_cuda_graph=self.use_cuda_graph, + ) + + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py new file mode 100644 index 0000000000000..4a924abfb8600 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -0,0 +1,507 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import gc +import os +import pathlib +from collections import OrderedDict + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import tensorrt as trt +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from trt_utilities import TRT_LOGGER + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, +} + + +def _cuda_assert(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class TensorrtEngine: + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.cuda_graph_instance = None + + def __del__(self): + del self.engine + del self.context + del self.buffers + del self.tensors + + def refit(self, onnx_path, onnx_refit_path): + def convert_int64(arr): + if len(arr.shape) == 0: + return np.int32(arr) + return arr + + def add_to_map(refit_dict, name, values): + if name in refit_dict: + assert refit_dict[name] is None + if values.dtype == np.int64: + values = convert_int64(values) + refit_dict[name] = values + + print(f"Refitting TensorRT engine with {onnx_refit_path} weights") + refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes + + # Construct mapping from weight names in refit model -> original model + name_map = {} + for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): + refit_node = refit_nodes[n] + assert node.op == refit_node.op + # Constant nodes in ONNX do not have inputs but have a constant output + if node.op == "Constant": + name_map[refit_node.outputs[0].name] = node.outputs[0].name + # Handle scale and bias weights + elif node.op == "Conv": + if node.inputs[1].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" + if node.inputs[2].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" + # For all other nodes: find node inputs that are initializers (gs.Constant) + else: + for i, inp in enumerate(node.inputs): + if inp.__class__ == gs.Constant: + name_map[refit_node.inputs[i].name] = inp.name + + def map_name(name): + if name in name_map: + return name_map[name] + return name + + # Construct refit dictionary + refit_dict = {} + refitter = trt.Refitter(self.engine, TRT_LOGGER) + all_weights = refitter.get_all() + for layer_name, role in zip(all_weights[0], all_weights[1]): + # for specialized roles, use a unique name in the map: + if role == trt.WeightsRole.KERNEL: + name = layer_name + "_TRTKERNEL" + elif role == trt.WeightsRole.BIAS: + name = layer_name + "_TRTBIAS" + else: + name = layer_name + + assert name not in refit_dict, "Found duplicate layer: " + name + refit_dict[name] = None + + for n in refit_nodes: + # Constant nodes in ONNX do not have inputs but have a constant output + if n.op == "Constant": + name = map_name(n.outputs[0].name) + print(f"Add Constant {name}\n") + add_to_map(refit_dict, name, n.outputs[0].values) + + # Handle scale and bias weights + elif n.op == "Conv": + if n.inputs[1].__class__ == gs.Constant: + name = map_name(n.name + "_TRTKERNEL") + add_to_map(refit_dict, name, n.inputs[1].values) + + if n.inputs[2].__class__ == gs.Constant: + name = map_name(n.name + "_TRTBIAS") + add_to_map(refit_dict, name, n.inputs[2].values) + + # For all other nodes: find node inputs that are initializers (AKA gs.Constant) + else: + for inp in n.inputs: + name = map_name(inp.name) + if inp.__class__ == gs.Constant: + add_to_map(refit_dict, name, inp.values) + + for layer_name, weights_role in zip(all_weights[0], all_weights[1]): + if weights_role == trt.WeightsRole.KERNEL: + custom_name = layer_name + "_TRTKERNEL" + elif weights_role == trt.WeightsRole.BIAS: + custom_name = layer_name + "_TRTBIAS" + else: + custom_name = layer_name + + # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model + if layer_name.startswith("onnx::Trilu"): + continue + + if refit_dict[custom_name] is not None: + refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) + else: + print(f"[W] No refit weights for layer: {layer_name}") + + if not refitter.refit_cuda_engine(): + print("Failed to refit!") + exit(0) + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + update_output_names=None, + ): + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + if update_output_names: + print(f"Updating network outputs to {update_output_names}") + network = ModifyNetworkOutputs(network, update_output_names) + engine = engine_from_network( + network, + config=CreateConfig( + fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs + ), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + print(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, reuse_device_memory=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = reuse_device_memory + else: + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + + def infer(self, feed_dict, stream, use_cuda_graph=False): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + if use_cuda_graph: + if self.cuda_graph_instance is not None: + _cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + _cuda_assert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + _cuda_assert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) + ) + self.context.execute_async_v3(stream) + self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream)) + + from cuda import nvrtc + + result, major, minor = nvrtc.nvrtcVersion() + assert result == nvrtc.nvrtcResult(0) + if major < 12: + self.cuda_graph_instance = _cuda_assert( + cudart.cudaGraphInstantiate(self.graph, b"", 0) + ) # cuda < 12 + else: + self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12 + else: + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TensorrtEngineBuilder(EngineBuilder): + """ + Helper class to hide the detail of TensorRT Engine from pipeline. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + self.stream = None + self.shared_device_memory = None + + def load_resources(self, image_height, image_width, batch_size): + super().load_resources(image_height, image_width, batch_size) + + self.stream = _cuda_assert(cudart.cudaStreamCreate()) + + def teardown(self): + super().teardown() + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def load_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + onnx_refit_dir=None, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to write the TensorRT engines. + framework_model_dir (str): + Directory to write the framework model ckpt. + onnx_dir (str): + Directory to write the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + force_export (bool): + Force re-exporting the ONNX models. + force_optimize (bool): + Force re-optimizing the ONNX models. + force_build (bool): + Force re-building the TensorRT engine. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_preview (bool): + Enable TensorRT preview features. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to accelerate build or None + onnx_refit_dir (str): + Directory containing refit ONNX models. + """ + # Create directory + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.load_models(framework_model_dir) + + # Export models to ONNX + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if force_export or force_build or not os.path.exists(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if force_export or not os.path.exists(onnx_opt_path): + if force_export or not os.path.exists(onnx_path): + print(f"Exporting model: {onnx_path}") + model = obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = obj.get_sample_input(1, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=obj.get_input_names(), + output_names=obj.get_output_names(), + dynamic_axes=obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + print(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_optimize or not os.path.exists(onnx_opt_path): + print(f"Generating optimizing model: {onnx_opt_path}") + obj.optimize_trt(onnx_path, onnx_opt_path) + else: + print(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + engine = TensorrtEngine(engine_path) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if force_build or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch, + static_shape, + ), + enable_refit=enable_refit, + enable_preview=enable_preview, + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=None, + ) + self.engines[model_name] = engine + + # Load TensorRT engines + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].load() + if onnx_refit_dir: + onnx_refit_path = self.get_onnx_path(model_name, onnx_refit_dir, opt=True) + if os.path.exists(onnx_refit_path): + self.engines[model_name].refit(onnx_opt_path, onnx_refit_path) + + def max_device_memory(self): + max_device_memory = 0 + for _model_name, engine in self.engines.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activate_engines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.max_device_memory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engines.values(): + engine.activate(reuse_device_memory=self.shared_device_memory) + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py deleted file mode 100644 index 0f7688a3df9f6..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py +++ /dev/null @@ -1,368 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Models used in Stable diffusion. -""" -import logging - -import onnx -import onnx_graphsurgeon as gs -import torch -from onnx import shape_inference -from ort_optimizer import OrtStableDiffusionOptimizer -from polygraphy.backend.onnx.loader import fold_constants - -logger = logging.getLogger(__name__) - - -class TrtOptimizer: - def __init__(self, onnx_graph): - self.graph = gs.import_onnx(onnx_graph) - - def cleanup(self): - self.graph.cleanup().toposort() - - def get_optimized_onnx_graph(self): - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self): - onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) - self.graph = gs.import_onnx(onnx_graph) - - def infer_shapes(self): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - raise TypeError("ERROR: model size exceeds supported 2GB limit") - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - - -class BaseModel: - def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77): - self.model = model - self.name = name - self.fp16 = fp16 - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - self.text_maxlen = text_maxlen - - self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" - self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type) - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT EP""" - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - - profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" - - if self.name != "CLIP": - if static_image_shape: - profile_id += f"_h_{image_height}_w_{image_width}" - else: - profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" - - return profile_id - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT""" - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): - self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_image_shape else self.min_image_shape - max_image_height = image_height if static_image_shape else self.max_image_shape - min_image_width = image_width if static_image_shape else self.min_image_shape - max_image_width = image_width if static_image_shape else self.max_image_shape - min_latent_height = latent_height if static_image_shape else self.min_latent_shape - max_latent_height = latent_height if static_image_shape else self.max_latent_shape - min_latent_width = latent_width if static_image_shape else self.min_latent_shape - max_latent_width = latent_width if static_image_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - -class CLIP(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="CLIP", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_image_shape - ) - return { - "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - -class UNet(BaseModel): - def __init__( - self, - model, - device="cuda", - fp16=False, # used by TRT - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - unet_dim=4, - ): - super().__init__( - model=model, - name="UNet", - device=device, - fp16=fp16, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - text_maxlen=text_maxlen, - ) - self.unet_dim = unet_dim - - def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), - ) - - -class VAE(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="VAE Decoder", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py index 6134fa7bddcf4..37785869a355b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py @@ -43,16 +43,14 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import Engines +from diffusion_models import CLIP, VAE, PipelineInfo, UNet +from ort_utils import Engines, StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer logger = logging.getLogger(__name__) -class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using CUDA provider in ONNX Runtime. This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. @@ -70,11 +68,12 @@ def __init__( requires_safety_checker: bool = True, # ONNX export parameters onnx_opset: int = 14, - onnx_dir: str = "raw_onnx", + onnx_dir: str = "onnx_ort", # Onnxruntime execution provider parameters - engine_dir: str = "onnxruntime_optimized_onnx", + engine_dir: str = "ORT_CUDA", force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: PipelineInfo = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -96,51 +95,38 @@ def __init__( self.fp16 = False - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size + self.pipeline_info = pipeline_info - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) + def load_models(self): + assert self.pipeline_info.clip_embedding_dim() == self.text_encoder.config.hidden_size - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=self.fp16, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), - ) + stages = self.pipeline_info.stages() + if "clip" in stages: + self.models["clip"] = CLIP( + self.pipeline_info, + self.text_encoder, + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) + if "unet" in stages: + self.models["unet"] = UNet( + self.pipeline_info, + self.unet, + device=self.torch_device, + fp16=False, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, + if "vae" in stages: + self.models["vae"] = VAE( + self.pipeline_info, + self.vae, + device=self.torch_device, + max_batch_size=self.max_batch_size, ) - ) def to( self, @@ -156,7 +142,7 @@ def to( # load models self.fp16 = torch_dtype == torch.float16 - self.__load_models() + self.load_models() # build engines self.engines.build( @@ -180,88 +166,6 @@ def to( return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = ( - self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone() - ) - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32) - - # Predict the noise residual - noise_pred = self.engines.get_engine("unet").infer( - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = self.engines.get_engine("vae").infer({"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - def __allocate_buffers(self, image_height, image_width, batch_size): # Allocate output tensors for I/O bindings for model_name, obj in self.models.items(): @@ -337,7 +241,7 @@ def __call__( with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines.get_engine("clip"), prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet_in_channels @@ -352,30 +256,37 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent( + self.engines.get_engine("unet"), latents, text_embeddings, timestep_fp16=self.fp16 + ) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines.get_engine("vae"), latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" +def example(): + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") - pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name_or_path, scheduler=scheduler, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models - pipe.set_cached_folder(model_name_or_path) + pipe.set_cached_folder(model_name_or_path, resume_download=True, local_files_only=True) pipe = pipe.to("cuda", torch_dtype=torch.float16) prompt = "photorealistic new zealand hills" image = pipe(prompt).images[0] - image.save("ort_trt_txt2img_new_zealand_hills.png") + image.save("ort_cuda_txt2img_new_zealand_hills.png") + + +if __name__ == "__main__": + example() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py index 6f3c215f36318..c663e37c7ea7d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py @@ -32,13 +32,11 @@ pip install onnxruntime-gpu """ -import gc +import logging import os -import shutil from typing import List, Optional, Union import torch -from cuda import cudart from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( StableDiffusionPipeline, @@ -46,224 +44,15 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE, logging -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import OrtCudaSession +from diffusion_models import PipelineInfo +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from ort_utils import StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -import onnxruntime as ort +logger = logging.getLogger(__name__) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -class Engine(OrtCudaSession): - def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): - self.engine_path = engine_path - self.ort_trt_provider_options = self.get_tensorrt_provider_options( - input_profile, - workspace_size, - fp16, - device_id, - enable_cuda_graph, - ) - - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - ort_session = ort.InferenceSession( - onnx_path, - sess_options, - providers=[ - ("TensorrtExecutionProvider", self.ort_trt_provider_options), - ], - ) - - device = torch.device("cuda", device_id) - super().__init__(ort_session, device, enable_cuda_graph) - - def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): - trt_ep_options = { - "device_id": device_id, - "trt_fp16_enable": fp16, - "trt_engine_cache_enable": True, - "trt_timing_cache_enable": True, - "trt_detailed_build_log": True, - "trt_engine_cache_path": self.engine_path, - } - - if enable_cuda_graph: - trt_ep_options["trt_cuda_graph_enable"] = True - - if workspace_size > 0: - trt_ep_options["trt_max_workspace_size"] = workspace_size - - if input_profile: - min_shapes = [] - max_shapes = [] - opt_shapes = [] - for name, profile in input_profile.items(): - assert isinstance(profile, list) and len(profile) == 3 - min_shape = profile[0] - opt_shape = profile[1] - max_shape = profile[2] - assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) - - min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) - opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) - max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) - - trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) - trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) - trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) - - logger.info("trt_ep_options=%s", trt_ep_options) - - return trt_ep_options - - -def get_onnx_path(model_name, onnx_dir, opt=True): - return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") - - -def get_engine_path(engine_dir, model_name, profile_id): - return os.path.join(engine_dir, model_name + profile_id) - - -def has_engine_file(engine_path): - if os.path.isdir(engine_path): - children = os.scandir(engine_path) - for entry in children: - if entry.is_file() and entry.name.endswith(".engine"): - return True - return False - - -def get_work_space_size(model_name, max_workspace_size): - gibibyte = 2**30 - workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size - if workspace_size == 0: - _, free_mem, _ = cudart.cudaMemGetInfo() - # The following logic are adopted from TensorRT demo diffusion. - if free_mem > 6 * gibibyte: - workspace_size = free_mem - 4 * gibibyte - return workspace_size - - -def build_engines( - models, - engine_dir, - onnx_dir, - onnx_opset, - opt_image_height, - opt_image_width, - opt_batch_size=1, - force_engine_rebuild=False, - static_batch=False, - static_image_shape=True, - max_workspace_size=0, - device_id=0, - enable_cuda_graph=False, -): - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - engine_path = get_engine_path(engine_dir, model_name, profile_id) - if not has_engine_file(engine_path): - onnx_path = get_onnx_path(model_name, onnx_dir, opt=False) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - if not os.path.exists(onnx_opt_path): - if not os.path.exists(onnx_path): - logger.info(f"Exporting model: {onnx_path}") - model = model_obj.get_model() - with torch.inference_mode(), torch.autocast("cuda"): - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) - torch.onnx.export( - model, - inputs, - onnx_path, - export_params=True, - opset_version=onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - else: - logger.info("Found cached model: %s", onnx_path) - - # Optimize onnx - if not os.path.exists(onnx_opt_path): - logger.info("Generating optimizing model: %s", onnx_opt_path) - model_obj.optimize_trt(onnx_path, onnx_opt_path) - else: - logger.info("Found cached optimized model: %s", onnx_opt_path) - - built_engines = {} - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - - engine_path = get_engine_path(engine_dir, model_name, profile_id) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - - if not has_engine_file(engine_path): - logger.info( - "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", - model_name, - onnx_opt_path, - engine_path, - ) - else: - logger.info("Reuse cached TensorRT engine in directory %s", engine_path) - - input_profile = model_obj.get_input_profile( - opt_batch_size, - opt_image_height, - opt_image_width, - static_batch=static_batch, - static_image_shape=static_image_shape, - ) - - engine = Engine( - engine_path, - device_id, - onnx_opt_path, - fp16=True, - input_profile=input_profile, - workspace_size=get_work_space_size(model_name, max_workspace_size), - enable_cuda_graph=enable_cuda_graph, - ) - - built_engines[model_name] = engine - - return built_engines - - -def run_engine(engine, feed_dict): - return engine.infer(feed_dict) - - -class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. @@ -285,11 +74,12 @@ def __init__( max_batch_size: int = 16, # ONNX export parameters onnx_opset: int = 17, - onnx_dir: str = "onnx", + onnx_dir: str = "onnx_trt", # TensorRT engine build parameters - engine_dir: str = "onnxruntime_tensorrt_engine", + engine_dir: str = "ORT_TRT", # use short name here to avoid path exceeds 260 chars in Windows. force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: Optional[PipelineInfo] = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -299,16 +89,14 @@ def __init__( self.image_height = image_height self.image_width = image_width - self.inpaint = False self.onnx_opset = onnx_opset self.onnx_dir = onnx_dir self.engine_dir = engine_dir self.force_engine_rebuild = force_engine_rebuild - self.enable_cuda_graph = enable_cuda_graph - # Although cuda graph requires static input shape, engine built with dyamic batch gets better performance in T4. + # Although cuda graph requires static input shape, engine built with dynamic batch gets better performance in T4. # Use static batch could reduce GPU memory footprint. - self.build_static_batch = False + self.build_static_batch = enable_cuda_graph # TODO: support dynamic image shape. self.build_dynamic_shape = False @@ -318,54 +106,13 @@ def __init__( if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: self.max_batch_size = 4 - self.models = {} # loaded in __load_models() self.engines = {} # loaded in build_engines() - - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size - - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) - - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=True, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), + self.engine_builder = OrtTensorrtEngineBuilder( + pipeline_info, max_batch_size=max_batch_size, use_cuda_graph=enable_cuda_graph ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) - - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - ) + self.pipeline_info = pipeline_info + self.stages = pipeline_info.stages() def to( self, @@ -381,11 +128,9 @@ def to( self.torch_device = self._execution_device logger.info(f"Running inference on device: {self.torch_device}") - self.__load_models() - - self.engines = build_engines( - self.models, + self.engines = self.engine_builder.build_engines( self.engine_dir, + None, self.onnx_dir, self.onnx_opset, opt_image_height=self.image_height, @@ -394,96 +139,10 @@ def to( static_batch=self.build_static_batch, static_image_shape=not self.build_dynamic_shape, device_id=self.torch_device.index, - enable_cuda_graph=self.enable_cuda_graph, ) return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # Predict the noise residual - timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - - noise_pred = run_engine( - self.engines["unet"], - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = run_engine(self.engines["vae"], {"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - - def __allocate_buffers(self, image_height, image_width, batch_size): - # Allocate output tensors for I/O bindings - for model_name, obj in self.models.items(): - self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) - @torch.no_grad() def __call__( self, @@ -547,11 +206,11 @@ def __call__( f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - self.__allocate_buffers(self.image_height, self.image_width, batch_size) + self.engine_builder.load_resources(self.image_height, self.image_width, batch_size) with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines["clip"], prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet.config.in_channels @@ -566,10 +225,10 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent(self.engines["unet"], latents, text_embeddings) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines["vae"], latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) @@ -577,8 +236,8 @@ def __call__( if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" - + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( @@ -589,6 +248,7 @@ def __call__( image_height=512, image_width=512, max_batch_size=4, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index aef60a534608a..ffcfd6d9fd7e0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -13,7 +13,7 @@ # python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16 # # Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support -# for the fused opeartors. The users could disable the operator fusion manually to workaround. +# for the fused operators. The users could disable the operator fusion manually to workaround. import argparse import logging @@ -49,7 +49,6 @@ def has_external_data(onnx_model_path): def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, - overwrite: bool, use_external_data_format: Optional[bool], float16: bool, force_fp32_ops: List[str], @@ -61,7 +60,6 @@ def _optimize_sd_pipeline( Args: source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. - overwrite (bool): Overwrite files if exists. use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision force_fp32_ops(List[str]): operators that are forced to run in float32. @@ -235,7 +233,7 @@ def optimize_stable_diffusion_pipeline( args, ): if os.path.exists(output_dir): - if args.overwrite: + if overwrite: shutil.rmtree(output_dir, ignore_errors=True) else: raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.") @@ -249,7 +247,6 @@ def optimize_stable_diffusion_pipeline( _optimize_sd_pipeline( source_dir, target_dir, - overwrite, use_external_data_format, float16, args.force_fp32_ops, @@ -321,7 +318,7 @@ def parse_arguments(argv: Optional[List[str]] = None): required=False, action="store_true", help="Onnx model larger than 2GB need to use external data format. " - "If specifed, save each onnx model to two files: one for onnx graph, another for weights. " + "If specified, save each onnx model to two files: one for onnx graph, another for weights. " "If not specified, use same format as original model by default. ", ) parser.set_defaults(use_external_data_format=None) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 0824c8f07d6e2..2c4b8e8a1639e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -12,6 +12,7 @@ from pathlib import Path import onnx +from optimize_pipeline import has_external_data from onnxruntime.transformers.fusion_options import FusionOptions from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel @@ -32,21 +33,25 @@ def __init__(self, model_type: str): "clip": ClipOnnxModel, } - def optimize_by_ort(self, onnx_model): + def optimize_by_ort(self, onnx_model, use_external_data_format=False): # Use this step to see the final graph that executed by Onnx Runtime. with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path)) - ort_optimized_model_path = tmp_model_path + onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) + ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" optimize_by_onnxruntime( - str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + str(tmp_model_path), + use_gpu=True, + optimized_model_path=str(ort_optimized_model_path), + save_as_external_data=use_external_data_format, + external_data_filename="optimized.onnx_data", ) model = onnx.load(str(ort_optimized_model_path), load_external_data=True) return self.model_type_class_mapping[self.model_type](model) - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep_io_types=False, keep_outputs=None): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") fusion_options = FusionOptions(self.model_type) @@ -54,6 +59,8 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): fusion_options.enable_packed_kv = False fusion_options.enable_packed_qkv = False + use_external_data_format = has_external_data(input_fp32_onnx_path) + m = optimize_model( input_fp32_onnx_path, model_type=self.model_type, @@ -64,21 +71,24 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): use_gpu=True, ) - if self.model_type == "clip": - m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + if keep_outputs is None and self.model_type == "clip": + # remove the pooler_output, and only keep the first output. + keep_outputs = ["text_embeddings"] + + if keep_outputs: + m.prune_graph(outputs=keep_outputs) if float16: logger.info("Convert to float16 ...") m.convert_float_to_float16( - keep_io_types=False, - op_block_list=["RandomNormalLike"], + keep_io_types=keep_io_types, ) - # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 + # Note that ORT < 1.16 could not save model larger than 2GB. if float16 or (self.model_type != "unet"): - m = self.optimize_by_ort(m) + m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) m.get_operator_statistics() m.get_fused_operator_statistics() - m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 7192e4ad5584f..5c2145845e757 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -7,122 +7,24 @@ import logging import os import shutil -from collections import OrderedDict -from typing import Any, Dict +from typing import Union import torch import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import TypeHelper +from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) -class OrtCudaSession: - """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" - - def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False): - self.ort_session = ort_session - self.input_names = [input.name for input in self.ort_session.get_inputs()] - self.output_names = [output.name for output in self.ort_session.get_outputs()] - self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) - self.io_binding = self.ort_session.io_binding() - self.enable_cuda_graph = enable_cuda_graph - - self.input_tensors = OrderedDict() - self.output_tensors = OrderedDict() - self.device = device - - def __del__(self): - del self.input_tensors - del self.output_tensors - del self.io_binding - del self.ort_session - - def allocate_buffers(self, shape_dict: Dict[str, tuple]): - """Allocate tensors for I/O Binding""" - if self.enable_cuda_graph: - for name, shape in shape_dict.items(): - if name in self.input_names: - # Reuse allocated buffer when the shape is same - if name in self.input_tensors: - if tuple(self.input_tensors[name].shape) == tuple(shape): - continue - raise RuntimeError("Expect static input shape for cuda graph") - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.input_tensors[name] = tensor - - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - for name, shape in shape_dict.items(): - if name in self.output_names: - # Reuse allocated buffer when the shape is same - if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): - continue - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.output_tensors[name] = tensor - - self.io_binding.bind_output( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - def infer(self, feed_dict): - """Bind input tensors and run inference""" - for name, tensor in feed_dict.items(): - assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() - if name in self.input_names: - if self.enable_cuda_graph: - assert self.input_tensors[name].nelement() == tensor.nelement() - assert tensor.device.type == "cuda" - # Update input tensor inplace since cuda graph requires input and output has fixed memory address. - from cuda import cudart - - cudart.cudaMemcpy( - self.input_tensors[name].data_ptr(), - tensor.data_ptr(), - tensor.element_size() * tensor.nelement(), - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, - ) - else: - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - TypeHelper.torch_type_to_numpy_type(tensor.dtype), - [1] if len(tensor.shape) == 0 else list(tensor.shape), - tensor.data_ptr(), - ) - - self.ort_session.run_with_iobinding(self.io_binding) - - return self.output_tensors - - -class Engine(OrtCudaSession): +# ----------------------------------------------------------------------------------------------------- +# Utilities for CUDA EP +# ----------------------------------------------------------------------------------------------------- +class Engine(CudaSession): def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False): self.engine_path = engine_path self.provider = provider - self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) device = torch.device("cuda", device_id) ort_session = ort.InferenceSession( @@ -135,13 +37,6 @@ def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_g super().__init__(ort_session, device, enable_cuda_graph) - def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: - return { - "device_id": device_id, - "arena_extend_strategy": "kSameAsRequested", - "enable_cuda_graph": enable_cuda_graph, - } - class Engines: def __init__(self, provider, onnx_opset: int = 14): @@ -197,9 +92,16 @@ def build( model = model_obj.get_model().to(model_obj.device) with torch.inference_mode(): inputs = model_obj.get_sample_input(1, 512, 512) + fp32_inputs = tuple( + [ + (tensor.to(torch.float32) if tensor.dtype == torch.float16 else tensor) + for tensor in inputs + ] + ) + torch.onnx.export( model, - inputs, + fp32_inputs, onnx_path, export_params=True, opset_version=self.onnx_opset, @@ -224,3 +126,125 @@ def build( def get_engine(self, model_name): return self.engines[model_name] + + +def run_engine(engine, feed_dict): + return engine.infer(feed_dict) + + +# ----------------------------------------------------------------------------------------------------- +# Utilities for both CUDA and TensorRT EP +# ----------------------------------------------------------------------------------------------------- + + +class StableDiffusionPipelineMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def encode_prompt(self, clip_engine, prompt, negative_prompt): + """ + Encodes the prompt into text encoder hidden states. + """ + + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = run_engine(clip_engine, {"input_ids": text_input_ids})["text_embeddings"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + uncond_embeddings = run_engine(clip_engine, {"input_ids": uncond_input_ids})["text_embeddings"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def denoise_latent( + self, + unet_engine, + latents, + text_embeddings, + timesteps=None, + mask=None, + masked_image_latents=None, + timestep_fp16=False, + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + + for _step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.to(torch.float16) if timestep_fp16 else timestep.to(torch.float32) + + noise_pred = run_engine( + unet_engine, + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def decode_latent(self, vae_engine, latents): + images = run_engine(vae_engine, {"latent": latents})["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def set_cached_folder(self, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): + from diffusers.utils import DIFFUSERS_CACHE + from huggingface_hub import snapshot_download + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + self.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py new file mode 100644 index 0000000000000..0e2aeb6174666 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -0,0 +1,232 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Img2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Img2Img XL pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Img2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_sd_xl_refiner() + + super().__init__(pipeline_info, *args, **kwargs) + + self.requires_aesthetics_score = True + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0).to(device=self.device) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + assert len(prompt) == len(negative_prompt) + + # TODO(tianleiwu): Need we use image_height and image_width for the target size here? + original_size = (1024, 1024) + crops_coords_top_left = (0, 0) + target_size = (1024, 1024) + strength = 0.3 + aesthetic_score = 6.0 + negative_aesthetic_score = 2.5 + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + batch_size = len(prompt) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # Initialize timesteps + timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size) + + # CLIP text encoder 2 + text_embeddings, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=text_embeddings.dtype, + ) + + add_time_ids = add_time_ids.repeat(batch_size, 1) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # Pre-process input image + init_image = self.preprocess_images(batch_size, (init_image,))[0] + + # VAE encode init image + if init_image.shape[1] == 4: + init_latents = init_image + else: + init_latents = self.encode_image(init_image) + + # Add noise to latents using timesteps + noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float32, generator=self.generator) + latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep) + + # UNet denoiser + latents = self.denoise_latent( + latents, + text_embeddings, + timesteps=timesteps, + step_offset=t_start, + denoiser="unetxl", + guidance=guidance, + add_kwargs=add_kwargs, + ) + + with torch.inference_mode(): + # VAE decode latent + if return_type == "latents": + images = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Refiner Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "img2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + init_image (tuple[torch.Tensor]): + Image from base pipeline. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latents" or "images". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000000..a053c9d5d0835 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,429 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import os +import pathlib +import random + +import nvtx +import torch +from cuda import cudart +from diffusion_models import PipelineInfo, get_tokenizer +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from engine_builder import EngineType +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from engine_builder_tensorrt import TensorrtEngineBuilder + + +class StableDiffusionPipeline: + """ + Stable Diffusion pipeline using TensorRT. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + scheduler="DDIM", + device="cuda", + output_dir=".", + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + framework_model_dir="pytorch_model", + engine_type: EngineType = EngineType.ORT_TRT, + ): + """ + Initializes the Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + scheduler (str): + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + device (str): + PyTorch device to run inference. Default: 'cuda' + output_dir (str): + Output directory for log files and image artifacts + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + framework_model_dir (str): + cache directory for framework checkpoints + engine_type (EngineType) + backend engine type like ORT_TRT or TRT + """ + + self.pipeline_info = pipeline_info + self.version = pipeline_info.version + + self.vae_scaling_factor = pipeline_info.vae_scaling_factor() + + self.max_batch_size = max_batch_size + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + # Scheduler options + sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} + if self.version in ("2.0", "2.1"): + sched_opts["prediction_type"] = "v_prediction" + else: + sched_opts["prediction_type"] = "epsilon" + + if scheduler == "DDIM": + self.scheduler = DDIMScheduler(device=self.device, **sched_opts) + elif scheduler == "EulerA": + self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) + elif scheduler == "UniPC": + self.scheduler = UniPCMultistepScheduler(device=self.device) + else: + raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + + self.stages = pipeline_info.stages() + + self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + self.use_cuda_graph = use_cuda_graph + + self.tokenizer = None + self.tokenizer2 = None + + self.generator = None + self.denoising_steps = None + + # backend engine + self.engine_type = engine_type + if engine_type == EngineType.TRT: + self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + elif engine_type == EngineType.ORT_TRT: + self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + else: + raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") + + # Load text tokenizer + if not self.pipeline_info.is_sd_xl_refiner(): + self.tokenizer = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer" + ) + + if self.pipeline_info.is_sd_xl(): + self.tokenizer2 = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" + ) + + # Create CUDA events + self.events = {} + for stage in ["clip", "denoise", "vae", "vae_encoder"]: + for marker in ["start", "stop"]: + self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1] + + def is_backend_tensorrt(self): + return self.engine_type == EngineType.TRT + + def set_denoising_steps(self, denoising_steps: int): + if self.denoising_steps != denoising_steps: + assert self.denoising_steps is None # TODO(tianleiwu): support changing steps in different runs + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(denoising_steps) + self.scheduler.configure() + self.denoising_steps = denoising_steps + + def load_resources(self, image_height, image_width, batch_size): + # If engine is built with static input shape, call this only once after engine build. + # Otherwise, it need be called before every inference run. + self.backend.load_resources(image_height, image_width, batch_size) + + def set_random_seed(self, seed): + # Initialize noise generator. Usually, it is done before a batch of inference. + self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e) + + if self.backend: + self.backend.teardown() + + def run_engine(self, model_name, feed_dict): + return self.backend.run_engine(model_name, feed_dict) + + def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width): + latents_dtype = torch.float32 # text_embeddings.dtype + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + return timesteps, t_start + + def preprocess_images(self, batch_size, images=()): + if self.nvtx_profile: + nvtx_image_preprocess = nvtx.start_range(message="image_preprocess", color="pink") + init_images = [] + for i in images: + image = i.to(self.device).float() + if image.shape[0] != batch_size: + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + if self.nvtx_profile: + nvtx.end_range(nvtx_image_preprocess) + return tuple(init_images) + + def encode_prompt( + self, prompt, negative_prompt, encoder="clip", tokenizer=None, pooled_outputs=False, output_hidden_states=False + ): + if tokenizer is None: + tokenizer = self.tokenizer + + if self.nvtx_profile: + nvtx_clip = nvtx.start_range(message="clip", color="green") + cudart.cudaEventRecord(self.events["clip-start"], 0) + + # Tokenize prompt + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + hidden_states = outputs["hidden_states"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + if pooled_outputs: + pooled_output = text_embeddings + + if output_hidden_states: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + + cudart.cudaEventRecord(self.events["clip-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_clip) + + if pooled_outputs: + return text_embeddings, pooled_output + return text_embeddings + + def denoise_latent( + self, + latents, + text_embeddings, + denoiser="unet", + timesteps=None, + step_offset=0, + mask=None, + masked_image_latents=None, + guidance=7.5, + image_guidance=1.5, + add_kwargs=None, + ): + assert guidance > 1.0, "Guidance has to be > 1.0" + assert image_guidance > 1.0, "Image guidance has to be > 1.0" + + cudart.cudaEventRecord(self.events["denoise-start"], 0) + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + if self.nvtx_profile: + nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") + + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, step_offset + step_index, timestep + ) + + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_scale) + + # Predict the noise residual + if self.nvtx_profile: + nvtx_unet = nvtx.start_range(message="unet", color="blue") + + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = latent_model_input + timestep_inp = timestep_float + embeddings_inp = text_embeddings + + params = {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp} + if add_kwargs: + params.update(add_kwargs) + + noise_pred = self.run_engine(denoiser, params)["latent"] + + if self.nvtx_profile: + nvtx.end_range(nvtx_unet) + + if self.nvtx_profile: + nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + + if type(self.scheduler) == UniPCMultistepScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + else: + latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) + + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_step) + + latents = 1.0 / self.vae_scaling_factor * latents + cudart.cudaEventRecord(self.events["denoise-stop"], 0) + return latents + + def encode_image(self, init_image): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae_encoder", color="red") + cudart.cudaEventRecord(self.events["vae_encoder-start"], 0) + init_latents = self.run_engine("vae_encoder", {"images": init_image})["latent"] + cudart.cudaEventRecord(self.events["vae_encoder-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + + init_latents = self.vae_scaling_factor * init_latents + return init_latents + + def decode_latent(self, latents): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae", color="red") + cudart.cudaEventRecord(self.events["vae-start"], 0) + images = self.backend.vae_decode(latents) + cudart.cudaEventRecord(self.events["vae-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + return images + + def print_summary(self, tic, toc, batch_size, vae_enc=False): + print("|------------|--------------|") + print("| {:^10} | {:^12} |".format("Module", "Latency")) + print("|------------|--------------|") + if vae_enc: + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "UNet x " + str(self.denoising_steps), + cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + ) + ) + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("|------------|--------------|") + print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + + @staticmethod + def to_pil_image(images): + images = ( + ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() + ) + from PIL import Image + + return [Image.fromarray(images[i]) for i in range(images.shape[0])] + + def save_images(self, images, pipeline, prompt): + image_name_prefix = ( + pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" + ) + + images = self.to_pil_image(images) + random_session_id = str(random.randint(1000, 9999)) + for i, image in enumerate(images): + image_path = os.path.join( + self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" + ) + print(f"Saving image {i+1} / {len(images)} to: {image_path}") + image.save(image_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py new file mode 100644 index 0000000000000..82f73e8b3cc61 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -0,0 +1,155 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, **kwargs): + """ + Initializes the Txt2Img Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + super().__init__(pipeline_info, **kwargs) + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=50, + guidance=7.5, + seed=None, + warmup=False, + return_type="latents", + ): + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) + + # VAE decode latent + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "txt2img", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=7.5, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + type of return. The value can be "latents" or "images". + """ + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py new file mode 100644 index 0000000000000..d8f00ed619354 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img XL pipeline. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Txt2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_sd_xl_base() + + super().__init__(pipeline_info, *args, **kwargs) + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + assert len(prompt) == len(negative_prompt) + + # TODO(tianleiwu): Need we use image_height and image_width for the target size here? + original_size = (1024, 1024) + crops_coords_top_left = (0, 0) + target_size = (1024, 1024) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt( + prompt, negative_prompt, encoder="clip", tokenizer=self.tokenizer, output_hidden_states=True + ) + # CLIP text encoder 2 + text_embeddings2, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Merged text embeddings + text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype + ) + add_time_ids = add_time_ids.repeat(batch_size, 1) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # UNet denoiser + latents = self.denoise_latent( + latents, text_embeddings, denoiser="unetxl", guidance=guidance, add_kwargs=add_kwargs + ) + + # VAE decode latent + if return_type == "latents": + images = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Base Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + if return_type == "images": + self.save_images(images, "txt2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latents" or "images". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt index b942749f8dcd2..2a3caf4c2392b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt @@ -1,8 +1,14 @@ -r requirements.txt -onnxruntime-gpu>=1.14 +onnxruntime-gpu>=1.16 py3nvml>=0.2.7 + # cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu. -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 +cuda-python==11.8.0 +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# To export onnx, please install PyTorch 2.10 like +# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt index 567f39c0119e6..5b59c64ab7470 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt @@ -1,18 +1,2 @@ -diffusers>=0.16.0 -transformers>=4.26.0 -numpy>=1.24.1 -accelerate -onnx>=1.13.0 -coloredlogs -packaging -protobuf -psutil -sympy +-r requirements-cuda.txt tensorrt>=8.6.1 -onnxruntime-gpu>=1.15.1 -py3nvml -# cuda-python version shall be compatible with CUDA version of torch and onnxruntime-gpu -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py new file mode 100644 index 0000000000000..d03a9f9f55372 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import tensorrt as trt + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +def init_trt_plugins(): + # Register TensorRT plugins + trt.init_libnvinfer_plugins(TRT_LOGGER, "") diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 60be2d84b2bc8..e9c24ed3eb09b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -610,7 +610,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled. - Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to eanble + Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to enable symbolic shape inference. If your model is not optimized, you can also use model path to call convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to avoid the 2GB limit. @@ -832,7 +832,7 @@ def get_first_output(node): # Keep track of nodes to keep. The key is first output of node, and the value is the node. output_to_node = {} - # Start from graph outputs, and find parent nodes recurisvely, and add nodes to the output_to_node dictionary. + # Start from graph outputs, and find parent nodes recursively, and add nodes to the output_to_node dictionary. dq = deque() for output in keep_outputs: if output in output_name_to_node: @@ -1161,7 +1161,7 @@ def has_same_value( signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. Returns: - bool: True when two intializers has same value. + bool: True when two initializers has same value. """ sig1 = ( signature_cache1[tensor1.name] From 1bc115719c0d07ad658a9bb458ddf8f1bf379a41 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 4 Oct 2023 08:55:08 -0700 Subject: [PATCH 10/10] Unify handling of public headers in onnxruntime.cmake. (#17779) The changes in PR #8919 overwrote the PUBLIC_HEADER property value of the `onnxruntime` target with a list that did not include EP-specific headers. We should probably be using a consistent set of header files across packages anyway. --- cmake/onnxruntime.cmake | 57 +++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 59ebf8eca4306..0fe9a0a4b0bfb 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -18,35 +18,21 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") set(OUTPUT_STYLE xcode) endif() -set(ONNXRUNTIME_PUBLIC_HEADERS - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" -) - -if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") -endif() - -# This macro is to get the path of header files for mobile packaging, for iOS and Android -macro(get_mobile_api_headers _HEADERS) - # include both c and cxx api - set(${_HEADERS} +# Gets the public C/C++ API header files +function(get_c_cxx_api_headers HEADERS_VAR) + set(_headers "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" ) if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") endif() # need to add header files for enabled EPs @@ -54,10 +40,13 @@ macro(get_mobile_api_headers _HEADERS) file(GLOB _provider_headers CONFIGURE_DEPENDS "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" ) - list(APPEND ${_HEADERS} "${_provider_headers}") - unset(_provider_headers) + list(APPEND _headers ${_provider_headers}) endforeach() -endmacro() + + set(${HEADERS_VAR} ${_headers} PARENT_SCOPE) +endfunction() + +get_c_cxx_api_headers(ONNXRUNTIME_PUBLIC_HEADERS) #If you want to verify if there is any extra line in symbols.txt, run # nm -C -g --defined libonnxruntime.so |grep -v '\sA\s' | cut -f 3 -d ' ' | sort @@ -84,11 +73,9 @@ if(WIN32) "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc" ) elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) - get_mobile_api_headers(APPLE_FRAMEWORK_HEADERS) - # apple framework requires the header file be part of the library onnxruntime_add_shared_library(onnxruntime - ${APPLE_FRAMEWORK_HEADERS} + ${ONNXRUNTIME_PUBLIC_HEADERS} "${CMAKE_CURRENT_BINARY_DIR}/generated_source.c" ) @@ -107,10 +94,9 @@ elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) set_target_properties(onnxruntime PROPERTIES FRAMEWORK TRUE FRAMEWORK_VERSION A - PUBLIC_HEADER "${APPLE_FRAMEWORK_HEADERS}" - MACOSX_FRAMEWORK_INFO_PLIST ${CMAKE_CURRENT_BINARY_DIR}/Info.plist - VERSION ${ORT_VERSION} - SOVERSION ${ORT_VERSION} + MACOSX_FRAMEWORK_INFO_PLIST ${INFO_PLIST_PATH} + SOVERSION ${ORT_VERSION} + # Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file. ) else() onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) @@ -180,11 +166,10 @@ endif() # we need to copy C/C++ API headers to be packed into Android AAR package if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_BUILD_JAVA) - get_mobile_api_headers(ANDROID_AAR_HEADERS) set(ANDROID_HEADERS_DIR ${CMAKE_CURRENT_BINARY_DIR}/android/headers) file(MAKE_DIRECTORY ${ANDROID_HEADERS_DIR}) # copy the header files one by one - foreach(h_ ${ANDROID_AAR_HEADERS}) + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${ANDROID_HEADERS_DIR}/${HEADER_NAME_}) endforeach() @@ -328,7 +313,7 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_HEADER_DIR}) # copy the header files one by one, and the Info.plist - foreach(h_ ${APPLE_FRAMEWORK_HEADERS}) + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) endforeach()