Skip to content

Commit

Permalink
Add fp8 quantization to the compile stage of the MIGraphX EP
Browse files Browse the repository at this point in the history
Mirror the same calibration code we use for int8 and just change which quantize we call through the MIGraphx API
  • Loading branch information
TedThemistokleous committed Dec 19, 2024
1 parent 11ff644 commit ac77aac
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1190,9 +1190,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);

// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) {
if ((int8_enable_ xor fp8_enable_) && int8_calibration_cache_available_) {

Check warning on line 1193 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1193

Use operator ^ instead of xor [readability/alt_tokens] [2]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1193:  Use operator ^ instead of xor  [readability/alt_tokens] [2]
LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl;
migraphx::quantize_int8_options quant_opts;
migraphx::program_parameters quant_params;

auto param_shapes = prog.get_parameter_shapes();
Expand All @@ -1202,15 +1201,26 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
}
quant_opts.add_calibration_data(quant_params);

// specify thing we want to int8 quantize
quant_opts.add_op_name("convolution");
quant_opts.add_op_name("dot");

// perform static quantization on the programs
migraphx::quantize_int8(prog, t_, quant_opts);
LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl;
if(int8_enable_)

Check warning on line 1206 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1206

Missing space before ( in if( [whitespace/parens] [5]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1206:  Missing space before ( in if(  [whitespace/parens] [5]
{

Check warning on line 1207 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1207

{ should almost always be at the end of the previous line [whitespace/braces] [4]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1207:  { should almost always be at the end of the previous line  [whitespace/braces] [4]
migraphx::quantize_int8_options quant_opts;
quant_opts.add_calibration_data(quant_params);
// specify thing we want to int8 quantize
quant_opts.add_op_name("convolution");
quant_opts.add_op_name("dot");
migraphx::quantize_int8(prog, t_, quant_opts);
LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl;
}
else if(fp8_enable_)

Check warning on line 1216 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1216

An else should appear on the same line as the preceding } [whitespace/newline] [4]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1216:  An else should appear on the same line as the preceding }  [whitespace/newline] [4]

Check warning on line 1216 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1216

Missing space before ( in if( [whitespace/parens] [5]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1216:  Missing space before ( in if(  [whitespace/parens] [5]
{

Check warning on line 1217 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1217

{ should almost always be at the end of the previous line [whitespace/braces] [4]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1217:  { should almost always be at the end of the previous line  [whitespace/braces] [4]
migraphx::quantize_fp8_options quant_opts;
quant_opts.add_calibration_data(quant_params);
migraphx::quantize_fp8(prog, t_, quant_opts);
LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl;
}

Check warning on line 1223 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1223

Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1223:  Redundant blank line at the end of a code block should be deleted.  [whitespace/blank_line] [3]
}

if (fp16_enable_) {
Expand Down Expand Up @@ -1333,9 +1343,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);

// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
if (int8_enable && int8_calibration_cache_available) {
if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) {

Check warning on line 1346 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1346

Use operator ^ instead of xor [readability/alt_tokens] [2]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1346:  Use operator ^ instead of xor  [readability/alt_tokens] [2]
LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl;
migraphx::quantize_int8_options quant_opts;
migraphx::program_parameters quant_params;

auto param_shapes = prog.get_parameter_shapes();
Expand Down Expand Up @@ -1364,14 +1373,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
}
quant_opts.add_calibration_data(quant_params);
// specify thing we want to int8 quantize
quant_opts.add_op_name("convolution");
quant_opts.add_op_name("dot");

// perform static quantization on the programs
migraphx::quantize_int8(prog, t, quant_opts);
LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl;
if(int8_enable)

Check warning on line 1378 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1378

Missing space before ( in if( [whitespace/parens] [5]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1378:  Missing space before ( in if(  [whitespace/parens] [5]
{

Check warning on line 1379 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1379

{ should almost always be at the end of the previous line [whitespace/braces] [4]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1379:  { should almost always be at the end of the previous line  [whitespace/braces] [4]
migraphx::quantize_int8_options quant_opts;
quant_opts.add_calibration_data(quant_params);
// specify thing we want to int8 quantize
quant_opts.add_op_name("convolution");
quant_opts.add_op_name("dot");
migraphx::quantize_int8(prog, t, quant_opts);
LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl;
}
else if(fp8_enable)

Check warning on line 1388 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1388

An else should appear on the same line as the preceding } [whitespace/newline] [4]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1388:  An else should appear on the same line as the preceding }  [whitespace/newline] [4]

Check warning on line 1388 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1388

Missing space before ( in if( [whitespace/parens] [5]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1388:  Missing space before ( in if(  [whitespace/parens] [5]
{

Check warning on line 1389 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L1389

{ should almost always be at the end of the previous line [whitespace/braces] [4]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1389:  { should almost always be at the end of the previous line  [whitespace/braces] [4]
migraphx::quantize_fp8_options quant_opts;
quant_opts.add_calibration_data(quant_params);
migraphx::quantize_fp8(prog, t, quant_opts);
LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl;
}
}

if (fp16_enable) {
Expand Down

0 comments on commit ac77aac

Please sign in to comment.