Skip to content

Commit

Permalink
MIGraphX EP: Add set_false_math to false by default (#20520)
Browse files Browse the repository at this point in the history
Patching in fast match disabled in the MIGraphX Compile stage in the
MIGraphX EP

### Description

Allow the MIGraphX API to compile the program given to the EP to turn
off fast math by default.

### Motivation and Context

Fixes accuracy issue we're seeing with GELU parity tests. Without fast
math disabled GELU will use a faster but less numerically stable version
which trades speed for accuracy.

Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
TedThemistokleous and Ted Themistokleous authored May 8, 2024
1 parent 8d09baf commit 737eb48
Showing 1 changed file with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
// perform static quantization on the programs
migraphx::quantize_int8(prog, t_, quant_opts);
}
prog.compile(t_);
migraphx::compile_options co;
co.set_fast_math(false);
prog.compile(t_, co);
auto prog_output_shapes = prog.get_output_shapes();
for (std::size_t i = 0; i < output_names.size(); ++i) {
auto out_len = prog_output_shapes[i].lengths();
Expand Down Expand Up @@ -1265,7 +1267,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
migraphx::quantize_int8(prog, t, quant_opts);
}

prog.compile(t);
migraphx::compile_options co;
co.set_fast_math(false);
prog.compile(t, co);
mgx_state->prog = prog;
param_shapes = prog.get_parameter_shapes();
no_input_shape = false;
Expand Down

0 comments on commit 737eb48

Please sign in to comment.