getRunData changes
dakenf committed Oct 15, 2023
1 parent 53543bc commit 542c2dd
Showing 3 changed files with 126 additions and 125 deletions.
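Note: the pattern is the same across all three files. Each kernel's ProgramInfo moves its run-time data out of static fields (inputTypes, cacheHint, outputs, dispatchGroup) and into a single getRunData callback, with the shader-compilation cache key carried by shaderCache.hint. A minimal sketch of the before/after shapes, inferred from the hunks below (any field not visible in this diff is an assumption, not the full interface):

// Sketch only: shapes inferred from this commit's hunks.
// Before: run data spread over static fields on the program info.
interface ProgramInfoBefore {
  name: string;
  inputTypes: GpuDataType[];                 // removed in this commit
  cacheHint?: string;                        // replaced by shaderCache.hint
  outputs: Array<{dims: readonly number[]; dataType: number; gpuDataType: GpuDataType}>;
  getShaderSource: (shaderHelper: ShaderHelper) => string;
  dispatchGroup: () => {x: number; y?: number; z?: number};
}

// After: everything the backend needs per run comes from one callback,
// so the compiled shader can be cached independently of run-time sizes.
interface ProgramInfoAfter {
  name: string;
  shaderCache?: {hint?: string};             // keys the compiled-shader cache
  getShaderSource: (shaderHelper: ShaderHelper) => string;
  getRunData: () => {
    outputs: Array<{dims: readonly number[]; dataType: number}>;
    dispatchGroup: {x: number; y?: number; z?: number};
  };
}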
41 changes: 21 additions & 20 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -161,13 +161,14 @@ const computeMean =
       const meanValues = context.compute(
           {
             name: 'InstanceNormComputeMean',
-            inputTypes: [GpuDataType.default],
-            cacheHint: JSON.stringify({components, n, h, c}),
-            outputs: [
-              {dims: [n, c, WG, 2], dataType: DataType.float, gpuDataType: GpuDataType.default},
-            ],
+            shaderCache: {hint: JSON.stringify({components, n, h, c})},
+            getRunData: () => ({
+              outputs: [
+                {dims: [n, c, WG, 2], dataType: DataType.float},
+              ],
+              dispatchGroup: {x: n * c / components},
+            }),
             getShaderSource: getMeanShaderSource,
-            dispatchGroup: () => ({x: n * c / components})
           },
           {inputs: [input], outputs: [-1]})[0];
       const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -206,20 +207,20 @@ const computeMean =
       return context.compute(
           {
             name: 'InstanceNormComputeChannelScaleShift',
-            inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default],
-            cacheHint: JSON.stringify({components, n, h, c, epsilon}),
-            outputs: [
-              {dims: [n, c, 2], dataType: DataType.float, gpuDataType: GpuDataType.default},
-            ],
+            shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})},
+            getRunData: () => ({
+              outputs: [
+                {dims: [n, c, 2], dataType: DataType.float},
+              ],
+              dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
+            }),
             getShaderSource,
-            dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)})
           },
           {inputs: [meanValues, scale, bias], outputs: [-1]})[0];
     };
 
 const createInstanceNormNHWCProgramInfo =
-    (context: ComputeContext, metadata: ProgramMetadata, inputs: readonly TensorView[],
-     attributes: InstanceNormAttributes) => {
+    (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => {
       const xShape = inputs[0].dims;
       const outputShape = xShape;
       const N = xShape[0];
@@ -255,13 +256,13 @@ const createInstanceNormNHWCProgramInfo =
   }`;
       context.compute(
           {
-            ...metadata,
-            inputTypes: [GpuDataType.default, GpuDataType.default],
-            outputs: [
-              {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
-            ],
+            name: 'InstanceNormalization',
+            shaderCache: {hint: `${attributes.cacheKey}`},
+            getRunData: () => ({
+              outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
+              dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
+            }),
             getShaderSource,
-            dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
           },
           {inputs: [inputs[0], channelScaleShift]});
     };
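Usage note on the dispatch sizes above: the mean pass launches one workgroup per vectorized (n, c) slice, and the scale/shift pass packs 64 units of work per workgroup. A worked example with hypothetical shapes (the values below are illustrative, not from this commit):

// Hypothetical sizes, for illustration only.
const n = 2, c = 320, components = 4;                         // batch, channels, vector width
const meanDispatch = {x: n * c / components};                 // 2 * 320 / 4 = 160 workgroups
const unitsOfWork = n * c;                                    // 640 channel instances
const scaleShiftDispatch = {x: Math.ceil(unitsOfWork / 64)};  // ceil(640 / 64) = 10 workgroups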
98 changes: 49 additions & 49 deletions js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
@@ -26,50 +26,50 @@ const createLayerNormProgramInfo =
       const scale = inputs[1];
       const bias = inputs[2];
 
-        const outputShape = xShape;
-        const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length);
-        const normCount = ShapeUtil.sizeToDimension(xShape, axis);
-        const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
-
-        const scaleSize = ShapeUtil.size(scale.dims);
-        const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
-        if (scaleSize !== normSize || (bias && biasSize !== normSize)) {
-          throw new Error(`Size of X.shape()[axis:] == ${normSize}.
+      const outputShape = xShape;
+      const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length);
+      const normCount = ShapeUtil.sizeToDimension(xShape, axis);
+      const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
+
+      const scaleSize = ShapeUtil.size(scale.dims);
+      const biasSize = bias ? ShapeUtil.size(bias.dims) : 0;
+      if (scaleSize !== normSize || (bias && biasSize !== normSize)) {
+        throw new Error(`Size of X.shape()[axis:] == ${normSize}.
        Size of scale and bias (if provided) must match this.
        Got scale size of ${scaleSize} and bias size of ${biasSize}`);
-        }
-
-        const meanInvStdDevDim = [];
-        for (let i = 0; i < xShape.length; ++i) {
-          if (i < axis) {
-            meanInvStdDevDim.push(xShape[i]);
-          } else {
-            meanInvStdDevDim.push(1);
-          }
-        }
-
-        const components = getMaxComponents(normSize);
-        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
-        const variables = [
-          inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
-          inputVariable('scale', scale.dataType, scale.dims, components),
-        ];
-        if (bias) {
-          variables.push(inputVariable('bias', bias.dataType, bias.dims, components));
-        }
-        variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
-
-        const hasMeanDataOutput = outputCount > 1;
-        const hasInvStdOutput = outputCount > 2;
-
-        if (hasMeanDataOutput) {
-          variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim));
-        }
-        if (hasInvStdOutput) {
-          variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
-        }
-
-        const getShaderSource = (shaderHelper: ShaderHelper) => `
+      }
+
+      const meanInvStdDevDim = [];
+      for (let i = 0; i < xShape.length; ++i) {
+        if (i < axis) {
+          meanInvStdDevDim.push(xShape[i]);
+        } else {
+          meanInvStdDevDim.push(1);
+        }
+      }
+
+      const components = getMaxComponents(normSize);
+      const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+      const variables = [
+        inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
+        inputVariable('scale', scale.dataType, scale.dims, components),
+      ];
+      if (bias) {
+        variables.push(inputVariable('bias', bias.dataType, bias.dims, components));
+      }
+      variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
+
+      const hasMeanDataOutput = outputCount > 1;
+      const hasInvStdOutput = outputCount > 2;
+
+      if (hasMeanDataOutput) {
+        variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim));
+      }
+      if (hasInvStdOutput) {
+        variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
+      }
+
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
   const normSize: f32 = ${normSize};
   const normSizeVectorized: u32 = ${normSize / components};
   const epsilon: f32 = ${attributes.epsilon};
@@ -101,13 +101,13 @@ const createLayerNormProgramInfo =
     ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''};
     ${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''};
   }`;
-        const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
-        if (hasMeanDataOutput) {
-          outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default});
-        }
-        if (hasInvStdOutput) {
-          outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default});
-        }
+      const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
+      if (hasMeanDataOutput) {
+        outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
+      }
+      if (hasInvStdOutput) {
+        outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
+      }
 
       return {
         name: 'LayerNormalization',
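To make the shape bookkeeping in createLayerNormProgramInfo concrete, a small example under assumed inputs (the shapes below are hypothetical):

// Hypothetical input, for illustration only.
const xShape = [2, 4, 8];
const axis = 2;
const normCount = 2 * 4;             // sizeToDimension(xShape, axis): independent rows to normalize
const normSize = 8;                  // sizeFromDimension(xShape, axis): elements per row; scale/bias must match
const meanInvStdDevDim = [2, 4, 1];  // dims before axis kept, the rest collapsed to 1,
                                     // so optional Mean/InvStdDev outputs broadcast over each row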
112 changes: 56 additions & 56 deletions js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -71,44 +71,44 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
 };
 
 const createSkipLayerNormProgramInfo =
-    (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number,
-     isTraining: boolean): ProgramInfo => {
-      const inputShape = inputs[0].dims;
-      const inputSize = ShapeUtil.size(inputShape);
-      const outputShape = inputShape;
-      const outputSize = inputSize;
-      const hiddenSize = inputShape.slice(-1)[0];
-      const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
-      const hasBetaInput = inputs.length > 3;
-      const hasBiasInput = inputs.length > 4;
-      const hasMeanOutput = isTraining && outputCount > 1;
-      const hasInvStdDevOutput = isTraining && outputCount > 2;
-      const hasInputSkipBiasSumOutput = outputCount > 3;
-
-      const components = getMaxComponents(hiddenSize);
-      const variables = [
-        inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
-        inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
-        inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
-      ];
-      if (hasBetaInput) {
-        variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
-      }
-      if (hasBiasInput) {
-        variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
-      }
-      variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
-      if (hasMeanOutput) {
-        variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim));
-      }
-      if (hasInvStdDevOutput) {
-        variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
-      }
-      if (hasInputSkipBiasSumOutput) {
-        variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
-      }
-      const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
-      const getShaderSource = (shaderHelper: ShaderHelper) => `
+    (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean):
+        ProgramInfo => {
+          const inputShape = inputs[0].dims;
+          const inputSize = ShapeUtil.size(inputShape);
+          const outputShape = inputShape;
+          const outputSize = inputSize;
+          const hiddenSize = inputShape.slice(-1)[0];
+          const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
+          const hasBetaInput = inputs.length > 3;
+          const hasBiasInput = inputs.length > 4;
+          const hasMeanOutput = isTraining && outputCount > 1;
+          const hasInvStdDevOutput = isTraining && outputCount > 2;
+          const hasInputSkipBiasSumOutput = outputCount > 3;
+
+          const components = getMaxComponents(hiddenSize);
+          const variables = [
+            inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
+            inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
+            inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
+          ];
+          if (hasBetaInput) {
+            variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
+          }
+          if (hasBiasInput) {
+            variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
+          }
+          variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
+          if (hasMeanOutput) {
+            variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim));
+          }
+          if (hasInvStdDevOutput) {
+            variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
+          }
+          if (hasInputSkipBiasSumOutput) {
+            variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
+          }
+          const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+          const getShaderSource = (shaderHelper: ShaderHelper) => `
   const hiddenSize: f32 = ${hiddenSize};
   const hiddenSizeVectorized: u32 = ${hiddenSize / components};
   const epsilon: f32 = ${attributes.epsilon};
@@ -140,24 +140,24 @@ const createSkipLayerNormProgramInfo =
             + ${hasBetaInput ? 'beta[i]' : '0.0'};
       }
     }`;
-      const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
-      if (outputCount > 1) {
-        outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
-      }
-      if (outputCount > 2) {
-        outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
-      }
-      if (outputCount > 3) {
-        outputs.push({dims: inputShape, dataType: inputs[0].dataType});
-      }
-
-      return {
-        name: 'SkipLayerNormalization',
-        shaderCache: {hint: attributes.cacheKey},
-        getShaderSource,
-        getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}),
-      };
-    };
+          const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
+          if (outputCount > 1) {
+            outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
+          }
+          if (outputCount > 2) {
+            outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
+          }
+          if (outputCount > 3) {
+            outputs.push({dims: inputShape, dataType: inputs[0].dataType});
+          }
+
+          return {
+            name: 'SkipLayerNormalization',
+            shaderCache: {hint: attributes.cacheKey},
+            getShaderSource,
+            getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}),
+          };
+        };
 
 export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => {
   // TODO: initialize isTraining from ComputeContext
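Similarly, SkipLayerNormalization dispatches one thread per row of hiddenSize elements, 64 threads per workgroup. A quick check with assumed transformer-ish sizes (hypothetical, not from this commit):

// Hypothetical sizes, for illustration only.
const hiddenSize = 768;
const outputSize = 8 * 128 * hiddenSize;     // batch 8, sequence length 128
const rows = outputSize / hiddenSize;        // 1024 rows, one per token
const dispatch = {x: Math.ceil(rows / 64)};  // ceil(1024 / 64) = 16 workgroups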
