diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 8773f0c0b0..83669c85b1 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -58,7 +58,7 @@ struct Value final { bool as_bool; } u; - api::vTensor as_tensor; + std::unique_ptr as_tensor; api::StagingBuffer as_staging; TensorRef as_tensorref; @@ -106,15 +106,18 @@ struct Value final { rhs.payload.member_name.~dtor_name(); \ break; +#define CASE_MOVE_UNIQUE_PTR_TYPE(type_tag, member_name) \ + case type_tag: \ + payload.member_name = std::move(rhs.payload.member_name); \ + break; + Value(Value&& rhs) noexcept : tag(rhs.tag) { switch (tag) { // Scalar types CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int); CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double); CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool); - // Tensor and tensor adjacent types - CASE_MOVE_MOVEABLE_TYPE( - TypeTag::TENSOR, api::vTensor, as_tensor, vTensor); + // Tensor adjacent types CASE_MOVE_MOVEABLE_TYPE( TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer); CASE_MOVE_MOVEABLE_TYPE( @@ -132,6 +135,8 @@ struct Value final { CASE_MOVE_MOVEABLE_TYPE( TypeTag::STRING, std::string, as_string, basic_string); CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt); + // Tensor type + CASE_MOVE_UNIQUE_PTR_TYPE(TypeTag::TENSOR, as_tensor); case TypeTag::NONE: clearToNone(); @@ -142,6 +147,7 @@ struct Value final { #undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE #undef CASE_MOVE_MOVEABLE_TYPE +#undef CASE_MOVE_UNIQUE_PTR_TYPE // // Accessors @@ -157,9 +163,6 @@ struct Value final { ~Value() { switch (tag) { - case TypeTag::TENSOR: - payload.as_tensor.~vTensor(); - break; case TypeTag::STAGING: payload.as_staging.~StagingBuffer(); break; @@ -184,6 +187,9 @@ struct Value final { case TypeTag::SYMINT: payload.as_symint.~SymInt(); break; + case TypeTag::TENSOR: + payload.as_tensor.reset(); + break; // Manually list out the types so that if a type here is added later and // not handled the compiler can catch it. case TypeTag::NONE: @@ -252,12 +258,6 @@ struct Value final { return payload.member_name; \ } - SUPPORT_TRIVIALLY_MOVEABLE_TYPE( - api::vTensor, - Tensor, - TypeTag::TENSOR, - as_tensor); - SUPPORT_TRIVIALLY_MOVEABLE_TYPE( api::StagingBuffer, Staging, @@ -302,9 +302,36 @@ struct Value final { SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint); -#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE #undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE +#define SUPPORT_UNIQUE_PTR_TYPE(type, type_name, type_tag, member_name) \ + explicit Value(type t) : tag(type_tag) { \ + payload.member_name = std::make_unique(std::move(t)); \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline type& to##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return *payload.member_name; \ + } \ + inline const type& toConst##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return *payload.member_name; \ + } + + SUPPORT_UNIQUE_PTR_TYPE(api::vTensor, Tensor, TypeTag::TENSOR, as_tensor); + +#undef SUPPORT_UNIQUE_PTR_TYPE + private: Payload payload; TypeTag tag; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index ee49f95ee2..36f12e39da 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1087,8 +1087,8 @@ TEST_F(VulkanComputeAPITest, print_object_sizes) { // Current known size on 64 bit system: 1040 B EXPECT_TRUE(sizeof(vTensor) < 1200); - // Current known size on 64 bit system: 1056 B - EXPECT_TRUE(sizeof(Value) < 1200); + // Current known size on 64 bit system: 80 B + EXPECT_TRUE(sizeof(Value) < 88); // Current known size on 64 bit system: 120 B EXPECT_TRUE(sizeof(StagingBuffer) < 500); // Current known size on 64 bit system: 384 B