Skip to content

Commit

Permalink
Support multiple UniformParamsBuffer (#2348)
Browse files Browse the repository at this point in the history
Summary:
bypass-github-export-checks

Pull Request resolved: #2348

Before: Each node contains a `UniformParamsBuffer`.
After: Each node contains a `std::vector<std::shared_ptr<UniformParamsBuffer>>`.

In follow up changes, we will break up parameters to be passed via multiple UniformParamsBuffer, since
1. some are tensor-specific (e.g. image extents) and
2. others are operator-specific (e.g. alpha for binary ops).

Hence, we need **`std::vector`**.

We are adding the methods for #1 in #2340. Since #1 and #2 will be owned by different objects, we need **pointers**. Since #1 is owned by `vTensor` which is non-copyable, we can't use unique_ptr so we need **`std::shared_ptr`**.
ghstack-source-id: 218195447
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D54691831

fbshipit-source-id: 84ab9f777e247fd56234290ed7f7343b9701c73f
  • Loading branch information
jorgep31415 authored and facebook-github-bot committed Mar 11, 2024
1 parent caade55 commit a33fbd8
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 34 deletions.
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ class ComputeGraph final {
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);

template <typename Block>
inline std::shared_ptr<api::UniformParamsBuffer> create_params_buffer(
const Block& data) {
return std::make_shared<api::UniformParamsBuffer>(context_.get(), data);
}

/*
* Convenience function to add an input tensor along with its staging buffer
*/
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ ExecuteNode::ExecuteNode(
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params)
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(std::move(params)) {
params_(params) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

Expand All @@ -43,7 +43,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph, args_, pipeline_barrier, descriptor_set, idx);
descriptor_set.bind(idx, params_.buffer());
bind_params_to_descriptor_set(params_, descriptor_set, idx);

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
Expand Down
5 changes: 2 additions & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ExecuteNode final {
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params);
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);

~ExecuteNode() = default;

Expand All @@ -64,9 +64,8 @@ class ExecuteNode final {
const api::utils::uvec3 global_workgroup_size_;
const api::utils::uvec3 local_workgroup_size_;
const std::vector<ArgGroup> args_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
};

} // namespace vulkan
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ PrepackNode::PrepackNode(
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
api::UniformParamsBuffer&& params)
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
tref_(tref),
packed_(packed),
params_(std::move(params)) {
params_(params) {
graph.update_descriptor_counts(shader, /*execute = */ false);
}

Expand Down Expand Up @@ -61,7 +61,7 @@ void PrepackNode::encode(ComputeGraph* graph) {
descriptor_set,
idx++);
bind_staging_to_descriptor_set(staging, descriptor_set, idx++);
descriptor_set.bind(idx, params_.buffer());
bind_params_to_descriptor_set(params_, descriptor_set, idx);

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
Expand Down
5 changes: 2 additions & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class PrepackNode final {
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
api::UniformParamsBuffer&& params);
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);

~PrepackNode() = default;

Expand All @@ -49,9 +49,8 @@ class PrepackNode final {
const api::utils::uvec3 local_workgroup_size_;
const ValueRef tref_;
const ValueRef packed_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
};

} // namespace vulkan
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ void add_arithmetic_node(
get_size_as_ivec4(t_in2),
alpha_val,
};
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
Expand All @@ -81,7 +80,7 @@ void add_arithmetic_node(
local_size,
{{out, api::MemoryAccessType::WRITE},
{{arg1, arg2}, api::MemoryAccessType::READ}},
std::move(params)));
{graph.create_params_buffer(block)}));
}

REGISTER_OPERATORS {
Expand Down
17 changes: 9 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,14 @@ void add_staging_to_tensor_node(
api::utils::uvec3 global_size = t_out.extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

api::UniformParamsBuffer params(
graph.context(), create_staging_params(t_out));

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
shader,
global_size,
local_size,
{{out_tensor, api::MemoryAccessType::WRITE},
{in_staging, api::MemoryAccessType::READ}},
std::move(params)));
{graph.create_params_buffer(create_staging_params(t_out))}));
}

void add_tensor_to_staging_node(
Expand All @@ -71,7 +68,6 @@ void add_tensor_to_staging_node(
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

StagingParams sp = create_staging_params(t_in);
api::UniformParamsBuffer params(graph.context(), sp);

// TODO(T181194784): These are workgroup sizes for special cases. Refactor the
// calculation of workgroup sizes to a standalone function. We should use
Expand All @@ -98,7 +94,7 @@ void add_tensor_to_staging_node(
local_size,
{{in_tensor, api::MemoryAccessType::READ},
{out_staging, api::MemoryAccessType::WRITE}},
std::move(params)));
{graph.create_params_buffer(sp)}));
}

ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
Expand All @@ -112,10 +108,15 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

StagingParams sp = create_staging_params(t);
api::UniformParamsBuffer params(graph.context(), sp);

graph.prepack_nodes().emplace_back(new PrepackNode(
graph, shader, global_size, local_size, vref, v, std::move(params)));
graph,
shader,
global_size,
local_size,
vref,
v,
{graph.create_params_buffer(sp)}));

return v;
}
Expand Down
25 changes: 18 additions & 7 deletions backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ void bind_tensor_to_descriptor_set(
}
}

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx) {
descriptor_set.bind(idx, staging.buffer());
}

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
Expand Down Expand Up @@ -63,6 +56,24 @@ uint32_t bind_values_to_descriptor_set(
return idx;
}

uint32_t bind_params_to_descriptor_set(
std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx) {
uint32_t idx = base_idx;
for (auto& param : params) {
descriptor_set.bind(idx++, param->buffer());
}
return idx;
}

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx) {
descriptor_set.bind(idx, staging.buffer());
}

} // namespace vulkan
} // namespace native
} // namespace at
23 changes: 18 additions & 5 deletions backends/vulkan/runtime/graph/ops/utils/BindingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,38 @@ namespace at {
namespace native {
namespace vulkan {

//
// For objects in the graph
//

void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
api::PipelineBarrier& pipeline_barrier,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

//
// For objects NOT in the graph
//

uint32_t bind_params_to_descriptor_set(
std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

} // namespace vulkan
} // namespace native
} // namespace at
Expand Down

0 comments on commit a33fbd8

Please sign in to comment.