
Commit

Merge branch 'main' into Cjian/cuda12
jchen351 committed Oct 25, 2023
2 parents 617aad4 + 706e13e commit 128696a
Showing 161 changed files with 13,448 additions and 1,783 deletions.
2 changes: 2 additions & 0 deletions .lintrunner.toml
@@ -45,6 +45,7 @@ exclude_patterns = [
'cmake/external/**',
# ignore generated flatbuffers code
'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**',
'orttraining/orttraining/python/training/optim/_ds_code_store.py',
]
command = [
'python',
@@ -76,6 +77,7 @@ exclude_patterns = [
'cmake/**',
'orttraining/*',
'onnxruntime/core/flatbuffers/**',
'orttraining/orttraining/python/training/optim/_ds_code_store.py',
]
command = [
'python',
32 changes: 31 additions & 1 deletion cgmanifests/generated/cgmanifest.json
@@ -2,6 +2,36 @@
"$schema": "https://json.schemastore.org/component-detection-manifest.json",
"Version": 1,
"Registrations": [
{
"component": {
"type": "git",
"git": {
"commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57",
"repositoryUrl": "https://github.com/emscripten-core/emsdk.git"
},
"comments": "git submodule at cmake/external/emsdk"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db",
"repositoryUrl": "https://github.com/google/libprotobuf-mutator.git"
},
"comments": "git submodule at cmake/external/libprotobuf-mutator"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40",
"repositoryUrl": "https://github.com/onnx/onnx.git"
},
"comments": "git submodule at cmake/external/onnx"
}
},
{
"component": {
"type": "git",
@@ -166,7 +196,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f",
"commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828",
"repositoryUrl": "https://github.com/onnx/onnx.git"
},
"comments": "onnx"
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442
onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e
#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d
protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
2 changes: 1 addition & 1 deletion cmake/external/onnx
Submodule onnx updated 960 files
6 changes: 6 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -325,7 +325,9 @@ else()
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
@@ -334,6 +336,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
)
if (NOT APPLE)
set(mlas_platform_srcs
@@ -348,6 +352,8 @@
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
51 changes: 50 additions & 1 deletion docs/ContribOperators.md
@@ -90,6 +90,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RemovePadding">com.microsoft.RemovePadding</a>
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
* <a href="#com.microsoft.RotaryEmbedding">com.microsoft.RotaryEmbedding</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
@@ -2834,7 +2835,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)</dd>
<dd>Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
@@ -4796,6 +4797,54 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>

RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
that are multiplied with the query and key before their inner product is taken.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1.0.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain input and output types to integer tensors.</dd>
</dl>
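
For orientation, the following is a minimal NumPy sketch of the non-interleaved rotation described above. It is illustrative only: the explicit `num_heads` argument and the treatment of a length-1 `position_ids` as a starting offset are assumptions, and the `interleaved` and `scale` attributes are not modeled.

```python
# Minimal NumPy sketch of non-interleaved RoPE as described above (illustrative only).
import numpy as np

def rotary_embedding(x, position_ids, cos_cache, sin_cache, num_heads):
    batch, seq_len, hidden = x.shape
    head_size = cos_cache.shape[1] * 2          # cos/sin caches hold head_size / 2 values
    x = x.reshape(batch, seq_len, num_heads, head_size)

    if position_ids.ndim == 1:                  # shape (1): assumed starting offset
        pos = position_ids[0] + np.arange(seq_len)
        cos, sin = cos_cache[pos][None], sin_cache[pos][None]
    else:                                       # shape (batch_size, sequence_length)
        cos, sin = cos_cache[position_ids], sin_cache[position_ids]
    cos, sin = cos[:, :, None, :], sin[:, :, None, :]   # broadcast over heads

    x1, x2 = x[..., : head_size // 2], x[..., head_size // 2 :]
    rotated = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return rotated.reshape(batch, seq_len, hidden)
```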


### <a name="com.microsoft.SampleOp"></a><a name="com.microsoft.sampleop">**com.microsoft.SampleOp**</a>

Sample echo operator.
10 changes: 8 additions & 2 deletions docs/OperatorKernels.md
@@ -25,6 +25,7 @@ Do not modify directly.*
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|AffineGrid|*in* theta:**T1**<br> *in* size:**T2**<br> *out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
@@ -156,8 +157,10 @@ Do not modify directly.*
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
|||[1, 12]|**T** = tensor(float)|
@@ -477,9 +480,11 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -866,6 +871,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
5 changes: 4 additions & 1 deletion include/onnxruntime/core/framework/float8.h
@@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
// the highest available value
val |= 0x7F;
} else {
// infinity
// NaN
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
@@ -362,8 +363,10 @@ struct Float8E5M2 {
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
// the highest available value
val |= 0x7B;
} else {
// the infinity
val |= 0x7C;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
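
As a sanity check on the corrected comments above, the bit patterns involved decode as follows under the standard E5M2 and E4M3FNUZ layouts: 0x7B is the largest finite E5M2 value, 0x7C is E5M2 infinity, 0x7F is the largest finite E4M3FNUZ value, and 0x80 is the single E4M3FNUZ NaN. A small illustrative Python sketch (not ORT code):

```python
# Decode the float8 bit patterns referenced in the comments above (illustrative sketch).
def decode_e5m2(b):             # 1 sign, 5 exponent bits (bias 15), 2 mantissa bits
    sign = -1.0 if b & 0x80 else 1.0
    exp, man = (b >> 2) & 0x1F, b & 0x3
    if exp == 0x1F:
        return sign * float("inf") if man == 0 else float("nan")
    if exp == 0:
        return sign * (man / 4) * 2.0 ** -14    # subnormal
    return sign * (1 + man / 4) * 2.0 ** (exp - 15)

def decode_e4m3fnuz(b):         # 1 sign, 4 exponent bits (bias 8), 3 mantissa bits; no infinities
    if b == 0x80:                               # the only NaN encoding
        return float("nan")
    sign = -1.0 if b & 0x80 else 1.0
    exp, man = (b >> 3) & 0xF, b & 0x7
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -7     # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 8)

print(decode_e5m2(0x7B), decode_e5m2(0x7C))          # 57344.0 (max finite), inf
print(decode_e4m3fnuz(0x7F), decode_e4m3fnuz(0x80))  # 240.0 (max finite), nan
```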
44 changes: 33 additions & 11 deletions include/onnxruntime/core/graph/graph.h
@@ -3,6 +3,7 @@

#pragma once

#include <functional>
#include <limits>
#include <memory>
#include <string>
@@ -83,10 +84,10 @@ class Node {
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
Init(std::string{name}, std::string{op_type}, std::string{description},
std::vector<NodeArg*>{input_args.begin(), input_args.end()},
std::vector<NodeArg*>{output_args.begin(), output_args.end()},
attributes, std::string{domain});
Init(name, op_type, description,
input_args,
output_args,
attributes, domain);
}
#endif

@@ -563,13 +564,13 @@ class Node {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
void Init(const std::string& name,
const std::string& op_type,
const std::string& description,
const std::vector<NodeArg*>& input_args,
const std::vector<NodeArg*>& output_args,
void Init(std::string_view name,
std::string_view op_type,
std::string_view description,
gsl::span<NodeArg* const> input_args,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
const std::string& domain);
std::string_view domain);
#endif

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1141,8 +1142,22 @@ class Graph {
*/
Status InlineFunction(Node& node);

/**
Directly insert the nodes in the provided function proto into the graph.
The function converts Constant nodes into initializers in the graph.
It then creates a node in the graph for each of the function nodes.
All of the names are expected to be specialized, and therefore unique.
See function_utils::Specialize().

The Graph needs to be Resolve()d after this call.
@param func_to_inline
@returns Status indicating success or providing an error message.
*/

Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);

/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
be used as a GraphProto attribute in another Node..
be used as a GraphProto attribute in another Node.
e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
when the Graph is resolved.
@@ -1391,6 +1406,13 @@ class Graph {
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type);

/** Helper that converts a constant node proto to an initializer and adds it to the graph.
@param constant_node_proto Constant node to convert.
@param new_name Optional new name to use for the initializer.
*/
Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
std::optional<std::string_view> new_name);

#endif

Version IrVersion() const noexcept {
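
For intuition, here is a rough Python analogue of the `InlineFunctionProto` behaviour documented above, written against the `onnx` Python package rather than the ORT C++ API. It is illustrative only and assumes the names are already specialized and that each Constant node carries a plain tensor `value` attribute.

```python
# Rough Python analogue of InlineFunctionProto (illustrative only, not the ORT C++ API):
# copy a FunctionProto's nodes into a graph, turning Constant nodes into initializers.
import onnx
from onnx import helper

def inline_function_nodes(graph: onnx.GraphProto, func: onnx.FunctionProto) -> None:
    for node in func.node:
        if node.op_type == "Constant":
            tensor = helper.get_attribute_value(node.attribute[0])  # assumes a 'value' tensor
            tensor.name = node.output[0]      # initializer takes the Constant's output name
            graph.initializer.append(tensor)
        else:
            graph.node.append(node)
    # In ORT, the Graph still needs to be Resolve()d afterwards to rebuild edges and types.
```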
31 changes: 31 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
@@ -44,6 +45,36 @@
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
}

void* AllocDeferredCpuMem(size_t size) const {
if (0 == size) {
return {};
}
const auto& ort_api = Ort::GetApi();
void* mem = {};
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
if (status) {
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
return mem;
}

void FreeDeferredCpuMem(void* mem) const {
if (mem) {
const auto& ort_api = Ort::GetApi();
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
if (status) {
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
}
}
};

5 changes: 3 additions & 2 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,10 +3,11 @@

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 1
#define ORT_CUDA_RESOUCE_VERSION 2

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cudnn_handle_t,
cublas_handle_t
cublas_handle_t,
deferred_cpu_allocator_t,
};
