Skip to content

Commit

Permalink
[webgpu]: add a simple GetCapability implementation (microsoft#17643)
Browse files Browse the repository at this point in the history
Most of the function body was copied from CUDA EP.
  • Loading branch information
snnn authored and kleiti committed Mar 22, 2024
1 parent a871a3c commit a18093b
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "js_execution_provider.h"

Expand All @@ -13,9 +14,11 @@
#endif

#include "core/graph/function_utils.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/framework/compute_capability.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/fallback_cpu_capability.h"
#include "core/providers/shared/node_unit/node_unit.h"
#include "allocator.h"
#include "data_transfer.h"
Expand Down Expand Up @@ -645,7 +648,45 @@ std::vector<AllocatorPtr> JsExecutionProvider::CreatePreferredAllocators() {
std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapability(
const onnxruntime::GraphViewer& graph,
const IKernelLookup& kernel_lookup) const {
return IExecutionProvider::GetCapability(graph, kernel_lookup);
InlinedVector<NodeIndex> candidates;
// `tenative_candidates` is a subset of `candidates`.
InlinedVector<NodeIndex> tenative_candidates;
for (auto& node_index : graph.GetNodesInTopologicalOrder()) {
const auto* p_node = graph.GetNode(node_index);
if (p_node == nullptr)
continue;

const auto& node = *p_node;
if (!node.GetExecutionProviderType().empty()) {
// If the node was added by layout transformer, do not move it to CPU
if (node.GetExecutionProviderType() == kJsExecutionProvider) {
candidates.push_back(node.Index());
}
continue;
}

const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node);
// none of the provided registries has a webgpu kernel for this node
if (webgpu_kernel_def == nullptr) {
LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: "
<< node.OpType() << " node name: " << node.Name();
continue;
}
candidates.push_back(node.Index());
tenative_candidates.push_back(node.Index());
}
auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0) {
continue;
}

auto sub_graph = std::make_unique<IndexedSubGraph>();
sub_graph->nodes.push_back(node_index);
result.emplace_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
}
return result;
}

std::shared_ptr<KernelRegistry> JsExecutionProvider::GetKernelRegistry() const {
Expand Down

0 comments on commit a18093b

Please sign in to comment.