Skip to content

Commit

Permalink
[ET-VK] Adding function to set push constants in Command buffer.
Browse files Browse the repository at this point in the history
Differential Revision: D66714317

Pull Request resolved: #7221
  • Loading branch information
trivedivivek authored Dec 9, 2024
1 parent a9cf2a4 commit b6df23b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
13 changes: 12 additions & 1 deletion backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ void Context::register_shader_dispatch(
const vkapi::DescriptorSet& descriptors,
vkapi::PipelineBarrier& pipeline_barrier,
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& global_workgroup_size) {
const utils::uvec3& global_workgroup_size,
const void* push_constants_data,
const uint32_t push_constants_size) {
// Adjust the global workgroup size based on the output tile size
uint32_t global_wg_w = utils::div_up(
global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
Expand All @@ -145,6 +147,15 @@ void Context::register_shader_dispatch(
cmd_.bind_descriptors(descriptors.get_bind_handle());
cmd_.insert_barrier(pipeline_barrier);

if (push_constants_size > 0 && push_constants_data != nullptr) {
const VkDescriptorSetLayout shader_layout =
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
const VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout);
cmd_.set_push_constants(
pipeline_layout, push_constants_data, push_constants_size);
}

cmd_.dispatch(effective_global_wg);
}

Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ class Context final {
const vkapi::DescriptorSet&,
vkapi::PipelineBarrier&,
const vkapi::ShaderInfo&,
const utils::uvec3&);
const utils::uvec3&,
const void* = nullptr,
const uint32_t = 0);

void register_blit(
vkapi::PipelineBarrier&,
Expand Down
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/vk_api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
}

void CommandBuffer::set_push_constants(
VkPipelineLayout pipeline_layout,
const void* push_constants_data,
uint32_t push_constants_size) {
if (push_constants_data != nullptr && push_constants_size > 0) {
vkCmdPushConstants(
handle_,
pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT,
0,
push_constants_size,
push_constants_data);
}
}

void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
VK_CHECK_COND(
state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/vk_api/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class CommandBuffer final {

void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
void bind_descriptors(VkDescriptorSet);
void set_push_constants(VkPipelineLayout, const void*, uint32_t);

void insert_barrier(PipelineBarrier& pipeline_barrier);
void dispatch(const utils::uvec3&);
Expand Down

0 comments on commit b6df23b

Please sign in to comment.