Skip to content

Commit

Permalink
[VitisAI] Align TensorProto_DataType with onnx1.16 (#21067)
Browse files Browse the repository at this point in the history
### Description
Vitis AI EP synchronously supports the TensorProto data types supported
by ONNX 1.16.
Add error message show when graph resolve fail for troubleshooting.


### Motivation and Context
ONNX 1.15 & 1.16 add support some new TensorProto DataType , such as 
- FLOAT8E4M3FN
- FLOAT8E4M3FNUZ
- FLOAT8E5M2
- FLOAT8E5M2FNUZ
- UINT4
- INT4

---------

Co-authored-by: liumingyue <[email protected]>
  • Loading branch information
mingyueliuh and liumingyue authored Jun 29, 2024
1 parent 6baaaf5 commit 7e93cd7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 2 deletions.
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
graph.SetGraphResolveNeeded();
}
auto status = graph.Resolve();
if (!status.IsOK()) {
std::cerr << "graph resolve error:" << status.ErrorMessage() << std::endl;
}
return status.Code();
};
the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto {
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ enum TensorProto_DataType : int {
TensorProto_DataType_UINT64 = 13,
TensorProto_DataType_COMPLEX64 = 14,
TensorProto_DataType_COMPLEX128 = 15,
TensorProto_DataType_BFLOAT16 = 16
TensorProto_DataType_BFLOAT16 = 16,
TensorProto_DataType_FLOAT8E4M3FN = 17,
TensorProto_DataType_FLOAT8E4M3FNUZ = 18,
TensorProto_DataType_FLOAT8E5M2 = 19,
TensorProto_DataType_FLOAT8E5M2FNUZ = 20,
TensorProto_DataType_UINT4 = 21,
TensorProto_DataType_INT4 = 22
};
enum AttributeProto_AttributeType : int {
AttributeProto_AttributeType_UNDEFINED = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct OrtApi;
namespace vaip_core {

#define VAIP_ORT_API_MAJOR (3u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_MINOR (1u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
uint32_t magic; // 'VAIP' or something else to make sure the following field
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,12 @@ struct ProviderHostImpl : ProviderHost {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
} else if (data_type->s() == "int32") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32;
} else if (data_type->s() == "uint32") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
} else if (data_type->s() == "int64") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64;
} else if (data_type->s() == "uint64") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
} else if (data_type->s() == "int1") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
} else if (data_type->s() == "bfloat16") {
Expand All @@ -625,6 +629,26 @@ struct ProviderHostImpl : ProviderHost {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
} else if (data_type->s() == "int16") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16;
} else if (data_type->s() == "double") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
} else if (data_type->s() == "string") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING;
} else if (data_type->s() == "complex64") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
} else if (data_type->s() == "complex128") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
} else if (data_type->s() == "float8e4m3fn") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
} else if (data_type->s() == "float8e4m3fnuz") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
} else if (data_type->s() == "float8e5m2") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
} else if (data_type->s() == "float8e5m2funz") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
} else if (data_type->s() == "uint4") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
} else if (data_type->s() == "int4") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4;
} else {
return;
}
Expand Down

0 comments on commit 7e93cd7

Please sign in to comment.