
Commit

Merge remote-tracking branch 'origin/upstream' into wasm32
dakenf committed Oct 31, 2023
2 parents 58c722b + 1c25fe5 commit a6b4549
Showing 27 changed files with 1,289 additions and 212 deletions.
20 changes: 20 additions & 0 deletions js/web/.npmignore
@@ -4,6 +4,26 @@

/dist/**/*.report.html

# We remove some of the files in the NPM package because of restrictions in jsdelivr:
#
# "Packages larger than 150 MB or single files larger than 20 MB (in the case of GitHub) are not supported"
#
# from https://www.jsdelivr.com/documentation
#
# We only include the development build in the NPM package for the following targets:
# - /dist/ort.js
# - /dist/ort.all.js
#
/dist/cjs/ort.js
/dist/esm/ort.js
/dist/cjs/ort.all.js
/dist/esm/ort.all.js
/dist/**/ort.wasm.js
/dist/**/ort.wasm-core.js
/dist/**/ort.webgl.js
/dist/**/ort.webgpu.js
/dist/**/ort.training.wasm.js

/types/

karma.conf.js
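With only the development builds of ort.js and ort.all.js kept under the jsdelivr limits, those two bundles stay loadable straight from the CDN. A minimal sketch, assuming the package is published as onnxruntime-web (pin an explicit version in practice):

```typescript
// Hypothetical browser loader for the development build retained in the package.
const script = document.createElement('script');
script.src = 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.js';
script.onload = () => console.log('ort development build loaded');
document.head.appendChild(script);
```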
3 changes: 3 additions & 0 deletions js/web/lib/index.ts
@@ -7,6 +7,9 @@
// So we import code inside the if-clause to allow the bundler to remove the code safely.

export * from 'onnxruntime-common';
import * as ort from 'onnxruntime-common';
export default ort;

import {registerBackend, env} from 'onnxruntime-common';
import {version} from './version';

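The added default export lets consumers use namespace-style imports alongside the existing named re-exports. A short usage sketch in an ESM module (the model path is illustrative):

```typescript
// Both forms resolve to the same API surface re-exported from onnxruntime-common.
import ort from 'onnxruntime-web';                 // default export added here
import {InferenceSession} from 'onnxruntime-web';  // named re-export, unchanged

const session = await ort.InferenceSession.create('model.onnx');
console.log(session.inputNames);
```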
44 changes: 30 additions & 14 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
@@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut

const cols = shape[axis];
const rows = outputSize / cols;
const components = getMaxComponents(cols);
const packedCols = cols / components;
const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`;

const maxVector = (name: string, components: number) => {
if (components === 4) {
return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`;
} else if (components === 2) {
return `max(${name}.x, ${name}.y)`;
} else if (components === 3) {
return `max(max(${name}.x, ${name}.y), ${name}.z)`;
}

return name;
};

// 6.2.4 in wgsl spec
const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
const threadMaxDecl =
dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`;
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
var<workgroup> threadShared : array<${dataType}, ${WG}>;
var<workgroup> rowMaxShared : ${valueType};
var<workgroup> rowSumShared : ${valueType};
var<workgroup> threadShared : array<${valueType}, ${WG}>;
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${dataType}>;
@group(0) @binding(0) var<storage, read> x : array<${valueType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${valueType}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
let index = row * row_stride + col;
return x[index];
}
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
let index = row * row_stride + col;
result[index] = value;
}
@@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
let lindex = i32(local_id.x);
const wg = ${WG};
let row = gindex / wg;
let cols = ${cols};
let row_stride : i32 = ${cols};
let cols = ${packedCols};
let row_stride : i32 = ${packedCols};
// find the rows max
${threadMaxDecl}
@@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowMaxShared = threadShared[0];
rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
}
workgroupBarrier();
// find the rows sum
var threadSum: ${dataType} = 0.0;
var threadSum = ${valueType}(0.0);
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
@@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowSumShared = threadShared[0];
rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
}
workgroupBarrier();
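For reference, the two workgroup reductions in this shader compute the numerically stable softmax: the row maximum is subtracted before exponentiation so exp() stays finite, which is why threadMax is seeded with the lowest finite f32/f16 values (-3.402823e+38 and -65504, per section 6.2.4 of the WGSL spec):

$$\operatorname{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}}, \qquad m = \max_j x_j$$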
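The vectorization in this file leans on getMaxComponents to pack each row into vec2/vec4 lanes and on maxVector/sumVector to collapse the packed accumulator back to a scalar. A minimal sketch of such helpers (the real ones live in common.ts; the selection rule and sumVector body below are assumptions inferred from the diff):

```typescript
// Assumed rule: widest WGSL vector width that evenly divides the row length.
const getMaxComponentsSketch = (size: number): number =>
    size % 4 === 0 ? 4 : size % 2 === 0 ? 2 : 1;

// Emit a WGSL expression summing the lanes of a packed value, mirroring maxVector above.
const sumVectorSketch = (name: string, components: number): string => {
  if (components === 4) {
    return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
  }
  if (components === 3) {
    return `(${name}.x + ${name}.y + ${name}.z)`;
  }
  if (components === 2) {
    return `(${name}.x + ${name}.y)`;
  }
  return name;
};

// For cols = 768: components = 4, so the shader iterates packedCols = 192 lanes per row.
console.log(getMaxComponentsSketch(768), sumVectorSketch('threadShared[0]', 4));
```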
95 changes: 19 additions & 76 deletions js/web/package.json
@@ -68,103 +68,46 @@
".": {
"node": "./dist/ort.node.min.js",
"default": {
"import": {
"development": "./dist/esm/ort.js",
"default": "./dist/esm/ort.min.js"
},
"require": {
"development": "./dist/cjs/ort.js",
"default": "./dist/cjs/ort.min.js"
},
"import": "./dist/esm/ort.min.js",
"require": "./dist/cjs/ort.min.js",
"default": {
"development": "./dist/ort.js",
"default": "./dist/ort.min.js"
}
}
},
"./experimental": {
"import": {
"development": "./dist/esm/ort.all.js",
"default": "./dist/esm/ort.all.min.js"
},
"require": {
"development": "./dist/cjs/ort.all.js",
"default": "./dist/cjs/ort.all.min.js"
},
"import": "./dist/esm/ort.all.min.js",
"require": "./dist/cjs/ort.all.min.js",
"default": {
"development": "./dist/ort.all.js",
"default": "./dist/ort.all.min.js"
}
},
"./wasm": {
"import": {
"development": "./dist/esm/ort.wasm.js",
"default": "./dist/esm/ort.wasm.min.js"
},
"require": {
"development": "./dist/cjs/ort.wasm.js",
"default": "./dist/cjs/ort.wasm.min.js"
},
"default": {
"development": "./dist/ort.wasm.js",
"default": "./dist/ort.wasm.min.js"
}
"import": "./dist/esm/ort.wasm.min.js",
"require": "./dist/cjs/ort.wasm.min.js",
"default": "./dist/ort.wasm.min.js"
},
"./wasm-core": {
"import": {
"development": "./dist/esm/ort.wasm-core.js",
"default": "./dist/esm/ort.wasm-core.min.js"
},
"require": {
"development": "./dist/cjs/ort.wasm-core.js",
"default": "./dist/cjs/ort.wasm-core.min.js"
},
"default": {
"development": "./dist/ort.wasm-core.js",
"default": "./dist/ort.wasm-core.min.js"
}
"import": "./dist/esm/ort.wasm-core.min.js",
"require": "./dist/cjs/ort.wasm-core.min.js",
"default": "./dist/ort.wasm-core.min.js"
},
"./webgl": {
"import": {
"development": "./dist/esm/ort.webgl.js",
"default": "./dist/esm/ort.webgl.min.js"
},
"require": {
"development": "./dist/cjs/ort.webgl.js",
"default": "./dist/cjs/ort.webgl.min.js"
},
"default": {
"development": "./dist/ort.webgl.js",
"default": "./dist/ort.webgl.min.js"
}
"import": "./dist/esm/ort.webgl.min.js",
"require": "./dist/cjs/ort.webgl.min.js",
"default": "./dist/ort.webgl.min.js"
},
"./webgpu": {
"import": {
"development": "./dist/esm/ort.webgpu.js",
"default": "./dist/esm/ort.webgpu.min.js"
},
"require": {
"development": "./dist/cjs/ort.webgpu.js",
"default": "./dist/cjs/ort.webgpu.min.js"
},
"default": {
"development": "./dist/ort.webgpu.js",
"default": "./dist/ort.webgpu.min.js"
}
"import": "./dist/esm/ort.webgpu.min.js",
"require": "./dist/cjs/ort.webgpu.min.js",
"default": "./dist/ort.webgpu.min.js"
},
"./training": {
"import": {
"development": "./dist/esm/ort.training.wasm.js",
"default": "./dist/esm/ort.training.wasm.min.js"
},
"require": {
"development": "./dist/cjs/ort.training.wasm.js",
"default": "./dist/cjs/ort.training.wasm.min.js"
},
"default": {
"development": "./dist/ort.training.wasm.js",
"default": "./dist/ort.training.wasm.min.js"
}
"import": "./dist/esm/ort.training.wasm.min.js",
"require": "./dist/cjs/ort.training.wasm.min.js",
"default": "./dist/ort.training.wasm.min.js"
}
},
"types": "./types.d.ts",
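With the development condition dropped, each subpath maps directly to its minified build per module system: "import" resolves to the ESM bundle and "require" to the CJS one. A usage sketch (the subpath and model path are illustrative):

```typescript
// ESM import hits the "import" condition -> ./dist/esm/ort.webgpu.min.js
// CommonJS require would hit "require"  -> ./dist/cjs/ort.webgpu.min.js
import * as ort from 'onnxruntime-web/webgpu';

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
});
console.log(session.inputNames);
```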
82 changes: 45 additions & 37 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -51,35 +51,34 @@ half maybe2half(float x) {

// Using only power-of-2 numbers leads to wasted compute for sizes such as 768, which is a very common case
// in BERT. Ideally we could step by warp_size * num_unroll, but listing too many steps would cause long compile times.
constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
constexpr int kMaxBlockSize = 256;
constexpr int kMaxBlockSize = 1024;

int NextSize(int x) {
for (size_t i = 0; i < kNumOfSizes; ++i) {
if (x <= kSizes[i]) {
return kSizes[i];
}
}
return kMaxSize;
return kMaxSize + 1;
}

template <typename T, int NumUnroll>
bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias,
const T* gamma, const T* beta, const int ld, const int next_size) {
constexpr int alignment = std::alignment_of<aligned_vector<T, NumUnroll>>::value;
return ld % NumUnroll == 0 &&
bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias,
const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) {
int alignment = element_size * num_unroll;
return ld % num_unroll == 0 &&
reinterpret_cast<uint64_t>(output) % alignment == 0 &&
reinterpret_cast<uint64_t>(sum_output) % alignment == 0 &&
reinterpret_cast<uint64_t>(input) % alignment == 0 &&
reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
reinterpret_cast<uint64_t>(bias) % alignment == 0 &&
reinterpret_cast<uint64_t>(gamma) % alignment == 0 &&
reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
next_size / NumUnroll >= kMinBlockSize &&
next_size / NumUnroll <= kMaxBlockSize;
next_size / num_unroll >= kMinBlockSize &&
next_size / num_unroll <= kMaxBlockSize;
}
} // namespace
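CanVectorized now takes void pointers plus a runtime num_unroll/element_size instead of a template parameter, so a single symbol serves both the 4-wide and 8-wide paths. The check itself is plain modular arithmetic; a TypeScript sketch of the same predicate, with device addresses modeled as numbers (all names here are assumptions):

```typescript
// Sketch: vectorized loads need the row length divisible by the unroll factor,
// every buffer address aligned to (elementSize * numUnroll) bytes, and a
// resulting block size inside [kMinBlockSize, kMaxBlockSize].
const canVectorizeSketch = (
    addresses: number[], ld: number, nextSize: number, numUnroll: number,
    elementSize: number, kMinBlockSize = 32, kMaxBlockSize = 1024): boolean => {
  const alignment = elementSize * numUnroll;
  return ld % numUnroll === 0 &&
      addresses.every((addr) => addr % alignment === 0) &&
      nextSize / numUnroll >= kMinBlockSize &&
      nextSize / numUnroll <= kMaxBlockSize;
};
```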

@@ -187,8 +186,14 @@ void LaunchSkipLayerNormKernel(
int ld, int row_count, int skip_size) {
const int next_size = NextSize(ld);
const int grid_size = row_count;
bool flag_vec2 = CanVectorized<T, 2>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
bool flag_vec4 = CanVectorized<T, 4>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
bool can_unroll_vec4 = CanVectorized(output, sum_output, input,
skip, bias, gamma,
beta, ld, next_size,
4, sizeof(T));
bool can_unroll_vec8 = CanVectorized(output, sum_output, input,
skip, bias, gamma,
beta, ld, next_size,
8, sizeof(T));

#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
@@ -198,39 +203,42 @@ void LaunchSkipLayerNormKernel(
SkipLayerNormKernel<T, block_size, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \
if (flag_vec4) { \
constexpr int block_size = next_size_value / 4; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \
} else if (flag_vec2) { \
constexpr int block_size = next_size_value / 2; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \
} else { \
if (next_size_value <= kMaxBlockSize) { \
constexpr int block_size = next_size_value; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \
} else { \
constexpr int block_size = 256; \
LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
} \
} \
#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \
if constexpr (next_size_value >= 8 * 256) { \
if (can_unroll_vec8) { \
constexpr int block_size = next_size_value / 8; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \
} else { \
constexpr int block_size = 256; \
LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
} \
} else { \
if (can_unroll_vec4) { \
constexpr int block_size = next_size_value / 4; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \
} else { \
if (next_size_value <= kMaxBlockSize) { \
constexpr int block_size = next_size_value; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \
} else { \
constexpr int block_size = 256; \
LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
} \
} \
} \
} break

switch (next_size) {
case kSizes[0]: {
constexpr int block_size = kSizes[0];
// TODO: Add back the small TensorRT kernel for 32. No need to use a vectorized kernel for such a small size.
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);
break;
}
CASE_NEXT_SIZE(kSizes[0]);
CASE_NEXT_SIZE(kSizes[1]);
CASE_NEXT_SIZE(kSizes[2]);
CASE_NEXT_SIZE(kSizes[3]);
CASE_NEXT_SIZE(kSizes[4]);
CASE_NEXT_SIZE(kSizes[5]);
// kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize.
CASE_NEXT_SIZE(kSizes[6]);
CASE_NEXT_SIZE(kSizes[7]);
default: {
constexpr int block_size = 256;
LAUNCH_SKIP_LAYER_NORM_KERNEL();
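The dispatch switch above rounds ld up to the next entry in kSizes and then trades unroll width against block size. A sketch of that selection policy (it mirrors the CASE_NEXT_SIZE logic under the constants above but omits the pointer-alignment checks; names are assumptions):

```typescript
const kSizes = [128, 384, 768, 1024, 2048, 4096, 5120, 8192];
const kMaxSize = kSizes[kSizes.length - 1];
const kMaxBlockSize = 1024;

// Round up to the next supported size; kMaxSize + 1 means "no small kernel fits".
const nextSizeSketch = (x: number): number =>
    kSizes.find((s) => x <= s) ?? kMaxSize + 1;

// Pick (blockSize, numUnroll) the way CASE_NEXT_SIZE does: 8-wide unrolls for
// rows of 8 * 256 = 2048+ elements, 4-wide below that, scalar as the fallback.
const pickLaunchSketch = (ld: number, canVec4: boolean, canVec8: boolean) => {
  const next = nextSizeSketch(ld);
  if (next > kMaxSize) return {blockSize: 256, small: false};      // generic kernel
  if (next >= 8 * 256) {
    return canVec8 ? {blockSize: next / 8, small: true, unroll: 8}
                   : {blockSize: 256, small: false};
  }
  if (canVec4) return {blockSize: next / 4, small: true, unroll: 4};
  if (next <= kMaxBlockSize) return {blockSize: next, small: true, unroll: 1};
  return {blockSize: 256, small: false};
};

console.log(pickLaunchSketch(768, true, false));  // {blockSize: 192, small: true, unroll: 4}
```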
@@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
Nop{});
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
impl->GetTypeString(), " does not support the params");
if constexpr (USE_MASK) {
ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
