Skip to content

Commit

Permalink
[tools/onert_train] Support checkpoint loader
Browse files Browse the repository at this point in the history
ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun committed Aug 12, 2024
1 parent 22230e9 commit 6f1d0e8
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
2 changes: 1 addition & 1 deletion runtime/onert/api/nnfw/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1719,7 +1719,7 @@ NNFW_STATUS nnfw_session::train_import_checkpoint(const char *path)

try
{
onert::loader::loadCheckpoint(_execution, _train_info, path);
onert::loader::loadCheckpoint(path, _train_info, _execution);
}
catch (const std::exception &e)
{
Expand Down
6 changes: 3 additions & 3 deletions runtime/onert/core/include/loader/CheckpointLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ namespace onert
namespace loader
{

void loadCheckpoint(const std::unique_ptr<onert::exec::Execution> &exec, const
std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::string &filename);
void loadCheckpoint(const std::string &filename,
const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<onert::exec::Execution> &exec);

} // namespace loader
} // namespace onert
Expand Down
28 changes: 15 additions & 13 deletions runtime/onert/core/src/loader/CheckpointLoader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct __attribute__((packed)) Header

struct DataBufferPair
{
DataBufferPair(uint32_t _offset, uint32_t _size): offset{_offset}, size{_size}
DataBufferPair(uint32_t _offset, uint32_t _size) : offset{_offset}, size{_size}
{
// DO NOTHING
}
Expand All @@ -63,27 +63,29 @@ struct DataBuffer
std::vector<uint32_t> offset;
std::vector<uint32_t> size;

void resize(uint32_t length) {
void resize(uint32_t length)
{
offset.resize(length);
size.resize(length);
}

char *getOffsetBuf() {
return reinterpret_cast<char *>(offset.data());
}
char *getOffsetBuf() { return reinterpret_cast<char *>(offset.data()); }

void calculateSize(uint32_t next_beg_offset) {
void calculateSize(uint32_t next_beg_offset)
{
assert(offset.size() == size.size());
uint32_t cur = offset[0];
for (size_t i = 1; i < offset.size(); ++i) {
for (size_t i = 1; i < offset.size(); ++i)
{
size[i - 1] = offset[i] - cur;
cur = offset[i];
}
size.back() = next_beg_offset - offset.back();
}

// offset, size
DataBufferPair operator[](uint32_t i) {
DataBufferPair operator[](uint32_t i)
{
assert(offset.size() == size.size());
assert(i <= offset.size());
return DataBufferPair{offset[i], size[i]};
Expand All @@ -105,7 +107,7 @@ class CheckpointLoader
_file.seekg(0, std::ios::end);
const auto filesize = _file.tellg();
_file.seekg(0, std::ios::beg);

if (filesize < static_cast<long int>(sizeof(_header)))
throw std::runtime_error{"Invalid checkpoint file data"};

Expand All @@ -123,7 +125,7 @@ class CheckpointLoader
_tensor_data.resize(_header.length);
_file.read(_tensor_data.getOffsetBuf(), _header.length * sizeof(uint32_t));
_tensor_data.calculateSize(_header.opt1_offset);

if (_header.opt1_offset)
{
_opt1_data.resize(_header.length);
Expand Down Expand Up @@ -171,9 +173,9 @@ class CheckpointLoader
DataBuffer _opt2_data;
};


void loadCheckpoint(const std::unique_ptr<onert::exec::Execution> &exec, const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::string &filename)
void loadCheckpoint(const std::string &filename,
const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<onert::exec::Execution> &exec)
{
CheckpointLoader loader(filename);
loader.updateTensor(exec);
Expand Down
7 changes: 6 additions & 1 deletion tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ int main(const int argc, char **argv)
// prepare execution

// TODO When nnfw_{prepare|run} are failed, can't catch the time
measure.run(PhaseType::PREPARE, [&]() { NNPR_ENSURE_STATUS(nnfw_train_prepare(session)); });
measure.run(PhaseType::PREPARE, [&]() {
NNPR_ENSURE_STATUS(nnfw_train_prepare(session));

if (auto name = args.getCheckpointFilename(); name != "")
NNPR_ENSURE_STATUS(nnfw_train_import_checkpoint(session, name.c_str()));
});

// prepare input and expected tensor info lists
std::vector<nnfw_tensorinfo> input_infos;
Expand Down

0 comments on commit 6f1d0e8

Please sign in to comment.