
Commit

Merge branch 'main' of https://github.com/microsoft/onnxruntime into zhanyi/replaceT4
mszhanyi committed Jan 30, 2024
2 parents b38936e + 9f68a27 commit 8ea07e2
Showing 60 changed files with 383 additions and 236 deletions.
2 changes: 2 additions & 0 deletions .pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml
@@ -29,6 +29,8 @@ extends:
  git:
    submodules: false
  globalSdl: # https://aka.ms/obpipelines/sdl
+   asyncSdl:
+     enabled: false
    tsa:
      enabled: true
    prefast:
2 changes: 1 addition & 1 deletion .pipelines/windowsai-steps.yml
@@ -84,7 +84,7 @@ jobs:
7z x cmake-3.26.3-windows-x86_64.zip
set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
- $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_qspectre --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe
+ $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe
workingDirectory: '$(Build.BinariesDirectory)'
displayName: 'Generate cmake config'
22 changes: 21 additions & 1 deletion docs/python/on_device_training/training_api.rst
@@ -42,12 +42,32 @@ Sample usage:
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)

.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameter
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __repr__

.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameters
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__

.. autoclass:: onnxruntime.training.api.checkpoint_state.Properties
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__

.. autoclass:: onnxruntime.training.api.CheckpointState
    :members:
    :show-inheritance:
    :member-order: bysource
    :inherited-members:
    :special-members: __getitem__, __setitem__, __contains__

.. autoclass:: onnxruntime.training.api.Module
    :members:
@@ -443,9 +443,9 @@ export const createMatmulProgramInfo =

const components = isVec4 ? 4 : 1;
const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
- const aShapeOrRank = aShapeTemp.length;
+ const aRank = aShapeTemp.length;
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
- const bShapeOrRank = bShapeTemp.length;
+ const bRank = bShapeTemp.length;
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
@@ -467,12 +467,12 @@ export const createMatmulProgramInfo =
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

const getShaderSource = (shaderHelper: ShaderHelper) => {
- const batchShapeOrRank = outerDims.length;
- const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
+ const batchRank = outerDims.length;
+ const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

- const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
- const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
+ const A = inputVariable('a', inputs[0].dataType, aRank, components);
+ const B = inputVariable('b', inputs[1].dataType, bRank, components);
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
const inputVariables = [A, B];
if (hasBias) {
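Note: the ShapeOrRank → Rank renames above reflect that these variables now always hold a rank. The generated WGSL depends only on how many dimensions a tensor has; the concrete dims are supplied at dispatch time through uniforms. A toy TypeScript sketch of that idea (a hypothetical helper, not ONNX Runtime's actual IndicesHelper):

// Toy model: index math is emitted against uniform-provided strides, so the
// generated shader source is identical for every shape of the same rank.
const makeOffsetExpr = (name: string, rank: number): string =>
    Array.from({length: rank}, (_, i) => `i${i} * uniforms.${name}_strides[${i}]`).join(' + ');

// Same emitted code for shapes [2, 3, 4] and [7, 8, 9], since both are rank 3.
console.log(makeOffsetExpr('a', 3));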
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
if (sum == 0) {
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
- x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')};
+ x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')};
}
} else {
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
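Note: the one-line fix above threads the tensor's actual element type into the fill expression instead of hard-coding 'f32', which matters once the softmax runs over f16 data. A minimal sketch of what a fill helper like this plausibly emits (an assumption; the real fillVector lives in ops/common.ts and may differ):

// Hypothetical fill helper: builds a scalar or vector WGSL constructor expression.
const fillVector = (dataType = 'f32', components = 1, value = '0'): string =>
    components === 1 ? `${dataType}(${value})` : `vec${components}<${dataType}>(${value})`;

console.log(fillVector('f16', 4, 'uniforms.d_inv'));  // "vec4<f16>(uniforms.d_inv)"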
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
@@ -8,7 +8,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

- import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
+ import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';

export interface BatchNormAttributes extends AttributeWithCacheKey {
readonly epsilon: number;
@@ -61,7 +61,7 @@ const createBatchNormInferenceProgramInfo =
const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
const outputSize = ShapeUtil.size(yShape) / components;
// Only support uniforms for opset version >= 9 (spatial = true).
- const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
+ const useShapesUniforms = spatial;
const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
37 changes: 14 additions & 23 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

- import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
+ import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

type BuiltinFunctionName = string;
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
@@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
const createBinaryOpProgramShader =
(shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
- typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
- additionalImplementation?: string) => {
+ typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => {
let expressionScalar: BinaryCustomExpression;
let expressionVector: BinaryCustomExpression;
if (typeof funcCall === 'string') {
@@ -31,12 +30,9 @@ const createBinaryOpProgramShader =
expressionVector = funcCall.vector;
}

- const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA;
- const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB;
- const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput;
- const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4);
- const a = inputVariable('aData', typeA, inputAShapeOrRank, 4);
- const b = inputVariable('bData', typeB, inputBShapeOrRank, 4);
+ const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
+ const a = inputVariable('aData', typeA, dimsA.length, 4);
+ const b = inputVariable('bData', typeB, dimsB.length, 4);

let assignment: string;
if (vectorize) {
@@ -169,30 +165,25 @@ const createBinaryOpProgramInfo =
vectorize = true;
}
cacheKeyAux.push(vectorize);
- const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) &&
-     enableShapesUniforms(outputShape.length);

return {
name,
shaderCache: {
hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
- inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'],
+ inputDependencies: ['rank', 'rank'],
},
getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall,
- a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation),
+ a.dataType, b.dataType, outputDataType, additionalImplementation),
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
- programUniforms: useShapesUniforms ?
-     [
-       {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
-       ...createTensorShapeVariables(a.dims),
-       ...createTensorShapeVariables(b.dims),
-       ...createTensorShapeVariables(outputShape),
-     ] :
-     [
-       {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
-     ],
+ programUniforms: [
+   {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
+   ...createTensorShapeVariables(a.dims),
+   ...createTensorShapeVariables(b.dims),
+   ...createTensorShapeVariables(outputShape),
+ ],
}),
};
};
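Note: with enableShapesUniforms gone, binary-op.ts follows the rank convention unconditionally: inputDependencies is always ['rank', 'rank'], and the dims of a, b, and the output are always pushed as program uniforms. A simplified sketch of that pairing, using a stand-in for createTensorShapeVariables (an assumption; see ops/common.ts for the real helper):

// Stand-in: a shape contributes its dims and row-major strides as uniforms.
type Uniform = {type: 'uint32'; data: number|number[]};
const shapeUniforms = (dims: readonly number[]): Uniform[] => {
  const strides = dims.map((_, i) => dims.slice(i + 1).reduce((a, b) => a * b, 1));
  return [{type: 'uint32', data: [...dims]}, {type: 'uint32', data: strides}];
};

// The shader cache key depends on ranks only; dims vary freely per dispatch.
const cacheHint = (a: readonly number[], b: readonly number[]): string => `rank${a.length}_rank${b.length}`;

console.log(shapeUniforms([2, 3, 4]));  // dims [2, 3, 4], strides [12, 4, 1]
console.log(cacheHint([2, 3, 4], [3, 4]));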
3 changes: 0 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -922,6 +922,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
}
return dims;
};

- // TODO: remove this when all related uses have been removed.
- export const enableShapesUniforms = (_rank: number): boolean => true;
26 changes: 8 additions & 18 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

- import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+ import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
@@ -94,32 +94,22 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
- const inputShapeOrRanks = [];
- const enableInputShapesUniforms = [];
+ const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
- enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length));
- inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
- inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
- inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
+ inputRanks.push(inputs[i].dims.length);
+ inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
+ inputDependencies.push('rank');
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
- if (enableInputShapesUniforms[i]) {
-   programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
- }
+ programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
-
- const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
- if (enableOutputShapesUniforms) {
-   programUniforms.push(...createTensorShapeVariables(outputShape));
- }
+ programUniforms.push(...createTensorShapeVariables(outputShape));

- const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
- const output = outputVariable('output', dataType, outputShapeOrRank);
+ const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
31 changes: 10 additions & 21 deletions js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
@@ -6,8 +6,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

- import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
-
+ import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

export interface EinsumAttributes extends AttributeWithCacheKey {
readonly equation: string;
@@ -181,14 +180,12 @@ class EinsumEquation {
const appendMax = (name: string): string => name + '_max';

const createEinsumProgramInfo =
- (enableInputShapesUniforms: readonly boolean[], inputShapes: Array<readonly number[]>, dataType: number,
-  einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => {
-   const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims);
-   const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank));
+ (inputShapes: Array<readonly number[]>, dataType: number, einsumEquation: EinsumEquation,
+  outputShape: readonly number[]): ProgramInfo => {
+   const ranks = inputShapes.map((dims) => dims.length);
+   const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
const outputSize = ShapeUtil.size(outputShape);
- const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
- const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
- const output = outputVariable('output', dataType, outputShapeOrRank);
+ const output = outputVariable('output', dataType, outputShape.length);
const uniformsSymbols =
[...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -269,10 +266,7 @@ const createEinsumProgramInfo =
};
return {
name: 'Einsum',
- shaderCache: {
-   hint: einsumEquation.equation,
-   inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
- },
+ shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')},
getRunData: () => {
// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
// filter is added to make sure that dimValue is never 0.
@@ -281,12 +275,9 @@ const createEinsumProgramInfo =
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
programUniformsInit.push({type: 'uint32', data: outputSize});
const programUniforms: ProgramUniform[] =
- inputShapes.filter((_, index) => enableInputShapesUniforms[index])
-     .map((dims, _) => [...createTensorShapeVariables(dims)])
+ inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
- if (enableOutputShapesUniforms) {
-   programUniforms.push(...createTensorShapeVariables(outputShape));
- }
+ programUniforms.push(...createTensorShapeVariables(outputShape));
return ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
@@ -299,11 +290,9 @@ const createEinsumProgramInfo =

export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
- const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length));
const outputShape = einsumEquation.outputDims;
const inputShapes = context.inputs.map((input, _) => input.dims);
- context.compute(createEinsumProgramInfo(
-     enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
+ context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
};

export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {