diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 0962d9191d785..e0a0113e13224 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,11 +47,11 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -const std::string AppendPermFunction(gsl::span perm) { +const std::string AppendPermFunction(gsl::span perm) { std::ostringstream ss; ss.imbue(std::locale::classic()); - ss << "fn perm(i: y_indices_t)->x_indices_t {\n" - " var a: x_indices_t;\n"; + ss << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; for (size_t i = 0; i < perm.size(); ++i) { ss << " a[" << perm[i] << "] = i[" << i << "];\n"; } @@ -60,21 +60,52 @@ const std::string AppendPermFunction(gsl::span perm) { return ss.str(); } +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { + for (auto i = 0; i < shape.size(); ++i) { + if (shape[i] != 1) { + new_shape.push_back(shape[i]); + } + if (shape[adjusted_perm[i]] != 1) { + new_perm.push_back(adjusted_perm[i]); + } + } +}; + Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AppendImplementation(AppendPermFunction(this->perm_)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " let indices = ", output.OffsetToIndices("global_idx"), - ";\n" - " let x_indices = perm(indices); \n" - " ", - output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + + if (use_shared_) { + shader.AppendImplementation("var tile : array, tile_size>;\n"); + shader.SetMainFunctionBody( + " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + " tile[local_id.y][local_id.x] = " + + input.GetByIndices("a_indices_t(input_row, input_col)") + + ";\n" + " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " + + output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n }"); + } else { + shader.AppendImplementation(AppendPermFunction(this->perm_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " let indices = ", output.OffsetToIndices("global_idx"), + ";\n" + " let x_indices = perm(indices);\n", + " ", + output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + } return Status::OK(); } Status Transpose::ComputeInternal(ComputeContext& context) const { - // TODO: there is an optimized version of transpose to port. const auto* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); @@ -86,16 +117,42 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = input_shape; + TensorShape new_output_shape(output_dims); + if (use_shared) { + new_input_shape = channels_last + ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); + } + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); - TransposeProgram program{*p_perm}; + TransposeProgram program{*p_perm, use_shared}; + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + } + program .CacheHint(absl::StrJoin(*p_perm, "-")) - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({output_tensor}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ {static_cast(output_size)}, }); + + use_shared ? program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 3ca5674d5dfab..7cf5c1fe0865d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -11,26 +11,28 @@ namespace onnxruntime { namespace webgpu { +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + Status ComputeInternal(ComputeContext& context) const override; + constexpr static uint32_t TILE_SIZE = 16; +}; + class TransposeProgram final : public Program { public: - TransposeProgram(const gsl::span& permutations) - : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) { + TransposeProgram(const gsl::span& permutations, bool use_shared) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) { } Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE}); private: - InlinedVector perm_; -}; - -class Transpose final : public WebGpuKernel, public TransposeBase { - public: - Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { - } - - Status ComputeInternal(ComputeContext& context) const override; + InlinedVector perm_; + const bool use_shared_; }; } // namespace webgpu