Merge branch 'main' into webgpu_where
axinging committed Sep 25, 2023
2 parents 7bb7924 + ce287a4 commit bfe43d4
Showing 71 changed files with 1,735 additions and 1,093 deletions.
23 changes: 17 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -62,14 +62,24 @@ const createBinaryOpProgramShader =
         let assignment: string;
         if (vectorize) {
           if (doBroadcast) {
-            assignment = `
+            const isAOneElement = ShapeUtil.size(dimsA) === 1;
+            const isBOneElement = ShapeUtil.size(dimsB) === 1;
+            if (isAOneElement || isBOneElement) {
+              assignment = output.setByOffset(
+                  'global_idx',
+                  expressionVector(
+                      isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'),
+                      isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx')));
+            } else {
+              assignment = `
             let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
             let offsetA = calcOffsetA(outputIndices);
             let offsetB = calcOffsetB(outputIndices);
             ${
-              output.setByOffset(
-                  'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
+                output.setByOffset(
+                    'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
           `;
+            }
           } else {
             assignment = output.setByOffset(
                 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')));
@@ -141,6 +151,8 @@ const createBinaryOpProgramInfo =
       }
       outputShape = calculatedShape;
       outputSize = ShapeUtil.size(outputShape);
+      const isAOneElement = ShapeUtil.size(a.dims) === 1;
+      const isBOneElement = ShapeUtil.size(b.dims) === 1;

       // check whether vectorize can be enabled
       let sharedDimension = 1;
@@ -153,7 +165,7 @@
             break;
           }
         }
-        if (sharedDimension % 4 === 0) {
+        if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) {
           vectorize = true;
         }
       } else {
@@ -167,8 +179,7 @@
                 shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
                 outputDataType, additionalImplementation),
         outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}],
-        dispatchGroup: () =>
-            ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)})
+        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)})
       };
     };

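The hunks above add a fast path for elementwise binary ops: when either operand holds exactly one element, broadcasting degenerates to a splat, so the vec4 path no longer requires the shared dimension to be a multiple of 4, and the dispatch size always divides by the component size of 4. A minimal standalone sketch of that decision logic (the helper names here are illustrative, not this file's API):

```ts
// Sketch: vectorization now also triggers for one-element (scalar) operands,
// because every vec4 lane reads the same value regardless of alignment.
const size = (dims: readonly number[]): number => dims.reduce((a, b) => a * b, 1);

const canVectorize = (dimsA: readonly number[], dimsB: readonly number[], sharedDimension: number): boolean =>
    sharedDimension % 4 === 0 || size(dimsA) === 1 || size(dimsB) === 1;

// Dispatch: each workgroup of 64 invocations handles 64 vec4s, i.e. 256 elements.
const dispatchX = (outputSize: number): number => Math.ceil(outputSize / 64 / 4);
```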
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -366,7 +366,7 @@ const createIndicesHelper =

   const getByIndicesImplementation = rank < 2 ? '' : `
   fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
-    return ${name}[i2o_${name}(indices)];
+    return ${getByOffset(`i2o_${name}(indices)`)};
   }`;

   const getImplementation = rank < 2 ? '' : (() => {
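The one-line change above routes indexed reads through getByOffset instead of subscripting the storage array directly, so element types whose storage representation differs from their value type pick up the correct load expression. A hedged sketch of the idea (hypothetical helper, not the repo's implementation):

```ts
// Sketch: the offset-based accessor is the single place that knows how to turn a
// storage element into a value; index-based reads reuse it instead of `name[offset]`.
const getByOffsetExpr = (name: string, offset: string, storageType: string): string => {
  switch (storageType) {
    case 'vec2<u32>':  // e.g. a 64-bit integer emulated as two u32 words
      return `i32(${name}[${offset}].x)`;
    default:
      return `${name}[${offset}]`;
  }
};

// getByIndices composes the index-to-offset conversion with the same load path.
const getByIndicesExpr = (name: string, indices: string, storageType: string): string =>
    getByOffsetExpr(name, `i2o_${name}(${indices})`, storageType);
```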
91 changes: 41 additions & 50 deletions js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -1,13 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

-import {ShaderHelper} from './common';
+import {inputVariable, outputVariable, ShaderHelper} from './common';

 export interface GatherAttributes extends AttributeWithCacheKey {
   axis: number;
@@ -30,63 +29,55 @@ const createGatherProgramInfo =
       const outputShape = inputShape.slice(0);
       outputShape.splice(axis, 1, ...indicesShape);

-      const inputDataType = inputs[0].dataType;
-      const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1);
-      const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
-      const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1;
-      const blockSize = elementSize * block;
-      const M = ShapeUtil.sizeToDimension(inputShape, axis);
-      const N = ShapeUtil.size(indicesShape);
-      const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize;
-      const gatheredBatchElements = N * block * elementSize;
       const axisDimLimit = inputShape[axis];
+      const outputSize = ShapeUtil.size(outputShape);

+      const data = inputVariable('data', inputs[0].dataType, inputs[0].dims);
+      const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims);
+      const output = outputVariable('output', inputs[0].dataType, outputShape);
+      const calcDataIndices = (): string => {
+        const indicesRank = indicesShape.length;
+        let calcStr = `var indicesIndices = ${indices.type.indices}(0);`;
+        for (let i = 0; i < indicesRank; i++) {
+          calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${
+              outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`;
+        }
+        calcStr += `
+          var idx = ${indices.getByIndices('indicesIndices')};
+          if (idx < 0) {
+            idx = idx + ${axisDimLimit};
+          }
+          var dataIndices = ${data.type.indices}(0);
+        `;
+        for (let i = 0, j = 0; i < inputRank; i++) {
+          if (i === axis) {
+            calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`;
+            j += indicesRank;
+          } else {
+            calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${
+                outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`;
+            j++;
+          }
+        }
+        return calcStr;
+      };

-      const inputSize = ShapeUtil.size(inputShape) * elementSize;
-      const outputSize = ShapeUtil.size(outputShape) * elementSize;
-
-      const totalGathers = M * N;
-      // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
-      // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
-      // Input data will be treated as u32 or two u32 for 8-byte tensors
       const getShaderSource = (shaderHelper: ShaderHelper) => `
-      const N: u32 = ${N};
-      const elementSize: u32 = ${elementSize};
-      const indicesElementSize: u32 = ${indicesElementSize};
-      @group(0) @binding(0) var<storage, read> input : array<u32>;
-      @group(0) @binding(1) var<storage, read> inputIndices : array<i32>;
-      @group(0) @binding(2) var<storage, read_write> output: array<u32>;
-      ${shaderHelper.mainStart()}
-        let batch: u32 = global_idx / N;
-        let i: u32 = global_idx % N;
-        let srcOffsetBatch: u32 = batch * ${dataBatchElements};
-        let dstOffsetBatch: u32 = batch * ${gatheredBatchElements};
-        var idx = inputIndices[i * indicesElementSize];
-        if (idx < 0) {
-          idx = idx + ${axisDimLimit};
-        }
-        let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize};
-        let dstOffset = dstOffsetBatch + i * ${blockSize};
-        if (srcOffset >= ${inputSize}) {
-          return;
-        }
-        if (dstOffset >= ${outputSize}) {
-          return;
-        }
-        for (var j: u32 = 0; j < ${blockSize}; j++) {
-          output[dstOffset + j] = input[srcOffset + j];
-        }
-      }`;
+  ${shaderHelper.declareVariables(data, indices, output)}
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
+    let outputIndices = ${output.offsetToIndices('global_idx')};
+    ${calcDataIndices()};
+    let value = ${data.getByIndices('dataIndices')};
+    ${output.setByOffset('global_idx', 'value')};
+  }`;
       return {
         ...metadata,
         outputs: [
           {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
         ],
         getShaderSource,
-        dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)})
+        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
       };
     };

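The rewrite above replaces the block-copy shader (raw u32 bindings with batch and block offsets) with one thread per output element, indexing through the IndicesHelper abstractions. The mapping it emits can be stated in plain code; below is a small CPU sketch (illustration only, not part of the repo) of the same scheme, including the negative-index wrap-around on the gather axis:

```ts
// CPU sketch of Gather's index mapping: output indices in [axis, axis + indicesRank)
// address the indices tensor; the selected entry becomes the data index on `axis`.
function gather(data: number[], dataShape: number[], indicesData: number[], indicesShape: number[],
                axis: number): number[] {
  const strides = (shape: number[]): number[] =>
      shape.map((_, i) => shape.slice(i + 1).reduce((a, b) => a * b, 1));
  const outputShape = [...dataShape.slice(0, axis), ...indicesShape, ...dataShape.slice(axis + 1)];
  const [dataStrides, indicesStrides, outStrides] = [dataShape, indicesShape, outputShape].map(strides);
  const outputSize = outputShape.reduce((a, b) => a * b, 1);
  const output = new Array<number>(outputSize);
  for (let offset = 0; offset < outputSize; offset++) {
    // offset -> output indices (the shader's offsetToIndices)
    const outIdx = outputShape.map((d, i) => Math.floor(offset / outStrides[i]) % d);
    const idxOffset = outIdx.slice(axis, axis + indicesShape.length)
                          .reduce((acc, v, i) => acc + v * indicesStrides[i], 0);
    let idx = indicesData[idxOffset];
    if (idx < 0) idx += dataShape[axis];  // wrap negative indices, as the shader does
    const dataIdx = [...outIdx.slice(0, axis), idx, ...outIdx.slice(axis + indicesShape.length)];
    output[offset] = data[dataIdx.reduce((acc, v, i) => acc + v * dataStrides[i], 0)];
  }
  return output;
}

// gather([1, 2, 3, 4, 5, 6], [2, 3], [0, 2], [2], 1)  ->  [1, 3, 4, 6]
```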
6 changes: 1 addition & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor-view';
 import {PoolConvUtil, ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -22,9 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
   if (inputs[0].dims.length !== 4) {
     throw new Error('Pool ops supports 2-D inputs only for now.');
   }
-  if (inputs[0].dataType !== DataType.float) {
-    throw new Error('Invalid input type.');
-  }
 };

 const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
@@ -248,7 +244,7 @@ const createAveragePoolProgramInfo =
       const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape);

       const x = inputVariable('x', input.dataType, input.dims);
-      const dataType = 'f32';
+      const dataType = x.type.value;

       const op1 = 'value += x_val;';
       let op2 = '';
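With dataType derived from the input variable rather than hard-coded to 'f32', the same WGSL templates emit valid code for f32 and f16 inputs, which is why the f32-only input validation above could be dropped. A hypothetical sketch (not this file's API) of typed-literal emission:

```ts
// Sketch: emit the average-pool divisor as a literal of the input's own value type,
// so an f16 shader never mixes in f32 constants.
const averagePoolFooter = (dataType: string, kernelSize: number): string =>
    `value /= ${dataType}(${kernelSize});`;

averagePoolFooter('f32', 9);  // 'value /= f32(9);'
averagePoolFooter('f16', 9);  // 'value /= f16(9);'
```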
14 changes: 5 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -17,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
   if (inputs.length === 2 && inputs[1].dims.length !== 1) {
     throw new Error('Invalid axes input dims.');
   }
-
-  if (inputs[0].dataType !== DataType.float) {
-    throw new Error('Invalid input type.');
-  }
 };

 export interface ReduceAttributes extends AttributeWithCacheKey {
@@ -161,7 +157,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes):
 export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => {
   validateInputs(context.inputs);
   const reduceOp: ReduceOp = (input, output) =>
-      [`var t = f32(0); var value = ${output.type.storage}(0);`,
+      [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
        '',
        `t = ${input.getByOffset('inputOffset')}; value += (t * t);`,
        'value = sqrt(value);',
@@ -212,10 +208,10 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes
   }

   return [
-      `var value = ${output.type.storage}(0);`,
+      'var sum = f32(0);',
       '',
-      `value += ${input.getByOffset('inputOffset')};`,
-      `value = value / ${size}.;`,
+      `sum += f32(${input.getByOffset('inputOffset')});`,
+      `let value = ${output.type.value}(sum / ${size});`,
   ];
 };
 context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMean', attributes, reduceOp), {inputs: [0]});
@@ -266,7 +262,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes)
 export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => {
   validateInputs(context.inputs);
   const reduceOp: ReduceOp = (input, output) =>
-      [`var t = f32(0); var value = ${output.type.storage}(0);`,
+      [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
       '',
       `t = ${input.getByOffset('inputOffset')}; value += t * t;`,
       '',
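The reduceMean change above is more than a type swap: the running sum is now kept in f32 and converted to the output type once at the end. Accumulating directly in a low-precision type silently drops small terms once the running total grows; the standalone f32 analogue below (Math.fround rounds to f32 precision) demonstrates the failure mode, which f16 accumulation hits at much smaller counts (around 2048 for unit values):

```ts
// In f32, 2^24 is the point where adding 1.0 no longer changes the sum.
let sum = Math.fround(2 ** 24);
sum = Math.fround(sum + 1);
console.log(sum === 2 ** 24);  // true: the increment was rounded away

// Accumulating in a wider type and casting once at the end preserves the mean.
const values = new Array(4096).fill(1);
const mean = values.reduce((a, b) => a + b, 0) / values.length;  // wide accumulate
console.log(mean);  // 1
```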
54 changes: 25 additions & 29 deletions onnxruntime/core/common/cpuid_info.cc
@@ -135,38 +135,34 @@ void CPUIDInfo::ArmLinuxInit() {
     LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features.";
     return;
   }
+  is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
+  has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
+  has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
+  const uint32_t core_cnt = cpuinfo_get_cores_count();
+  core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
+  is_armv8_narrow_ld_.resize(core_cnt, false);
+  for (uint32_t c = 0; c < core_cnt; c++) {
+    const struct cpuinfo_processor* proc = cpuinfo_get_processor(c);
+    if (proc == nullptr) {
+      continue;
+    }
+    const struct cpuinfo_core* corep = proc->core;
+    if (corep == nullptr) {
+      continue;
+    }
+    auto coreid = proc->linux_id;
+    auto uarch = corep->uarch;
+    core_uarchs_[coreid] = uarch;
+    if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 ||
+        uarch == cpuinfo_uarch_cortex_a55) {
+      is_armv8_narrow_ld_[coreid] = true;
+    }
+  }
 #else
-  pytorch_cpuinfo_init_ = false;
+  has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
+  has_fp16_ |= has_arm_neon_dot_;
 #endif

-  if (pytorch_cpuinfo_init_) {
-    is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
-    has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
-    has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
-    const uint32_t core_cnt = cpuinfo_get_cores_count();
-    core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
-    is_armv8_narrow_ld_.resize(core_cnt, false);
-    for (uint32_t c = 0; c < core_cnt; c++) {
-      const struct cpuinfo_processor* proc = cpuinfo_get_processor(c);
-      if (proc == nullptr) {
-        continue;
-      }
-      const struct cpuinfo_core* corep = proc->core;
-      if (corep == nullptr) {
-        continue;
-      }
-      auto coreid = proc->linux_id;
-      auto uarch = corep->uarch;
-      core_uarchs_[coreid] = uarch;
-      if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 ||
-          uarch == cpuinfo_uarch_cortex_a55) {
-        is_armv8_narrow_ld_[coreid] = true;
-      }
-    }
-  } else {
-    has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
-    has_fp16_ |= has_arm_neon_dot_;
-  }
 }

 #elif defined(_WIN32)
17 changes: 11 additions & 6 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -112,6 +112,14 @@ MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[
 #define _XCR_XFEATURE_ENABLED_MASK 0
 #endif

+#if !defined(XFEATURE_MASK_XTILE)
+#define XFEATURE_XTILECFG 17
+#define XFEATURE_XTILEDATA 18
+#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
+#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
+#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
+#endif
+
 inline
 uint64_t
 MlasReadExtendedControlRegister(
@@ -142,11 +150,6 @@ bool
 MlasInitAMX()
 {
 #if defined(__linux__)
-#define XFEATURE_XTILECFG 17
-#define XFEATURE_XTILEDATA 18
-#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
-#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
-#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)

 #define ARCH_GET_XCOMP_PERM 0x1022
 #define ARCH_REQ_XCOMP_PERM 0x1023
@@ -417,7 +420,9 @@ Return Value:
     // Check if the processor supports AMX-TILE and AMX-INT8
     // features.
     //
-    if ((Cpuid7[3] & 0b1 << 24) != 0 && (Cpuid7[3] & 0b1 << 25) != 0) {
+    if ((Cpuid7[3] & 0b1 << 24) != 0 &&
+        (Cpuid7[3] & 0b1 << 25) != 0 &&
+        (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) {
         if (MlasInitAMX()) {
             this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx;
             this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx;