Skip to content

Commit

Permalink
wasm32 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Oct 31, 2023
1 parent a6b4549 commit fcb06bf
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType } from './common'
import {erfImpl} from './unary-op';

const validateInputs = (inputs: readonly TensorView[]): void => {
Expand Down Expand Up @@ -35,14 +35,15 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
const output = outputVariable('output', inputs[0].dataType, outputShape, 4);

const outputSize = ShapeUtil.size(outputShape) / 4;
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const M_SQRT2 = sqrt(2.0);
const halfChannels = ${inputs[0].dims[2] / 4 / 2}u;
${shaderHelper.declareVariables(input, bias, output)}
${erfImpl('vec4f')}
${erfImpl(`vec4<${dataType}>`, dataType)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ const maybeTransposeToBNSHAndAddBias =
reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
}
return context.compute(
createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
{inputs: [reshapedInput], outputs: [-1]})[0];
} else {
if (sequenceLength === 1) {
Expand All @@ -290,7 +290,7 @@ const maybeTransposeToBNSHAndAddBias =
addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!);
reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
return context.compute(
createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
{inputs: [reshapedInput], outputs: [-1]})[0];
}
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/wasm/js_post_js.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
Module["PTR_SIZE"] = 4;
Module["createFileFromArrayBuffer"] = (path, buffer) => {
const weightsFile = FS.create(path);
weightsFile.contents = buffer;
weightsFile.contents = new Uint8Array(buffer);
weightsFile.usedBytes = buffer.byteLength;
}

0 comments on commit fcb06bf

Please sign in to comment.