From c3a38f7c535b4bed0b39a1be41c6bfeada453eeb Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:27:22 -0400 Subject: [PATCH] Add support for softmaxcrossentropy loss to MIGraphX EP (#64) --- .../providers/migraphx/migraphx_execution_provider.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e41cd577b0b21..d3baaf0c3db45 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -915,6 +915,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "SkipSimplifiedLayerNormalization", "Slice", "Softmax", + "SoftmaxCrossEntropyLoss", "Softplus", "Softsign", "SpaceToDepth", @@ -1026,15 +1027,6 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v return result; } - // migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now, - // so if a model contain any of these operators, fall back to CPU - std::unordered_set vec_ops = {"SoftmaxCrossEntropyLoss"}; - if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) { - return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0); - })) { - return result; - } - auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); // check whether a subgrap should fallback to CPU