Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into snnn/fix_333
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jan 25, 2024
2 parents e755298 + 4477f57 commit 6c676b4
Show file tree
Hide file tree
Showing 21 changed files with 387 additions and 565 deletions.
6 changes: 5 additions & 1 deletion cmake/external/xnnpack.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "")
set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "")
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "")

if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*")
set(XNNPACK_USE_SYSTEM_LIBS OFF)
endif()

# BF16 instructions cause ICE in Android NDK compiler
if(CMAKE_ANDROID_ARCH_ABI STREQUAL armeabi-v7a)
set(XNNPACK_ENABLE_ARM_BF16 OFF)
ENDIF()
endif()

# fp16 depends on psimd
FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd})
Expand Down
4 changes: 3 additions & 1 deletion cmake/onnxruntime_common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(ARM TRUE)
elseif(dumpmachine_output MATCHES "^aarch64.*")
set(ARM64 TRUE)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*")
set(RISCV64 TRUE)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
set(X86 TRUE)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
Expand All @@ -198,7 +200,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
endif()


if (ARM64 OR ARM OR X86 OR X64 OR X86_64)
if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64)
if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC))
# msvc compiler report syntax error with cpuinfo arm source files
# and cpuinfo does not have code for getting arm uarch info under windows
Expand Down
35 changes: 35 additions & 0 deletions cmake/riscv64.toolchain.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024 SiFive, Inc. All rights reserved.
# Copyright (c) 2024, Phoebe Chen <[email protected]>
# Licensed under the MIT License.

set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR riscv64)

list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES RISCV_TOOLCHAIN_ROOT)

if(NOT RISCV_TOOLCHAIN_ROOT)
message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT is not defined. Please set the RISCV_TOOLCHAIN_ROOT variable.")
endif()

set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc")
set(CMAKE_ASM_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc")
set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++")

set(CMAKE_FIND_ROOT_PATH ${RISCV_TOOLCHAIN_ROOT})
set(CMAKE_SYSROOT "${RISCV_TOOLCHAIN_ROOT}/sysroot")
set(CMAKE_INCLUDE_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/include/")
set(CMAKE_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib/")
set(CMAKE_PROGRAM_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/bin/")

if(RISCV_QEMU_PATH)
message(STATUS "RISCV_QEMU_PATH=${RISCV_QEMU_PATH} is defined during compilation.")
set(CMAKE_CROSSCOMPILING_EMULATOR "${RISCV_QEMU_PATH};-L;${CMAKE_SYSROOT}")
endif()

set(CMAKE_CROSSCOMPILING TRUE)

set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)

4 changes: 0 additions & 4 deletions js/web/lib/build-def.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ interface BuildDefinitions {
/**
* defines whether to disable the whole WebNN backend in the build.
*/
readonly DISABLE_WEBNN: boolean;
/**
* defines whether to disable the whole WebAssembly backend in the build.
*/
readonly DISABLE_WASM: boolean;
/**
* defines whether to disable proxy feature in WebAssembly backend in the build.
Expand Down
4 changes: 1 addition & 3 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
if (!BUILD_DEFS.DISABLE_WEBNN) {
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
2 changes: 1 addition & 1 deletion js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {

_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;

_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
_OrtReleaseSession(sessionHandle: number): void;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;
Expand Down
20 changes: 10 additions & 10 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,6 @@ export class WebGpuBackend {
getCommandEncoder(): GPUCommandEncoder {
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();

if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
count: this.maxDispatchNumber * 2,
});
this.queryResolveBuffer = this.device.createBuffer(
// eslint-disable-next-line no-bitwise
{size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE});
}
}
return this.commandEncoder;
}
Expand Down Expand Up @@ -654,6 +644,16 @@ export class WebGpuBackend {
} else if (this.device.features.has('timestamp-query')) {
this.queryType = 'at-passes';
}

if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
count: this.maxDispatchNumber * 2,
});
this.queryResolveBuffer = this.device.createBuffer(
// eslint-disable-next-line no-bitwise
{size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE});
}
}
}
onRunStart(): void {
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
Expand Down Expand Up @@ -228,7 +228,7 @@ export const createSession = async(
await Promise.all(loadingPromises);
}

sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
Expand Down
7 changes: 1 addition & 6 deletions js/web/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // <ORT_ROOT>/js/
const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WEBGL': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'false',
'BUILD_DEFS.DISABLE_WEBNN': 'false',
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
Expand Down Expand Up @@ -364,7 +363,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
Expand Down Expand Up @@ -397,7 +395,7 @@ async function main() {
// ort.webgpu[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.webgpu',
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
});
// ort.wasm[.min].js
await addAllWebBuildTasks({
Expand All @@ -411,7 +409,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WASM': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
// ort.wasm-core[.min].js
Expand All @@ -421,7 +418,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
Expand All @@ -434,7 +430,6 @@ async function main() {
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
}
Expand Down
4 changes: 0 additions & 4 deletions js/web/script/test-runner-cli-args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs

const globalEnvFlags = parseGlobalEnvFlags(args);

if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}

// Options:
// --log-verbose=<...>
// --log-info=<...>
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ std::string GetShapeString(std::vector<T>& shape) {
return shape_info.str();
}

inline std::vector<uint32_t> GetVecUint32FromVecInt64(const std::vector<int64_t>& int64_vec) {
std::vector<uint32_t> uint32_vec;
uint32_vec.reserve(int64_vec.size());
std::transform(int64_vec.begin(), int64_vec.end(),
std::back_inserter(uint32_vec),
[](int64_t val) -> uint32_t { return SafeInt<uint32_t>(val); });
return uint32_vec;
}

template <typename T>
bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector<T>& array, const logging::Logger& logger) {
std::vector<uint8_t> unpacked_tensor;
Expand Down
Loading

0 comments on commit 6c676b4

Please sign in to comment.