Merge remote-tracking branch 'origin/main' into promote_inputs
egalli committed Dec 13, 2024
2 parents 5e3295f + 62e7e24 commit 21edcaf
Showing 41 changed files with 282 additions and 103 deletions.
1 change: 1 addition & 0 deletions cmake/external/eigen.cmake
@@ -15,6 +15,7 @@ else ()
eigen
URL ${DEP_URL_eigen}
URL_HASH SHA1=${DEP_SHA1_eigen}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/eigen/eigen-edge.patch
)
endif()

4 changes: 2 additions & 2 deletions cmake/onnxruntime_providers_openvino.cmake
@@ -13,8 +13,8 @@

# Header paths
find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX)
if(OpenVINO_VERSION VERSION_LESS 2024.3)
message(FATAL_ERROR "OpenVINO 2024.3 and newer are supported. Please, use latest OpenVINO release")
if(OpenVINO_VERSION VERSION_LESS 2024.4)
message(FATAL_ERROR "OpenVINO 2024.4 and newer are supported. Please, use latest OpenVINO release")
endif()

if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4)
13 changes: 13 additions & 0 deletions cmake/patches/eigen/eigen-edge.patch
@@ -0,0 +1,13 @@
diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h
index f85de305f..3dc2bb5e7 100644
--- a/Eigen/src/Core/util/IndexedViewHelper.h
+++ b/Eigen/src/Core/util/IndexedViewHelper.h
@@ -178,7 +178,7 @@ namespace placeholders {

EIGEN_DEPRECATED static const all_t all = Eigen::all; // PLEASE use Eigen::all instead of Eigen::placeholders::all
EIGEN_DEPRECATED static const last_t last = Eigen::last; // PLEASE use Eigen::last instead of Eigen::placeholders::last
- EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end
+ // EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end
}

} // end namespace Eigen
@@ -261,8 +261,8 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";

// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
// "0": dump the EP context into separate file, keep the file name in the Onnx model. (default).
// "1": dump the EP context into the Onnx model.
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";

// Specify the EPContext node name prefix to make it unique
33 changes: 25 additions & 8 deletions js/common/lib/env.ts
@@ -45,17 +45,19 @@ export declare namespace Env {
*
* This setting is available only when WebAssembly SIMD feature is available in current context.
*
* @defaultValue `true`
*
* @deprecated This property is deprecated. Since SIMD is supported by all major JavaScript engines, non-SIMD
* build is no longer provided. This property will be removed in future release.
* @defaultValue `true`
*/
simd?: boolean;

/**
* set or get a boolean value indicating whether to enable trace.
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
* @defaultValue `false`
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
*/
trace?: boolean;

@@ -153,7 +155,7 @@
/**
* Set or get the profiling configuration.
*/
profiling?: {
profiling: {
/**
* Set or get the profiling mode.
*
@@ -176,6 +178,9 @@
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
*
* @defaultValue `undefined`
*
* @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if
* you want to use a specific power preference.
*/
powerPreference?: 'low-power' | 'high-performance';
/**
@@ -187,6 +192,9 @@
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
*
* @defaultValue `undefined`
*
* @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if
* you want to use a specific fallback option.
*/
forceFallbackAdapter?: boolean;
/**
@@ -199,16 +207,25 @@
* value will be the GPU adapter that created by the underlying WebGPU backend.
*
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
*
* @deprecated It is no longer recommended to use this property. The latest WebGPU spec adds `GPUDevice.adapterInfo`
* (https://www.w3.org/TR/webgpu/#dom-gpudevice-adapterinfo), which allows to get the adapter information from the
* device. When it's available, there is no need to set/get the {@link adapter} property.
*/
adapter: TryGetGlobalType<'GPUAdapter'>;
/**
* Get the device for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
* Set or get the GPU device for WebGPU.
*
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
* There are 3 valid scenarios of accessing this property:
* - Set a value before the first WebGPU inference session is created. The value will be used by the WebGPU backend
* to perform calculations. If the value is not a `GPUDevice` object, an error will be thrown.
* - Get the value before the first WebGPU inference session is created. This will try to create a new GPUDevice
* instance. Returns a `Promise` that resolves to a `GPUDevice` object.
* - Get the value after the first WebGPU inference session is created. Returns a resolved `Promise` to the
* `GPUDevice` object used by the WebGPU backend.
*/
readonly device: TryGetGlobalType<'GPUDevice'>;
get device(): Promise<TryGetGlobalType<'GPUDevice'>>;
set device(value: TryGetGlobalType<'GPUDevice'>);
/**
* Set or get whether validate input content.
*
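With the typing above, `env.webgpu.device` is now an asynchronous getter paired with a synchronous setter. A minimal usage sketch, assuming a browser with WebGPU available (the model path is a placeholder):

    import * as ort from 'onnxruntime-web';

    // Optionally supply your own GPUDevice before the first WebGPU session is created.
    const adapter = await navigator.gpu.requestAdapter();
    ort.env.webgpu.device = await adapter!.requestDevice();

    // Reading the property returns a Promise resolving to the GPUDevice the WebGPU
    // backend uses; a device is created on first access if none was set.
    const device = await ort.env.webgpu.device;

    const session = await ort.InferenceSession.create('model.onnx', {
      executionProviders: ['webgpu'],
    });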
16 changes: 15 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -143,7 +143,21 @@ const conv2dCommonSnippet = (
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`;

const sampleW = `${getWSnippet(innerElementSizeW)}`;
const sampleW = isChannelsLast
? fitInner && fitBOuter
? getWSnippet(innerElementSizeW)
: `
let col = colIn * ${innerElementSizeW};
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
${getWSnippet(innerElementSizeW)}
}
return ${typeSnippet(innerElementSizeW, dataType)}(0.0);`
: `
let col = colIn * ${innerElementSizeW};
if (row < uniforms.dim_inner && col < uniforms.dim_a_outer) {
${getWSnippet(innerElementSizeW)}
}
return ${typeSnippet(innerElementSizeW, dataType)}(0.0);`;

const resType = typeSnippet(innerElementSize, dataType);
const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
12 changes: 6 additions & 6 deletions js/web/test/test-runner.ts
@@ -586,11 +586,11 @@ export class TensorResultValidator {
}
}

function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise<ort.Tensor> {
if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) {
throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`);
}
const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
@@ -612,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
});
}

function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
if (!isGpuBufferSupportedType(type)) {
throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`);
}

const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!;

const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
@@ -725,7 +725,7 @@ export async function sessionRun(options: {
if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') {
feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]);
} else {
feeds[name] = createGpuTensorForInput(feeds[name]);
feeds[name] = await createGpuTensorForInput(feeds[name]);
}
}
}
@@ -742,7 +742,7 @@
if (options.ioBinding === 'ml-tensor') {
fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims);
} else {
fetches[name] = createGpuTensorForOutput(type, dims);
fetches[name] = await createGpuTensorForOutput(type, dims);
}
}
}
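Outside the test runner the same pattern applies: await the device, create a GPU buffer, then wrap it as a tensor. A rough sketch, assuming the `Tensor.fromGpuBuffer` options match the current onnxruntime-web typings (shape and data type are illustrative):

    import * as ort from 'onnxruntime-web';

    const device = await ort.env.webgpu.device;             // getter now returns a Promise
    const dims = [1, 3, 224, 224];                           // illustrative shape
    const byteLength = 4 * dims.reduce((a, b) => a * b, 1);  // float32 elements

    const gpuBuffer = device.createBuffer({
      // eslint-disable-next-line no-bitwise
      usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
      size: byteLength,
    });

    const gpuTensor = ort.Tensor.fromGpuBuffer(gpuBuffer, { dataType: 'float32', dims });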
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/openvino/backend_manager.cc
@@ -70,7 +70,10 @@ BackendManager::BackendManager(const GlobalContext& global_context,
i++;
}
subgraph_context_.subgraph_name = fused_node.Name();
auto model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
std::unique_ptr<onnx::ModelProto> model_proto;
if (!ep_ctx_handle_.IsValidOVEPCtxGraph()) {
model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
}
std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type;

if (ModelHasSymbolicInputDims(subgraph)) {
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/openvino/backend_utils.cc
@@ -39,21 +39,21 @@ struct static_cast_int64 {
int64_t operator()(const T1& x) const { return static_cast<int64_t>(x); }
};

std::shared_ptr<OVNetwork>
std::shared_ptr<const OVNetwork>
CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context,
std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
if (IsCILogEnabled()) {
std::cout << "CreateNgraphFunc" << std::endl;
}
const std::string model = model_proto.SerializeAsString();
try {
auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name);
auto ov_model = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name);

// Check for Constant Folding
if (!global_context.is_wholly_supported_graph) {
if ((global_context.device_type != "NPU") && !global_context.is_wholly_supported_graph) {
ov::pass::ConstantFolding pass_const_obj;
pass_const_obj.run_on_model(cnn_network);
auto& results = const_cast<ov::ResultVector&>(cnn_network.get()->get_results());
pass_const_obj.run_on_model(ov_model);
auto& results = const_cast<ov::ResultVector&>(ov_model.get()->get_results());
size_t index = results.size() - 1;

for (auto it = results.rbegin(); it != results.rend(); ++it) {
@@ -67,12 +67,12 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext
}
#ifndef NDEBUG
if (IsDebugEnabled()) {
std::string name = cnn_network->get_friendly_name();
std::string name = ov_model->get_friendly_name();
ov::pass::Serialize serializer(name + ".xml", name + ".bin");
serializer.run_on_model(cnn_network);
serializer.run_on_model(ov_model);
}
#endif
return cnn_network;
return ov_model;
} catch (std::string const& msg) {
ORT_THROW(msg);
}
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/openvino/backend_utils.h
@@ -60,7 +60,7 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor,
size_t batch_slice_idx);

std::shared_ptr<OVNetwork>
std::shared_ptr<const OVNetwork>
CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto,
const GlobalContext& global_context,
std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
57 changes: 48 additions & 9 deletions onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -48,6 +48,16 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
// Set the inference_num_threads property of the CPU
SetNumThreads(device_config);

auto npuw_status =
std::any_of(device_config.begin(), device_config.end(), [&](const std::pair<std::string, ov::Any>& pair) {
return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is<std::string>()) &&
(pair.second.as<std::string>() == "YES");
});

if (npuw_status) {
LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation";
}

try {
std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str;

@@ -81,7 +91,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
device_config,
global_context_.ep_context_embed_mode,
subgraph_context_.subgraph_name);
ie_cnn_network_ = exe_network_.Get().get_runtime_model();
} else if (global_context_.export_ep_ctx_blob &&
hw_target.find("NPU") != std::string::npos &&
!global_context_.has_external_weights) {
@@ -106,15 +115,15 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
device_config,
subgraph_context_.subgraph_name);
} else { // For all other types use ov::Model Type
ie_cnn_network_ = CreateOVModel(*model_proto, global_context_, const_outputs_map_);
auto ov_model = CreateOVModel(*model_proto, global_context_, const_outputs_map_);
exe_network_ = global_context_.ie_core.CompileModel(
ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
}
#endif
} else { // Full graph is not supported
ie_cnn_network_ = CreateOVModel(*model_proto, global_context_, const_outputs_map_);
auto ov_model = CreateOVModel(*model_proto, global_context_, const_outputs_map_);
exe_network_ = global_context_.ie_core.CompileModel(
ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
}
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
} catch (const char* msg) {
@@ -145,8 +154,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
device_config.emplace(ov::hint::inference_precision("f32"));
}
if (global_context_.precision_str.find("ACCURACY") != std::string::npos &&
global_context_.device_type == "GPU") {
if (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) >= 1) {
global_context_.device_type.find("GPU") != std::string::npos) {
if (global_context_.OpenVINO_Version.at(0) >= 2024) {
device_config.emplace(ov::hint::inference_precision(ov::element::undefined));
device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY));
} else {
@@ -174,7 +183,7 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type);
}
device_config.emplace(ov::device::properties("NPU", device_property));
#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3)
#if (((OPENVINO_VERSION_MAJOR == 2024) && (OPENVINO_VERSION_MINOR > 3)) || (OPENVINO_VERSION_MAJOR > 2024))
if (global_context_.export_ep_ctx_blob) {
global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true));
}
@@ -184,6 +193,33 @@
if (!global_context_.load_config.empty()) {
const std::map<std::string, ov::AnyMap>& target_config = global_context_.load_config;

if (global_context_.device_type.find("NPU") != std::string::npos) {
auto npuw_config = target_config.at("NPU");

// Check if "NPU_USE_NPUW" exists and is set to "YES"
auto npu_use_npuw_it = npuw_config.find("NPU_USE_NPUW");
if (npu_use_npuw_it != npuw_config.end() &&
npu_use_npuw_it->second.is<std::string>() &&
npu_use_npuw_it->second.as<std::string>() == "YES") {
// Only add NPUW-related keys if NPU_USE_NPUW is "YES"
for (const auto& [key, value] : npuw_config) {
if (key.find("NPUW") != std::string::npos) {
if (!value.is<std::string>()) {
LOGS_DEFAULT(ERROR) << "Invalid value type for key: " << key;
continue;
}
device_config[key] = value;
}
}
} else {
// Check if there are any "NPUW" keys and log a warning
if (std::any_of(npuw_config.begin(), npuw_config.end(),
[&](const auto& pair) { return pair.first.find("NPUW") != std::string::npos; })) {
LOGS_DEFAULT(WARNING) << "Skipping NPUW-related configurations as NPU_USE_NPUW is not set to 'YES'.";
}
}
}

// Parse device types like "AUTO:CPU,GPU" and extract individual devices
auto parse_individual_devices = [&](const std::string& device_type) -> std::vector<std::string> {
std::vector<std::string> devices;
@@ -213,6 +249,9 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
const std::vector<ov::PropertyName>& supported_properties) {
for (const auto& [key, value] : config_options) {
if (key.find("NPUW") != std::string::npos) {
continue;
}
if (is_supported_and_mutable(key, supported_properties)) {
global_context_.ie_core.Get().set_property(device, ov::AnyMap{{key, value}});
} else {
@@ -378,7 +417,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
if ((it == ort_ov_tensor_map.end()) ||
(it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) {
ov_tensor_data_t ov_tensor_data;
auto input = graph_input_info.at(input_idx);
const auto& input = graph_input_info.at(input_idx);
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
const_cast<void*>(tensor.GetTensorRawData()));

@@ -58,7 +58,6 @@ class BasicBackend : public IBackend {
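The NPUW handling added above only forwards NPUW-prefixed entries when the NPU section of the load_config map sets NPU_USE_NPUW to "YES". As a hedged sketch, assuming load_config is populated from the OpenVINO EP's JSON configuration option and that NPUW_DEVICES is a valid NPUW property for the installed OpenVINO version, such a map could look like:

    {
      "NPU": {
        "NPU_USE_NPUW": "YES",
        "NPUW_DEVICES": "NPU,CPU"
      }
    }

If NPUW keys are present but NPU_USE_NPUW is not "YES", they are skipped with a warning, and NPUW keys are also excluded from the generic per-device property pass-through later in PopulateConfigValue.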
GlobalContext& global_context_;
SubGraphContext subgraph_context_;
mutable std::mutex compute_lock_;
std::shared_ptr<const OVNetwork> ie_cnn_network_;
OVExeNetwork exe_network_;
std::map<std::string, std::shared_ptr<ov::Node>> const_outputs_map_;
std::unique_ptr<InferRequestsQueue> inferRequestsQueue_;
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/openvino/contexts.h
@@ -18,7 +18,7 @@ struct GlobalContext {
bool is_wholly_supported_graph = false;
bool enable_opencl_throttling = false;
bool disable_dynamic_shapes = false;
bool ep_context_embed_mode = true;
bool ep_context_embed_mode = false;
bool export_ep_ctx_blob = false;
bool enable_qdq_optimizer = false;
bool disable_cpu_fallback = false;