Skip to content

Commit

Permalink
FP16 Gemm, Softmax & Transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 11, 2023
1 parent 24f0893 commit ce329e7
Show file tree
Hide file tree
Showing 12 changed files with 79 additions and 88 deletions.
10 changes: 2 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {GemmUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {ShaderHelper} from './common';
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs) {
Expand All @@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
throw new Error('Invalid input shape of C');
}

if ((inputs[0].dataType !== DataType.float) || (inputs[1].dataType !== DataType.float) ||
(inputs.length === 3 && inputs[2].dataType !== DataType.float)) {
throw new Error('Invalid input type.');
}

if ((inputs[0].dataType !== inputs[1].dataType) ||
(inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) {
throw new Error('Input types are mismatched');
Expand Down Expand Up @@ -81,7 +75,7 @@ const createGemmProgramInfo =
line = 'value += a[m * K + k] * b[k * N + n];';
}

const dataType = 'f32'; // TODO: support other data type
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); // TODO: support other data type
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;';
const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : '';
const inputStorageBuffersDeclarations = [
Expand Down
14 changes: 6 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,17 @@
// performance limitations when the reduced axis is long. Need to add
// an optimized codepath for this.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';

import {ShaderHelper} from './common';
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Softmax op requires 1 input.');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Softmax input needs to be float.');
}
};

export interface SoftmaxAttributes extends AttributeWithCacheKey {
Expand All @@ -33,7 +29,7 @@ export const softmaxProgramMetadata = {


const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const dataType = 'f32';
const dataType = tensorTypeToWsglStorageType(input.dataType);
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const WG = 64;
Expand All @@ -48,6 +44,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
const cols = shape[axis];
const rows = outputSize / cols;

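// Seed for the running row maximum: (approximately) the most negative finite value of the
// element type, f32 or f16. See section 6.2.4 of the WGSL spec.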
const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
Expand Down Expand Up @@ -76,7 +74,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
let row_stride : i32 = ${cols};
// find the rows max
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
${threadMaxDecl}
for (var col = lindex; col < cols; col += wg) {
let value = getValue(row, col, row_stride);
threadMax = max(threadMax, value);
Expand All @@ -100,7 +98,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
// find the rows sum
var threadSum = 0.0;
var threadSum: ${dataType} = 0.0;
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
Expand Down
6 changes: 0 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand All @@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Transpose requires 1 input.');
}

if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 &&
inputs[0].dataType !== DataType.uint32) {
throw new Error('Transpose only support float, int32, and uint32 data types');
}
};

const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] =>
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/js/js_data_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,30 @@ namespace js {
// Element types used to build the general JSEP kernel type constraint
// (see JsepSupportedDataTypes below).
// MLFloat16 is part of the list only when the build defines
// ENABLE_WEBASSEMBLY_FLOAT16.
using SupportedTypes =
TypeList<
float,
#ifdef ENABLE_WEBASSEMBLY_FLOAT16
MLFloat16,
#endif
int32_t,
uint32_t>;

// Floating-point element types used to build the float-only JSEP kernel
// type constraint (see JsepSupportedFloatTypes below).
// With ENABLE_WEBASSEMBLY_FLOAT16 defined the list is {float, MLFloat16};
// otherwise it is {float} only.
using SupportedFloats =
#ifdef ENABLE_WEBASSEMBLY_FLOAT16
TypeList<
float,
MLFloat16>;
#else
TypeList<float>;
#endif

// Returns the kernel type constraints built from SupportedTypes.
// The vector is a function-local static: it is built on the first call and
// the same instance is returned by reference on every later call.
// NOTE(review): js_data_types.h declares this function as returning a
// non-const std::vector<MLDataType>&, while this definition returns a const
// reference - confirm the declaration and definition agree.
const std::vector<MLDataType>& JsepSupportedDataTypes() {
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedTypes>();
return supportedDataTypes;
}

// Returns the kernel type constraints built from SupportedFloats
// (float, plus MLFloat16 when ENABLE_WEBASSEMBLY_FLOAT16 is defined).
// The vector is a function-local static: it is built on the first call and
// the same instance is returned by reference on every later call.
// NOTE(review): js_data_types.h declares this function as returning a
// non-const std::vector<MLDataType>&, while this definition returns a const
// reference - confirm the declaration and definition agree.
const std::vector<MLDataType>& JsepSupportedFloatTypes() {
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedFloats>();
return supportedDataTypes;
}

} // namespace js
} // namespace onnxruntime
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/js/js_data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
namespace onnxruntime {
namespace js {
std::vector<MLDataType>& JsepSupportedDataTypes();
}
std::vector<MLDataType>& JsepSupportedFloatTypes();
} // namespace js
} // namespace onnxruntime
28 changes: 14 additions & 14 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);

Expand All @@ -269,9 +269,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin);

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat);
Expand Down Expand Up @@ -496,10 +496,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,

Expand All @@ -522,9 +522,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/js/js_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "core/framework/op_kernel.h"
#include "core/providers/js/js_execution_provider.h"
#include "core/providers/js/js_data_types.h"

struct pthreadpool;

Expand Down
63 changes: 28 additions & 35 deletions onnxruntime/core/providers/js/operators/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,34 @@
namespace onnxruntime {
namespace js {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
13, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
11, 12, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
9, 10, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
7, 8, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>);

REGISTER_KERNEL_TYPED(float)
// Gemm kernel registrations for the JS execution provider.
// Each registration constrains "T" to JsepSupportedFloatTypes() rather than
// to a single element type, so one non-templated Gemm class serves every
// supported float type.

// Opset 13 (presumably 13 and later - confirm against the macro definition).
ONNX_OPERATOR_KERNEL_EX(
Gemm,
kOnnxDomain,
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
// Opset versions 11 through 12.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gemm,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
// Opset versions 9 through 10.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gemm,
kOnnxDomain,
9, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
// Opset versions 7 through 8.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gemm,
kOnnxDomain,
7, 8,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);

} // namespace js
} // namespace onnxruntime
1 change: 0 additions & 1 deletion onnxruntime/core/providers/js/operators/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
namespace onnxruntime {
namespace js {

template <typename T>
class Gemm : public JsKernel {
public:
Gemm(const OpKernelInfo& info) : JsKernel(info) {
Expand Down
14 changes: 6 additions & 8 deletions onnxruntime/core/providers/js/operators/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,25 @@ namespace onnxruntime {
namespace js {

#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
SoftmaxOp, \
kOnnxDomain, \
sinceVersion, endVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
SoftmaxOp<float>);
.TypeConstraint("T", JsepSupportedFloatTypes()), \
SoftmaxOp);

#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ONNX_OPERATOR_KERNEL_EX( \
SoftmaxOp, \
kOnnxDomain, \
sinceVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
.TypeConstraint("T", JsepSupportedFloatTypes()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
SoftmaxOp<float>);
SoftmaxOp);

REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10);
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12);
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/providers/js/operators/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

namespace onnxruntime {
namespace js {
template <typename T>
class Softmax : public JsKernel {
public:
Softmax(const OpKernelInfo& info) : JsKernel(info) {
Expand Down
8 changes: 2 additions & 6 deletions onnxruntime/core/providers/js/operators/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(
Expand All @@ -23,9 +21,7 @@ ONNX_OPERATOR_KERNEL_EX(
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Transpose);

} // namespace js
Expand Down

0 comments on commit ce329e7

Please sign in to comment.