Commit

Merge remote-tracking branch 'origin/main' into skottmckay/CoremL_MLProgram_MoreOps_PR
skottmckay committed Feb 27, 2024
2 parents 99c40f8 + c20ced4 commit 8608f25
Showing 26 changed files with 599 additions and 67 deletions.
24 changes: 20 additions & 4 deletions cmake/CMakeLists.txt
@@ -324,15 +324,27 @@ if (onnxruntime_USE_ROCM)
endif()

# replicate strategy used by pytorch to get ROCM_VERSION
# https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173
file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
if (ROCM_VERSION_DEV_MATCH)
# https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake
# with modification
if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version-dev")
file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\\.([0-9]+)\\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\\.([0-9]+)\\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
endif()

if (ROCM_VERSION_MATCH)
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
else()
message(FATAL_ERROR "Cannot determine ROCm version string")
endif()
message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n")
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
@@ -1400,6 +1412,10 @@ endif()
if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)
if(onnxruntime_CUDA_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT)
endif()
find_package(CUDAToolkit REQUIRED)
if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
endif()
2 changes: 1 addition & 1 deletion cmake/adjust_global_compile_flags.cmake
@@ -205,7 +205,7 @@ endif()


macro(check_nvcc_compiler_flag _FLAG _RESULT)
execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
message("NVCC_ERROR = ${NVCC_ERROR}")
message("NVCC_OUT = ${NVCC_OUT}")
if ("${NVCC_OUT}" MATCHES "0")
3 changes: 1 addition & 2 deletions cmake/external/onnxruntime_external_deps.cmake
@@ -556,16 +556,15 @@ message("Finished fetching external dependencies")
set(onnxruntime_LINK_DIRS )
if (onnxruntime_USE_CUDA)
#TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same
find_package(CUDAToolkit REQUIRED)
if (WIN32)
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64)
endif()
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64)
else()
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64)
endif()
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64)
endif()
endif()

20 changes: 10 additions & 10 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -178,15 +178,16 @@
add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(onnxruntime_CUDA_MINIMAL)
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart)
else()
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart
${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
endif()
endif()

if (onnxruntime_USE_TRITON_KERNEL)
# compile triton kernel, generate .a and .h files
include(onnxruntime_compile_triton_kernel.cmake)
@@ -196,25 +197,24 @@
target_include_directories(${target} PRIVATE ${triton_kernel_header_dir})
target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive)
# lib cuda needed by cuLaunchKernel
target_link_libraries(${target} PRIVATE cuda)
target_link_libraries(${target} PRIVATE CUDA::cuda_driver)
endif()

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA)
set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime")

if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling
target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64)
target_link_libraries(${target} PRIVATE cupti)
target_link_libraries(${target} PRIVATE CUDA::cupti)
endif()

if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32)
target_link_libraries(${target} PRIVATE nvToolsExt)
if (onnxruntime_ENABLE_NVTX_PROFILE)
target_link_libraries(${target} PRIVATE CUDA::nvtx3)
endif()

if (onnxruntime_ENABLE_TRAINING_OPS)
11 changes: 6 additions & 5 deletions cmake/onnxruntime_providers_tensorrt.cmake
@@ -8,7 +8,7 @@
set(BUILD_LIBRARY_ONLY 1)
add_definitions("-DONNX_ML=1")
add_definitions("-DONNX_NAMESPACE=onnx")
set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS})
set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME})
set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(PROTOBUF_LIBRARY ${PROTOBUF_LIB})
@@ -58,7 +58,7 @@
URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt}
)
if (NOT CUDA_INCLUDE_DIR)
set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build
set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build
endif()
# The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses
# unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose.
@@ -102,11 +102,12 @@
onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS})
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
else()
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS})
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
endif()
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()
5 changes: 1 addition & 4 deletions cmake/onnxruntime_python.cmake
@@ -282,10 +282,7 @@ if (WIN32)
get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE)
string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}")
if(NOT onnxruntime_CUDA_VERSION)
message("Reading json file ${onnxruntime_CUDA_HOME}/version.json")
set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json")
file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT)
string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version")
set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION})
message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}")
endif()
file(APPEND "${VERSION_INFO_FILE}"
4 changes: 2 additions & 2 deletions cmake/onnxruntime_unittests.cmake
@@ -67,7 +67,7 @@ function(AddTest)
if(onnxruntime_USE_CUDA)
#XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
# otherwise it will impact when CUDA DLLs can be unloaded.
target_link_libraries(${_UT_TARGET} PRIVATE cudart)
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart)
endif()
target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
endif()
@@ -1268,7 +1268,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo)
endif()
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart)
endif()
if (onnxruntime_USE_ROCM)
list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
9 changes: 9 additions & 0 deletions js/common/lib/env.ts
@@ -36,6 +36,7 @@ export declare namespace Env {
/**
* set or get a boolean value indicating whether to enable trace.
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
* @defaultValue `false`
*/
trace?: boolean;
@@ -167,13 +168,21 @@ export interface Env {
* @defaultValue `'warning'`
*/
logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal';

/**
* Indicate whether run in debug mode.
*
* @defaultValue `false`
*/
debug?: boolean;

/**
* set or get a boolean value indicating whether to enable trace.
*
* @defaultValue `false`
*/
trace?: boolean;

/**
* Get version of the current package.
*/
6 changes: 3 additions & 3 deletions js/common/lib/trace.ts
@@ -4,7 +4,7 @@
import {env} from './env-impl.js';

export const TRACE = (deviceType: string, label: string) => {
if (!env.wasm.trace) {
if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
return;
}
// eslint-disable-next-line no-console
@@ -30,14 +30,14 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => {
};

export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
if (!env.wasm.trace) {
if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
return;
}
TRACE_FUNC('BEGIN', extraMsg);
};

export const TRACE_FUNC_END = (extraMsg?: string) => {
if (!env.wasm.trace) {
if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
return;
}
TRACE_FUNC('END', extraMsg);
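
All three guards in trace.ts apply the same precedence rule: the new top-level env.trace wins whenever it is defined, and only then does the deprecated env.wasm.trace still apply. A hedged sketch of that resolution rule in Python (names invented for illustration):

from typing import Optional

# Illustrative only; mirrors the typeof env.trace === 'undefined' checks above.
def trace_enabled(env_trace: Optional[bool], wasm_trace: bool) -> bool:
    # env.trace, once set, takes precedence even when it is False.
    return wasm_trace if env_trace is None else env_trace

assert trace_enabled(None, True) is True     # fall back to deprecated env.wasm.trace
assert trace_enabled(False, True) is False   # explicit env.trace overrides
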
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -710,7 +710,8 @@ export class WebGpuBackend {
}
setQueryType(): void {
this.queryType = 'none';
if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) {
if (this.env.webgpu.profiling?.mode === 'default' ||
(typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) {
if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) {
this.queryType = 'inside-passes';
} else if (this.device.features.has('timestamp-query')) {
5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/nvtx_profile.cc
@@ -4,13 +4,8 @@
#ifdef ENABLE_NVTX_PROFILE
#include "nvtx_profile.h"
#include "core/common/common.h"
#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__)
#include <nvtx3/nvToolsExt.h>
#include <nvtx3/nvToolsExtCuda.h>
#else
#include <nvToolsExt.h>
#include <nvToolsExtCuda.h>
#endif

namespace onnxruntime {
namespace profile {
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
@@ -195,7 +195,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"LessOrEqual", {"lesserOrEqual", false}},
{"Log", {"log", false}},
{"LpPool", {"l2Pool2d", false}},
{"MatMul", {"matmul", false}},
{"MatMul", {"matmul", true}},
{"MatMulInteger", {"matmulInteger", false}},
{"Max", {"max", true}},
{"MaxPool", {"maxPool2d", true}},
14 changes: 12 additions & 2 deletions onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
@@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder {

// Add operator related.
Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& /* logger */) const {
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C
@@ -38,7 +38,17 @@
emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name());
emscripten::val output = emscripten::val::object();
if (op_type == "MatMul") {
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
std::vector<int64_t> a_shape;
if (!GetShape(*input_defs[a_idx], a_shape, logger)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A.");
}
// The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case.
// TODO: Remove this workaround when it is fixed in Chromium.
if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) {
output = model_builder.GetBuilder().call<emscripten::val>("gemm", a, b);
} else {
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
}
} else if (op_type == "MatMulInteger") {
emscripten::val a_zero_point = emscripten::val::null();
emscripten::val b_zero_point = emscripten::val::null();
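
The MatMul-to-gemm substitution above is safe because GEMM computes alpha * A @ B + beta * C; with the defaults alpha=1, beta=0 and no C term, that is exactly MatMul on rank-2 operands. A quick numerical check in Python (illustrative, not part of the commit):

import numpy as np

# gemm with default alpha/beta and no C reduces to matmul for 2-D inputs.
def gemm(a, b, c=None, alpha=1.0, beta=0.0):
    out = alpha * (a @ b)
    return out + beta * c if c is not None else out

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3)).astype(np.float32)
b = rng.standard_normal((3, 5)).astype(np.float32)
assert np.allclose(gemm(a, b), a @ b)
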
@@ -122,6 +122,11 @@ def fuse(

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1
self.fused_op_type,
name=self.create_unique_node_name(),
inputs=[subgraph_input],
outputs=[subgraph_output],
p=2,
axis=-1,
)
self.nodes_to_add.append(fused_node)
@@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm:
if fusion_layernorm.apply():
modified = True

# Make sure all nodes have a name.
unnamed_node_prefix = "qnn_preproc_node_"
available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1
for node in onnx_model.model.graph.node:
if node.op_type != "Constant" and not node.name:
new_node_name = f"{unnamed_node_prefix}{available_suffix!s}"
available_suffix += 1
node.name = new_node_name
modified = True
logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.")

if modified:
onnx_model.topological_sort()
onnx.save_model(model, model_output)
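
The renaming pass above seeds its counter from the largest suffix already present, so freshly named nodes cannot collide with existing names. An illustrative re-implementation of that scan (assumed to match what get_largest_node_name_suffix does):

# Illustrative only: find the largest integer suffix among "<prefix>N" names.
def largest_node_name_suffix(node_names: list[str], prefix: str) -> int:
    suffix = -1
    for name in node_names:
        tail = name[len(prefix):]
        if name.startswith(prefix) and tail.isdigit():
            suffix = max(suffix, int(tail))
    return suffix

names = ["Conv_0", "qnn_preproc_node_1", "qnn_preproc_node_7"]
assert largest_node_name_suffix(names, "qnn_preproc_node_") == 7
# The first unnamed node would therefore become qnn_preproc_node_8.
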
15 changes: 15 additions & 0 deletions onnxruntime/python/tools/quantization/fusions/fusion.py
@@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
self.nodes_to_remove: list = []
self.nodes_to_add: list = []

self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.

def fuse(
self,
node: onnx.NodeProto,
@@ -57,6 +60,18 @@ def apply(self) -> bool:

return graph_updated

def create_unique_node_name(self):
prefix = self._new_node_name_prefix

if self._new_node_name_suffix is None:
largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
self._new_node_name_suffix = largest_suffix + 1

new_name = f"{prefix}{self._new_node_name_suffix!s}"
self._new_node_name_suffix += 1

return new_name

@staticmethod
def is_safe_to_fuse_nodes(
nodes_to_remove: list[onnx.NodeProto],
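
create_unique_node_name above pays a one-time scan for the largest existing suffix, then hands out names from a cached counter, so repeated fusions stay O(1) per name. A self-contained model of that caching pattern (illustrative; the class and the example prefix values are invented, not the library's API):

# Illustrative stand-in for the lazy-suffix pattern in Fusion above.
class UniqueNamer:
    def __init__(self, prefix: str, existing_names: list[str]):
        self.prefix = prefix
        self.existing_names = existing_names
        self.suffix = None  # like _new_node_name_suffix: filled on first use

    def create(self) -> str:
        if self.suffix is None:
            # one-time scan, mirroring get_largest_node_name_suffix
            used = [int(n[len(self.prefix):]) for n in self.existing_names
                    if n.startswith(self.prefix) and n[len(self.prefix):].isdigit()]
            self.suffix = max(used, default=-1) + 1
        name = f"{self.prefix}{self.suffix}"
        self.suffix += 1
        return name

namer = UniqueNamer("LpNormalization_fused_ReduceL2_", ["LpNormalization_fused_ReduceL2_2"])
assert namer.create() == "LpNormalization_fused_ReduceL2_3"  # after the scan
assert namer.create() == "LpNormalization_fused_ReduceL2_4"  # cached counter
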