diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 776f5e12ee..1253111150 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -154,6 +154,12 @@ class ComputeGraph final { ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); + template + inline std::shared_ptr create_params_buffer( + const Block& data) { + return std::make_shared(context_.get(), data); + } + /* * Convenience function to add an input tensor along with its staging buffer */ diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index f6649fb19c..496a94238b 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -22,12 +22,12 @@ ExecuteNode::ExecuteNode( const api::utils::uvec3& global_workgroup_size, const api::utils::uvec3& local_workgroup_size, const std::vector& args, - api::UniformParamsBuffer&& params) + const std::vector>& params) : shader_(shader), global_workgroup_size_(global_workgroup_size), local_workgroup_size_(local_workgroup_size), args_(args), - params_(std::move(params)) { + params_(params) { graph.update_descriptor_counts(shader, /*execute = */ true); } @@ -43,7 +43,7 @@ void ExecuteNode::encode(ComputeGraph* graph) { uint32_t idx = 0; idx = bind_values_to_descriptor_set( graph, args_, pipeline_barrier, descriptor_set, idx); - descriptor_set.bind(idx, params_.buffer()); + bind_params_to_descriptor_set(params_, descriptor_set, idx); context->register_shader_dispatch( descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 5bcad8fb80..5e3a1e003b 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -53,7 +53,7 @@ class ExecuteNode final { const api::utils::uvec3& global_workgroup_size, const api::utils::uvec3& local_workgroup_size, const std::vector& args, - api::UniformParamsBuffer&& params); + const std::vector>& params); ~ExecuteNode() = default; @@ -64,9 +64,8 @@ class ExecuteNode final { const api::utils::uvec3 global_workgroup_size_; const api::utils::uvec3 local_workgroup_size_; const std::vector args_; - // TODO(T180906086): pass multiple buffers and index with ValueRef. // TODO(T180906457): allow re-computing param buffers. - api::UniformParamsBuffer params_; + std::vector> params_; }; } // namespace vulkan diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 43ad64c942..c21c1447d9 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -24,13 +24,13 @@ PrepackNode::PrepackNode( const api::utils::uvec3& local_workgroup_size, const ValueRef tref, const ValueRef packed, - api::UniformParamsBuffer&& params) + const std::vector>& params) : shader_(shader), global_workgroup_size_(global_workgroup_size), local_workgroup_size_(local_workgroup_size), tref_(tref), packed_(packed), - params_(std::move(params)) { + params_(params) { graph.update_descriptor_counts(shader, /*execute = */ false); } @@ -61,7 +61,7 @@ void PrepackNode::encode(ComputeGraph* graph) { descriptor_set, idx++); bind_staging_to_descriptor_set(staging, descriptor_set, idx++); - descriptor_set.bind(idx, params_.buffer()); + bind_params_to_descriptor_set(params_, descriptor_set, idx); context->register_shader_dispatch( descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index 80465703da..7d8a8b4ce3 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -37,7 +37,7 @@ class PrepackNode final { const api::utils::uvec3& local_workgroup_size, const ValueRef tref, const ValueRef packed, - api::UniformParamsBuffer&& params); + const std::vector>& params); ~PrepackNode() = default; @@ -49,9 +49,8 @@ class PrepackNode final { const api::utils::uvec3 local_workgroup_size_; const ValueRef tref_; const ValueRef packed_; - // TODO(T180906086): pass multiple buffers and index with ValueRef. // TODO(T180906457): allow re-computing param buffers. - api::UniformParamsBuffer params_; + std::vector> params_; }; } // namespace vulkan diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp index 52cd04c492..453e290045 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp @@ -72,7 +72,6 @@ void add_arithmetic_node( get_size_as_ivec4(t_in2), alpha_val, }; - api::UniformParamsBuffer params(graph.context(), block); graph.execute_nodes().emplace_back(new ExecuteNode( graph, @@ -81,7 +80,7 @@ void add_arithmetic_node( local_size, {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, - std::move(params))); + {graph.create_params_buffer(block)})); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 349e000086..1659a030ff 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -45,9 +45,6 @@ void add_staging_to_tensor_node( api::utils::uvec3 global_size = t_out.extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - api::UniformParamsBuffer params( - graph.context(), create_staging_params(t_out)); - graph.execute_nodes().emplace_back(new ExecuteNode( graph, shader, @@ -55,7 +52,7 @@ void add_staging_to_tensor_node( local_size, {{out_tensor, api::MemoryAccessType::WRITE}, {in_staging, api::MemoryAccessType::READ}}, - std::move(params))); + {graph.create_params_buffer(create_staging_params(t_out))})); } void add_tensor_to_staging_node( @@ -71,7 +68,6 @@ void add_tensor_to_staging_node( api::utils::uvec3 local_size = adaptive_work_group_size(global_size); StagingParams sp = create_staging_params(t_in); - api::UniformParamsBuffer params(graph.context(), sp); // TODO(T181194784): These are workgroup sizes for special cases. Refactor the // calculation of workgroup sizes to a standalone function. We should use @@ -98,7 +94,7 @@ void add_tensor_to_staging_node( local_size, {{in_tensor, api::MemoryAccessType::READ}, {out_staging, api::MemoryAccessType::WRITE}}, - std::move(params))); + {graph.create_params_buffer(sp)})); } ValueRef prepack(ComputeGraph& graph, const ValueRef vref) { @@ -112,10 +108,15 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) { api::utils::uvec3 local_size = adaptive_work_group_size(global_size); StagingParams sp = create_staging_params(t); - api::UniformParamsBuffer params(graph.context(), sp); graph.prepack_nodes().emplace_back(new PrepackNode( - graph, shader, global_size, local_size, vref, v, std::move(params))); + graph, + shader, + global_size, + local_size, + vref, + v, + {graph.create_params_buffer(sp)})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp index 6e471167ec..6e1d9b3013 100644 --- a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp @@ -29,13 +29,6 @@ void bind_tensor_to_descriptor_set( } } -void bind_staging_to_descriptor_set( - api::StorageBuffer& staging, - api::DescriptorSet& descriptor_set, - const uint32_t idx) { - descriptor_set.bind(idx, staging.buffer()); -} - uint32_t bind_values_to_descriptor_set( ComputeGraph* graph, const std::vector& args, @@ -63,6 +56,24 @@ uint32_t bind_values_to_descriptor_set( return idx; } +uint32_t bind_params_to_descriptor_set( + std::vector>& params, + api::DescriptorSet& descriptor_set, + const uint32_t base_idx) { + uint32_t idx = base_idx; + for (auto& param : params) { + descriptor_set.bind(idx++, param->buffer()); + } + return idx; +} + +void bind_staging_to_descriptor_set( + api::StorageBuffer& staging, + api::DescriptorSet& descriptor_set, + const uint32_t idx) { + descriptor_set.bind(idx, staging.buffer()); +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h index 28649a1194..e8d508b791 100644 --- a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h @@ -16,6 +16,10 @@ namespace at { namespace native { namespace vulkan { +// +// For objects in the graph +// + void bind_tensor_to_descriptor_set( vTensor& tensor, api::PipelineBarrier& pipeline_barrier, @@ -23,11 +27,6 @@ void bind_tensor_to_descriptor_set( api::DescriptorSet& descriptor_set, const uint32_t idx); -void bind_staging_to_descriptor_set( - api::StorageBuffer& staging, - api::DescriptorSet& descriptor_set, - const uint32_t idx); - uint32_t bind_values_to_descriptor_set( ComputeGraph* graph, const std::vector& args, @@ -35,6 +34,20 @@ uint32_t bind_values_to_descriptor_set( api::DescriptorSet& descriptor_set, const uint32_t base_idx); +// +// For objects NOT in the graph +// + +uint32_t bind_params_to_descriptor_set( + std::vector>& params, + api::DescriptorSet& descriptor_set, + const uint32_t base_idx); + +void bind_staging_to_descriptor_set( + api::StorageBuffer& staging, + api::DescriptorSet& descriptor_set, + const uint32_t idx); + } // namespace vulkan } // namespace native } // namespace at