Merge branch 'main' into webgpu_where

axinging committed Sep 19, 2023
2 parents b035d3a + f969e7f commit dbd9e63
Showing 53 changed files with 3,162 additions and 1,393 deletions.
13 changes: 7 additions & 6 deletions docs/ContribOperators.md
@@ -1351,8 +1351,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int32)</dt>
<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit signed integer tensors.</dd>
<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)</dt>
<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'y', 'x_scale' to float tensors.</dd>
</dl>
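
A minimal C++ sketch of the dequantization these widened constraints describe, assuming the usual formula y = (x - x_zero_point) * x_scale; the helper name and the example call are illustrative, not the onnxruntime kernel:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper: y = (x - x_zero_point) * x_scale, element-wise.
// QuantType may now also be int16_t or uint16_t under the widened T1 constraint.
template <typename QuantType>
std::vector<float> DequantizeLinearRef(const std::vector<QuantType>& x,
                                       float x_scale, QuantType x_zero_point) {
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // Widen to int64_t so the subtraction cannot overflow for int32 inputs.
    y[i] = static_cast<float>(static_cast<int64_t>(x[i]) -
                              static_cast<int64_t>(x_zero_point)) *
           x_scale;
  }
  return y;
}

// Example (illustrative): dequantize three int16 values with scale 0.05 and zero point 0.
// std::vector<float> y = DequantizeLinearRef<int16_t>({-32768, 0, 32767}, 0.05f, 0);
```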
@@ -4194,8 +4194,9 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.QuantizeLinear"></a><a name="com.microsoft.quantizelinear">**com.microsoft.QuantizeLinear**</a>

The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').
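
As a concrete reading of the formula above, here is a hedged per-element C++ sketch covering the widened output types; the names are illustrative and the rounding relies on the default round-to-nearest-even floating-point environment:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Sketch of y = saturate(round_half_to_even(x / y_scale) + y_zero_point).
// QuantType may be int8_t, uint8_t, int16_t, or uint16_t, matching the widened T2;
// saturation bounds then become [-128,127], [0,255], [-32768,32767], or [0,65535].
template <typename QuantType>
QuantType QuantizeLinearRef(float x, float y_scale, QuantType y_zero_point) {
  // nearbyintf rounds to nearest, ties to even, under the default FE_TONEAREST mode.
  float v = std::nearbyintf(x / y_scale) + static_cast<float>(y_zero_point);
  constexpr float lo = static_cast<float>(std::numeric_limits<QuantType>::min());
  constexpr float hi = static_cast<float>(std::numeric_limits<QuantType>::max());
  return static_cast<QuantType>(std::clamp(v, lo, hi));  // saturate to the target range
}
```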

#### Version
@@ -4232,8 +4233,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'x', 'y_scale' to float tensors.</dd>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
<dd>Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.</dd>
<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)</dt>
<dd>Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.</dd>
</dl>


4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -439,7 +439,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int32)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br/> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br/> **T1** = tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -472,7 +472,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
32 changes: 19 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
case DataType.float16:
func = 'vec4<f16>';
break;
case DataType.float:
func = 'vec4<f32>';
break;
@@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
attributes.cacheKey),
{inputs: [0]});
@@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

export const erfImpl = (dataType: string) => `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
const r2: f32 = -0.284496736;
const r3: f32 = 1.421413741;
const r4: f32 = -1.453152027;
const r5: f32 = 1.061405429;
export const erfImpl = (dataType: string, varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;
fn erf_vf32(v: ${dataType}) -> ${dataType} {
let absv = abs(v);
@@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;
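
The r0..r5 constants above are the coefficients of the classic Abramowitz-Stegun polynomial approximation of erf, now parameterized by varType so the same source can be emitted for f16 as well as f32. A scalar C++ sketch of the same approximation, given only as an illustrative restatement of what the shader computes per lane:

```cpp
#include <cmath>

// Abramowitz-Stegun style polynomial approximation of erf(x) using the same
// r0..r5 coefficients as the WGSL template above (sketch, not the shader code).
float erf_approx(float v) {
  const float r0 = 0.3275911f;
  const float r1 = 0.254829592f;
  const float r2 = -0.284496736f;
  const float r3 = 1.421413741f;
  const float r4 = -1.453152027f;
  const float r5 = 1.061405429f;
  const float absv = std::fabs(v);
  const float t = 1.0f / (1.0f + r0 * absv);
  const float poly = t * (r1 + t * (r2 + t * (r3 + t * (r4 + t * r5))));
  const float y = 1.0f - poly * std::exp(-absv * absv);
  return std::copysign(y, v);  // erf is odd, so restore the input's sign
}
```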

export const erf = (context: ComputeContext): void => {
context.compute(
createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};

export const exp = (context: ComputeContext): void => {
@@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
};

export const gelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl('vec4<f32>')));
erfImpl(`vec4<${dataType}>`, dataType)));
};
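
For reference, the expression in gelu above is the exact-erf form of GELU, 0.5 * x * (1 + erf(x / sqrt(2))), where 0.7071067811865475 is 1/sqrt(2). A plain C++ restatement using the standard library, purely for illustration:

```cpp
#include <cmath>

// GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))); 0.7071067811865475f is 1/sqrt(2).
// Same formula the shader evaluates per lane, restated with std::erf.
float gelu_ref(float x) {
  return 0.5f * x * (1.0f + std::erf(x * 0.7071067811865475f));
}
```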

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -56,9 +56,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLine
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid);
@@ -191,9 +195,13 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int32_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint16_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int16_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QLinearLeakyRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QLinearSigmoid)>,
56 changes: 0 additions & 56 deletions onnxruntime/contrib_ops/cpu/quantization/quantize_ops.cc

This file was deleted.

16 changes: 9 additions & 7 deletions onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -136,8 +136,9 @@ Performs element-wise binary {name} on 8 bit data types (with Numpy-style broadc

static const char* QuantizeLinear_ver1_doc = R"DOC(
The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, [-128, 127] if it's int8,
[0, 65,535] if it's uint16, and [-32,768, 32,767] if it's int16. For (x / y_scale), it's rounding to nearest ties to even.
Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
@@ -161,8 +162,8 @@
"T2", OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2")
.TypeConstraint("T1", {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.")
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"},
"Constrain 'y_zero_point' and 'y' to 8-bit integer tensors.")
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)", "tensor(int16)", "tensor(uint16)"},
"Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.")
.SetDoc(QuantizeLinear_ver1_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getNumInputs() == 3 && ctx.getInputType(2) != nullptr) {
@@ -202,9 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(DequantizeLinear, 1,
"T1", OpSchema::Optional)
.Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.",
"T2")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int32)"},
"Constrain 'x' and 'x_zero_point' to 8-bit integer tensors or 32-bit "
"signed integer tensors.")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)", "tensor(int16)",
"tensor(uint16)", "tensor(int32)"},
"Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, "
"16-bit integer tensors, or 32-bit signed integer tensors.")
.TypeConstraint("T2", {"tensor(float16)", "tensor(float)"},
"Constrain 'y', 'x_scale' to float tensors.")
.SetDoc(DequantizeLinear_ver1_doc)
24 changes: 24 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -633,6 +633,24 @@ void
int8_t ZeroPoint
);

typedef
void
(MLASCALL MLAS_QUANTIZE_LINEAR_U16_KERNEL)(
const float* Input,
uint16_t* Output,
size_t N,
float Scale,
uint16_t ZeroPoint);

typedef
void
(MLASCALL MLAS_QUANTIZE_LINEAR_S16_KERNEL)(
const float* Input,
int16_t* Output,
size_t N,
float Scale,
int16_t ZeroPoint);
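
A scalar reference sketch of what a kernel matching the new MLAS_QUANTIZE_LINEAR_U16_KERNEL signature computes; the shipped MLAS kernels are platform-specific and vectorized, so this is only an illustration of the contract, and the function name is hypothetical:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Illustrative scalar fallback matching MLAS_QUANTIZE_LINEAR_U16_KERNEL above;
// not the actual vectorized MLAS kernel.
void
QuantizeLinearU16Reference(
    const float* Input,
    uint16_t* Output,
    size_t N,
    float Scale,
    uint16_t ZeroPoint)
{
    for (size_t n = 0; n < N; n++) {
        // Round to nearest-even, add the zero point, then saturate to [0, 65535].
        float Value = std::nearbyintf(Input[n] / Scale) + static_cast<float>(ZeroPoint);
        Value = std::min(std::max(Value, 0.0f), 65535.0f);
        Output[n] = static_cast<uint16_t>(Value);
    }
}
```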

template<typename InputType, typename FilterType>
struct MLAS_QUANT_KERNEL
{
@@ -749,6 +767,8 @@ extern "C" {
MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8Kernel;
MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8Kernel;
MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel;
MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel;
MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel;
#if defined(MLAS_TARGET_AMD64)
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3;
@@ -959,6 +979,8 @@ struct MLAS_PLATFORM {
const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch;
MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel;
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel;
MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine;
@@ -986,6 +1008,8 @@ struct MLAS_PLATFORM {
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel;
MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel;
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel;
MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
uint32_t NchwcBlockSize;
uint32_t PreferredBufferAlignment;
int32_t MaximumThreadCount;
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -230,6 +230,8 @@ Return Value:
this->QLinearAddU8Kernel = MlasQLinearAddU8Kernel;
this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel;
this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel;
this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel;
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;

this->NchwcBlockSize = 8;
this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT;
@@ -475,6 +477,8 @@ Return Value:
this->GemmDoubleKernel = MlasDgemmKernel;
this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel;
this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel;
this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel;
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;

#if defined(__linux__)
unsigned long hwcap2 = getauxval(AT_HWCAP2);
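
The platform.cpp hunks above only wire the new 16-bit kernel pointers into MLAS_PLATFORM during platform initialization; callers then reach the selected kernel through those pointers. A self-contained sketch of that dispatch pattern, with the member name taken from the diff and everything else hypothetical:

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>

// Function-pointer dispatch sketch mirroring the MLAS_PLATFORM wiring above.
// Only the member name QuantizeLinearU16Kernel comes from the diff; the rest
// is a hypothetical stand-in for the real platform-selection machinery.
using QuantizeLinearU16Fn = void (*)(const float*, uint16_t*, size_t, float, uint16_t);

static void QuantizeLinearU16Portable(const float* Input, uint16_t* Output, size_t N,
                                      float Scale, uint16_t ZeroPoint) {
    for (size_t n = 0; n < N; n++) {
        float v = std::nearbyintf(Input[n] / Scale) + static_cast<float>(ZeroPoint);
        Output[n] = static_cast<uint16_t>(std::fmin(std::fmax(v, 0.0f), 65535.0f));
    }
}

struct PlatformLike {
    // Set once during platform initialization; a SIMD variant could be swapped
    // in here when the CPU supports it, much as platform.cpp does for MLAS.
    QuantizeLinearU16Fn QuantizeLinearU16Kernel = QuantizeLinearU16Portable;
};

static PlatformLike g_platform;

void QuantizeLinearU16(const float* Input, uint16_t* Output, size_t N,
                       float Scale, uint16_t ZeroPoint) {
    g_platform.QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint);
}
```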
