Skip to content

Commit

Permalink
[js/webgpu] fix heap access > 2GB (#19010)
Browse files (browse the repository at this point in the history)
  • Loading branch information
guschmue authored Jan 9, 2024
1 parent 975a315 commit a8bb1df
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 46 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/providers/js/js_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace js {
float value; \
ORT_ENFORCE(info.GetAttr<float>(#attr_name, &value));, \
, ({#attr_name : $1}), static_cast<double>(value))
// Reinterprets a native pointer as an unsigned integer (uintptr_t) so the
// address can be passed as a numeric argument to the kernel-attribute macros.
// Several call sites shift the result right by 2 to obtain an index into HEAP32.
// NOTE(review): per the commit title, the unsigned type is presumably what keeps
// addresses above 2GB from becoming negative (as they could with int32_t) — confirm.
#define JSEP_HEAP_PTR(ptr) reinterpret_cast<uintptr_t>(ptr)

// TODO:
// class JsMultiProgramKernel : public OpKernel { /* TBD */ };
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/providers/js/operators/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ class ConvBase : public JsKernel {
static_cast<int32_t>(conv_attrs_.group),
static_cast<int32_t>(kernel_shape_0),
static_cast<int32_t>(local_pads.size()),
reinterpret_cast<int32_t>(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
conv_attrs_.activation.c_str(),
activation_params.size(),
reinterpret_cast<int32_t>(activation_params_ptr) >> 2);
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
} else {
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
"format" : $11 ? "NHWC" : "NCHW",
Expand All @@ -81,14 +81,14 @@ class ConvBase : public JsKernel {
static_cast<int32_t>(kernel_shape_0),
static_cast<int32_t>(kernel_shape_1),
static_cast<int32_t>(local_pads.size()),
reinterpret_cast<int32_t>(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
conv_attrs_.activation.c_str(),
activation_params.size(),
reinterpret_cast<int32_t>(activation_params_ptr) >> 2);
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
}
}

Expand Down
20 changes: 10 additions & 10 deletions onnxruntime/core/providers/js/operators/conv_transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ class ConvTranspose : public JsKernel {
static_cast<int32_t>(pads_1),
static_cast<int32_t>(strides),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
gsl::narrow_cast<int32_t>(local_output_padding.size()),
reinterpret_cast<int32_t>(local_output_padding_ptr) >> 2,
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
gsl::narrow_cast<int32_t>(local_output_shape.size()),
reinterpret_cast<int32_t>(local_output_shape_ptr) >> 2,
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
conv_transpose_attrs_.activation.c_str());
} else {
constexpr size_t pads_vec_size = 4;
Expand Down Expand Up @@ -114,17 +114,17 @@ class ConvTranspose : public JsKernel {
"activation" : UTF8ToString($13)
}),
static_cast<int32_t>(conv_transpose_attrs_.auto_pad),
reinterpret_cast<int32_t>(local_dilations.data()) >> 2,
JSEP_HEAP_PTR(local_dilations.data()) >> 2,
static_cast<int32_t>(conv_transpose_attrs_.group),
reinterpret_cast<int32_t>(local_kernel_shape.data()) >> 2,
reinterpret_cast<int32_t>(local_pads.data()) >> 2,
reinterpret_cast<int32_t>(local_strides.data()) >> 2,
JSEP_HEAP_PTR(local_kernel_shape.data()) >> 2,
JSEP_HEAP_PTR(local_pads.data()) >> 2,
JSEP_HEAP_PTR(local_strides.data()) >> 2,
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
gsl::narrow_cast<int32_t>(local_output_padding.size()),
reinterpret_cast<int32_t>(local_output_padding_ptr) >> 2,
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
gsl::narrow_cast<int32_t>(local_output_shape.size()),
reinterpret_cast<int32_t>(local_output_shape_ptr) >> 2,
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
conv_transpose_attrs_.activation.c_str());
}
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Pad : public JsKernel, public PadBase {
static_cast<int32_t>(mode_),
static_cast<double>(value_),
gsl::narrow_cast<int32_t>(pads.size()),
reinterpret_cast<int32_t>((pads.size() > 0) ? pads.data() : nullptr) >> 2);
JSEP_HEAP_PTR((pads.size() > 0) ? pads.data() : nullptr) >> 2);
}
};

Expand Down
46 changes: 23 additions & 23 deletions onnxruntime/core/providers/js/operators/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@

namespace onnxruntime {
namespace js {
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
template <bool allow_multi_axes = true> \
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
public: \
using ReduceKernelBase<allow_multi_axes>::axes_; \
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_; \
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
std::vector<int32_t> axes(axes_.size()); \
if (axes_.size() > 0) { \
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
} \
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
"keepDims" : !!$1, \
"noopWithEmptyAxes" : !!$2, \
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
}), \
static_cast<int32_t>(keepdims_), \
static_cast<int32_t>(noop_with_empty_axes_), \
gsl::narrow_cast<int32_t>(axes.size()), \
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
} \
// Defines a class template named `ReduceKernel` (the macro argument) that derives
// from JsKernel and ReduceKernelBase<allow_multi_axes>.
//
// The generated constructor:
//   1. copies the int64 `axes_` values into a local std::vector<int32_t> using
//      gsl::narrow_cast, so the values can be read as 32-bit integers;
//   2. invokes JSEP_INIT_KERNEL_ATTRIBUTE with
//        $1 = keepdims_, $2 = noop_with_empty_axes_,
//        $3 = number of axes,
//        $4 = address of the axes buffer shifted right by 2 (or nullptr's
//             address shifted when there are no axes), which the JS snippet uses
//             as the start index for HEAP32.subarray($4, $4 + $3).
//
// The address is produced with JSEP_HEAP_PTR rather than a cast to int32_t.
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
template <bool allow_multi_axes = true> \
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
public: \
using ReduceKernelBase<allow_multi_axes>::axes_; \
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_; \
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
std::vector<int32_t> axes(axes_.size()); \
if (axes_.size() > 0) { \
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
} \
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
"keepDims" : !!$1, \
"noopWithEmptyAxes" : !!$2, \
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
}), \
static_cast<int32_t>(keepdims_), \
static_cast<int32_t>(noop_with_empty_axes_), \
gsl::narrow_cast<int32_t>(axes.size()), \
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
} \
};

JSEP_DEFINE_REDUCE_KERNEL(ReduceMax);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Resize : public JsKernel, public UpsampleBase {
}),
static_cast<int32_t>(antialias_),
gsl::narrow_cast<int32_t>(axes.size()),
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2,
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2,
resize_coordinate_transformation_mode.c_str(),
static_cast<double>(cubic_coeff_a_),
static_cast<int32_t>(exclude_outside_),
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/js/operators/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class Slice : public JsKernel, public SliceBase {
"ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [],
"axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}),
gsl::narrow_cast<int32_t>(starts.size()),
reinterpret_cast<int32_t>((starts.size() > 0) ? starts.data() : nullptr) >> 2,
JSEP_HEAP_PTR((starts.size() > 0) ? starts.data() : nullptr) >> 2,
gsl::narrow_cast<int32_t>(ends.size()),
reinterpret_cast<int32_t>((ends.size() > 0) ? ends.data() : nullptr) >> 2,
JSEP_HEAP_PTR((ends.size() > 0) ? ends.data() : nullptr) >> 2,
gsl::narrow_cast<int32_t>(axes.size()),
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2);
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2);
}
};

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Split : public JsKernel, public SplitBase {
static_cast<int32_t>(axis_),
static_cast<int32_t>(num_outputs_),
gsl::narrow_cast<int32_t>(split_sizes.size()),
reinterpret_cast<int32_t>((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
JSEP_HEAP_PTR((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
}
};

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Transpose final : public JsKernel, public TransposeBase {
gsl::narrow_cast<int32_t>(perm_specified_ ? perm_.size() : 0),
// $2: index to HEAP32 of the first int32 element. calculated from right shift memory
// address by 2
reinterpret_cast<int32_t>(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2);
JSEP_HEAP_PTR(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2);
}
};

Expand Down

0 comments on commit a8bb1df

Please sign in to comment.