From dadd0c451a60a6f908ef40fc93e7775765280c09 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Mon, 17 Jun 2024 23:18:13 -0400 Subject: [PATCH] [MIGraphX EP] Fix MIGraphX mixed precision run input parameters (#20982) See #20643 ### Description Changes order of how we perform quantization to better support mixed precision and fixes a bug found with parameters of inputs for int8 quantization not being correctly handled. We now perform int8 quantization first on a full precision input model, before then quantizing the model to fp16 for remain ops that aren't quantized. The former case was causing us to use a low precision input which could cause larger values to be inserted than intended to the model when int8 quantization is perform. The symptom of this was a failure during quantization steps. Similar to the above input parameters were being uninitialized and resulting in similar failure during int8 quantization. GPU faults were intermittent but present as using uninitialized memory created undefined behavior when we started testing more complex models during mixed precision. ### Motivation and Context In some cases we've seen random data and/or invalid values entering into compiled onnx graphs. This is due to input parameters to the MIGraphX Graph not being set correctly when mixed precision (int8 + fp16) is used and ordering of quantization steps is causes a lower precision model to be used to perform int8 quantization. In most cases the failure is silent/intermittent. In some cases we've observed gpu faults due to out of bounds values being set. This change is required as a large input parameter to the MIGraphX graph is initialized to a large random value, and the next operator is using that for indexing, we get undefined behavior and a GPU fault. --- .../core/providers/migraphx/migraphx_execution_provider.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 581376623ffe0..6ee85c3a4c047 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1172,7 +1172,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } std::vector input_names, output_names; - no_input_shape = no_input_shape or get_input_output_names(graph_body_viewer, input_names, output_names); + no_input_shape = no_input_shape || get_input_output_names(graph_body_viewer, input_names, output_names); // by parsing the model_proto, create a program corresponding to // the input fused_node @@ -1356,7 +1356,6 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize quant_opts.add_op_name("convolution"); quant_opts.add_op_name("dot");