Skip to content

Commit

Permalink
Merge branch 'main' into yuslepukhin/resize_cuda_18
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Feb 28, 2024
2 parents b54e7df + a93c31e commit 8487adf
Show file tree
Hide file tree
Showing 103 changed files with 7,420 additions and 834 deletions.
30 changes: 22 additions & 8 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,27 @@ if (onnxruntime_USE_ROCM)
endif()

# replicate strategy used by pytorch to get ROCM_VERSION
# https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173
file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
if (ROCM_VERSION_DEV_MATCH)
# https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake
# with modification
if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version-dev")
file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
endif()

if (ROCM_VERSION_MATCH)
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
else()
message(FATAL_ERROR "Cannot determine ROCm version string")
endif()
message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n")
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
Expand Down Expand Up @@ -1400,6 +1412,10 @@ endif()
if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)
if(onnxruntime_CUDA_HOME)
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
endif()
find_package(CUDAToolkit REQUIRED)
if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
endif()
Expand Down Expand Up @@ -1729,14 +1745,12 @@ if(onnxruntime_BUILD_KERNEL_EXPLORER)
endif()

# When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions).
if (WIN32 AND NOT GDK_PLATFORM)
if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING)
if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib)
# On onecore, link to the onecore build of the MSVC runtime
get_filename_component(msvc_path "${CMAKE_C_COMPILER}/../../../.." ABSOLUTE)
link_directories(BEFORE "${msvc_path}/lib/onecore/${onnxruntime_target_platform}")
# The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, which in turn links to reverse forwarders.
# We ignore that entry and use onecore_apiset.lib instead, since system components must not rely on reverse forwarders.
add_link_options("/NODEFAULTLIB:onecore.lib")
# The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, but it shold not cause any conflict with onecoreuap.lib
endif()
endif()

Expand Down
2 changes: 1 addition & 1 deletion cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ endif()


macro(check_nvcc_compiler_flag _FLAG _RESULT)
execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
message("NVCC_ERROR = ${NVCC_ERROR}")
message("NVCC_OUT = ${NVCC_OUT}")
if ("${NVCC_OUT}" MATCHES "0")
Expand Down
3 changes: 1 addition & 2 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -556,16 +556,15 @@ message("Finished fetching external dependencies")
set(onnxruntime_LINK_DIRS )
if (onnxruntime_USE_CUDA)
#TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same
find_package(CUDAToolkit REQUIRED)
if (WIN32)
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64)
endif()
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64)
else()
if(onnxruntime_CUDNN_HOME)
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64)
endif()
list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64)
endif()
endif()

Expand Down
20 changes: 10 additions & 10 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,16 @@
add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(onnxruntime_CUDA_MINIMAL)
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart)
else()
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart
${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
endif()
endif()

if (onnxruntime_USE_TRITON_KERNEL)
# compile triton kernel, generate .a and .h files
include(onnxruntime_compile_triton_kernel.cmake)
Expand All @@ -196,25 +197,24 @@
target_include_directories(${target} PRIVATE ${triton_kernel_header_dir})
target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive)
# lib cuda needed by cuLaunchKernel
target_link_libraries(${target} PRIVATE cuda)
target_link_libraries(${target} PRIVATE CUDA::cuda_driver)
endif()

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA)
set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime")

if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling
target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64)
target_link_libraries(${target} PRIVATE cupti)
target_link_libraries(${target} PRIVATE CUDA::cupti)
endif()

if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32)
target_link_libraries(${target} PRIVATE nvToolsExt)
if (onnxruntime_ENABLE_NVTX_PROFILE)
target_link_libraries(${target} PRIVATE CUDA::nvtx3)
endif()

if (onnxruntime_ENABLE_TRAINING_OPS)
Expand Down
11 changes: 6 additions & 5 deletions cmake/onnxruntime_providers_tensorrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
set(BUILD_LIBRARY_ONLY 1)
add_definitions("-DONNX_ML=1")
add_definitions("-DONNX_NAMESPACE=onnx")
set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS})
set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME})
set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(PROTOBUF_LIBRARY ${PROTOBUF_LIB})
Expand Down Expand Up @@ -58,7 +58,7 @@
URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt}
)
if (NOT CUDA_INCLUDE_DIR)
set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build
set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build
endif()
# The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses
# unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose.
Expand Down Expand Up @@ -102,11 +102,12 @@
onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS})
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
else()
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS})
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
endif()
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()
Expand Down
5 changes: 1 addition & 4 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ if (WIN32)
get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE)
string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}")
if(NOT onnxruntime_CUDA_VERSION)
message("Reading json file ${onnxruntime_CUDA_HOME}/version.json")
set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json")
file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT)
string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version")
set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION})
message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}")
endif()
file(APPEND "${VERSION_INFO_FILE}"
Expand Down
4 changes: 2 additions & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function(AddTest)
if(onnxruntime_USE_CUDA)
#XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
# otherwise it will impact when CUDA DLLs can be unloaded.
target_link_libraries(${_UT_TARGET} PRIVATE cudart)
target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart)
endif()
target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
endif()
Expand Down Expand Up @@ -1268,7 +1268,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo)
endif()
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart)
endif()
if (onnxruntime_USE_ROCM)
list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
Expand Down
4 changes: 2 additions & 2 deletions cmake/wcos_rules_override.cmake
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib)
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib)
set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap.lib)
set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap.lib)
9 changes: 9 additions & 0 deletions js/common/lib/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export declare namespace Env {
/**
* set or get a boolean value indicating whether to enable trace.
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
* @defaultValue `false`
*/
trace?: boolean;
Expand Down Expand Up @@ -167,13 +168,21 @@ export interface Env {
* @defaultValue `'warning'`
*/
logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal';

/**
* Indicate whether run in debug mode.
*
* @defaultValue `false`
*/
debug?: boolean;

/**
* set or get a boolean value indicating whether to enable trace.
*
* @defaultValue `false`
*/
trace?: boolean;

/**
* Get version of the current package.
*/
Expand Down
6 changes: 3 additions & 3 deletions js/common/lib/trace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {env} from './env-impl.js';

export const TRACE = (deviceType: string, label: string) => {
if (!env.wasm.trace) {
if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
return;
}
// eslint-disable-next-line no-console
Expand All @@ -30,14 +30,14 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => {
};

export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
if (!env.wasm.trace) {
if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
return;
}
TRACE_FUNC('BEGIN', extraMsg);
};

export const TRACE_FUNC_END = (extraMsg?: string) => {
if (!env.wasm.trace) {
if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
return;
}
TRACE_FUNC('END', extraMsg);
Expand Down
35 changes: 30 additions & 5 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -385,11 +385,16 @@ export class WebGpuBackend {
// create info for inputs
const inputDatas: GpuData[] = [];
for (let i = 0; i < inputTensorViews.length; ++i) {
const gpuData = this.gpuDataManager.get(inputTensorViews[i].data);
const data = inputTensorViews[i].data;
// if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
if (data === 0) {
continue;
}
const gpuData = this.gpuDataManager.get(data);
if (!gpuData) {
throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`);
throw new Error(`no GPU data for input: ${data}`);
}
inputDatas[i] = gpuData;
inputDatas.push(gpuData);
}

const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);
Expand Down Expand Up @@ -419,6 +424,11 @@ export class WebGpuBackend {
const tensorView = (isTemporary || isPersistent) ?
createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
outputTensorViews.push(tensorView);
// if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
if (tensorView.data === 0) {
continue;
}
const gpuData = this.gpuDataManager.get(tensorView.data);
if (!gpuData) {
throw new Error(`no GPU data for output: ${tensorView.data}`);
Expand All @@ -434,10 +444,24 @@ export class WebGpuBackend {
}
persistentData.push(gpuData);
}
outputTensorViews.push(tensorView);
outputDatas.push(gpuData);
}

// when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are
// zero-sized tensors.
if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
// if all outputs are zero-sized tensors, there is no need to run the program.
if (outputDatas.length === 0) {
TRACE_FUNC_END(program.name);
return outputTensorViews;
}
// if some outputs are zero-sized tensors, report an error.
//
// TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors.
// If we see such use case, we need to make a change here to support it.
throw new Error(
`Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
}

// load uniforms
// TODO: add cache for uniform (is it necessary?)
Expand Down Expand Up @@ -686,7 +710,8 @@ export class WebGpuBackend {
}
setQueryType(): void {
this.queryType = 'none';
if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) {
if (this.env.webgpu.profiling?.mode === 'default' ||
(typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) {
if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) {
this.queryType = 'inside-passes';
} else if (this.device.features.has('timestamp-query')) {
Expand Down
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class ComputeContextImpl implements ComputeContext {
throw new Error(`Unsupported data type: ${dataType}`);
}
const bufferSize = elementSize * ShapeUtil.size(dims);
return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims);
const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
};
return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput);
}
Expand Down
11 changes: 10 additions & 1 deletion js/web/lib/wasm/jsep/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,16 @@ export class BroadcastUtil {
if (aLen !== bLen && aLen > 1 && bLen > 1) {
return undefined;
}
cdims[crank - i] = Math.max(aLen, bLen);
const max = Math.max(aLen, bLen);
if (aLen && bLen) {
cdims[crank - i] = Math.max(aLen, bLen);
} else {
// when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable.
if (max > 1) {
return undefined;
}
cdims[crank - i] = 0;
}
}

return cdims;
Expand Down
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
context.compute(createConcatProgramInfo(context.inputs, attributes.axis));
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
};

export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
if (idx${x} < 0) {
idx${x} = idx${x} + uniforms.axisDimLimit;
}
var dataIndices${x} = ${data.type.indices}(0);
var dataIndices${x} : ${data.type.indices};
`;
for (let i = 0, j = 0; i < inputRank; i++) {
if (i === axis) {
Expand Down
Loading

0 comments on commit 8487adf

Please sign in to comment.