diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 8773f0c0b0..07176b025c 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -58,7 +58,7 @@ struct Value final { bool as_bool; } u; - api::vTensor as_tensor; + std::unique_ptr as_tensor; api::StagingBuffer as_staging; TensorRef as_tensorref; @@ -106,6 +106,11 @@ struct Value final { rhs.payload.member_name.~dtor_name(); \ break; +#define CASE_MOVE_UNIQUE_PTR_TYPE(type_tag, member_name) \ + case type_tag: \ + payload.member_name = std::move(rhs.payload.member_name); \ + break; + Value(Value&& rhs) noexcept : tag(rhs.tag) { switch (tag) { // Scalar types @@ -113,8 +118,7 @@ struct Value final { CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double); CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool); // Tensor and tensor adjacent types - CASE_MOVE_MOVEABLE_TYPE( - TypeTag::TENSOR, api::vTensor, as_tensor, vTensor); + CASE_MOVE_UNIQUE_PTR_TYPE(TypeTag::TENSOR, as_tensor); CASE_MOVE_MOVEABLE_TYPE( TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer); CASE_MOVE_MOVEABLE_TYPE( @@ -142,6 +146,7 @@ struct Value final { #undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE #undef CASE_MOVE_MOVEABLE_TYPE +#undef CASE_MOVE_UNIQUE_PTR_TYPE // // Accessors @@ -158,7 +163,7 @@ struct Value final { ~Value() { switch (tag) { case TypeTag::TENSOR: - payload.as_tensor.~vTensor(); + payload.as_tensor.reset(); break; case TypeTag::STAGING: payload.as_staging.~StagingBuffer(); @@ -227,6 +232,39 @@ struct Value final { #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE +#define SUPPORT_TRIVIALLY_MOVEABLE_UNIQUE_PTR_TYPE( \ + type, type_name, type_tag, member_name) \ + explicit Value(type t) : tag(type_tag) { \ + payload.member_name = std::make_unique(std::move(t)); \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline type& to##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return *payload.member_name; \ + } \ + inline const type& toConst##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return *payload.member_name; \ + } + + SUPPORT_TRIVIALLY_MOVEABLE_UNIQUE_PTR_TYPE( + api::vTensor, + Tensor, + TypeTag::TENSOR, + as_tensor); + +#undef SUPPORT_TRIVIALLY_MOVEABLE_UNIQUE_PTR_TYPE + #define SUPPORT_TRIVIALLY_MOVEABLE_TYPE( \ type, type_name, type_tag, member_name) \ explicit Value(type&& t) : tag(type_tag) { \ @@ -252,12 +290,6 @@ struct Value final { return payload.member_name; \ } - SUPPORT_TRIVIALLY_MOVEABLE_TYPE( - api::vTensor, - Tensor, - TypeTag::TENSOR, - as_tensor); - SUPPORT_TRIVIALLY_MOVEABLE_TYPE( api::StagingBuffer, Staging,