From 2e597fa4e5f292a088db560ada67b7bcdf5a5bc7 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 9 Jan 2024 09:08:09 -0800 Subject: [PATCH 1/4] more fixes for access above 2GB --- js/web/lib/wasm/jsep/init.ts | 8 ++-- onnxruntime/core/providers/js/js_kernel.h | 2 +- .../core/providers/js/operators/conv.h | 24 +++++----- .../providers/js/operators/conv_transpose.h | 40 ++++++++-------- onnxruntime/core/providers/js/operators/pad.h | 4 +- .../core/providers/js/operators/reduce.h | 46 +++++++++---------- .../core/providers/js/operators/resize.h | 4 +- .../core/providers/js/operators/slice.h | 12 ++--- .../core/providers/js/operators/split.h | 4 +- .../core/providers/js/operators/transpose.h | 4 +- 10 files changed, 75 insertions(+), 73 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 3c6edf3ebb35d..8f77f1fb29acc 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -61,7 +61,8 @@ class ComputeContextImpl implements ComputeContext { return this.backend.currentKernelCustomData; } get customDataBuffer(): Uint8Array { - return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize); + return this.module.HEAPU8.subarray( + this.customDataOffset >>> 0, (this.customDataOffset >>> 0) + this.customDataSize); } private customDataOffset = 0; private customDataSize = 0; @@ -170,7 +171,7 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte backend.memcpy(src, dst); } else { LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); - const data = module.HEAPU8.subarray(src, src + size); + const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); backend.upload(dst, data); } }, @@ -182,7 +183,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte 'verbose', () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); - await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size)); + await backend.download(gpuDataId, + () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); }, // jsepCreateKernel diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index b850bea4bc275..1b15c97f0e587 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -67,7 +67,7 @@ namespace js { float value; \ ORT_ENFORCE(info.GetAttr(#attr_name, &value));, \ , ({#attr_name : $1}), static_cast(value)) -#define JSEP_HEAP_PTR(ptr) reinterpret_cast(ptr) +#define JSEP_HEAP_INDEX(ptr) reinterpret_cast(ptr) // TODO: // class JsMultiProgramKernel : public OpKernel { /* TBD */ }; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 98a530c6b77f6..2c28263646140 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -43,24 +43,24 @@ class ConvBase : public JsKernel { "dilations" : [$2], "group" : $3, "kernel_shape" : [$4], - "pads" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : [], + "pads" : $5 ? Array.from(HEAP32.subarray($6 >>> 0, ($6 >>> 0) + $5)) : [], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9]), + "w_is_const" : () JS_ARROW(!!HEAP8[$9 >>> 0]), "activation" : UTF8ToString($10), - "activation_params" : $11 ? Array.from(HEAPF32.subarray($12, $12 + $11)) : [] + "activation_params" : $11 ? Array.from(HEAPF32.subarray($12 >>> 0, ($12 >>> 0) + $11)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.group), static_cast(kernel_shape_0), static_cast(local_pads.size()), - JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, + JSEP_HEAP_INDEX(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - JSEP_HEAP_PTR(&w_is_const_), + JSEP_HEAP_INDEX(&w_is_const_), conv_attrs_.activation.c_str(), activation_params.size(), - JSEP_HEAP_PTR(activation_params_ptr) >> 2); + JSEP_HEAP_INDEX(activation_params_ptr) >> 2); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", @@ -68,11 +68,11 @@ class ConvBase : public JsKernel { "dilations" : [ $2, $3 ], "group" : $4, "kernel_shape" : [ $5, $6 ], - "pads" : $7 ? Array.from(HEAP32.subarray($8, $8 + $7)) : [], + "pads" : $7 ? Array.from(HEAP32.subarray($8 >>> 0, ($8 >>> 0) + $7)) : [], "strides" : [ $9, $10 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "w_is_const" : () JS_ARROW(!!HEAP8[$12 >>> 0]), "activation" : UTF8ToString($13), - "activation_params" : $14 ? Array.from(HEAPF32.subarray($15, $15 + $14)) : [] + "activation_params" : $14 ? Array.from(HEAPF32.subarray($15 >>> 0, ($15 >>> 0) + $14)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), @@ -81,14 +81,14 @@ class ConvBase : public JsKernel { static_cast(kernel_shape_0), static_cast(kernel_shape_1), static_cast(local_pads.size()), - JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, + JSEP_HEAP_INDEX(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), static_cast(channels_last), - JSEP_HEAP_PTR(&w_is_const_), + JSEP_HEAP_INDEX(&w_is_const_), conv_attrs_.activation.c_str(), activation_params.size(), - JSEP_HEAP_PTR(activation_params_ptr) >> 2); + JSEP_HEAP_INDEX(activation_params_ptr) >> 2); } } diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 353a946e95c21..3e37f80d1d0ce 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -51,9 +51,9 @@ class ConvTranspose : public JsKernel { "kernel_shape" : [$4], "pads" : [ $5, $6 ], "strides" : [$7], - "wIsConst" : () JS_ARROW(!!HEAP8[$9]), - "outputPadding" : $10 ? Array.from(HEAP32.subarray($11, $11 + $10)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [], + "wIsConst" : () JS_ARROW(!!HEAP8[$9 >>> 0]), + "outputPadding" : $10 ? Array.from(HEAP32.subarray($11 >>> 0, ($11 >>> 0) + $10)) : [], + "outputShape" : $12 ? Array.from(HEAP32.subarray($13 >>> 0, ($13 >>> 0) + $12)) : [], "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), @@ -64,11 +64,11 @@ class ConvTranspose : public JsKernel { static_cast(pads_1), static_cast(strides), static_cast(channels_last), - JSEP_HEAP_PTR(&w_is_const_), + JSEP_HEAP_INDEX(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - JSEP_HEAP_PTR(local_output_padding_ptr) >> 2, + JSEP_HEAP_INDEX(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - JSEP_HEAP_PTR(local_output_shape_ptr) >> 2, + JSEP_HEAP_INDEX(local_output_shape_ptr) >> 2, conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; @@ -103,28 +103,28 @@ class ConvTranspose : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $7 ? "NHWC" : "NCHW", "autoPad" : $1, - "dilations" : Array.from(HEAP32.subarray($2, $2 + /* dialations_vec_size */ 2)), + "dilations" : Array.from(HEAP32.subarray($2 >>> 0, ($2 >>> 0) + /* dialations_vec_size */ 2)), "group" : $3, - "kernelShape" : Array.from(HEAP32.subarray($4, $4 + /* kernel_shape_vec_size */ 2)), - "pads" : Array.from(HEAP32.subarray($5, $5 + /* pads_vec_size */ 4)), - "strides" : Array.from(HEAP32.subarray($6, $6 + /* strides_vec_size */ 2)), - "wIsConst" : () JS_ARROW(!!HEAP8[$8]), - "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10, $10 + $9)) : [], - "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [], + "kernelShape" : Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), + "pads" : Array.from(HEAP32.subarray($5 >>> 0, ($5 >>> 0) + /* pads_vec_size */ 4)), + "strides" : Array.from(HEAP32.subarray($6 >>> 0, ($6 >>> 0) + /* strides_vec_size */ 2)), + "wIsConst" : () JS_ARROW(!!HEAP8[$8 >>> 0]), + "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10 >>> 0, ($10 >>> 0) + $9)) : [], + "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12 >>> 0, ($12 >>> 0) + $11)) : [], "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), - JSEP_HEAP_PTR(local_dilations.data()) >> 2, + JSEP_HEAP_INDEX(local_dilations.data()) >> 2, static_cast(conv_transpose_attrs_.group), - JSEP_HEAP_PTR(local_kernel_shape.data()) >> 2, - JSEP_HEAP_PTR(local_pads.data()) >> 2, - JSEP_HEAP_PTR(local_strides.data()) >> 2, + JSEP_HEAP_INDEX(local_kernel_shape.data()) >> 2, + JSEP_HEAP_INDEX(local_pads.data()) >> 2, + JSEP_HEAP_INDEX(local_strides.data()) >> 2, static_cast(channels_last), - JSEP_HEAP_PTR(&w_is_const_), + JSEP_HEAP_INDEX(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - JSEP_HEAP_PTR(local_output_padding_ptr) >> 2, + JSEP_HEAP_INDEX(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - JSEP_HEAP_PTR(local_output_shape_ptr) >> 2, + JSEP_HEAP_INDEX(local_output_shape_ptr) >> 2, conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index bf808be949cf8..c3929ad592f38 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -22,11 +22,11 @@ class Pad : public JsKernel, public PadBase { JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, "value" : $2, - "pads" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}), + "pads" : $3 ? Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3)) : []}), static_cast(mode_), static_cast(value_), gsl::narrow_cast(pads.size()), - JSEP_HEAP_PTR((pads.size() > 0) ? pads.data() : nullptr) >> 2); + JSEP_HEAP_INDEX((pads.size() > 0) ? pads.data() : nullptr) >> 2); } }; diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 95c4f2bec230d..3db168a59cabc 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -8,29 +8,29 @@ namespace onnxruntime { namespace js { -#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ - class ReduceKernel : public JsKernel, public ReduceKernelBase { \ - public: \ - using ReduceKernelBase::axes_; \ - using ReduceKernelBase::noop_with_empty_axes_; \ - using ReduceKernelBase::keepdims_; \ - ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ - std::vector axes(axes_.size()); \ - if (axes_.size() > 0) { \ - std::transform(axes_.begin(), axes_.end(), axes.begin(), \ - [](int64_t axis) { return gsl::narrow_cast(axis); }); \ - } \ - JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ - "keepDims" : !!$1, \ - "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \ - }), \ - static_cast(keepdims_), \ - static_cast(noop_with_empty_axes_), \ - gsl::narrow_cast(axes.size()), \ - JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); \ - } \ +#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ + template \ + class ReduceKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::noop_with_empty_axes_; \ + using ReduceKernelBase::keepdims_; \ + ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + } \ + JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ + "keepDims" : !!$1, \ + "noopWithEmptyAxes" : !!$2, \ + "axes" : $3 ? (Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3))) : [], \ + }), \ + static_cast(keepdims_), \ + static_cast(noop_with_empty_axes_), \ + gsl::narrow_cast(axes.size()), \ + JSEP_HEAP_INDEX((axes.size() > 0) ? axes.data() : nullptr) >> 2); \ + } \ }; JSEP_DEFINE_REDUCE_KERNEL(ReduceMax); diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 4b1c288ae3015..533bb4908c773 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase { std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); }); JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({ "antialias" : $1, - "axes" : $2 ? Array.from(HEAP32.subarray($3, $3 + $2)) : [], + "axes" : $2 ? Array.from(HEAP32.subarray($3 >>> 0, ($3 >>> 0) + $2)) : [], "coordinateTransformMode" : UTF8ToString($4), "cubicCoeffA" : $5, "excludeOutside" : $6, @@ -34,7 +34,7 @@ class Resize : public JsKernel, public UpsampleBase { }), static_cast(antialias_), gsl::narrow_cast(axes.size()), - JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2, + JSEP_HEAP_INDEX((axes.size() > 0) ? axes.data() : nullptr) >> 2, resize_coordinate_transformation_mode.c_str(), static_cast(cubic_coeff_a_), static_cast(exclude_outside_), diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index 989adabf029a5..0b9365a2d1896 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -20,15 +20,15 @@ class Slice : public JsKernel, public SliceBase { std::vector starts(attr_starts.begin(), attr_starts.end()); std::vector ends(attr_ends.begin(), attr_ends.end()); - JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($2, $2 + $1)) : [], - "ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [], - "axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}), + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($2 >>> 0, ($2 >>> 0) + $1)) : [], + "ends" : $3 ? Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3)) : [], + "axes" : $5 ? Array.from(HEAP32.subarray($6 >>> 0, ($6 >>> 0) + $5)) : []}), gsl::narrow_cast(starts.size()), - JSEP_HEAP_PTR((starts.size() > 0) ? starts.data() : nullptr) >> 2, + JSEP_HEAP_INDEX((starts.size() > 0) ? starts.data() : nullptr) >> 2, gsl::narrow_cast(ends.size()), - JSEP_HEAP_PTR((ends.size() > 0) ? ends.data() : nullptr) >> 2, + JSEP_HEAP_INDEX((ends.size() > 0) ? ends.data() : nullptr) >> 2, gsl::narrow_cast(axes.size()), - JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); + JSEP_HEAP_INDEX((axes.size() > 0) ? axes.data() : nullptr) >> 2); } }; diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 1c1874e5aa98e..6b7cc32322d24 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -49,11 +49,11 @@ class Split : public JsKernel, public SplitBase { JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, - "splitSizes" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}), + "splitSizes" : $3 ? Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3)) : []}), static_cast(axis_), static_cast(num_outputs_), gsl::narrow_cast(split_sizes.size()), - JSEP_HEAP_PTR((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2); + JSEP_HEAP_INDEX((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2); } }; diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index dae442b9f5a13..4d282410d95a7 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -21,13 +21,13 @@ class Transpose final : public JsKernel, public TransposeBase { } } JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ - "perm" : $1 ? Array.from(HEAP32.subarray($2, $2 + $1)) : [] + "perm" : $1 ? Array.from(HEAP32.subarray($2 >>> 0, ($2 >>> 0) + $1)) : [] }), // $1: length of attribute "perm" (int32[]) gsl::narrow_cast(perm_specified_ ? perm_.size() : 0), // $2: index to HEAP32 of the first int32 element. calculated from right shift memory // address by 2 - JSEP_HEAP_PTR(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2); + JSEP_HEAP_INDEX(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2); } }; From ba8540ee37d58045b4b383531013d3d29782aed3 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 9 Jan 2024 10:01:17 -0800 Subject: [PATCH 2/4] lint --- js/web/lib/wasm/jsep/init.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 8f77f1fb29acc..d82bc24b34824 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -62,7 +62,7 @@ class ComputeContextImpl implements ComputeContext { } get customDataBuffer(): Uint8Array { return this.module.HEAPU8.subarray( - this.customDataOffset >>> 0, (this.customDataOffset >>> 0) + this.customDataSize); + this.customDataOffset >>> 0, (this.customDataOffset >>> 0) + this.customDataSize); } private customDataOffset = 0; private customDataSize = 0; @@ -183,8 +183,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte 'verbose', () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); - await backend.download(gpuDataId, - () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); + await backend.download( + gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); }, // jsepCreateKernel From b41132ed4ec9edc7199042e27b8c7d2a58728c1d Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 12 Jan 2024 17:16:23 -0800 Subject: [PATCH 3/4] address review feedback --- js/web/lib/wasm/jsep/init.ts | 2 +- onnxruntime/core/providers/js/js_kernel.h | 5 +- .../core/providers/js/operators/conv.h | 33 ++++++------ .../providers/js/operators/conv_transpose.h | 52 +++++++++---------- onnxruntime/core/providers/js/operators/pad.h | 6 +-- .../core/providers/js/operators/reduce.h | 46 ++++++++-------- .../core/providers/js/operators/resize.h | 6 +-- .../core/providers/js/operators/slice.h | 18 +++---- .../core/providers/js/operators/split.h | 6 +-- .../core/providers/js/operators/transpose.h | 9 ++-- 10 files changed, 89 insertions(+), 94 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index d82bc24b34824..1b35686c78487 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -62,7 +62,7 @@ class ComputeContextImpl implements ComputeContext { } get customDataBuffer(): Uint8Array { return this.module.HEAPU8.subarray( - this.customDataOffset >>> 0, (this.customDataOffset >>> 0) + this.customDataSize); + this.customDataOffset, this.customDataOffset + this.customDataSize); } private customDataOffset = 0; private customDataSize = 0; diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 1b15c97f0e587..7324b0d69474c 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -67,7 +67,10 @@ namespace js { float value; \ ORT_ENFORCE(info.GetAttr(#attr_name, &value));, \ , ({#attr_name : $1}), static_cast(value)) -#define JSEP_HEAP_INDEX(ptr) reinterpret_cast(ptr) + +#define JSEP_HEAP8_INDEX(ptr) reinterpret_cast(ptr) +#define JSEP_HEAP32_INDEX_START(vec) ((vec.size() > 0) ? reinterpret_cast(vec.data()) >> 2 : 0) +#define JSEP_HEAP32_INDEX_END(vec) ((reinterpret_cast(vec.data()) >> 2) + vec.size()) // TODO: // class JsMultiProgramKernel : public OpKernel { /* TBD */ }; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 2c28263646140..89719f6ba6657 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -29,7 +29,6 @@ class ConvBase : public JsKernel { } conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); std::vector activation_params = info.GetAttrsOrDefault("activation_params"); - const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr; int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0; @@ -43,24 +42,24 @@ class ConvBase : public JsKernel { "dilations" : [$2], "group" : $3, "kernel_shape" : [$4], - "pads" : $5 ? Array.from(HEAP32.subarray($6 >>> 0, ($6 >>> 0) + $5)) : [], + "pads" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9 >>> 0]), + "w_is_const" : () JS_ARROW(!!HEAP8[$9]), "activation" : UTF8ToString($10), - "activation_params" : $11 ? Array.from(HEAPF32.subarray($12 >>> 0, ($12 >>> 0) + $11)) : [] + "activation_params" : $11 ? Array.from(HEAPF32.subarray($11, $12)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.group), static_cast(kernel_shape_0), - static_cast(local_pads.size()), - JSEP_HEAP_INDEX(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, + JSEP_HEAP32_INDEX_START(local_pads), + JSEP_HEAP32_INDEX_END(local_pads), static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - JSEP_HEAP_INDEX(&w_is_const_), + JSEP_HEAP8_INDEX(&w_is_const_), conv_attrs_.activation.c_str(), - activation_params.size(), - JSEP_HEAP_INDEX(activation_params_ptr) >> 2); + JSEP_HEAP32_INDEX_START(activation_params), + JSEP_HEAP32_INDEX_END(activation_params)); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", @@ -68,11 +67,11 @@ class ConvBase : public JsKernel { "dilations" : [ $2, $3 ], "group" : $4, "kernel_shape" : [ $5, $6 ], - "pads" : $7 ? Array.from(HEAP32.subarray($8 >>> 0, ($8 >>> 0) + $7)) : [], + "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], "strides" : [ $9, $10 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$12 >>> 0]), + "w_is_const" : () JS_ARROW(!!HEAP8[$12]), "activation" : UTF8ToString($13), - "activation_params" : $14 ? Array.from(HEAPF32.subarray($15 >>> 0, ($15 >>> 0) + $14)) : [] + "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), @@ -80,15 +79,15 @@ class ConvBase : public JsKernel { static_cast(conv_attrs_.group), static_cast(kernel_shape_0), static_cast(kernel_shape_1), - static_cast(local_pads.size()), - JSEP_HEAP_INDEX(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, + JSEP_HEAP32_INDEX_START(local_pads), + JSEP_HEAP32_INDEX_END(local_pads), static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), static_cast(channels_last), - JSEP_HEAP_INDEX(&w_is_const_), + JSEP_HEAP8_INDEX(&w_is_const_), conv_attrs_.activation.c_str(), - activation_params.size(), - JSEP_HEAP_INDEX(activation_params_ptr) >> 2); + JSEP_HEAP32_INDEX_START(activation_params), + JSEP_HEAP32_INDEX_END(activation_params)); } } diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 3e37f80d1d0ce..258f5676eb93e 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -29,10 +29,6 @@ class ConvTranspose : public JsKernel { conv_transpose_attrs_.output_shape.end()); std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); - const auto* local_output_padding_ptr = - local_output_padding.size() > 0 ? local_output_padding.data() : nullptr; - const auto* local_output_shape_ptr = - local_output_shape.size() > 0 ? local_output_shape.data() : nullptr; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_transpose_attrs_.dilations.size() == 1 || @@ -51,9 +47,9 @@ class ConvTranspose : public JsKernel { "kernel_shape" : [$4], "pads" : [ $5, $6 ], "strides" : [$7], - "wIsConst" : () JS_ARROW(!!HEAP8[$9 >>> 0]), - "outputPadding" : $10 ? Array.from(HEAP32.subarray($11 >>> 0, ($11 >>> 0) + $10)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($13 >>> 0, ($13 >>> 0) + $12)) : [], + "wIsConst" : () JS_ARROW(!!HEAP8[$9]), + "outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [], + "outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [], "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), @@ -64,11 +60,11 @@ class ConvTranspose : public JsKernel { static_cast(pads_1), static_cast(strides), static_cast(channels_last), - JSEP_HEAP_INDEX(&w_is_const_), - gsl::narrow_cast(local_output_padding.size()), - JSEP_HEAP_INDEX(local_output_padding_ptr) >> 2, - gsl::narrow_cast(local_output_shape.size()), - JSEP_HEAP_INDEX(local_output_shape_ptr) >> 2, + JSEP_HEAP8_INDEX(&w_is_const_), + JSEP_HEAP32_INDEX_START(local_output_padding), + JSEP_HEAP32_INDEX_END(local_output_padding), + JSEP_HEAP32_INDEX_START(local_output_shape), + JSEP_HEAP32_INDEX_END(local_output_shape), conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; @@ -103,28 +99,28 @@ class ConvTranspose : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $7 ? "NHWC" : "NCHW", "autoPad" : $1, - "dilations" : Array.from(HEAP32.subarray($2 >>> 0, ($2 >>> 0) + /* dialations_vec_size */ 2)), + "dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)), "group" : $3, - "kernelShape" : Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), - "pads" : Array.from(HEAP32.subarray($5 >>> 0, ($5 >>> 0) + /* pads_vec_size */ 4)), - "strides" : Array.from(HEAP32.subarray($6 >>> 0, ($6 >>> 0) + /* strides_vec_size */ 2)), - "wIsConst" : () JS_ARROW(!!HEAP8[$8 >>> 0]), - "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10 >>> 0, ($10 >>> 0) + $9)) : [], - "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12 >>> 0, ($12 >>> 0) + $11)) : [], + "kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), + "pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)), + "strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)), + "wIsConst" : () JS_ARROW(!!HEAP8[$8]), + "outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], + "outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [], "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), - JSEP_HEAP_INDEX(local_dilations.data()) >> 2, + JSEP_HEAP32_INDEX_START(local_dilations), static_cast(conv_transpose_attrs_.group), - JSEP_HEAP_INDEX(local_kernel_shape.data()) >> 2, - JSEP_HEAP_INDEX(local_pads.data()) >> 2, - JSEP_HEAP_INDEX(local_strides.data()) >> 2, + JSEP_HEAP32_INDEX_START(local_kernel_shape), + JSEP_HEAP32_INDEX_START(local_pads), + JSEP_HEAP32_INDEX_START(local_strides), static_cast(channels_last), - JSEP_HEAP_INDEX(&w_is_const_), - gsl::narrow_cast(local_output_padding.size()), - JSEP_HEAP_INDEX(local_output_padding_ptr) >> 2, - gsl::narrow_cast(local_output_shape.size()), - JSEP_HEAP_INDEX(local_output_shape_ptr) >> 2, + JSEP_HEAP8_INDEX(&w_is_const_), + JSEP_HEAP32_INDEX_START(local_output_padding), + JSEP_HEAP32_INDEX_END(local_output_padding), + JSEP_HEAP32_INDEX_START(local_output_shape), + JSEP_HEAP32_INDEX_END(local_output_shape), conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index c3929ad592f38..c18c7dd456dc2 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -22,11 +22,11 @@ class Pad : public JsKernel, public PadBase { JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, "value" : $2, - "pads" : $3 ? Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3)) : []}), + "pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), static_cast(mode_), static_cast(value_), - gsl::narrow_cast(pads.size()), - JSEP_HEAP_INDEX((pads.size() > 0) ? pads.data() : nullptr) >> 2); + JSEP_HEAP32_INDEX_START(pads), + JSEP_HEAP32_INDEX_END(pads)); } }; diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 3db168a59cabc..937f1f990dc67 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -8,29 +8,29 @@ namespace onnxruntime { namespace js { -#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ - class ReduceKernel : public JsKernel, public ReduceKernelBase { \ - public: \ - using ReduceKernelBase::axes_; \ - using ReduceKernelBase::noop_with_empty_axes_; \ - using ReduceKernelBase::keepdims_; \ - ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ - std::vector axes(axes_.size()); \ - if (axes_.size() > 0) { \ - std::transform(axes_.begin(), axes_.end(), axes.begin(), \ - [](int64_t axis) { return gsl::narrow_cast(axis); }); \ - } \ - JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ - "keepDims" : !!$1, \ - "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3))) : [], \ - }), \ - static_cast(keepdims_), \ - static_cast(noop_with_empty_axes_), \ - gsl::narrow_cast(axes.size()), \ - JSEP_HEAP_INDEX((axes.size() > 0) ? axes.data() : nullptr) >> 2); \ - } \ +#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ + template \ + class ReduceKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::noop_with_empty_axes_; \ + using ReduceKernelBase::keepdims_; \ + ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + } \ + JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ + "keepDims" : !!$1, \ + "noopWithEmptyAxes" : !!$2, \ + "axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \ + }), \ + static_cast(keepdims_), \ + static_cast(noop_with_empty_axes_), \ + JSEP_HEAP32_INDEX_START(axes), \ + JSEP_HEAP32_INDEX_END(axes)); \ + } \ }; JSEP_DEFINE_REDUCE_KERNEL(ReduceMax); diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 533bb4908c773..134eb4bf5a7f4 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase { std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); }); JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({ "antialias" : $1, - "axes" : $2 ? Array.from(HEAP32.subarray($3 >>> 0, ($3 >>> 0) + $2)) : [], + "axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], "coordinateTransformMode" : UTF8ToString($4), "cubicCoeffA" : $5, "excludeOutside" : $6, @@ -33,8 +33,8 @@ class Resize : public JsKernel, public UpsampleBase { "nearestMode" : UTF8ToString($10), }), static_cast(antialias_), - gsl::narrow_cast(axes.size()), - JSEP_HEAP_INDEX((axes.size() > 0) ? axes.data() : nullptr) >> 2, + JSEP_HEAP32_INDEX_START(axes), + JSEP_HEAP32_INDEX_END(axes), resize_coordinate_transformation_mode.c_str(), static_cast(cubic_coeff_a_), static_cast(exclude_outside_), diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index 0b9365a2d1896..daeffaa664741 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -20,15 +20,15 @@ class Slice : public JsKernel, public SliceBase { std::vector starts(attr_starts.begin(), attr_starts.end()); std::vector ends(attr_ends.begin(), attr_ends.end()); - JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($2 >>> 0, ($2 >>> 0) + $1)) : [], - "ends" : $3 ? Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3)) : [], - "axes" : $5 ? Array.from(HEAP32.subarray($6 >>> 0, ($6 >>> 0) + $5)) : []}), - gsl::narrow_cast(starts.size()), - JSEP_HEAP_INDEX((starts.size() > 0) ? starts.data() : nullptr) >> 2, - gsl::narrow_cast(ends.size()), - JSEP_HEAP_INDEX((ends.size() > 0) ? ends.data() : nullptr) >> 2, - gsl::narrow_cast(axes.size()), - JSEP_HEAP_INDEX((axes.size() > 0) ? axes.data() : nullptr) >> 2); + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [], + "ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [], + "axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}), + JSEP_HEAP32_INDEX_START(starts), + JSEP_HEAP32_INDEX_END(starts), + JSEP_HEAP32_INDEX_START(ends), + JSEP_HEAP32_INDEX_END(ends), + JSEP_HEAP32_INDEX_START(axes), + JSEP_HEAP32_INDEX_END(axes)); } }; diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 6b7cc32322d24..4fdbab00e739c 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -49,11 +49,11 @@ class Split : public JsKernel, public SplitBase { JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, - "splitSizes" : $3 ? Array.from(HEAP32.subarray($4 >>> 0, ($4 >>> 0) + $3)) : []}), + "splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), static_cast(axis_), static_cast(num_outputs_), - gsl::narrow_cast(split_sizes.size()), - JSEP_HEAP_INDEX((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2); + JSEP_HEAP32_INDEX_START(split_sizes), + JSEP_HEAP32_INDEX_END(split_sizes)); } }; diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index 4d282410d95a7..f43dd814aa959 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -21,13 +21,10 @@ class Transpose final : public JsKernel, public TransposeBase { } } JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ - "perm" : $1 ? Array.from(HEAP32.subarray($2 >>> 0, ($2 >>> 0) + $1)) : [] + "perm" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [] }), - // $1: length of attribute "perm" (int32[]) - gsl::narrow_cast(perm_specified_ ? perm_.size() : 0), - // $2: index to HEAP32 of the first int32 element. calculated from right shift memory - // address by 2 - JSEP_HEAP_INDEX(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2); + JSEP_HEAP32_INDEX_START(perm), + JSEP_HEAP32_INDEX_END(perm)); } }; From 1209b26f98e82be3d66df2c60afd6c61e0a0d715 Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 12 Jan 2024 17:25:00 -0800 Subject: [PATCH 4/4] lint --- js/web/lib/wasm/jsep/init.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 1b35686c78487..935f0dcabcd73 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -61,8 +61,7 @@ class ComputeContextImpl implements ComputeContext { return this.backend.currentKernelCustomData; } get customDataBuffer(): Uint8Array { - return this.module.HEAPU8.subarray( - this.customDataOffset, this.customDataOffset + this.customDataSize); + return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize); } private customDataOffset = 0; private customDataSize = 0;