Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
sumitsays committed Feb 23, 2024
2 parents 520e3e8 + aec2389 commit 4b7df02
Show file tree
Hide file tree
Showing 164 changed files with 4,510 additions and 1,833 deletions.
3 changes: 1 addition & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ title: ONNX Runtime
message: "Please use this information to cite ONNX Runtime in
research or other publications."
authors:
- affiliation: Microsoft Corporation
given-names: ONNX Runtime developers
- name: ONNX Runtime developers
date-released: 2018-11-29
url: "https://onnxruntime.ai"
repository-code: "https://github.com/microsoft/onnxruntime"
Expand Down
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,7 @@ if (UNIX AND onnxruntime_USE_NCCL)
else()
set(onnxruntime_USE_NCCL OFF)
set(onnxruntime_USE_MPI OFF)
message( WARNING "MPI and NCCL disabled on Win build." )
message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." )
endif()

if (onnxruntime_USE_MPI)
Expand Down
9 changes: 2 additions & 7 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,8 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()

# Enable stream for all the non-minimal build, except for DML. There's currently a bug
# in the allocation planner when reusing buffers and more than one streams are used that
# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
# safest option for now.
# https://github.com/microsoft/onnxruntime/issues/19480
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
# Enable stream for all the non-minimal build
if (NOT onnxruntime_MINIMAL_BUILD)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()

Expand Down
9 changes: 0 additions & 9 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
"bert/fastertransformer_decoder_attention/*"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
"bert/fast_gelu_impl.cu"
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
"bert/fast_gelu.h"
"bert/relative_attn_bias.cc"
"bert/relative_attn_bias.h"
"bert/relative_attn_bias_impl.cu"
Expand All @@ -44,12 +40,7 @@ set(contrib_ops_excluded_files
"bert/packed_multihead_attention.cc"
"bert/packed_multihead_attention_impl.h"
"bert/packed_multihead_attention_impl.cu"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/group_norm_impl_kernel.cuh"
"diffusion/group_norm_common_base.h"
"diffusion/group_norm_common_base.cc"
"diffusion/nhwc_conv.cc"
"math/gemm_float8.cc"
"math/gemm_float8.cu"
Expand Down
2 changes: 2 additions & 0 deletions cmake/winml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ if (winml_is_inbox)
get_target_property(compile_options ${target} COMPILE_OPTIONS)
get_target_property(include_directories ${target} INCLUDE_DIRECTORIES)
get_target_property(link_libraries ${target} LINK_LIBRARIES)
get_target_property(link_flags ${target} LINK_FLAGS)
get_target_property(link_options ${target} LINK_OPTIONS)

add_library(${new_target} SHARED ${sources})
Expand All @@ -835,6 +836,7 @@ if (winml_is_inbox)
target_compile_options(${new_target} PRIVATE ${compile_options})
target_include_directories(${new_target} PRIVATE ${include_directories})
target_link_libraries(${new_target} PRIVATE ${link_libraries})
set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}")
target_link_options(${new_target} PRIVATE ${link_options})
endfunction()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
private void CanRunInferenceOnAModelWithTensorRT()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");

int deviceId = 0;
string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)
Expand Down
3 changes: 3 additions & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
Expand Down Expand Up @@ -606,6 +607,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
Expand All @@ -617,6 +619,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
10 changes: 8 additions & 2 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Node;
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

struct OrtRunOptions;

namespace onnxruntime {

/**
Expand All @@ -51,6 +53,8 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};

using RunOptions = OrtRunOptions;

enum class DataLayout {
NCHW,
NHWC,
Expand Down Expand Up @@ -184,15 +188,17 @@ class IExecutionProvider {
Run may not be finished on device This function should be regarded as the
point after which a new Run would start to submit commands from CPU
*/
virtual common::Status OnRunStart() { return Status::OK(); }
virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }

/**
Called when InferenceSession::Run ended
NOTE that due to async execution in provider, the actual work of this Run
may not be finished on device This function should be regarded as the point
that all commands of current Run has been submmited by CPU
*/
virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
return Status::OK();
}

/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor
// Per default it will be set to '0'
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";

// Set HTP performance mode for QNN HTP backend before session run.
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
// "sustained_high_performance". Default to "default".
static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";

// Set HTP performance mode for QNN HTP backend post session run.
static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";

// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
15 changes: 13 additions & 2 deletions java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand All @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags {
/** Enables CoreML on subgraphs. */
ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002)
/** Only enable usage of CoreML if the device has an Apple Neural Engine. */
ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004),
ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004)
/**
* Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also
* allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs
* have dynamic shapes.
*/
ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008)
/**
* Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or
* later.
*/
CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010)

/** The native value of the enum. */
public final int value;
Expand Down
34 changes: 23 additions & 11 deletions js/common/lib/tensor-impl-type-mapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
['uint8', Uint8Array],
['int8', Int8Array],
['uint16', Uint16Array],
['float16', Uint16Array],
['int16', Int16Array],
['int32', Int32Array],
['bool', Uint8Array],
Expand All @@ -34,16 +33,22 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArray
[Uint32Array, 'uint32'],
]);

// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
// if available.
let isBigIntChecked = false;
export const checkBigInt = () => {
if (!isBigIntChecked) {
isBigIntChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
const isBigUint64ArrayAvailable =
typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
// a dummy type declaration for Float16Array in case any polyfill is available.
declare global {
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array: any;
}

// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
// polyfill if available.
let isTypedArrayChecked = false;
export const checkTypedArray = () => {
if (!isTypedArrayChecked) {
isTypedArrayChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;
const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;

if (isBigInt64ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
Expand All @@ -53,5 +58,12 @@ export const checkBigInt = () => {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
}
if (isFloat16ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16');
} else {
// if Float16Array is not available, use 'Uint16Array' to store the data.
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
}
}
};
12 changes: 7 additions & 5 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
import {Tensor as TensorInterface} from './tensor.js';

Expand Down Expand Up @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface {
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
TextureConstructorParameters|GpuBufferConstructorParameters,
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
// perform one-time check for BigInt support
checkBigInt();
// perform one-time check for BigInt/Float16Array support
checkTypedArray();

let type: TensorType;
let dims: readonly number[];
Expand Down Expand Up @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface {
}
case 'gpu-buffer': {
if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' &&
type !== 'bool')) {
type !== 'uint8' && type !== 'bool')) {
throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
}
this.gpuBufferData = arg0.gpuBuffer;
Expand Down Expand Up @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface {
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
}
if (Array.isArray(arg1)) {
if (arg0 === 'float16') {
if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
//
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.
Expand Down
2 changes: 1 addition & 1 deletion js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ export declare namespace Tensor {
/**
* supported data types for constructing a tensor from a WebGPU buffer
*/
export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool';
export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool';

/**
* represent where the tensor data is stored
Expand Down
Loading

0 comments on commit 4b7df02

Please sign in to comment.