Skip to content

Commit

Permalink
fix f16 for attention, enable slice and flatten for more types (#19262)
Browse files Browse the repository at this point in the history
  • Loading branch information
guschmue authored Jan 29, 2024
1 parent e96a038 commit 9e69606
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
if (sum == 0) {
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')};
x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')};
}
} else {
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/js/operators/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand All @@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand All @@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);

ONNX_OPERATOR_KERNEL_EX(
Expand All @@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);

} // namespace js
Expand Down
12 changes: 4 additions & 8 deletions onnxruntime/core/providers/js/operators/slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1, 9,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice_1);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand All @@ -26,8 +25,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand All @@ -40,8 +38,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice);

ONNX_OPERATOR_KERNEL_EX(
Expand All @@ -54,8 +51,7 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice);

} // namespace js
Expand Down

0 comments on commit 9e69606

Please sign in to comment.