Remove overzealous fp32->fp16 CPU casts
PatriceVignola committed Oct 22, 2024
1 parent 44d20cd commit e430232
Showing 1 changed file with 28 additions and 39 deletions.
67 changes: 28 additions & 39 deletions onnxruntime/core/optimizer/cast_graph_io_to_fp16_transformer.cc
@@ -67,49 +67,38 @@ Status CastGraphIOToFp16Transformer::ApplyImpl(

     std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
 
-    bool has_graph_input = false;
-
-    for (auto& input_def : node->MutableInputDefs()) {
-      if (graph.IsInputsIncludingInitializers(input_def)) {
-        has_graph_input = true;
-        break;
-      }
-    }
-
-    // We currently only need to convert the inputs that are not non-overridable initializers, since
-    // that means they cannot be changed by DML without a cast. On the other hand, initializer inputs
-    // can be modified within DML. We can make this pass more complete later on and convert every
-    // single node to fp16, but it is far more complex than simply handling the inputs.
-    if (has_graph_input) {
-      for (auto& input_def : node->MutableInputDefs()) {
-        if (!IsMLFloat32Tensor(*input_def)) {
-          continue;
-        }
-
-        // TODO (pavignol): Convert scale/zeropoints in-place
-        if (input_def_updates.count(input_def)) {
-          replacement_defs[input_def] = input_def_updates[input_def];
-        } else {
-          // Add an fp32->fp16 node running on the CPU
-          auto cpu_cast_output_arg = AddCastNode(graph,
-                                                 input_def,
-                                                 &float_16_tensor_proto,
-                                                 false,
-                                                 static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16),
-                                                 onnxruntime::kCpuExecutionProvider);
-
-          // Add an fp16->fp32 node running on the EP
-          auto ep_cast_output_arg = AddCastNode(graph,
-                                                cpu_cast_output_arg,
-                                                &float_32_tensor_proto,
-                                                false,
-                                                static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT),
-                                                node->GetExecutionProviderType());
-
-          replacement_defs[input_def] = ep_cast_output_arg;
-          input_def_updates[input_def] = ep_cast_output_arg;
-          modified = true;
-        }
+    for (auto& input_def : node->MutableInputDefs()) {
+      if (!IsMLFloat32Tensor(*input_def) || !graph.IsInputsIncludingInitializers(input_def)) {
+        continue;
+      }
+
+      // TODO (pavignol): Convert scale/zeropoints in-place
+      if (input_def_updates.count(input_def)) {
+        replacement_defs[input_def] = input_def_updates[input_def];
+      } else {
+        // Add an fp32->fp16 node running on the CPU
+        auto cpu_cast_output_arg = AddCastNode(graph,
+                                               input_def,
+                                               &float_16_tensor_proto,
+                                               false,
+                                               static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16),
+                                               onnxruntime::kCpuExecutionProvider);
+
+        // Add an fp16->fp32 node running on the EP
+        auto ep_cast_output_arg = AddCastNode(graph,
+                                              cpu_cast_output_arg,
+                                              &float_32_tensor_proto,
+                                              false,
+                                              static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT),
+                                              node->GetExecutionProviderType());
+
+        replacement_defs[input_def] = ep_cast_output_arg;
+        input_def_updates[input_def] = ep_cast_output_arg;
+        modified = true;
+      }
       }
     }
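
Editor's note on why the old gating was overzealous: the removed code set a single has_graph_input flag per node, so one fp32 graph input was enough to reroute every fp32 input of that node, including purely internal values, through the CPU fp32->fp16 cast followed by an EP fp16->fp32 cast. The new code moves the graph.IsInputsIncludingInitializers test into the per-input filter, so only inputs that are themselves graph inputs or initializers receive the cast pair. The minimal standalone sketch below contrasts the two filters; InputDef, CastedInputsOld, and CastedInputsNew are hypothetical stand-ins for illustration, not the real onnxruntime types or API.

// Minimal standalone sketch (hypothetical stand-ins, not the onnxruntime API)
// contrasting the per-node gate that this commit removes with the per-input
// filter that replaces it.
#include <iostream>
#include <string>
#include <vector>

struct InputDef {
  std::string name;
  bool is_fp32;         // stands in for IsMLFloat32Tensor(*input_def)
  bool is_graph_input;  // stands in for graph.IsInputsIncludingInitializers(input_def)
};

// Old behavior: one node-level flag gates the cast for every fp32 input.
std::vector<std::string> CastedInputsOld(const std::vector<InputDef>& inputs) {
  bool has_graph_input = false;
  for (const auto& input : inputs) {
    if (input.is_graph_input) {
      has_graph_input = true;
      break;
    }
  }

  std::vector<std::string> casted;
  if (has_graph_input) {
    for (const auto& input : inputs) {
      // Overzealous: internal fp32 values are cast too, just because the
      // node happens to have at least one graph input.
      if (input.is_fp32) {
        casted.push_back(input.name);
      }
    }
  }
  return casted;
}

// New behavior: the graph-input test moves into the per-input filter, so only
// inputs that are themselves graph inputs (or initializers) get the cast pair.
std::vector<std::string> CastedInputsNew(const std::vector<InputDef>& inputs) {
  std::vector<std::string> casted;
  for (const auto& input : inputs) {
    if (!input.is_fp32 || !input.is_graph_input) {
      continue;
    }
    casted.push_back(input.name);
  }
  return casted;
}

int main() {
  // A node fed by one fp32 graph input and one fp32 internal value.
  const std::vector<InputDef> inputs = {{"X", true, true}, {"conv_out", true, false}};
  std::cout << "old: " << CastedInputsOld(inputs).size() << " cast pair(s)\n";  // prints 2
  std::cout << "new: " << CastedInputsNew(inputs).size() << " cast pair(s)\n";  // prints 1
  return 0;
}

Either way, each selected input still gets the same two-node pattern shown in the diff: an fp32->fp16 Cast assigned to kCpuExecutionProvider followed by an fp16->fp32 Cast assigned to the node's own execution provider, with input_def_updates memoizing the result so an input shared by several nodes is only cast once.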

