From fcb06bf8b1649a6c96a5820c9a05afbec67c0fbe Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Tue, 31 Oct 2023 17:57:16 +0400
Subject: [PATCH] wasm32 fixes

---
 js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts       | 5 +++--
 js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts | 4 ++--
 onnxruntime/wasm/js_post_js.js                           | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
index 14eefc344f3c0..bd3ce0a433211 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
@@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo} from '../types';
 
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
 import {erfImpl} from './unary-op';
 
 const validateInputs = (inputs: readonly TensorView[]): void => {
@@ -35,6 +35,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
   const output = outputVariable('output', inputs[0].dataType, outputShape, 4);
 
   const outputSize = ShapeUtil.size(outputShape) / 4;
+  const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
 
   const getShaderSource = (shaderHelper: ShaderHelper) => `
   const M_SQRT2 = sqrt(2.0);
@@ -42,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
 
   ${shaderHelper.declareVariables(input, bias, output)}
 
-  ${erfImpl('vec4f')}
+  ${erfImpl(`vec4<${dataType}>`, dataType)}
 
   ${shaderHelper.mainStart()}
     ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
index 0a55ade2469eb..ad6d1d3c561f7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
@@ -280,7 +280,7 @@ const maybeTransposeToBNSHAndAddBias =
           reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
         }
         return context.compute(
-            createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
+            createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
             {inputs: [reshapedInput], outputs: [-1]})[0];
       } else {
         if (sequenceLength === 1) {
@@ -290,7 +290,7 @@ const maybeTransposeToBNSHAndAddBias =
               addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!);
           reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
           return context.compute(
-              createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
+              createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
               {inputs: [reshapedInput], outputs: [-1]})[0];
         }
       }
diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js
index 012c11d4343d6..daed028a4fa31 100644
--- a/onnxruntime/wasm/js_post_js.js
+++ b/onnxruntime/wasm/js_post_js.js
@@ -6,6 +6,6 @@ Module["PTR_SIZE"] = 4;
 
 Module["createFileFromArrayBuffer"] = (path, buffer) => {
   const weightsFile = FS.create(path);
-  weightsFile.contents = buffer;
+  weightsFile.contents = new Uint8Array(buffer);
   weightsFile.usedBytes = buffer.byteLength;
 }
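
A minimal usage sketch (illustrative only, not part of the patch) for the createFileFromArrayBuffer helper changed in onnxruntime/wasm/js_post_js.js. It assumes Module is the Emscripten module object that js_post_js.js is appended to; the mountWeights name and the weights URL are hypothetical. Emscripten's MEMFS expects a file node's contents to be a typed array (it calls methods such as subarray on it), which is why the helper stores a Uint8Array view rather than the raw ArrayBuffer.

  // Illustrative sketch; names and URL are hypothetical.
  async function mountWeights(Module) {
    // Fetch an external-weights file as an ArrayBuffer.
    const response = await fetch('./model.onnx.data');
    const weights = await response.arrayBuffer();
    // createFileFromArrayBuffer creates the MEMFS file at the given path and
    // stores the bytes as a Uint8Array view over the ArrayBuffer.
    Module['createFileFromArrayBuffer']('/model.onnx.data', weights);
  }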