diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index cb958cefea..201278ac61 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -350,6 +351,28 @@ class ComputeGraph final { return values_.at(idx).toTensor().logical_limits_ubo(); } + inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes); + } + + inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), + api::kTensorStrides); + } + + inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), + api::kTensorLogicalLimits); + } + + inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel); + } + // // Scalar Value Extraction //