Skip to content

Commit

Permalink
[MIGraphX EP] Ensure we support all inputs for MatMulInteger and Conv…
Browse files Browse the repository at this point in the history
…Integer. (#21680)

… to int8 for now

Allow for models with biases/full input and only check for int8 support
in EP

### Description
<!-- Describe your changes. -->
Allows for all inputs for MatMulInteger and ConvInteger to be supported
for prequantized models


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Fixes issues when using prequantized models that contain weight biases

---------

Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
TedThemistokleous and Ted Themistokleous authored Aug 21, 2024
1 parent 009209e commit ed155ad
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,22 +316,14 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
} else if (optype == "ConvInteger") {
if (node->InputDefs()[0]->Shape()->dim_size() != 4) {
return true;
}

// migraphx can handle only two inputs
if (node->InputDefs().size() != 2) {
return true;
}

// only support int8 type
// only support int8 and uint8 type
const auto& input_type = node->InputDefs()[0]->TypeAsProto();
if (input_type == nullptr) {
return true;
}

if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and
(input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) {
return true;
}
} else if (optype == "Expand") {
Expand Down Expand Up @@ -373,18 +365,14 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
} else if (optype == "MatMulInteger") {
// migraphx can handle only two inputs
if (node->InputDefs().size() != 2) {
return true;
}

// only support int8 type
// only support int8 and uint8 type
const auto& input_type = node->InputDefs()[0]->TypeAsProto();
if (input_type == nullptr) {
return true;
}

if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and
(input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) {
return true;
}
} else if (optype == "NonZero") {
Expand Down Expand Up @@ -456,7 +444,6 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
}

} else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
Expand Down

0 comments on commit ed155ad

Please sign in to comment.