Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/Web] Added Uniforms support to binary ops. #18260

Merged
merged 17 commits into from
Nov 7, 2023

Conversation

satyajandhyala
Copy link
Contributor

Description

Added Uniform support to binary ops

Motivation and Context

To improve performance

@satyajandhyala satyajandhyala added the ep:WebGPU ort-web webgpu provider label Nov 3, 2023
@satyajandhyala
Copy link
Contributor Author

satyajandhyala commented Nov 3, 2023

For the test_add with both inputs same shape I got the following shader
image

      struct Uniforms { vec_size:u32 };
      @group(0) @binding(3) var<uniform> uniforms: Uniforms;
const aData_shape = vec3<u32>(3,4,5);
const aData_strides = vec3<u32>(20,5,1);
const bData_shape = vec3<u32>(3,4,5);
const bData_strides = vec3<u32>(20,5,1);
const outputData_shape = vec3<u32>(3,4,5);
const outputData_strides = vec3<u32>(20,5,1);

        @group(0) @binding(0) var<storage, read> aData: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> bData: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> outputData: array<vec4<f32>>;




        @compute @workgroup_size(64, 1, 1)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
    @builtin(local_invocation_id) local_id : vec3<u32>) {
    let global_idx = global_id.x;

        if (global_idx >= uniforms.vec_size) { return; }
        outputData[global_idx]=aData[global_idx]+bData[global_idx];
      }
      

@satyajandhyala
Copy link
Contributor Author

satyajandhyala commented Nov 3, 2023

For the test add with different input shapes I got the following shader
image


      struct Uniforms { vec_size:u32, aData_shape:vec3<u32>, aData_strides:vec3<u32>, bData_shape:u32, bData_strides:u32, outputData_shape:vec3<u32>, outputData_strides:vec3<u32> };
      @group(0) @binding(3) var<uniform> uniforms: Uniforms;


  fn o2i_outputData(offset: u32) -> vec3<u32> {
    var indices: vec3<u32>;
    var current = offset;

    let dim0 = current / uniforms.outputData_strides[0];
    let rest0 = current % uniforms.outputData_strides[0];
    indices[0] = dim0;
    current = rest0;

    let dim1 = current / uniforms.outputData_strides[1];
    let rest1 = current % uniforms.outputData_strides[1];
    indices[1] = dim1;
    current = rest1;
    indices[2] = current;
    return indices;
  }

        @group(0) @binding(0) var<storage, read> aData: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> bData: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> outputData: array<vec4<f32>>;




          fn calcOffsetaData(outputIndices: vec3<u32>) -> u32{
            var retval : u32 = 0;
            for (var i = 3 - 1; i >= 0; i = i - 1) {
              retval = retval + uniforms.aData_strides[i] * (outputIndices[i] % uniforms.aData_shape[i]);
            }
            return retval;
          }

          fn calcOffsetbData(outputIndices: vec3<u32>) -> u32{
            var retval : u32 = 0;
            retval = uniforms.bData_strides * (outputIndices[3 - 1] % uniforms.bData_shape);
            return retval;
          }


        @compute @workgroup_size(64, 1, 1)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
    @builtin(local_invocation_id) local_id : vec3<u32>) {
    let global_idx = global_id.x;

        if (global_idx >= uniforms.vec_size) { return; }


            let outputIndices0 = o2i_outputData(global_idx * 4u + 0u);
            let offsetA0 = calcOffsetaData(outputIndices0);
            let offsetB0 = calcOffsetbData(outputIndices0);
            let indexA0 = offsetA0 / 4u;
            let indexB0 = offsetB0 / 4u;
            let componentA0 = offsetA0 % 4u;
            let componentB0 = offsetB0 % 4u;
            outputData[global_idx][0] = (aData[indexA0][componentA0]+bData[indexB0][componentB0]);


            let outputIndices1 = o2i_outputData(global_idx * 4u + 1u);
            let offsetA1 = calcOffsetaData(outputIndices1);
            let offsetB1 = calcOffsetbData(outputIndices1);
            let indexA1 = offsetA1 / 4u;
            let indexB1 = offsetB1 / 4u;
            let componentA1 = offsetA1 % 4u;
            let componentB1 = offsetB1 % 4u;
            outputData[global_idx][1] = (aData[indexA1][componentA1]+bData[indexB1][componentB1]);


            let outputIndices2 = o2i_outputData(global_idx * 4u + 2u);
            let offsetA2 = calcOffsetaData(outputIndices2);
            let offsetB2 = calcOffsetbData(outputIndices2);
            let indexA2 = offsetA2 / 4u;
            let indexB2 = offsetB2 / 4u;
            let componentA2 = offsetA2 % 4u;
            let componentB2 = offsetB2 % 4u;
            outputData[global_idx][2] = (aData[indexA2][componentA2]+bData[indexB2][componentB2]);


            let outputIndices3 = o2i_outputData(global_idx * 4u + 3u);
            let offsetA3 = calcOffsetaData(outputIndices3);
            let offsetB3 = calcOffsetbData(outputIndices3);
            let indexA3 = offsetA3 / 4u;
            let indexB3 = offsetB3 / 4u;
            let componentA3 = offsetA3 % 4u;
            let componentB3 = offsetB3 % 4u;
            outputData[global_idx][3] = (aData[indexA3][componentA3]+bData[indexB3][componentB3]);


      }

js/web/lib/wasm/jsep/webgpu/ops/common.ts Outdated Show resolved Hide resolved
js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts Outdated Show resolved Hide resolved
@satyajandhyala satyajandhyala marked this pull request as ready for review November 3, 2023 16:29
@satyajandhyala satyajandhyala force-pushed the sajandhy/webgpu_unifroms_suport_binary_ops branch from 2c05fc9 to bfdd2d5 Compare November 5, 2023 04:59
Copy link
Contributor

@qjia7 qjia7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

js/web/lib/wasm/jsep/webgpu/ops/common.ts Show resolved Hide resolved
@satyajandhyala satyajandhyala force-pushed the sajandhy/webgpu_unifroms_suport_binary_ops branch from bfdd2d5 to 0a43dcd Compare November 6, 2023 17:26
guschmue
guschmue previously approved these changes Nov 6, 2023
@satyajandhyala satyajandhyala merged commit a16d528 into main Nov 7, 2023
57 checks passed
@satyajandhyala satyajandhyala deleted the sajandhy/webgpu_unifroms_suport_binary_ops branch November 7, 2023 16:41
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
### Description
Added Uniform support to binary ops



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
To improve performance
siweic0 pushed a commit to siweic0/onnxruntime-web that referenced this pull request May 9, 2024
### Description
Added Uniform support to binary ops



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
To improve performance
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ep:WebGPU ort-web webgpu provider
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants