Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fs-eire/webgpu-ep' into webgpu-e…
Browse files Browse the repository at this point in the history
…p-gather
  • Loading branch information
fs-eire committed Sep 25, 2024
2 parents d734543 + 9bdbd85 commit bac2848
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/webgpu/shader_variable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,6 @@ ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableD
element_type_alias_{name_ + "_element_t"},
indices_type_alias_{name_ + "_indices_t"} {}

inline int ShaderIndicesHelper::Rank() {
// getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride.
usage_ |= ShaderUsage::UseShapeAndStride;
return rank_;
}

ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims)
: ShaderIndicesHelper{name, type, usage, dims} {
ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_);
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/webgpu/shader_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ShaderIndicesHelper {
inline int NumComponents() const { return num_components_; }

// get the rank of the indices.
inline int Rank();
inline int Rank() const;

// create a WGSL expression ({varname}_indices_t) for getting indices from offset.
// \param offset: a WGSL expression (u32) representing the offset.
Expand Down Expand Up @@ -213,6 +213,12 @@ std::string pass_as_string(T&& v) {
}
} // namespace detail

inline int ShaderIndicesHelper::Rank() const {
// getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride.
usage_ |= ShaderUsage::UseShapeAndStride;
return rank_;
}

inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const {
usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride;
return rank_ < 2 ? std::string{offset_expr}
Expand Down

0 comments on commit bac2848

Please sign in to comment.