Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] more fixes for access above 2GB #19065

Merged
merged 5 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src, src + size);
const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
backend.upload(dst, data);
}
},
Expand All @@ -182,7 +182,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
'verbose',
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);

await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
await backend.download(
gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
},

// jsepCreateKernel
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/js/js_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ namespace js {
float value; \
ORT_ENFORCE(info.GetAttr<float>(#attr_name, &value));, \
, ({#attr_name : $1}), static_cast<double>(value))
#define JSEP_HEAP_PTR(ptr) reinterpret_cast<uintptr_t>(ptr)

// Helpers that turn native addresses into indices for the JS-side typed-array heap views.
//
// JSEP_HEAP8_INDEX(ptr)
//   The address of `ptr` as an unsigned integer, i.e. a byte index (used as `HEAP8[$n]`).
//
// JSEP_HEAP32_INDEX_START(vec) / JSEP_HEAP32_INDEX_END(vec)
//   Half-open [start, end) range of 4-byte element indices covering the contents of `vec`
//   (used as `HEAP32.subarray(start, end)` / `HEAPF32.subarray(start, end)`).
//   START is the data address shifted right by 2, or 0 when `vec` is empty, so the JS side can
//   test it for truthiness to detect "no data". END is (data address >> 2) + element count and
//   is only meaningful when START is non-zero.
//   NOTE(review): presumably the end index is computed here, in unsigned arithmetic, so that JS
//   never adds a size to an index that may look negative above 2GB - confirm.
//
// `vec` is expected to hold 4-byte elements and is expanded more than once, so do not pass an
// expression with side effects. Every use of the argument is parenthesized so that non-trivial
// expressions (e.g. `*ptr`) expand correctly.
#define JSEP_HEAP8_INDEX(ptr) reinterpret_cast<uintptr_t>(ptr)
#define JSEP_HEAP32_INDEX_START(vec) (((vec).size() > 0) ? reinterpret_cast<uintptr_t>((vec).data()) >> 2 : 0)
#define JSEP_HEAP32_INDEX_END(vec) ((reinterpret_cast<uintptr_t>((vec).data()) >> 2) + (vec).size())

// TODO:
// class JsMultiProgramKernel : public OpKernel { /* TBD */ };
Expand Down
29 changes: 14 additions & 15 deletions onnxruntime/core/providers/js/operators/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class ConvBase : public JsKernel {
}
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr;
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0;
auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0;
Expand All @@ -43,52 +42,52 @@ class ConvBase : public JsKernel {
"dilations" : [$2],
"group" : $3,
"kernel_shape" : [$4],
"pads" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : [],
"pads" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [],
"strides" : [$7],
"w_is_const" : () JS_ARROW(!!HEAP8[$9]),
"activation" : UTF8ToString($10),
"activation_params" : $11 ? Array.from(HEAPF32.subarray($12, $12 + $11)) : []
"activation_params" : $11 ? Array.from(HEAPF32.subarray($11, $12)) : []
}),
static_cast<int32_t>(conv_attrs_.auto_pad),
static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
static_cast<int32_t>(conv_attrs_.group),
static_cast<int32_t>(kernel_shape_0),
static_cast<int32_t>(local_pads.size()),
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
JSEP_HEAP32_INDEX_START(local_pads),
JSEP_HEAP32_INDEX_END(local_pads),
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(channels_last),
JSEP_HEAP_PTR(&w_is_const_),
JSEP_HEAP8_INDEX(&w_is_const_),
conv_attrs_.activation.c_str(),
activation_params.size(),
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
JSEP_HEAP32_INDEX_START(activation_params),
JSEP_HEAP32_INDEX_END(activation_params));
} else {
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
"format" : $11 ? "NHWC" : "NCHW",
"auto_pad" : $1,
"dilations" : [ $2, $3 ],
"group" : $4,
"kernel_shape" : [ $5, $6 ],
"pads" : $7 ? Array.from(HEAP32.subarray($8, $8 + $7)) : [],
"pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [],
"strides" : [ $9, $10 ],
"w_is_const" : () JS_ARROW(!!HEAP8[$12]),
"activation" : UTF8ToString($13),
"activation_params" : $14 ? Array.from(HEAPF32.subarray($15, $15 + $14)) : []
"activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : []
}),
static_cast<int32_t>(conv_attrs_.auto_pad),
static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
static_cast<int32_t>(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0),
static_cast<int32_t>(conv_attrs_.group),
static_cast<int32_t>(kernel_shape_0),
static_cast<int32_t>(kernel_shape_1),
static_cast<int32_t>(local_pads.size()),
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
JSEP_HEAP32_INDEX_START(local_pads),
JSEP_HEAP32_INDEX_END(local_pads),
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
static_cast<int32_t>(channels_last),
JSEP_HEAP_PTR(&w_is_const_),
JSEP_HEAP8_INDEX(&w_is_const_),
conv_attrs_.activation.c_str(),
activation_params.size(),
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
JSEP_HEAP32_INDEX_START(activation_params),
JSEP_HEAP32_INDEX_END(activation_params));
}
}

Expand Down
48 changes: 22 additions & 26 deletions onnxruntime/core/providers/js/operators/conv_transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@
conv_transpose_attrs_.output_shape.end());
std::vector<int32_t> local_output_padding(conv_transpose_attrs_.output_padding.begin(),
conv_transpose_attrs_.output_padding.end());
const auto* local_output_padding_ptr =
local_output_padding.size() > 0 ? local_output_padding.data() : nullptr;
const auto* local_output_shape_ptr =
local_output_shape.size() > 0 ? local_output_shape.data() : nullptr;

// currently only support Conv 1D/2D. TODO: support Conv3D and other
if (conv_transpose_attrs_.dilations.size() == 1 ||
Expand All @@ -52,8 +48,8 @@
"pads" : [ $5, $6 ],
"strides" : [$7],
"wIsConst" : () JS_ARROW(!!HEAP8[$9]),
"outputPadding" : $10 ? Array.from(HEAP32.subarray($11, $11 + $10)) : [],
"outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [],
"outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [],
"outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [],
"activation" : UTF8ToString($14)
}),
static_cast<int32_t>(conv_transpose_attrs_.auto_pad),
Expand All @@ -64,11 +60,11 @@
static_cast<int32_t>(pads_1),
static_cast<int32_t>(strides),
static_cast<int32_t>(channels_last),
JSEP_HEAP_PTR(&w_is_const_),
gsl::narrow_cast<int32_t>(local_output_padding.size()),
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
gsl::narrow_cast<int32_t>(local_output_shape.size()),
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
JSEP_HEAP8_INDEX(&w_is_const_),
JSEP_HEAP32_INDEX_START(local_output_padding),
JSEP_HEAP32_INDEX_END(local_output_padding),
JSEP_HEAP32_INDEX_START(local_output_shape),
JSEP_HEAP32_INDEX_END(local_output_shape),
conv_transpose_attrs_.activation.c_str());
} else {
constexpr size_t pads_vec_size = 4;
Expand Down Expand Up @@ -103,28 +99,28 @@
JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({
"format" : $7 ? "NHWC" : "NCHW",
"autoPad" : $1,
"dilations" : Array.from(HEAP32.subarray($2, $2 + /* dialations_vec_size */ 2)),
"dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)),

Check warning on line 102 in onnxruntime/core/providers/js/operators/conv_transpose.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/js/operators/conv_transpose.h#L102

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/js/operators/conv_transpose.h:102:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
"group" : $3,
"kernelShape" : Array.from(HEAP32.subarray($4, $4 + /* kernel_shape_vec_size */ 2)),
"pads" : Array.from(HEAP32.subarray($5, $5 + /* pads_vec_size */ 4)),
"strides" : Array.from(HEAP32.subarray($6, $6 + /* strides_vec_size */ 2)),
"kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)),

Check warning on line 104 in onnxruntime/core/providers/js/operators/conv_transpose.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/js/operators/conv_transpose.h#L104

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/js/operators/conv_transpose.h:104:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
"pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)),
"strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)),
"wIsConst" : () JS_ARROW(!!HEAP8[$8]),
"outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10, $10 + $9)) : [],
"outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [],
"outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [],
"outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [],
"activation" : UTF8ToString($13)
}),
static_cast<int32_t>(conv_transpose_attrs_.auto_pad),
JSEP_HEAP_PTR(local_dilations.data()) >> 2,
JSEP_HEAP32_INDEX_START(local_dilations),
static_cast<int32_t>(conv_transpose_attrs_.group),
JSEP_HEAP_PTR(local_kernel_shape.data()) >> 2,
JSEP_HEAP_PTR(local_pads.data()) >> 2,
JSEP_HEAP_PTR(local_strides.data()) >> 2,
JSEP_HEAP32_INDEX_START(local_kernel_shape),
JSEP_HEAP32_INDEX_START(local_pads),
JSEP_HEAP32_INDEX_START(local_strides),
static_cast<int32_t>(channels_last),
JSEP_HEAP_PTR(&w_is_const_),
gsl::narrow_cast<int32_t>(local_output_padding.size()),
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
gsl::narrow_cast<int32_t>(local_output_shape.size()),
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
JSEP_HEAP8_INDEX(&w_is_const_),
JSEP_HEAP32_INDEX_START(local_output_padding),
JSEP_HEAP32_INDEX_END(local_output_padding),
JSEP_HEAP32_INDEX_START(local_output_shape),
JSEP_HEAP32_INDEX_END(local_output_shape),
conv_transpose_attrs_.activation.c_str());
}
}
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/js/operators/pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class Pad : public JsKernel, public PadBase {

JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1,
"value" : $2,
"pads" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}),
"pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}),
static_cast<int32_t>(mode_),
static_cast<double>(value_),
gsl::narrow_cast<int32_t>(pads.size()),
JSEP_HEAP_PTR((pads.size() > 0) ? pads.data() : nullptr) >> 2);
JSEP_HEAP32_INDEX_START(pads),
JSEP_HEAP32_INDEX_END(pads));
}
};

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/js/operators/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ namespace js {
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
"keepDims" : !!$1, \
"noopWithEmptyAxes" : !!$2, \
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
"axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \
}), \
static_cast<int32_t>(keepdims_), \
static_cast<int32_t>(noop_with_empty_axes_), \
gsl::narrow_cast<int32_t>(axes.size()), \
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
JSEP_HEAP32_INDEX_START(axes), \
JSEP_HEAP32_INDEX_END(axes)); \
} \
};

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/js/operators/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase {
std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast<int32_t>(axis); });
JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({
"antialias" : $1,
"axes" : $2 ? Array.from(HEAP32.subarray($3, $3 + $2)) : [],
"axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [],
"coordinateTransformMode" : UTF8ToString($4),
"cubicCoeffA" : $5,
"excludeOutside" : $6,
Expand All @@ -33,8 +33,8 @@ class Resize : public JsKernel, public UpsampleBase {
"nearestMode" : UTF8ToString($10),
}),
static_cast<int32_t>(antialias_),
gsl::narrow_cast<int32_t>(axes.size()),
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2,
JSEP_HEAP32_INDEX_START(axes),
JSEP_HEAP32_INDEX_END(axes),
resize_coordinate_transformation_mode.c_str(),
static_cast<double>(cubic_coeff_a_),
static_cast<int32_t>(exclude_outside_),
Expand Down
18 changes: 9 additions & 9 deletions onnxruntime/core/providers/js/operators/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ class Slice : public JsKernel, public SliceBase {
std::vector<int32_t> starts(attr_starts.begin(), attr_starts.end());
std::vector<int32_t> ends(attr_ends.begin(), attr_ends.end());

JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($2, $2 + $1)) : [],
"ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [],
"axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}),
gsl::narrow_cast<int32_t>(starts.size()),
JSEP_HEAP_PTR((starts.size() > 0) ? starts.data() : nullptr) >> 2,
gsl::narrow_cast<int32_t>(ends.size()),
JSEP_HEAP_PTR((ends.size() > 0) ? ends.data() : nullptr) >> 2,
gsl::narrow_cast<int32_t>(axes.size()),
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2);
JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [],
"ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [],
"axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}),
JSEP_HEAP32_INDEX_START(starts),
JSEP_HEAP32_INDEX_END(starts),
JSEP_HEAP32_INDEX_START(ends),
JSEP_HEAP32_INDEX_END(ends),
JSEP_HEAP32_INDEX_START(axes),
JSEP_HEAP32_INDEX_END(axes));
}
};

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/js/operators/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ class Split : public JsKernel, public SplitBase {

JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
"numOutputs" : $2,
"splitSizes" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}),
"splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}),
static_cast<int32_t>(axis_),
static_cast<int32_t>(num_outputs_),
gsl::narrow_cast<int32_t>(split_sizes.size()),
JSEP_HEAP_PTR((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
JSEP_HEAP32_INDEX_START(split_sizes),
JSEP_HEAP32_INDEX_END(split_sizes));
}
};

Expand Down
Loading
Loading