diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 1b1cbbf8dead9..35c342b28b9ac 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "js_execution_provider.h" @@ -13,9 +14,11 @@ #endif #include "core/graph/function_utils.h" +#include "core/graph/indexed_sub_graph.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/kernel_registry.h" +#include "core/framework/fallback_cpu_capability.h" #include "core/providers/shared/node_unit/node_unit.h" #include "allocator.h" #include "data_transfer.h" @@ -645,7 +648,45 @@ std::vector JsExecutionProvider::CreatePreferredAllocators() { std::vector> JsExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup) const { - return IExecutionProvider::GetCapability(graph, kernel_lookup); + InlinedVector candidates; + // `tenative_candidates` is a subset of `candidates`. + InlinedVector tenative_candidates; + for (auto& node_index : graph.GetNodesInTopologicalOrder()) { + const auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + const auto& node = *p_node; + if (!node.GetExecutionProviderType().empty()) { + // If the node was added by layout transformer, do not move it to CPU + if (node.GetExecutionProviderType() == kJsExecutionProvider) { + candidates.push_back(node.Index()); + } + continue; + } + + const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node); + // none of the provided registries has a webgpu kernel for this node + if (webgpu_kernel_def == nullptr) { + LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.OpType() << " node name: " << node.Name(); + continue; + } + candidates.push_back(node.Index()); + tenative_candidates.push_back(node.Index()); + } + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + std::vector> result; + for (auto& node_index : candidates) { + if (cpu_nodes.count(node_index) > 0) { + continue; + } + + auto sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node_index); + result.emplace_back(std::make_unique(std::move(sub_graph))); + } + return result; } std::shared_ptr JsExecutionProvider::GetKernelRegistry() const {