
Commit

Merge branch 'main' into yifanl/trt_10_2_oss
yf711 committed Jul 26, 2024
2 parents 3b61f68 + 0f1f3b7 commit 392cc74
Showing 294 changed files with 3,954 additions and 1,165 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
@@ -1,4 +1,4 @@
# This sets the default behaviour, overriding core.autocrlf
# This sets the default behavior, overriding core.autocrlf
* text=auto

# All source files should have unix line-endings in the repository,
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="python" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.14.1" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.15.0" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="pythonx86" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.14.1" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.15.0" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion ThirdPartyNotices.txt
@@ -4820,7 +4820,7 @@ SOFTWARE.

----------------------------------------------------------------------------

This is the MIT/Expat Licence. For more information see:
This is the MIT/Expat License. For more information see:

1. http://www.opensource.org/licenses/mit-license.php

2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.14.1)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.0)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
2 changes: 1 addition & 1 deletion cmake/onnxruntime.cmake
@@ -150,7 +150,7 @@ endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_MINIMAL_BUILD)
# target onnxruntime is a shared library, the dummy __cxa_demangle is only attach to it to avoid
# affecting downstream ort library users with the behaviour of dummy __cxa_demangle. So the dummy
# affecting downstream ort library users with the behavior of dummy __cxa_demangle. So the dummy
# __cxa_demangle must not expose to libonnxruntime_common.a. It works as when the linker is
# creating the DSO, our dummy __cxa_demangle always comes before libc++abi.a so the
# __cxa_demangle in libc++abi.a is discarded, thus, huge binary size reduction.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -15,6 +15,8 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/attention_kernel_options.h"
"bert/attention_kernel_options.cc"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter
2 changes: 1 addition & 1 deletion cmake/patches/composable_kernel/Fix_Clang_Build.patch
@@ -44,7 +44,7 @@ index c23746e7f..bc326c8b5 100644
find_package(HIP REQUIRED)
# Override HIP version in config.h, if necessary.
@@ -269,12 +248,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH )
message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}")
message(STATUS "CK_HIP_VERSION_PATCH overridden with ${CK_OVERRIDE_HIP_VERSION_PATCH}")
endif()
message(STATUS "Build with HIP ${HIP_VERSION}")
-link_libraries(hip::device)
@@ -39,7 +39,7 @@ Event {{name.0.value}}
Operator {{name.0.value}}
{{/inOperator}}
{{#inEii}}
Explict Interface Implementation {{name.0.value}}
Explicit Interface Implementation {{name.0.value}}
{{/inEii}}
{{#inVariable}}
Variable {{name.0.value}}
4 changes: 2 additions & 2 deletions dockerfiles/README.md
@@ -32,7 +32,7 @@
docker run -it onnxruntime-source
```

The docker file supports both x86_64 and ARM64(aarch64). You may use docker's "--platform" parameter to explictly specify which CPU architecture you want to build. For example:
The docker file supports both x86_64 and ARM64(aarch64). You may use docker's "--platform" parameter to explicitly specify which CPU architecture you want to build. For example:

```bash
docker build --platform linux/arm64/v8 -f Dockerfile.source
@@ -274,7 +274,7 @@ Note: You may add --use_tensorrt and --tensorrt_home options if you wish to use
Note: Resulting Docker image will have ONNX Runtime installed in /usr, and ONNX Runtime wheel copied to /onnxruntime directory.
Nothing else from ONNX Runtime source tree will be copied/installed to the image.

Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropiate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime).
Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropriate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime).

## MIGraphX
**Ubuntu 20.04, ROCm6.0, MIGraphX**
10 changes: 10 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```
#### ORTMODULE_ATEN_SDPA_FALLBACK
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used to enable a pre-export attention fallback to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution when calling torch.nn.functional.scaled_dot_product_attention. NOTE: only use this feature if the user model leverages memory efficient attention WITHOUT masking (i.e. attn_mask=None). Use a GPU profiling tool such as NVIDIA Nsight Systems to identify whether the user model leverages memory efficient attention.

```bash
export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
```
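
A minimal sketch of the call pattern this fallback targets: an unmasked call to torch.nn.functional.scaled_dot_product_attention. The device, dtype, and tensor shapes below are assumptions chosen only for illustration.

```python
# Minimal sketch (assumed shapes, CUDA device, fp16 dtype): the fallback applies
# to unmasked scaled_dot_product_attention calls, i.e. attn_mask=None.
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# No attention mask: this is the memory-efficient attention pattern that
# ORTMODULE_ATEN_SDPA_FALLBACK keeps as an ATen op at export time.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
```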

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
8 changes: 6 additions & 2 deletions docs/OperatorKernels.md
@@ -970,6 +970,7 @@ Do not modify directly.*
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Col2Im|*in* input:**T**<br> *in* image_shape:**tensor(int64)**<br> *in* block_shape:**tensor(int64)**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Concat|*in* inputs:**T**<br> *out* concat_result:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||4+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1131,7 +1132,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||7+|**T** = tensor(float), tensor(float16)|
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1199,7 +1201,9 @@ Do not modify directly.*
|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|19+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||18+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||11+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
4 changes: 2 additions & 2 deletions docs/python/notebooks/onnx-inference-byoc-gpu-cpu-aks.ipynb
@@ -64,7 +64,7 @@
"If you are using an Azure Machine Learning Notebook VM, you are all set. Otherwise, please follow the [Azure ML configuration notebook](https://github.com/Azure/MachineLearningNotebooks/blob/master/configuration.ipynb) to set up your environment.\n",
"\n",
"### Install additional packages needed for this Notebook\n",
"You need to install the popular plotting library matplotlib, the image manipulation library opencv, and the onnx library in the conda environment where Azure Maching Learning SDK is installed.\n",
"You need to install the popular plotting library matplotlib, the image manipulation library opencv, and the onnx library in the conda environment where Azure Machine Learning SDK is installed.\n",
"\n",
"```\n",
"(myenv) $ pip install matplotlib onnx opencv-python\n",
@@ -79,7 +79,7 @@
"source": [
"## 1. Obtain a model from the ONNX Model Zoo\n",
"\n",
"For more information on the Facial Emotion Recognition (FER+) model, you can explore the notebook explaning how to deploy [FER+ with ONNX Runtime on an ACI Instance](onnx-inference-facial-expression-recognition-deploy.ipynb)."
"For more information on the Facial Emotion Recognition (FER+) model, you can explore the notebook explaining how to deploy [FER+ with ONNX Runtime on an ACI Instance](onnx-inference-facial-expression-recognition-deploy.ipynb)."
]
},
{
7 changes: 5 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -10,6 +10,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/optimizer/rule_based_graph_transformer.h"
@@ -49,7 +50,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD)

@@ -78,7 +80,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

@@ -1129,7 +1129,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
//
// Ensure that the ThreadPoolParallelSection has sufficient workers to
// execute a loop with degree of parallelism n. We track the number
// of workers already avaiable to the parallel section, prior to
// of workers already available to the parallel section, prior to
// submitting tasks to the work queues to make up the total.
//
// Each worker will call in to worker_fn(idx) with a per-worker thread
12 changes: 8 additions & 4 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -53,21 +53,25 @@ struct CudaContext : public CustomOpContext {
cudnn_conv_use_max_workspace = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);

cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(
kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
}

template <typename T>
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
if constexpr (sizeof(T) > sizeof(void*)) {
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type),
OrtErrorCode::ORT_INVALID_ARGUMENT);
}
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
OrtStatus* status = ort_api.KernelContext_GetResource(
&kernel_ctx, ORT_CUDA_RESOURCE_VERSION, resource_type, &resource);
if (status) {
ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resource type: " + std::to_string(resource_type),
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
T t = {};
memcpy(&t, &resource, sizeof(T));
@@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 {
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
};
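
A minimal Python sketch of how CUDA EP options like these are typically passed to an inference session, assuming the new field is surfaced as a string key named "sdpa_kernel" in the same way as the existing use_tf32 and prefer_nhwc fields; the model path is a placeholder.

```python
# Minimal sketch, assuming onnxruntime-gpu is installed and that the new struct
# field is exposed as a provider-option key named "sdpa_kernel" (an assumption;
# "use_tf32" and "prefer_nhwc" are existing, documented keys).
import onnxruntime as ort

cuda_options = {
    "use_tf32": "1",      # allow TF32 math on Ampere+ GPUs
    "prefer_nhwc": "0",   # keep the default NCHW layout
    "sdpa_kernel": "0",   # assumed key for the Scaled Dot Product Attention kernel option
}

sess = ort.InferenceSession(
    "model.onnx",  # placeholder model path
    providers=[("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"],
)
```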
2 changes: 1 addition & 1 deletion include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,7 +3,7 @@

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 3
#define ORT_CUDA_RESOURCE_VERSION 3

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset, // 10000
9 changes: 6 additions & 3 deletions include/onnxruntime/core/providers/rocm/rocm_context.h
@@ -23,21 +23,24 @@ struct RocmContext : public CustomOpContext {
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::hip_stream_t, &resource);
status = ort_api.KernelContext_GetResource(
&kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hip_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
hip_stream = reinterpret_cast<hipStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::miopen_handle_t, &resource);
status = ort_api.KernelContext_GetResource(
&kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::miopen_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
miopen_handle = reinterpret_cast<miopenHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::rocblas_handle_t, &resource);
status = ort_api.KernelContext_GetResource(
&kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::rocblas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}