diff --git a/.github/actions/rust-toolchain-setup/action.yml b/.github/actions/rust-toolchain-setup/action.yml deleted file mode 100644 index bf73fede16c7f..0000000000000 --- a/.github/actions/rust-toolchain-setup/action.yml +++ /dev/null @@ -1,44 +0,0 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-action.json - -name: 'Rust toolchain setup' -description: 'Common setup steps for GitHub workflows for Rust projects' - -runs: - using: composite - steps: - - uses: dtolnay/rust-toolchain@1.71.0 - with: - components: clippy, rustfmt - - uses: extractions/setup-just@v1 - with: - just-version: '1.15.0' # optional semver specification, otherwise latest - - ### - ### Linux setup - ### - - name: rustup - # We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds. - if: ${{ (runner.os == 'Linux') }} - run: | - rustup set profile minimal - rustup install - shell: bash - # - name: Cargo login - # if: ${{ (runner.os == 'Linux') }} - # run: just cargo-login-ci - # shell: bash - - ### - ### Windows setup - ### - - name: rustup - # We need to use the nightly rust tool change to enable registry-auth / to connect to ADO feeds. - if: ${{ (runner.os == 'Windows') }} - run: | - rustup set profile minimal - rustup install - shell: pwsh - # - name: Cargo login - # if: ${{ (runner.os == 'Windows') }} - # run: just cargo-login-ci-windows - # shell: pwsh diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml deleted file mode 100644 index 725c40c2ded53..0000000000000 --- a/.github/workflows/rust-ci.yml +++ /dev/null @@ -1,132 +0,0 @@ -name: Rust - -on: [pull_request] - -env: - CARGO_TERM_COLOR: always - RUST_LOG: onnxruntime=debug,onnxruntime-sys=debug - RUST_BACKTRACE: 1 - MANIFEST_PATH: ${{ github.workspace }}/rust/Cargo.toml - -jobs: - fmt: - name: Rustfmt - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - name: fmt - run: cargo fmt --all -- --check - - download: - name: Download prebuilt ONNX Runtime archive from build.rs - runs-on: ubuntu-latest - env: - ORT_RUST_STRATEGY: download - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - run: rustup target install x86_64-unknown-linux-gnu - - run: rustup target install x86_64-apple-darwin - - run: rustup target install i686-pc-windows-msvc - - run: rustup target install x86_64-pc-windows-msvc - # ****************************************************************** - - name: Download prebuilt archive (CPU, x86_64-unknown-linux-gnu) - run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, x86_64-unknown-linux-gnu) - run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-1.*.tgz - # ****************************************************************** - - name: Download prebuilt archive (CPU, x86_64-apple-darwin) - run: cargo build --target x86_64-apple-darwin --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, x86_64-apple-darwin) - run: ls -lh target/x86_64-apple-darwin/debug/build/onnxruntime-sys-*/out/onnxruntime-osx-x64-1.*.tgz - # ****************************************************************** - - name: Download prebuilt archive (CPU, i686-pc-windows-msvc) - run: cargo build --target i686-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} - - name: 
Verify prebuilt archive downloaded (CPU, i686-pc-windows-msvc) - run: ls -lh target/i686-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x86-1.*.zip - # ****************************************************************** - - name: Download prebuilt archive (CPU, x86_64-pc-windows-msvc) - run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (CPU, x86_64-pc-windows-msvc) - run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-x64-1.*.zip - # ****************************************************************** - - name: Download prebuilt archive (GPU, x86_64-unknown-linux-gnu) - env: - ORT_USE_CUDA: "yes" - run: cargo build --target x86_64-unknown-linux-gnu --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (GPU, x86_64-unknown-linux-gnu) - run: ls -lh target/x86_64-unknown-linux-gnu/debug/build/onnxruntime-sys-*/out/onnxruntime-linux-x64-gpu-1.*.tgz - # ****************************************************************** - - name: Download prebuilt archive (GPU, x86_64-pc-windows-msvc) - env: - ORT_USE_CUDA: "yes" - run: cargo build --target x86_64-pc-windows-msvc --manifest-path ${{ env.MANIFEST_PATH }} - - name: Verify prebuilt archive downloaded (GPU, x86_64-pc-windows-msvc) - run: ls -lh target/x86_64-pc-windows-msvc/debug/build/onnxruntime-sys-*/out/onnxruntime-win-gpu-x64-1.*.zip - - test: - name: Test Suite - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - target: - [ - x86_64-unknown-linux-gnu, - x86_64-apple-darwin, - x86_64-pc-windows-msvc, - i686-pc-windows-msvc, - ] - include: - - target: x86_64-unknown-linux-gnu - os: ubuntu-latest - - target: x86_64-apple-darwin - os: macos-latest - - target: x86_64-pc-windows-msvc - os: windows-latest - - target: i686-pc-windows-msvc - os: windows-latest - env: - CARGO_BUILD_TARGET: ${{ matrix.target }} - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - run: rustup target install ${{ matrix.target }} - - name: Install additional packages (macOS) - if: contains(matrix.target, 'x86_64-apple-darwin') - run: brew install libomp - - name: Build (cargo build) - run: cargo build --all --manifest-path ${{ env.MANIFEST_PATH }} - - name: Build tests (cargo test) - run: cargo test --no-run --manifest-path ${{ env.MANIFEST_PATH }} - - name: Build onnxruntime with 'model-fetching' feature - run: cargo build --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching - - name: Test onnxruntime-sys - run: cargo build --package onnxruntime-sys -- --test-threads=1 --nocapture - - name: Test onnxruntime - run: cargo test --manifest-path ${{ env.MANIFEST_PATH }} --features model-fetching -- --test-threads=1 --nocapture - - clippy: - name: Clippy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - run: clippy --all-features --manifest-path ${{ env.MANIFEST_PATH }} -- -D warnings - - package-sys: - name: Package onnxruntime-sys - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: ./.github/actions/rust-toolchain-setup - - name: vendor onnxruntime source - run: just vendor - - run: cargo package --allow-dirty --package onnxruntime-sys diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0f57258dca706..1567da90cacfc 100644 --- 
a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -131,6 +131,7 @@ option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF) +option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF) option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF) option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON) option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON) @@ -1110,7 +1111,7 @@ function(onnxruntime_add_include_to_target dst_target) endfunction() # ACL -if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002) +if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308) set(onnxruntime_USE_ACL ON) if (onnxruntime_USE_ACL_1902) add_definitions(-DACL_1902=1) @@ -1121,7 +1122,11 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 if (onnxruntime_USE_ACL_2002) add_definitions(-DACL_2002=1) else() - add_definitions(-DACL_1905=1) + if (onnxruntime_USE_ACL_2308) + add_definitions(-DACL_2308=1) + else() + add_definitions(-DACL_1905=1) + endif() endif() endif() endif() diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ed174bc0982eb..90e02da986b8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,7 +16,7 @@ import {expand} from './ops/expand'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; -import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; +import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; @@ -82,7 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], - ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], + ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 3a84844544c96..056dd54d54591 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -4,58 +4,56 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import 
{fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; -export interface InstanceNormAttributes extends AttributeWithCacheKey { +export interface InstanceNormAttributes { epsilon: number; format: 'NHWC'|'NCHW'; } -const metadata = { - name: 'InstanceNormalization' -}; - const createInstanceNormProgramInfo = (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { const xShape = inputs[0].dims; - const outputShape = xShape; const axis = 2; const normCount = ShapeUtil.sizeToDimension(xShape, axis); const normSize = ShapeUtil.sizeFromDimension(xShape, axis); const components = getMaxComponents(normSize); const normPackedSize = normSize / components; - const C = xShape[1]; - const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); - const variables = [x, scale, bias, output]; - const dataType = x.type.value; - const f32Type = components === 1 ? 'f32' : `vec${components}`; - const workgroupSize = 64; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - - const C: u32 = ${C}; - const normSize: u32 = ${normSize}; - const epsilon: f32 = ${attributes.epsilon}; + const inputShape = [xShape[0], xShape[1], normPackedSize]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; + programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); + const variables = [x, scale, bias, output]; + const dataType = x.type.value; + const f32Type = components === 1 ? 
'f32' : `vec${components}`; + const workgroupSize = 64; + + const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}]; + return ` var meanShared : f32; var squaredNormShared : f32; var workgroupShared : array<${f32Type}, ${workgroupSize}>; const workgroupSize = ${workgroupSize}u; - ${shaderHelper.declareVariables(...variables)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart(workgroupSize)} let norm = global_idx / workgroupSize; - let batch = norm / C; - let channel = norm % C; + let batch = norm / uniforms.x_shape[1]; + let channel = norm % uniforms.x_shape[1]; let localIndex = local_id.x; // initialize workgroup memory var initial = ${f32Type}(0); - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); } workgroupShared[localIndex] = initial; @@ -69,13 +67,13 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize); + meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize); } workgroupBarrier(); // reinitialize workgroup memory. initial = ${f32Type}(0); - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); initial = initial + deviation * deviation; } @@ -94,23 +92,26 @@ const createInstanceNormProgramInfo = } workgroupBarrier(); - let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon); + let invStdDev = 1 / sqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ - f32Type}(channelShift)); + f32Type}(channelShift)); ${output.set('batch', 'channel', 'h', 'value')}; } }`; + }; return { - ...metadata, - shaderCache: {hint: attributes.cacheKey}, + ...{name: 'InstanceNormalization'}, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. 
+ shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, ], - dispatchGroup: {x: normCount} + dispatchGroup: {x: normCount}, + programUniforms }), getShaderSource, }; @@ -120,10 +121,6 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number, epsilon: number) => { const components = getMaxComponents(c); - const inputHelper = inputVariable('input', input.dataType, input.dims, components); - const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); - const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); - const WG = 64; // we will store channel scale and channel shift in [2, components] matrix // or in vec2 when components == 1 @@ -133,65 +130,79 @@ const computeMean = const unitsOfWork = n * c / components; const wgSize = Math.ceil(h / WG); - const getMeanShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${h}; - const C: u32 = ${c / components}; - const imageSize: u32 = ${h * c / components}; + const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + const meanProgramUniforms: ProgramUniform[] = [ + {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, + {type: 'uint32', data: Math.floor(h * c / components)} + ]; + const getMeanShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = inputVariable('input', input.dataType, input.dims, components); + return ` ${shaderHelper.declareVariables(inputHelper)} @group(0) @binding(1) var output : array<${outputType}>; + struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32}; + @group(0) @binding(2) var uniforms: Uniforms; ${shaderHelper.mainStart(WG)} - let currentImageNumber = global_idx / ${WG} / C; - let currentChannelNumber = (global_idx / ${WG}) % C; + let currentImageNumber = global_idx / ${WG} / uniforms.C; + let currentChannelNumber = (global_idx / ${WG}) % uniforms.C; let wgId = global_idx % ${WG}; - let wgOffset = wgId * ${wgSize}; - if (wgOffset >= H) { + let wgOffset = wgId * uniforms.wg_size; + if (wgOffset >= uniforms.H) { return; } - let wgMax = min(wgOffset + ${wgSize}, H); + let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H); - let offset = currentImageNumber * imageSize + currentChannelNumber; + let offset = currentImageNumber * uniforms.image_size + currentChannelNumber; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; for (var i: u32 = wgOffset; i < wgMax; i++) { - let value = ${sumCastType}(input[offset + i * C]); + let value = ${sumCastType}(input[offset + i * uniforms.C]); sum += value; squaredSum += value * value; } output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; + }; const meanValues = context.compute( { name: 'InstanceNormComputeMean', - shaderCache: {hint: JSON.stringify({components, n, h, c})}, + shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies}, getRunData: () => ({ outputs: [ {dims: [n, c, WG, 2], dataType: DataType.float}, ], dispatchGroup: {x: n * c / components}, + programUniforms: meanProgramUniforms }), getShaderSource: getMeanShaderSource, }, {inputs: [input], outputs: [-1]})[0]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${h}; - const C: u32 = ${c / components}; - const imageSize: u32 = ${WG * c / components}; - const epsilon: f32 = 
${epsilon}; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, + {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); + const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + return ` @group(0) @binding(0) var input : array<${outputType}>; @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; @group(0) @binding(3) var output : array<${outputType}>; + struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32}; + @group(0) @binding(4) var uniforms: Uniforms; ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} - let currentImageNumber = global_idx / C; - let currentChannelNumber = global_idx % C; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')} + let currentImageNumber = global_idx / uniforms.C; + let currentChannelNumber = global_idx % uniforms.C; - let offset = currentImageNumber * imageSize; + let offset = currentImageNumber * uniforms.image_size; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; for (var i: u32 = 0; i < ${WG}; i++) { @@ -199,24 +210,26 @@ const computeMean = sum += value[0]; squaredSum += value[1]; } - sum = sum / f32(H); - squaredSum = squaredSum / f32(H); - let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon); + sum = sum / f32(uniforms.H); + squaredSum = squaredSum / f32(uniforms.H); + let invStdDev = 1 / sqrt(squaredSum - sum * sum + f32(${epsilon})); let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; }`; - + }; return context.compute( { name: 'InstanceNormComputeChannelScaleShift', - shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})}, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: {hint: `${components};${epsilon}`, inputDependencies}, getRunData: () => ({ outputs: [ {dims: [n, c, 2], dataType: DataType.float}, ], dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, @@ -230,50 +243,51 @@ const createInstanceNormNHWCProgramInfo = const N = xShape[0]; const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; - const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); - const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const scaleCastType = components === 1 ? 
dataType : `vec${components}<${dataType}>`; + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${H}; - const C: u32 = ${C / components}; + const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); + const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); + return ` @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; @group(0) @binding(1) var scaleInput : array<${scaleType}>; @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; + struct Uniforms {H: u32, C : u32}; + @group(0) @binding(3) var uniforms: Uniforms; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / (C * H); - let currentChannelNumber = global_idx % C; + let currentImageNumber = global_idx / (uniforms.C * uniforms.H); + let currentChannelNumber = global_idx % uniforms.C; - let scaleOffset = currentImageNumber * C + currentChannelNumber; + let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber; let scale = scaleInput[scaleOffset]; output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); }`; + }; context.compute( { - name: 'InstanceNormalization', - shaderCache: {hint: `${attributes.cacheKey}`}, + name: 'InstanceNormalizationNHWC', + shaderCache: {hint: `${components}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, {inputs: [inputs[0], channelScaleShift]}); }; -export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes => - createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format}); - export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { if (attributes.format === 'NHWC') { createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc index 6a4e6912405ee..e89ac2da3795f 100644 --- a/js/web/test/data/ops/instance-norm.jsonc +++ b/js/web/test/data/ops/instance-norm.jsonc @@ -38,6 +38,79 @@ } ] }, + { + "name": "Simple test with NHWC, components 1", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], + "dims": [1, 5, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 
2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539 + ], + "dims": [1, 5, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NHWC, components 2", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8], + "dims": [2, 6, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8, 9], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 7, 8, 9, 4, 5, 6, 7, 8, 9], + "dims": [2, 6, 1, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "Simple test with NCHW", "operator": "InstanceNormalization", @@ -75,5 +148,81 @@ ] } ] + }, + { + "name": "Simple test with NCHW, components 1", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], + "dims": [1, 5, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539 + ], + "dims": [1, 5, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NCHW, components 2", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2], + "dims": [1, 3, 6, 1], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.5361523628234863, 3.1216912269592285, 3.70723032951355, 4.292769432067871, 4.878308296203613, + 5.4638471603393555, 1.8666191101074219, 3.9555397033691406, 6.044460296630859, 8.133380889892578, + 6.044460296630859, 3.9555397033691406, 10.3915433883667, 8.634925842285156, 6.878308296203613, + 5.121691703796387, 3.365074634552002, 1.6084575653076172 + ], + "dims": [1, 3, 6, 1], + "type": "float32" + } + ] + } + ] } ] diff --git a/onnxruntime/core/providers/acl/math/gemm.h b/onnxruntime/core/providers/acl/math/gemm.h index d2f297e83aedb..f5288d7f231b0 100644 --- a/onnxruntime/core/providers/acl/math/gemm.h +++ b/onnxruntime/core/providers/acl/math/gemm.h @@ -49,11 +49,18 @@ class Gemm : public onnxruntime::Gemm { } Status Compute(OpKernelContext* context) const override { +#ifdef ACL_2308 + if (this->packed_b_) { + // Prepacked RHS not supported, defaulting to cpu execution provider + return onnxruntime::Gemm::Compute(context); + } +#endif const auto A = context->Input(0); const auto B = context->Input(1); const auto C = context->Input(2); - GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans, C->Shape()); + GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans, + C != nullptr ? 
C->Shape() : TensorShape({})); if (!helper.State().IsOK()) return helper.State(); @@ -70,7 +77,7 @@ class Gemm : public onnxruntime::Gemm { return onnxruntime::Gemm::Compute(context); } - arm_compute::TensorShape cShape = ACLTensorShape(C->Shape()); + arm_compute::TensorShape cShape = ACLTensorShape(C != nullptr ? C->Shape() : TensorShape({})); if (useC && (cShape.num_dimensions() > 2 || (cShape.num_dimensions() == 2 && cShape[0] > 1 && cShape[1] > 1))) { // Multi-dimensional Bias @@ -89,8 +96,13 @@ class Gemm : public onnxruntime::Gemm { (cShape[1] == 1 && cShape[0] != (long unsigned int)N)) { return onnxruntime::Gemm::Compute(context); } +#ifdef ACL_2308 + cShape = arm_compute::TensorShape(N); + LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {" << N << "}"; +#else cShape = arm_compute::TensorShape(1, N); LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {1," << N << "}"; +#endif } int64_t K = helper.K(); diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.cc b/onnxruntime/core/providers/acl/nn/batch_norm.cc index da7fff730c96f..eb6a10074f1db 100755 --- a/onnxruntime/core/providers/acl/nn/batch_norm.cc +++ b/onnxruntime/core/providers/acl/nn/batch_norm.cc @@ -44,6 +44,16 @@ Status BatchNorm::Compute(OpKernelContext* context) const { const Tensor* M = context->Input(3); // mean const Tensor* V = context->Input(4); // var + if (S->Shape().NumDimensions() > 1) { + LOGS_DEFAULT(WARNING) << "ACL does not support scale with dimension greater than 1; defaulting to cpu implementation"; + return onnxruntime::BatchNorm::Compute(context); + } + + if (this->is_train_) { + LOGS_DEFAULT(WARNING) << "ACL does not have batchnorm training support; defaulting to cpu implementation"; + return onnxruntime::BatchNorm::Compute(context); + } + ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, S, B, M, V)); LOGS_DEFAULT(VERBOSE) << "BatchNorm ACL:"; @@ -70,7 +80,23 @@ Status BatchNorm::Compute(OpKernelContext* context) const { auto layer = std::make_shared(); +#ifdef ACL_2308 + arm_compute::TensorShape in_x_shape; + const TensorShape& x_shape = X->Shape(); + const auto& dims_vec = x_shape.GetDims(); + in_x_shape.set(3, onnxruntime::narrow(dims_vec[0])); // N + in_x_shape.set(1, 1); // H + size_t W = 1; + for (size_t i = 2; i < dims_vec.size(); ++i) { + W *= narrow(dims_vec[i]); + } + in_x_shape.set(0, W); // W + in_x_shape.set(2, onnxruntime::narrow(dims_vec[1])); // C + + tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(in_x_shape, arm_compute::Format::F32)); +#else tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32)); +#endif tbatch_norm.out->allocator()->init(arm_compute::TensorInfo(tbatch_norm.in->info()->tensor_shape(), arm_compute::Format::F32)); tbatch_norm.scale->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(S->Shape()), arm_compute::Format::F32)); @@ -132,11 +158,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 7, 9, kAclExecutionProvider, KernelDefBuilder() - .TypeConstraint("X", DataTypeImpl::GetTensorType()) - .TypeConstraint("scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("mean", DataTypeImpl::GetTensorType()) - .TypeConstraint("var", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", DataTypeImpl::GetTensorType()), BatchNorm); } // namespace acl diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.h b/onnxruntime/core/providers/acl/nn/batch_norm.h index c9ec08b67a779..264301976e6dc 100755 --- a/onnxruntime/core/providers/acl/nn/batch_norm.h +++ 
b/onnxruntime/core/providers/acl/nn/batch_norm.h @@ -31,9 +31,9 @@ typedef struct { typedef std::map::iterator BatchNormLayersIterator; template -class BatchNorm final : public OpKernel { +class BatchNorm : public onnxruntime::BatchNorm { public: - explicit BatchNorm(const OpKernelInfo& info) : OpKernel(info) { + explicit BatchNorm(const OpKernelInfo& info) : onnxruntime::BatchNorm(info) { auto st = info.GetAttr("epsilon", &epsilon_); ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index 1613d927d0f74..85bd0cfe96279 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -105,7 +105,11 @@ Status Conv::Compute(OpKernelContext* context) const { TensorShapeVector Y_dims; Y_dims.insert(Y_dims.begin(), {N, M}); TensorShape input_shape = X->Shape().Slice(2); +#ifdef ACL_2308 + ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); +#else ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); +#endif Tensor* Y = context->Output(0, TensorShape(Y_dims)); LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str(); @@ -222,6 +226,15 @@ Status Conv::Compute(OpKernelContext* context) const { 1 /* depth multiplier */, acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), arm_compute::Size2D(aclDilation0, dilations[0]))); +#elif defined(ACL_2308) + bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayer::validate(tconv.in->info(), + tconv.k->info(), + (B != nullptr) ? tconv.b->info() : nullptr, + tconv.out->info(), + aclPadStride, + 1 /* depth multiplier */, + acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + arm_compute::Size2D(aclDilation0, dilations[0]))); #endif if (optimizable) { @@ -230,7 +243,7 @@ Status Conv::Compute(OpKernelContext* context) const { auto layer = std::make_shared(); #elif defined(ACL_1908) auto layer = std::make_shared(); -#elif defined(ACL_2002) +#elif defined(ACL_2002) || defined(ACL_2308) auto layer = std::make_shared(); #endif @@ -238,7 +251,7 @@ Status Conv::Compute(OpKernelContext* context) const { layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), aclPadStride, 1 /* depth multiplier */, acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo()); -#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) +#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) || defined(ACL_2308) layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), aclPadStride, 1 /* depth multiplier */, acl_activ_enabled ? 
arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index ecb11fb3c8f4e..660d47b4172df 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -8,6 +8,9 @@ #include "core/providers/acl/acl_execution_provider.h" // ACL +#ifdef ACL_2308 +#include "arm_compute/runtime/Tensor.h" +#endif #include "arm_compute/core/TensorInfo.h" #include "arm_compute/runtime/TensorAllocator.h" #include "arm_compute/runtime/Allocator.h" diff --git a/onnxruntime/core/providers/acl/nn/pool.cc b/onnxruntime/core/providers/acl/nn/pool.cc index dc79ae65bf21e..8fbcba3ed87a7 100644 --- a/onnxruntime/core/providers/acl/nn/pool.cc +++ b/onnxruntime/core/providers/acl/nn/pool.cc @@ -61,7 +61,14 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, tpool.out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(Y->Shape(), PREF_DIM), arm_compute::Format::F32)); if (pool_attrs.global_pooling) { - layer->configure(tpool.in.get(), tpool.out.get(), arm_compute::PoolingLayerInfo(pool_type)); + layer->configure(tpool.in.get(), + tpool.out.get(), + arm_compute::PoolingLayerInfo(pool_type +#ifdef ACL_2308 + , + arm_compute::DataLayout::NCHW +#endif + )); } else { TensorShapeVector aclStrides(2); aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; @@ -104,7 +111,13 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; LOGS_DEFAULT(VERBOSE) << "excludePadding: " << excludePadding; - arm_compute::PoolingLayerInfo pool_info(pool_type, aclSize, aclPadStride, excludePadding); + arm_compute::PoolingLayerInfo pool_info(pool_type, + aclSize, +#ifdef ACL_2308 + arm_compute::DataLayout::NCHW, +#endif + aclPadStride, + excludePadding); layer->configure(tpool.in.get(), tpool.out.get(), pool_info); } diff --git a/onnxruntime/core/providers/acl/tensor/concat.cc b/onnxruntime/core/providers/acl/tensor/concat.cc index 081472729cfcf..75eedaac80aea 100644 --- a/onnxruntime/core/providers/acl/tensor/concat.cc +++ b/onnxruntime/core/providers/acl/tensor/concat.cc @@ -10,6 +10,8 @@ #include "core/providers/acl/acl_common.h" #include "core/providers/acl/acl_fwd.h" +#include + #define PREF_DIM 4 namespace onnxruntime { @@ -22,17 +24,27 @@ Status Concat::Compute(OpKernelContext* ctx) const { return onnxruntime::Concat::Compute(ctx); } + if (axis_ < 0) { + LOGS_DEFAULT(WARNING) << "ACL does not have support for negative axis; defaulting to cpu implementation"; + return onnxruntime::Concat::Compute(ctx); + } + // Number of input tensors to concatenate auto input_count = Node().InputArgCount().front(); // Hold pointers to the input tensors to be used in the PrepareForCompute() step std::vector input_tensors; - input_tensors.reserve(input_count); + int empty_tensors = 0; for (int i = 0; i < input_count; ++i) { + if (ctx->Input(i)->Shape().Size() == 0) { + empty_tensors++; + continue; + } input_tensors.push_back(ctx->Input(i)); } + input_count -= empty_tensors; - auto output_dims = input_tensors[0]->Shape().AsShapeVector(); + auto output_dims = ctx->Input(0)->Shape().AsShapeVector(); // 'Concat' mode if (!is_stack_) { @@ -64,7 +76,11 @@ Status Concat::Compute(OpKernelContext* ctx) const { LOGS_DEFAULT(VERBOSE) << "Concat ACL:"; arm_compute::Tensor output; +#ifdef ACL_2308 + std::vector inputs_vector; +#else std::vector inputs_vector; +#endif for 
(int i = 0; i < input_count; i++) { arm_compute::Tensor* input = new arm_compute::Tensor(); auto X = input_tensors[i]; @@ -75,7 +91,9 @@ Status Concat::Compute(OpKernelContext* ctx) const { } arm_compute::NEConcatenateLayer layer; - layer.configure(inputs_vector, &output, 3 - axis_); + if (input_count > 0) { + layer.configure(inputs_vector, &output, 3 - axis_); + } LOGS_DEFAULT(VERBOSE) << "axis: " << axis_; LOGS_DEFAULT(VERBOSE) << std::endl; @@ -83,7 +101,11 @@ Status Concat::Compute(OpKernelContext* ctx) const { for (int i = 0; i < input_count; i++) { auto X = input_tensors[i]; const T* x_data = X->Data(); +#ifdef ACL_2308 + arm_compute::Tensor* in = const_cast(static_cast(inputs_vector[i])); +#else arm_compute::Tensor* in = static_cast(inputs_vector[i]); +#endif if (X->Shape().Size() != 0 && in->info()->has_padding()) { in->allocator()->allocate(); @@ -101,7 +123,9 @@ Status Concat::Compute(OpKernelContext* ctx) const { ACLImportMemory(output.allocator(), (void*)y_data, Y->Shape().Size() * 4); } - layer.run(); + if (input_count > 0) { + layer.run(); + } if (Y->Shape().Size() != 0 && output.info()->has_padding()) { importDataFromTensor(&output, y_data); diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d34cb7e362446..7718fbdc2df88 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -12,6 +12,24 @@ namespace onnxruntime { namespace webnn { +InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer) { + InitializedTensorSet all_initializers; + if (graph_viewer.IsSubgraph()) { + const Graph* cur_graph = &graph_viewer.GetGraph(); + // Traverse up to the top-level graph, collecting all initializers. + while (cur_graph->IsSubgraph()) { + const auto& current_initializers = cur_graph->GetAllInitializedTensors(); + all_initializers.insert(current_initializers.begin(), current_initializers.end()); + cur_graph = cur_graph->ParentGraph(); + } + // Collect initializers in top-level graph. + const auto& current_initializers = cur_graph->GetAllInitializedTensors(); + all_initializers.insert(current_initializers.begin(), current_initializers.end()); + } + + return all_initializers; +} + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger) { const auto* shape_proto = node_arg.Shape(); if (!shape_proto) { diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 5aec81af15761..ea57ab1af19af 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -35,6 +35,9 @@ typedef struct { bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP). } WebnnOpInfo; +// Collects all the initializer tensors in the subGraph and its ancestor graphs. 
+InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer); + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index b6631263dfb93..b57e1b89b0af0 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -38,6 +38,25 @@ Status ModelBuilder::Initialize() { return Status::OK(); } +InitializedTensorSet ModelBuilder::GetInitializerTensors() { + if (graph_viewer_.IsSubgraph()) { + auto all_initializers = CollectAllInitializedTensors(graph_viewer_); + const auto sub_graph_id = graph_viewer_.GetFilterInfo(); + const auto subgraph_initializer_names = sub_graph_id->GetMetaDef()->constant_initializers; + InitializedTensorSet subgraph_initializers; + + for (const auto& name : subgraph_initializer_names) { + auto it = all_initializers.find(name); + if (it != all_initializers.end()) { + subgraph_initializers.insert(*it); + } + } + return subgraph_initializers; + } else { + return graph_viewer_.GetAllInitializedTensors(); + } +} + /* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { const auto& op_builders = GetOpBuilders(); const auto it = op_builders.find(node.OpType()); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index c381eef3f42f7..16c8bf2d3c77f 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -30,7 +30,7 @@ class ModelBuilder { // Accessors for members. const GraphViewer& GetGraphViewer() const { return graph_viewer_; } - const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } + InitializedTensorSet GetInitializerTensors(); const emscripten::val& GetBuilder() const { return wnn_builder_; } const emscripten::val& GetContext() const { return wnn_context_; } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 4da54aaad3a33..cf18b3225eb47 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -59,10 +59,15 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const IKernelLookup& /*kernel_registries*/) const { std::vector> result; - // We do not run WebNN EP on subgraph, instead we cover this in the control flow nodes. - // TODO investigate whether we want to support subgraph using WebNN EP. - if (graph_viewer.IsSubgraph()) { - return result; + // For a subgraph that is an attribute of a control flow node, some of its initializers may be stored in its + // ancestor graphs as common initializers shared with other subgraphs. Collect all of them here so that the + // required initializer names can be identified and stored in 'meta_def->constant_initializers', which lets + // GetInitializerTensors() (defined in model_builder.h) return the initialized tensors this subgraph needs. 
+ InitializedTensorSet all_initializers; + const bool is_subgraph = graph_viewer.IsSubgraph(); + if (is_subgraph) { + all_initializers = webnn::CollectAllInitializedTensors(graph_viewer); } /* @@ -110,6 +115,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::unique_ptr sub_graph = std::make_unique(); + std::vector subgraph_initializers; InlinedHashSet node_outputs; InlinedHashSet subgraph_inputs; InlinedHashSet subgraph_outputs; @@ -126,7 +132,11 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // skip the placeholder inputs. continue; } - // if the node input was not produced by this subgraph, add it to the subgraph inputs. + // If it is a subgraph of a control flow node, collect the constant initializer. + if (is_subgraph && Contains(all_initializers, input->Name())) { + subgraph_initializers.push_back(input->Name()); + } + // If the node input was not produced by this subgraph, add it to the subgraph inputs. if (node_outputs.count(input) == 0) { if (subgraph_inputs.count(input) == 0) { subgraph_inputs.insert(input); @@ -165,6 +175,12 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view meta_def->since_version = 1; meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + if (is_subgraph) { + for (const auto& initializer : subgraph_initializers) { + meta_def->constant_initializers.push_back(initializer); + } + } + for (const auto& input : ordered_subgraph_inputs) { meta_def->inputs.push_back(input->Name()); } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index c655100fbf475..3d0ec92a7bd23 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -606,7 +606,7 @@ def convert_arg_line_to_args(self, arg_line): "--use_acl", nargs="?", const="ACL_1905", - choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002"], + choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002", "ACL_2308"], help="Build with ACL for ARM architectures.", ) parser.add_argument("--acl_home", help="Path to ACL home dir") @@ -1031,6 +1031,7 @@ def generate_build_tree( "-Donnxruntime_USE_ACL_1905=" + ("ON" if args.use_acl == "ACL_1905" else "OFF"), "-Donnxruntime_USE_ACL_1908=" + ("ON" if args.use_acl == "ACL_1908" else "OFF"), "-Donnxruntime_USE_ACL_2002=" + ("ON" if args.use_acl == "ACL_2002" else "OFF"), + "-Donnxruntime_USE_ACL_2308=" + ("ON" if args.use_acl == "ACL_2308" else "OFF"), "-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"), "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml index 9755e1f0771ba..693a06f9844f5 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml @@ -13,7 +13,7 @@ stages: jobs: - job: Linux_Training_CPU_Wheels - timeoutInMinutes: 120 + timeoutInMinutes: 180 workspace: clean: all pool: onnxruntime-Ubuntu2004-AMD-CPU diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 44904f9248b10..7cee5045bc4f3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ 
-102,7 +102,7 @@ stages: setVcvars: true BuildConfig: 'RelWithDebInfo' ExtraParam: ${{ parameters.build_py_parameters }} - timeoutInMinutes: 120 + timeoutInMinutes: 180 workspace: clean: all @@ -329,7 +329,7 @@ stages: - ${{ if eq(parameters.enable_mac_cpu, true) }}: - job: MacOS_py_Wheels - timeoutInMinutes: 120 + timeoutInMinutes: 180 workspace: clean: all pool: diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt index 0cd5e5c5d5c46..01fa7b0ff956e 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt @@ -1,5 +1,5 @@ scikit-learn packaging==21.3 -transformers==v4.30.0 -accelerate==0.20.1 +transformers==v4.36.0 +accelerate==0.25.0 wget diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt index b4b265f65b69f..2b557f2aee00f 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt @@ -2,8 +2,8 @@ pandas scikit-learn numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' -transformers==v4.30.0 -accelerate +transformers==v4.36.0 +accelerate==0.25.0 rsa==4.9 tensorboard==2.13.0 h5py
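
A note on the instance-norm rewrite in `js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts`: the old kernels interpolated shape-derived values (`normSize`, `normPackedSize`, `C`, `H`) straight into the WGSL source, so every new input shape produced a new shader and a pipeline recompile. The PR moves those values into a uniform buffer and shrinks the `shaderCache` hint accordingly. The sketch below models that trade-off with simplified stand-in types; `ProgramUniform` and the cache-hint idea mirror the jsep types in the diff, but the Map-based cache and both factory functions are hypothetical illustrations, not the real onnxruntime-web API.

```typescript
// Simplified model of why shape-derived constants move into uniforms.
interface ProgramUniform { type: 'uint32'; data: number; }

interface Program {
  cacheHint: string;            // key into the compiled-pipeline cache
  shaderSource: string;         // WGSL compiled only when the key misses
  uniforms: ProgramUniform[];   // uploaded per dispatch; no recompile
}

// Before: normSize is baked into the WGSL, so each input shape yields a
// distinct shader source and therefore a distinct cache entry.
const bakedProgram = (normSize: number, components: number): Program => ({
  cacheHint: `${normSize};${components}`,
  shaderSource: `const normSize: u32 = ${normSize}u; /* ...kernel body... */`,
  uniforms: [],
});

// After: normSize travels in a uniform buffer; the source (and the cache key)
// depends only on the vectorization width, so one compiled pipeline serves
// every shape with the same `components`.
const uniformProgram = (normSize: number, components: number): Program => ({
  cacheHint: `${components}`,
  shaderSource: '/* ... */ let n = uniforms.normSize; /* ...kernel body... */',
  uniforms: [{ type: 'uint32', data: normSize }],
});

const pipelines = new Map<string, string>();
for (const normSize of [64, 256, 1024]) {          // three different shapes
  for (const make of [bakedProgram, uniformProgram]) {
    const p = make(normSize, 4);
    if (!pipelines.has(p.cacheHint)) pipelines.set(p.cacheHint, p.shaderSource);
  }
}
console.log(pipelines.size);  // 4: three baked variants, one shared uniform variant
```

Epsilon is the one value intentionally left baked into the source (and into the cache hint), per the TODO in the diff about `test_instancenorm_epsilon`.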
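The new cases added to `js/web/test/data/ops/instance-norm.jsonc` can be cross-checked with a plain reference implementation of InstanceNormalization: for each (batch, channel) pair, normalize over the spatial elements, then apply scale and bias. This is only a minimal checking aid, assuming NCHW layout with the spatial dimensions flattened and ONNX's default epsilon of 1e-5; it is not how the WGSL kernels are organized.

```typescript
// Reference InstanceNormalization: y = scale * (x - mean) / sqrt(var + eps) + bias,
// computed independently for each (batch, channel) over the spatial elements.
const instanceNorm = (
    x: number[], dims: [number, number, number],  // [N, C, spatialSize], NCHW flattened
    scale: number[], bias: number[], epsilon = 1e-5): number[] => {
  const [n, c, s] = dims;
  const out = new Array<number>(x.length);
  for (let b = 0; b < n; b++) {
    for (let ch = 0; ch < c; ch++) {
      const off = (b * c + ch) * s;
      const slice = x.slice(off, off + s);
      const mean = slice.reduce((acc, v) => acc + v, 0) / s;
      const variance = slice.reduce((acc, v) => acc + (v - mean) * (v - mean), 0) / s;
      const invStd = 1 / Math.sqrt(variance + epsilon);
      for (let i = 0; i < s; i++) {
        out[off + i] = (slice[i] - mean) * invStd * scale[ch] + bias[ch];
      }
    }
  }
  return out;
};

// First new NCHW case from the jsonc file: dims [1, 5, 3, 1] is 5 channels
// of 3 spatial elements each.
const y = instanceNorm(
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], [1, 5, 3],
    [1, 2, 3, 4, 5], [4, 5, 6, 7, 8]);
console.log(y.slice(0, 3));  // ≈ [2.7753, 4, 5.2247], matching the expected output
```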
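For the WebNN changes, the key idea is that a control-flow subgraph's constant initializers may live in any ancestor graph, so `CollectAllInitializedTensors` walks the parent chain and merges every level's initializer map before `GetCapability` records which names the fused node needs. Here is the same logic re-expressed as a TypeScript sketch for clarity; `Graph` is a hypothetical minimal shape, not the C++ `GraphViewer`/`Graph` API the real code uses.

```typescript
// Hypothetical minimal graph shape; `parent` is undefined at the top level.
interface Graph {
  parent?: Graph;
  initializers: Map<string, Float32Array>;
}

// Walk from the subgraph up to the top-level graph, merging initializers.
// Mirrors CollectAllInitializedTensors in helper.cc; like unordered_map::insert,
// the innermost definition wins if a name appears at several levels.
const collectAllInitializers = (graph: Graph): Map<string, Float32Array> => {
  const all = new Map<string, Float32Array>();
  for (let g: Graph | undefined = graph; g !== undefined; g = g.parent) {
    for (const [name, tensor] of g.initializers) {
      if (!all.has(name)) all.set(name, tensor);
    }
  }
  return all;
};

// Keep only the names the fused subgraph declared as constant initializers
// (the role played by meta_def->constant_initializers in the diff).
const filterForSubgraph = (
    all: Map<string, Float32Array>, requiredNames: string[]): Map<string, Float32Array> => {
  const subset = new Map<string, Float32Array>();
  for (const name of requiredNames) {
    const tensor = all.get(name);
    if (tensor !== undefined) subset.set(name, tensor);
  }
  return subset;
};
```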