Skip to content

Commit

Permalink
Hide NPU Adapter selection behind macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheil Kumar committed Nov 20, 2023
1 parent 1dd9bf5 commit ac870ac
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ enum OrtDmlPerformancePreference {
};

enum OrtDmlDeviceFilter : uint32_t {
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
Any = 0xffffffff,
Gpu = 1 << 0,
Npu = 1 << 1,
#else
Gpu = 1 << 0,
#endif
};

inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; }
Expand Down
17 changes: 12 additions & 5 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,12 @@ static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFi
return DeviceType::GPU;
}

#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
if (IsNPU(adapter) && allow_npus) {
return DeviceType::NPU;
}
#endif

return DeviceType::BadDevice;
}
Expand Down Expand Up @@ -216,13 +218,15 @@ static void SortHeterogenousDXCoreAdapterList(
return;
}

#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
// When considering both GPUs and NPUs sort them by performance preference
// of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first)
auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
auto only_npus = filter == OrtDmlDeviceFilter::Npu;
if (!keep_npus || only_npus) {
return;
}
#endif

struct SortingPolicy {
// default is false because GPUs are considered higher priority in
Expand Down Expand Up @@ -322,23 +326,26 @@ static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(con

static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& provider_options) {
static const std::string Filter = "filter";
static const std::string Any = "any";
static const std::string Gpu = "gpu";
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
static const std::string Any = "any";
static const std::string Npu = "npu";
#endif

auto preference_it = provider_options.find(Filter);
if (preference_it != provider_options.end()) {
if (preference_it->second == Any) {
return OrtDmlDeviceFilter::Any;
}

if (preference_it->second == Gpu) {
return OrtDmlDeviceFilter::Gpu;
}

#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
if (preference_it->second == Any) {
return OrtDmlDeviceFilter::Any;
}
if (preference_it->second == Npu) {
return OrtDmlDeviceFilter::Npu;
}
#endif

ORT_THROW("Invalid Filter provided for DirectML EP device selection.");
}
Expand Down

0 comments on commit ac870ac

Please sign in to comment.