From 73f82e85356a83c2c66d0aec40f0ffdb1ceb122c Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 24 Dec 2024 14:47:38 +0800 Subject: [PATCH] [webgpu] Use override shape in shader key --- .../providers/webgpu/program_cache_key.cc | 23 +++++++++++++++---- .../core/providers/webgpu/tensor/where.cc | 6 ++--- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index a5c21563dbfcd..a351cacc783cf 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -17,7 +17,7 @@ namespace webgpu { namespace { // append the info of an input or output to the cachekey -void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, +void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { if (first) { first = false; @@ -35,9 +35,9 @@ void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDat } if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { - ss D("Dims=") << tensor.Shape().ToString(); + ss D("Dims=") << tensor_shape.ToString(); } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { - ss D("Rank=") << tensor.Shape().NumDimensions(); + ss D("Rank=") << tensor_shape.NumDimensions(); } } } // namespace @@ -97,13 +97,26 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp ss << ":" D("Inputs="); first = true; for (const auto& input : program.Inputs()) { - AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first); + AppendTensorInfo(ss, input.use_override_shape ? input.override_shape : input.tensor->Shape(), input.var_type, input.dependency, first); } ss << ":" D("Outputs="); first = true; for (const auto& output : program.Outputs()) { - AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); + AppendTensorInfo(ss, output.use_override_shape ? output.override_shape : output.tensor->Shape(), output.var_type, output.dependency, first); + } + + if (!program.Indices().empty()) { + ss << ":" D("Indices="); + first = true; + for (const auto& indices_shape : program.Indices()) { + if (first) { + first = false; + } else { + ss << '|'; + } + ss D("Rank=") << indices_shape.NumDimensions(); + } } return SS_GET(ss); diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index 524dd07d5b710..e8cdabb9dbe40 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -134,9 +134,9 @@ Status Where::ComputeInternal(ComputeContext& context) const { program .CacheHint(is_broadcast) .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4}, - {x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4}, - {y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}}) + .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Type, {(cond_shape.Size() + 3) / 4}, 4}, + {x_tensor, ProgramTensorMetadataDependency::Type, {(x_shape.Size() + 3) / 4}, 4}, + {y_tensor, ProgramTensorMetadataDependency::Type, {(y_shape.Size() + 3) / 4}, 4}}) .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) .AddUniformVariables({ {static_cast(vec_size)},