Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding CUDNN Frontend and use for CUDA NN Convolution #19470

Merged
merged 8 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cgmanifests/generated/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,16 @@
},
"comments": "directx_headers"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "98ca4e1941fe3263f128f74f10063a3ea35c7019",
"repositoryUrl": "https://github.com/NVIDIA/cudnn-frontend.git"
},
"comments": "cudnn_frontend"
}
}
]
}
6 changes: 0 additions & 6 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,6 @@ set(ORT_PROVIDER_FLAGS)
set(ORT_PROVIDER_CMAKE_FLAGS)

if (onnxruntime_USE_CUDA)
if (onnxruntime_USE_CUDA_NHWC_OPS)
add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
endif()
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

Expand Down Expand Up @@ -1445,9 +1442,6 @@ if (onnxruntime_USE_CUDA)
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
endif()
find_package(CUDAToolkit REQUIRED)
if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
endif()
if (NOT CMAKE_CUDA_ARCHITECTURES)
if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
# Support for Jetson/Tegra ARM devices
Expand Down
1 change: 1 addition & 0 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@ utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029
111 changes: 111 additions & 0 deletions cmake/external/cuDNN.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)

find_path(
CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include
REQUIRED
)

file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")

function(find_cudnn_library NAME)
find_library(
${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib
REQUIRED
)

if(${NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties(
CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${NAME}_LIBRARY}
)
message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
else()
message(STATUS "${NAME} not found.")
endif()


endfunction()

find_cudnn_library(cudnn)

include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
LIBRARY REQUIRED_VARS
CUDNN_INCLUDE_DIR cudnn_LIBRARY
)

if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)

message(STATUS "cuDNN: ${cudnn_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")

set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")

else()

set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")

endif()

target_include_directories(
CUDNN::cudnn_all
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)

target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn
)

if(CUDNN_MAJOR_VERSION EQUAL 8)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)

target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
)
elseif(CUDNN_MAJOR_VERSION EQUAL 9)
find_cudnn_library(cudnn_cnn)
find_cudnn_library(cudnn_adv)
find_cudnn_library(cudnn_graph)
find_cudnn_library(cudnn_ops)
find_cudnn_library(cudnn_engines_runtime_compiled)
find_cudnn_library(cudnn_engines_precompiled)
find_cudnn_library(cudnn_heuristic)

target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv
CUDNN::cudnn_ops
CUDNN::cudnn_cnn
CUDNN::cudnn_graph
CUDNN::cudnn_engines_runtime_compiled
CUDNN::cudnn_engines_precompiled
CUDNN::cudnn_heuristic
)
endif()

mark_as_advanced(CUDNN_INCLUDE_DIR)
12 changes: 12 additions & 0 deletions cmake/external/cudnn_frontend.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
include(FetchContent)
FetchContent_Declare(
cudnn_frontend
URL ${DEP_URL_cudnn_frontend}
URL_HASH SHA1=${DEP_SHA1_cudnn_frontend}
)

set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_UNIT_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
FetchContent_MakeAvailable(cudnn_frontend)
16 changes: 6 additions & 10 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -587,20 +587,16 @@ endif()

message("Finished fetching external dependencies")


set(onnxruntime_LINK_DIRS )

if (onnxruntime_USE_CUDA)
#TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same
find_package(CUDAToolkit REQUIRED)
if (WIN32)
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64)
endif()
else()
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64)
endif()

if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
endif()
include(cuDNN)
endif()

if(onnxruntime_USE_SNPE)
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_framework.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ endif()
if(onnxruntime_USE_TENSORRT OR onnxruntime_USE_NCCL)
# TODO: for now, core framework depends on CUDA. It should be moved to TensorRT EP
# TODO: provider_bridge_ort.cc should not include nccl.h
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${onnxruntime_CUDNN_HOME}/include PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
else()
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
endif()
Expand Down
14 changes: 9 additions & 5 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,16 @@
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart)
else()
target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart
${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
include(cudnn_frontend) # also defines CUDNN::*
if (onnxruntime_USE_CUDA_NHWC_OPS)
if(CUDNN_MAJOR_VERSION GREATER 8)
add_compile_definitions(ENABLE_CUDA_NHWC_OPS)
else()
message( WARNING "To compile with NHWC ops enabled please compile against cuDNN 9 or newer." )
endif()
endif()
target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas CUDNN::cudnn_all cudnn_frontend CUDA::curand CUDA::cufft CUDA::cudart
${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
endif()

if (onnxruntime_USE_TRITON_KERNEL)
Expand Down
5 changes: 1 addition & 4 deletions cmake/onnxruntime_providers_tensorrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
if(onnxruntime_CUDA_MINIMAL)
set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
else()
set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
set(trt_link_libs CUDNN::cudnn_all cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
endif()
file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"
Expand All @@ -183,9 +183,6 @@
endif()
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()

# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA)
Expand Down
6 changes: 1 addition & 5 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,7 @@ endif()
onnxruntime_add_include_to_target(onnxruntime_pybind11_state Python::Module Python::NumPy)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${ONNXRUNTIME_ROOT} ${pybind11_INCLUDE_DIRS})
if(onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# cudnn_home is optional for Window when cuda and cudnn are installed in the same directory.
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()
target_include_directories(onnxruntime_pybind11_state PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
endif()
if(onnxruntime_USE_CANN)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${onnxruntime_CANN_HOME}/include)
Expand Down
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ set(provider_excluded_files
"rnn/rnn_impl.cu"
"rnn/rnn_impl.h"
"shared_inc/cuda_call.h"
"shared_inc/cudnn_fe_call.h"
"shared_inc/fpgeneric.h"
"cuda_allocator.cc"
"cuda_allocator.h"
Expand All @@ -171,6 +172,7 @@ set(provider_excluded_files
"cuda_utils.cu"
"cudnn_common.cc"
"cudnn_common.h"
"cudnn_fe_call.cc"
"cupti_manager.cc"
"cupti_manager.h"
"fpgeneric.cu"
Expand Down
4 changes: 1 addition & 3 deletions cmake/onnxruntime_session.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ if (onnxruntime_USE_EXTENSIONS)
endif()
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

if (onnxruntime_USE_ROCM)
target_compile_options(onnxruntime_session PRIVATE -Wno-sign-compare -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1)
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/hipcub/include ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
Expand Down
7 changes: 0 additions & 7 deletions cmake/onnxruntime_training.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ endif()

target_include_directories(onnxruntime_training PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header} ${MPI_CXX_INCLUDE_DIRS})

if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_training PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_training PRIVATE ${NCCL_INCLUDE_DIRS})
endif()
Expand Down Expand Up @@ -81,9 +77,6 @@ if (onnxruntime_BUILD_UNIT_TESTS)

target_include_directories(onnxruntime_training_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${onnxruntime_graph_header})
target_link_libraries(onnxruntime_training_runner PRIVATE nlohmann_json::nlohmann_json)
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_training_runner PUBLIC ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_training_runner PRIVATE ${NCCL_INCLUDE_DIRS})
Expand Down
8 changes: 4 additions & 4 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ function(AddTest)
if(onnxruntime_USE_CUDA)
#XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
# otherwise it will impact when CUDA DLLs can be unloaded.
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart)
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend)
endif()
target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
endif()

onnxruntime_add_include_to_target(${_UT_TARGET} date::date flatbuffers::flatbuffers)
target_include_directories(${_UT_TARGET} PRIVATE ${TEST_INC_DIR})
if (onnxruntime_USE_CUDA)
target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
target_include_directories(${_UT_TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
if (onnxruntime_USE_NCCL)
target_include_directories(${_UT_TARGET} PRIVATE ${NCCL_INCLUDE_DIRS})
endif()
Expand Down Expand Up @@ -392,7 +392,7 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R
)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src})

if (onnxruntime_USE_CUDA_NHWC_OPS)
if (onnxruntime_USE_CUDA_NHWC_OPS AND CUDNN_MAJOR_VERSION GREATER 8)
file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc"
)
Expand Down Expand Up @@ -1498,7 +1498,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
list(APPEND custom_op_src_patterns
"${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu"
"${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*")
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUDNN_INCLUDE_DIR})
if (HAS_QSPECTRE)
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /Qspectre>")
endif()
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct CudaContext : public CustomOpContext {
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;
bool use_tf32 = true;
bool fuse_conv_bias = true;

void Init(const OrtKernelContext& kernel_ctx) {
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
Expand All @@ -57,6 +58,7 @@ struct CudaContext : public CustomOpContext {
kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
fuse_conv_bias = FetchResource<bool>(kernel_ctx, CudaResource::fuse_conv_bias_t);
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int fuse_conv_bias = 0; // Enable CUDNN Frontend kernel fusing, results in JIT compiles

Check warning on line 41 in include/onnxruntime/core/providers/cuda/cuda_provider_options.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/providers/cuda/cuda_provider_options.h:41: Lines should be <= 120 characters long [whitespace/line_length] [2]
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
};
1 change: 1 addition & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
fuse_conv_bias_t
};
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace cuda {
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Conv<T, true>);
onnxruntime::cuda::Conv<T, true>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
Expand Down
Loading
Loading