Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Add PushConstantDataInfo and vector to hold push constants data in DispatchNode. #7223

Open
wants to merge 8 commits into
base: gh/trivedivivek/22/base
Choose a base branch
from
36 changes: 33 additions & 3 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@

namespace vkcompute {

uint32_t PushConstantDataInfo::write(
void* dst,
const uint32_t dst_offset,
const uint32_t max_dst_size) const {
if (tensorUniformData != nullptr) {
return tensorUniformData->write_attribute(
dst, dst_offset, max_dst_size, payload_.attr);
}

VK_CHECK_COND(
(dst_offset + payload_.dataSize) <= max_dst_size,
"Attempting to write push constant data outside data boundary.");
memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize);
return payload_.dataSize;
}

DispatchNode::DispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
Expand All @@ -23,13 +39,15 @@ DispatchNode::DispatchNode(
const vkapi::ParamsBindList& params,
const vkapi::SpecVarList& spec_vars,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
const std::vector<ValueRef>& resize_args,
const std::vector<PushConstantDataInfo>& push_constants)
: ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
params_(params),
spec_vars_(spec_vars) {
spec_vars_(spec_vars),
push_constants_(push_constants) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

Expand Down Expand Up @@ -57,8 +75,20 @@ void DispatchNode::encode(ComputeGraph* graph) {

bind_params_to_descriptor_set(params_, descriptor_set, idx);

uint8_t push_constants_data[128];
uint32_t push_constants_offset = 0;

for (const auto& push_constant : push_constants_) {
push_constants_offset +=
push_constant.write(push_constants_data, push_constants_offset, 128);
}
context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
descriptor_set,
pipeline_barrier,
shader_,
global_workgroup_size_,
push_constants_data,
push_constants_offset);

context->report_shader_dispatch_end();
}
Expand Down
48 changes: 47 additions & 1 deletion backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,50 @@ namespace vkcompute {

class ComputeGraph;

/*
* Represents a push constant data entry
* Which is either shared pointer to a tensor's uniform data with an attribute
* Or data with a maximum size of 16 bytes
*/
class PushConstantDataInfo {
std::shared_ptr<api::vTensor::UniformData> tensorUniformData;
union Payload {
struct {
api::vTensor::Attribute attr;
};
struct {
uint8_t data[16];
uint32_t dataSize;
};
};

Payload payload_;

public:
explicit PushConstantDataInfo(
const std::shared_ptr<api::vTensor::UniformData>& tensorUniformData,
api::vTensor::Attribute attr)
: tensorUniformData(tensorUniformData) {
payload_.attr = attr;
}

explicit PushConstantDataInfo(const void* data, uint32_t dataLen)
: tensorUniformData(nullptr) {
VK_CHECK_COND(
dataLen <= 16, "Single push constant data size must be <= 16 bytes");
payload_.dataSize = dataLen;
memcpy(payload_.data, data, payload_.dataSize);
}

/*
* Function writes push constant data to the destination buffer
*/
uint32_t write(
void* dst,
const uint32_t dst_offset,
const uint32_t max_dst_size) const;
};

/*
* Represents a single shader execution op in a ML model.
*/
Expand All @@ -34,7 +78,8 @@ class DispatchNode final : public ExecuteNode {
const vkapi::ParamsBindList& params,
const vkapi::SpecVarList& spec_vars = {},
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});
const std::vector<ValueRef>& resize_args = {},
const std::vector<PushConstantDataInfo>& push_constants = {});

~DispatchNode() override = default;

Expand All @@ -46,6 +91,7 @@ class DispatchNode final : public ExecuteNode {
const utils::uvec3 local_workgroup_size_;
const vkapi::ParamsBindList params_;
const vkapi::SpecVarList spec_vars_;
const std::vector<PushConstantDataInfo> push_constants_;

public:
operator bool() const {
Expand Down
Loading