Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Implement Checkpoint APIs #13561

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions runtime/onert/api/nnfw/include/nnfw_experimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,28 @@ NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *lo
*/
NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path);

/**
* @brief Import circle checkpoint
* @note This function should be called on training mode
* This function should be called before {@link nnfw_train}
*
* @param[in] session The session to export a checkpoint
* @param[in] path The path to export a checkpoint
* @return @c NNFW_STATUS_NO_ERROR if successful
*/
NNFW_STATUS nnfw_train_import_checkpoint(nnfw_session *session, const char *path);

/**
* @brief Export circle checkpoint
* @note This function should be called on training mode
* This function should be called after {@link nnfw_train}
*
* @param[in] session The session to export a checkpoint
* @param[in] path The path to export a checkpoint
* @return @c NNFW_STATUS_NO_ERROR if successful
*/
NNFW_STATUS nnfw_train_export_checkpoint(nnfw_session *session, const char *path);

//////////////////////////////////////////////
// Optional APIs for training
//////////////////////////////////////////////
Expand Down
12 changes: 12 additions & 0 deletions runtime/onert/api/nnfw/src/nnfw_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,18 @@ NNFW_STATUS nnfw_train_export_circleplus(nnfw_session *session, const char *path
return session->train_export_circleplus(path);
}

NNFW_STATUS nnfw_train_import_checkpoint(nnfw_session *session, const char *path)
{
NNFW_RETURN_ERROR_IF_NULL(session);
return session->train_import_checkpoint(path);
}

NNFW_STATUS nnfw_train_export_checkpoint(nnfw_session *session, const char *path)
{
NNFW_RETURN_ERROR_IF_NULL(session);
return session->train_export_checkpoint(path);
}

// Quantization

NNFW_STATUS nnfw_set_quantization_type(nnfw_session *session, NNFW_QUANTIZE_TYPE qtype)
Expand Down
57 changes: 57 additions & 0 deletions runtime/onert/api/nnfw/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
#include "util/Exceptions.h"
#include "util/logging.h"
#include "exec/Execution.h"
#include "loader/CheckpointLoader.h"
#include "loader/CircleLoader.h"
#include "loader/ModelLoader.h"
#include "loader/TFLiteLoader.h"
#include "loader/TrainInfoLoader.h"
#include "exporter/CircleExporter.h"
#include "exporter/CheckpointExporter.h"
#include "json/json.h"
#include "ir/NNPkg.h"
#include "ir/OpCode.h"
Expand Down Expand Up @@ -1701,6 +1703,61 @@ NNFW_STATUS nnfw_session::train_export_circleplus(const char *path)
return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::train_import_checkpoint(const char *path)
{
if (path == nullptr)
{
std::cerr << "Error during nnfw_session::train_import_checkpoint : path is null" << std::endl;
return NNFW_STATUS_UNEXPECTED_NULL;
}

if (!isStatePreparedOrFinishedTraining())
{
std::cerr << "Error during nnfw_session::train_import_checkpoint : invalid state" << std::endl;
return NNFW_STATUS_INVALID_STATE;
}

try
{
onert::loader::loadCheckpoint(path, _train_info, _execution);
}
catch (const std::exception &e)
{
std::cerr << "Error during nnfw_session::train_import_checkpoint : " << e.what() << std::endl;
return NNFW_STATUS_ERROR;
}

return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::train_export_checkpoint(const char *path)
{
if (path == nullptr)
{
std::cerr << "Error during nnfw_session::train_export_checkpoint : path is null" << std::endl;
return NNFW_STATUS_UNEXPECTED_NULL;
}

// Check training mode is enabled
if (!isStateFinishedTraining())
{
std::cerr << "Error during nnfw_session::train_export_checkpoint : invalid state" << std::endl;
return NNFW_STATUS_INVALID_STATE;
}

try
{
onert::exporter::exportCheckpoint(path, _train_info, _execution);
}
catch (const std::exception &e)
{
std::cerr << "Error during nnfw_session::train_export_checkpoint : " << e.what() << std::endl;
return NNFW_STATUS_ERROR;
}

return NNFW_STATUS_NO_ERROR;
}

bool nnfw_session::isStatePreparedTraining()
{
if (_state == State::PREPARED_TRAINING)
Expand Down
2 changes: 2 additions & 0 deletions runtime/onert/api/nnfw/src/nnfw_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ struct nnfw_session
NNFW_STATUS train_get_loss(uint32_t index, float *loss);
NNFW_STATUS train_export_circle(const char *path);
NNFW_STATUS train_export_circleplus(const char *path);
NNFW_STATUS train_import_checkpoint(const char *path);
NNFW_STATUS train_export_checkpoint(const char *path);

NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype);
NNFW_STATUS set_quantized_model_path(const char *path);
Expand Down
53 changes: 53 additions & 0 deletions runtime/onert/core/include/exporter/CheckpointExporter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_EXPORTER_CHECKPOINT_EXPORTER_H__
#define __ONERT_EXPORTER_CHECKPOINT_EXPORTER_H__

#include <string>
#include <vector>
#include <memory>

#include "ir/Checkpoint.h"

namespace onert
{
namespace exec
{
class Execution;
} // namespace exec
namespace ir
{
namespace train
{
class TrainingInfo;
} // namespace train
} // namespace ir
} // namespace onert

namespace onert
{
namespace exporter
{

void exportCheckpoint(const std::string &filename,
const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<onert::exec::Execution> &exec);

} // namespace exporter
} // namespace onert

#endif // __ONERT_EXPORTER_CHECKPOINT_EXPORTER_H__
48 changes: 48 additions & 0 deletions runtime/onert/core/include/ir/Checkpoint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_IR_CHECKPOINT_H__
#define __ONERT_IR_CHECKPOINT_H__

namespace onert
{
namespace checkpoint
{

struct __attribute__((packed)) Header
{
uint16_t magic;
uint8_t schema;
uint8_t reserved;
uint32_t opt1_offset;
uint32_t opt2_offset;
uint32_t other_offset;
uint32_t length;
};

struct __attribute__((packed)) Footer
{
uint32_t cur_step;
uint32_t cur_epoch;
};

constexpr uint16_t MAGIC_NUMBER = 429;
constexpr uint8_t SCHEMA_VERSION = 1;

} // namespace checkpoint
} // namespace onert

#endif // __ONERT_IR_CHECKPOINT_H__
50 changes: 50 additions & 0 deletions runtime/onert/core/include/loader/CheckpointLoader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_LOADER_CHECKPOINT_LOADER_H__
#define __ONERT_LOADER_CHECKPOINT_LOADER_H__

#include <string>
#include <memory>

namespace onert
{
namespace exec
{
class Execution;
} // namespace exec
namespace ir
{
namespace train
{
class TrainingInfo;
} // namespace train
} // namespace ir
} // namespace onert

namespace onert
{
namespace loader
{

void loadCheckpoint(const std::string &filename,
const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<onert::exec::Execution> &exec);

} // namespace loader
} // namespace onert

#endif // __ONERT_LOADER_CHECKPOINT_LOADER_H__
Loading