Skip to content

Commit

Permalink
[webgpu] Use override shape in shader key (#23188)
Browse files Browse the repository at this point in the history
### Description
This PR 1) uses override shape instead of tensor original shape in
shader key to reduce some shader variants; 2) adds indices shape rank to
shader key in case some potential errors.
  • Loading branch information
qjia7 authored Jan 7, 2025
1 parent 519fae0 commit 4883ec5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
.AddIndices(reshaped_output_shape)
.AddIndices(reshaped_lhs_shape)
.AddIndices(reshaped_rhs_shape)
.CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(),
reshaped_rhs_shape.NumDimensions(),
reshaped_output_shape.NumDimensions()},
";"));
.CacheHint("V");
} else {
// Mode Broadcast
// cache hint: "B"
Expand Down
23 changes: 18 additions & 5 deletions onnxruntime/core/providers/webgpu/program_cache_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace webgpu {

namespace {
// append the info of an input or output to the cachekey
void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
bool& first) {
if (first) {
first = false;
Expand All @@ -35,9 +35,9 @@ void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDat
}

if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) {
ss D("Dims=") << tensor.Shape().ToString();
ss D("Dims=") << tensor_shape.ToString();
} else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) {
ss D("Rank=") << tensor.Shape().NumDimensions();
ss D("Rank=") << tensor_shape.NumDimensions();
}
}
} // namespace
Expand Down Expand Up @@ -97,13 +97,26 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp
ss << ":" D("Inputs=");
first = true;
for (const auto& input : program.Inputs()) {
AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first);
AppendTensorInfo(ss, input.use_override_shape ? input.override_shape : input.tensor->Shape(), input.var_type, input.dependency, first);
}

ss << ":" D("Outputs=");
first = true;
for (const auto& output : program.Outputs()) {
AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first);
AppendTensorInfo(ss, output.use_override_shape ? output.override_shape : output.tensor->Shape(), output.var_type, output.dependency, first);
}

if (!program.Indices().empty()) {
ss << ":" D("Indices=");
first = true;
for (const auto& indices_shape : program.Indices()) {
if (first) {
first = false;
} else {
ss << '|';
}
ss D("Rank=") << indices_shape.NumDimensions();
}
}

return SS_GET(ss);
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/tensor/where.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ Status Where::ComputeInternal(ComputeContext& context) const {
program
.CacheHint(is_broadcast)
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}})
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Type, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::Type, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::Type, {(y_shape.Size() + 3) / 4}, 4}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddUniformVariables({
{static_cast<uint32_t>(vec_size)},
Expand Down

0 comments on commit 4883ec5

Please sign in to comment.