From 22230e9b7865b75df1c2b55a4235215a2ff9da5e Mon Sep 17 00:00:00 2001 From: Jiyoung Yun Date: Fri, 9 Aug 2024 10:46:31 +0900 Subject: [PATCH] [onert] Introduce CheckpointLoader ONE-DCO-1.0-Signed-off-by: Jiyoung Yun --- .../onert/api/nnfw/src/nnfw_api_internal.cc | 5 +- .../core/include/loader/CheckpointLoader.h | 50 +++++ .../onert/core/src/loader/CheckpointLoader.cc | 185 ++++++++++++++++++ 3 files changed, 237 insertions(+), 3 deletions(-) create mode 100644 runtime/onert/core/include/loader/CheckpointLoader.h create mode 100644 runtime/onert/core/src/loader/CheckpointLoader.cc diff --git a/runtime/onert/api/nnfw/src/nnfw_api_internal.cc b/runtime/onert/api/nnfw/src/nnfw_api_internal.cc index f830a7a7dba..6a2ab2e4ae3 100644 --- a/runtime/onert/api/nnfw/src/nnfw_api_internal.cc +++ b/runtime/onert/api/nnfw/src/nnfw_api_internal.cc @@ -21,6 +21,7 @@ #include "util/Exceptions.h" #include "util/logging.h" #include "exec/Execution.h" +#include "loader/CheckpointLoader.h" #include "loader/CircleLoader.h" #include "loader/ModelLoader.h" #include "loader/TFLiteLoader.h" @@ -1718,9 +1719,7 @@ NNFW_STATUS nnfw_session::train_import_checkpoint(const char *path) try { - // onert::exporter::CircleExporter exporter(_model_path, std::string{path}); - // exporter.updateWeight(_execution); - // exporter.updateMetadata(_train_info); + onert::loader::loadCheckpoint(_execution, _train_info, path); } catch (const std::exception &e) { diff --git a/runtime/onert/core/include/loader/CheckpointLoader.h b/runtime/onert/core/include/loader/CheckpointLoader.h new file mode 100644 index 00000000000..dd433895dfe --- /dev/null +++ b/runtime/onert/core/include/loader/CheckpointLoader.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_LOADER_CHECKPOINT_LOADER_H__ +#define __ONERT_LOADER_CHECKPOINT_LOADER_H__ + +#include +#include + +namespace onert +{ +namespace exec +{ +class Execution; +} // namespace exec +namespace ir +{ +namespace train +{ +class TrainingInfo; +} // namespace train +} // namespace ir +} // namespace onert + +namespace onert +{ +namespace loader +{ + +void loadCheckpoint(const std::unique_ptr &exec, const +std::unique_ptr &train_info, + const std::string &filename); + +} // namespace loader +} // namespace onert + +#endif // __ONERT_LOADER_CHECKPOINT_LOADER_H__ diff --git a/runtime/onert/core/src/loader/CheckpointLoader.cc b/runtime/onert/core/src/loader/CheckpointLoader.cc new file mode 100644 index 00000000000..42b502753e4 --- /dev/null +++ b/runtime/onert/core/src/loader/CheckpointLoader.cc @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loader/CheckpointLoader.h" + +#include "ir/Model.h" +#include "ir/train/TrainingInfo.h" +#include "exec/Execution.h" +#include "util/Utils.h" + +#include +#include +// #include "BaseLoader.h" +// #include "circle_schema_generated.h" + +namespace onert +{ +namespace loader +{ + +namespace checkpoint +{ + +struct __attribute__((packed)) Header +{ + uint16_t magic; + uint8_t schema; + uint8_t reserved; + uint32_t opt1_offset; + uint32_t opt2_offset; + uint32_t other_offset; + uint32_t length; +}; + +} // namespace checkpoint + +struct DataBufferPair +{ + DataBufferPair(uint32_t _offset, uint32_t _size): offset{_offset}, size{_size} + { + // DO NOTHING + } + + uint32_t offset; + uint32_t size; +}; + +struct DataBuffer +{ + std::vector offset; + std::vector size; + + void resize(uint32_t length) { + offset.resize(length); + size.resize(length); + } + + char *getOffsetBuf() { + return reinterpret_cast(offset.data()); + } + + void calculateSize(uint32_t next_beg_offset) { + assert(offset.size() == size.size()); + uint32_t cur = offset[0]; + for (size_t i = 1; i < offset.size(); ++i) { + size[i - 1] = offset[i] - cur; + cur = offset[i]; + } + size.back() = next_beg_offset - offset.back(); + } + + // offset, size + DataBufferPair operator[](uint32_t i) { + assert(offset.size() == size.size()); + assert(i <= offset.size()); + return DataBufferPair{offset[i], size[i]}; + } +}; + +class CheckpointLoader +{ +public: + CheckpointLoader(const std::string &filename) + { + if (filename.empty() || !std::filesystem::exists(filename)) + throw std::runtime_error{"Invalid checkpoint file"}; + + _file.open(filename.c_str(), std::ios::binary | std::ios::in); + if (!_file.good()) + throw std::runtime_error{"Failed to open checkpoint file"}; + + _file.seekg(0, std::ios::end); + const auto filesize = _file.tellg(); + _file.seekg(0, std::ios::beg); + + if (filesize < static_cast(sizeof(_header))) + throw std::runtime_error{"Invalid checkpoint file data"}; + + memset(reinterpret_cast(&_header), 0, sizeof(_header)); + _file.read(reinterpret_cast(&_header), sizeof(_header)); + if (_file.fail()) + throw std::runtime_error{"Failed to load header data"}; + + if (_header.magic != MAGIC_NUMBER) + throw std::runtime_error{"Invalid MAGIC NUMBER"}; + + if (_header.schema != SCHEMA_VERSION) + throw std::runtime_error{"Invalid SCHEMA VERSION"}; + + _tensor_data.resize(_header.length); + _file.read(_tensor_data.getOffsetBuf(), _header.length * sizeof(uint32_t)); + _tensor_data.calculateSize(_header.opt1_offset); + + if (_header.opt1_offset) + { + _opt1_data.resize(_header.length); + _file.seekg(_header.opt1_offset, std::ios::beg); + _file.read(_opt1_data.getOffsetBuf(), _header.length * sizeof(uint32_t)); + _opt1_data.calculateSize(_header.opt2_offset); + } + + if (_header.opt2_offset) + { + _opt2_data.resize(_header.length); + _file.seekg(_header.opt2_offset, std::ios::beg); + _file.read(_opt2_data.getOffsetBuf(), _header.length * sizeof(uint32_t)); + _opt2_data.calculateSize(_header.other_offset); + } + } + + ~CheckpointLoader() + { + if (_file.is_open()) + _file.close(); + } + + void updateTensor(const std::unique_ptr &exec) + { + auto vindex = 0; + exec->iterateTrainableTensors( + [&](const ir::OperandIndex &, const backend::train::ITrainableTensor *tensor) { + assert(tensor); + assert(tensor->total_size() == _tensor_data[vindex].size); + _file.seekg(_tensor_data[vindex].offset, std::ios::beg); + _file.read(reinterpret_cast(tensor->buffer()), tensor->total_size()); + vindex++; + }); + } + +private: + static constexpr uint16_t MAGIC_NUMBER = 429; + static constexpr uint8_t SCHEMA_VERSION = 1; + + std::ifstream _file; + checkpoint::Header _header; + DataBuffer _tensor_data; + DataBuffer _opt1_data; + DataBuffer _opt2_data; +}; + + +void loadCheckpoint(const std::unique_ptr &exec, const std::unique_ptr &train_info, + const std::string &filename) +{ + CheckpointLoader loader(filename); + loader.updateTensor(exec); + // loader.updateOptimizer(train_info, exec); + UNUSED_RELEASE(train_info); +} + +} // namespace loader +} // namespace onert