From 3fd091e9b4d1958f24668f66c941aea79fe90390 Mon Sep 17 00:00:00 2001 From: Hyeongseok Oh Date: Mon, 2 Sep 2024 15:39:32 +0900 Subject: [PATCH] [onert] Support block quantization operand size calculation This commit update total_size() method to calculate block quantization type operand size. ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh --- runtime/onert/core/include/ir/OperandInfo.h | 2 +- runtime/onert/core/src/ir/DataType.cc | 1 + runtime/onert/core/src/ir/OperandInfo.cc | 53 ++++++++++++++++++ runtime/onert/core/src/ir/OperandInfo.test.cc | 55 +++++++++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 runtime/onert/core/src/ir/OperandInfo.cc create mode 100644 runtime/onert/core/src/ir/OperandInfo.test.cc diff --git a/runtime/onert/core/include/ir/OperandInfo.h b/runtime/onert/core/include/ir/OperandInfo.h index 2957be23e99..f6739033e24 100644 --- a/runtime/onert/core/include/ir/OperandInfo.h +++ b/runtime/onert/core/include/ir/OperandInfo.h @@ -120,7 +120,7 @@ class OperandInfo * @brief Return size of tensor (bytes) * @return Tensor size */ - size_t total_size() const { return _shape.num_elements() * sizeOfDataType(_typeInfo.type()); } + size_t total_size() const; MemAllocType memAllocType() const { return _alloc_type; } void setAsConstant() { _const = true; } diff --git a/runtime/onert/core/src/ir/DataType.cc b/runtime/onert/core/src/ir/DataType.cc index 07670c72081..97ee2818ad5 100644 --- a/runtime/onert/core/src/ir/DataType.cc +++ b/runtime/onert/core/src/ir/DataType.cc @@ -53,6 +53,7 @@ size_t sizeOfDataType(DataType data_type) case DataType::QUANT_INT16_SYMM: return sizeof(int16_t); default: + // ggml block quantize type data size is not supported throw std::runtime_error{"Unsupported type size"}; } } diff --git a/runtime/onert/core/src/ir/OperandInfo.cc b/runtime/onert/core/src/ir/OperandInfo.cc new file mode 100644 index 00000000000..0a72af9288d --- /dev/null +++ b/runtime/onert/core/src/ir/OperandInfo.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/OperandInfo.h" + +#include + +namespace onert +{ +namespace ir +{ + +size_t OperandInfo::total_size() const +{ + const auto data_type = _typeInfo.type(); + try + { + return _shape.num_elements() * sizeOfDataType(data_type); + } + catch (const std::runtime_error &e) + { + // Caclulate total size for ggml block quantization type on exception handling + // because it is rare case and we should care about performance on non-block case. + if (data_type != DataType::QUANT_GGML_Q4_0 && data_type != DataType::QUANT_GGML_Q8_0) + throw e; + + if (_shape.dim(_shape.rank() - 1) % 32 != 0) + throw std::runtime_error{ + "Block quantization requires the last dimension to be a multiple of 32"}; + + const auto num_blocks = _shape.num_elements() / 32; + const auto block_size = data_type == DataType::QUANT_GGML_Q4_0 + ? (sizeof(uint8_t) * 32 / 2 + sizeof(uint16_t)) + : (sizeof(uint8_t) * 32 + sizeof(uint16_t)); + return num_blocks * block_size; + } +} + +} // namespace ir +} // namespace onert diff --git a/runtime/onert/core/src/ir/OperandInfo.test.cc b/runtime/onert/core/src/ir/OperandInfo.test.cc new file mode 100644 index 00000000000..4a04189df43 --- /dev/null +++ b/runtime/onert/core/src/ir/OperandInfo.test.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/OperandInfo.h" + +#include + +using namespace onert::ir; + +TEST(ir_OperandInfo, total_size) +{ + auto info = OperandInfo::createStaticInfo(Shape{1, 2, 3}, TypeInfo{DataType::FLOAT32}); + EXPECT_EQ(info.total_size(), 24); + + info = OperandInfo::createStaticInfo(Shape{1, 2, 3}, TypeInfo{DataType::QUANT_INT8_SYMM}); + EXPECT_EQ(info.total_size(), 6); + + // Block quantization type operand + info = OperandInfo::createStaticInfo(Shape{1, 4, 32}, TypeInfo{DataType::QUANT_GGML_Q4_0}); + EXPECT_EQ(info.total_size(), 18 * 4); +} + +// Unsupported type +TEST(ir_OperandInfo, neg_total_size_type) +{ + auto info = OperandInfo::createStaticInfo(Shape{1, 2, 3}, TypeInfo{DataType{-1}}); + EXPECT_THROW(info.total_size(), std::runtime_error); +} + +// Unsupported shape +TEST(ir_OperandInfo, neg_total_size_dimension) +{ + // Unspecified shape + auto info = OperandInfo::createStaticInfo(Shape{1, -1, 3}, TypeInfo{DataType::FLOAT32}); + EXPECT_THROW(info.total_size(), std::runtime_error); + + // Block quantization operand + info = OperandInfo::createStaticInfo(Shape{1, 2, 3}, TypeInfo{DataType::QUANT_GGML_Q4_0}); + EXPECT_THROW(info.total_size(), std::runtime_error); + info = OperandInfo::createStaticInfo(Shape{1, 2, 5}, TypeInfo{DataType::QUANT_GGML_Q8_0}); + EXPECT_THROW(info.total_size(), std::runtime_error); +}