
Fix broken tests
baijumeswani committed Jan 24, 2024
1 parent e02a5d6 commit c630fef
Showing 8 changed files with 11 additions and 7 deletions.
Binary file modified onnxruntime/test/testdata/training_api/checkpoint.ckpt
Binary file modified onnxruntime/test/testdata/training_api/custom_ops/checkpoint
Binary file modified onnxruntime/test/testdata/training_api/nominal_checkpoint
Binary file modified onnxruntime/test/testdata/training_api/ort_format/checkpoint
@@ -443,10 +443,10 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) {
* Saves the checkpoint, and loads it again. Checks for nominal flag, and that the state is empty.
*/
TEST(CheckpointApiTest, LoadAndSaveNominalCheckpoint) {
- auto checkpoint_uri = "testdata/training_api/nominal_checkpoint";
+ PathString nominal_checkpoint_path{ORT_TSTR("testdata/training_api/nominal_checkpoint")};

CheckpointState checkpoint_state;
- ASSERT_STATUS_OK(LoadCheckpoint(onnxruntime::ToUTF8String(checkpoint_uri), checkpoint_state));
+ ASSERT_STATUS_OK(LoadCheckpoint(nominal_checkpoint_path, checkpoint_state));
ASSERT_TRUE(checkpoint_state.module_checkpoint_state.is_nominal_state);
for (auto& [name, param] : checkpoint_state.module_checkpoint_state.named_parameters) {
ASSERT_TRUE(param->Data().IsTensor());
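Note on the path change above: the test now builds the path as a PathString with ORT_TSTR instead of converting a narrow string with ToUTF8String, which lines up with LoadCheckpoint expecting the platform's native path string type (std::wstring on Windows, std::string elsewhere). A minimal standalone sketch of that pattern, using simplified stand-ins for the ORT path helpers (assumptions for illustration, not the actual ONNX Runtime definitions):

#include <string>

// Simplified stand-ins for ONNX Runtime's PathString/ORT_TSTR (assumption for
// illustration; the real definitions come from the ORT headers).
#ifdef _WIN32
using PathString = std::wstring;
#define ORT_TSTR(x) L##x
#else
using PathString = std::string;
#define ORT_TSTR(x) x
#endif

int main() {
  // The same literal yields the right string type on every platform.
  PathString nominal_checkpoint_path{ORT_TSTR("testdata/training_api/nominal_checkpoint")};
  return nominal_checkpoint_path.empty() ? 1 : 0;
}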
@@ -590,8 +590,8 @@ TEST(TrainingApiTest, ModuleAndOptimizerWithNominalState) {

OrtValue input, target;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
- target = onnxruntime::test::CreateInputOrtValueOnCPU<int64_t>(
-     std::array<int64_t, 1>{2}, std::vector<int64_t>(2, 1));
+ target = onnxruntime::test::CreateInputOrtValueOnCPU<int32_t>(
+     std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1));
auto data_loader = std::vector<std::vector<OrtValue>>(4, std::vector<OrtValue>{input, target});

for (auto it = data_loader.begin(); it != data_loader.end(); ++it) {
@@ -450,7 +450,7 @@ TEST(TrainingCApiTest, ModuleAndOptimizerWithNominalState) {
std::vector<int64_t> x_shape{2, 784};
GenerateRandomData(x);

- std::vector<int64_t> labels{0, 8};
+ std::vector<int32_t> labels{0, 8};
std::vector<int64_t> labels_shape{2};

Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
@@ -460,9 +460,9 @@
x_shape.data(), x_shape.size(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels.data(),
- labels.size() * sizeof(int64_t),
+ labels.size() * sizeof(int32_t),
labels_shape.data(), labels_shape.size(),
- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64));
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));

std::vector<Ort::Value> complete_fetches = complete_training_session.TrainStep(ort_inputs);
std::vector<Ort::Value> nominal_fetches = nominal_training_session.TrainStep(ort_inputs);
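The label changes in this file keep the buffer type, the byte count, and the ONNX element type in sync: with int32_t labels, the tensor has to be created with sizeof(int32_t) and ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32. A small self-contained sketch of building such a label tensor with the public C++ API (names and values are illustrative, not taken from the test):

#include <cstdint>
#include <vector>

#include "onnxruntime_cxx_api.h"

int main() {
  // CPU memory info; the tensor borrows the caller's buffer rather than copying it.
  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  std::vector<int32_t> labels{0, 8};     // int32 labels matching the model's label input
  std::vector<int64_t> labels_shape{2};  // shape [2]

  // Byte count and element type must both describe the int32 buffer.
  Ort::Value label_tensor = Ort::Value::CreateTensor(
      memory_info, labels.data(), labels.size() * sizeof(int32_t),
      labels_shape.data(), labels_shape.size(),
      ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);

  return label_tensor.IsTensor() ? 0 : 1;
}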
4 changes: 4 additions & 0 deletions orttraining/orttraining/training_api/checkpoint.cc
@@ -650,6 +650,10 @@ Status ToModelProto(gsl::span<const uint8_t> checkpoint_bytes,
ORT_RETURN_IF_NOT(frozen_params,
"Checkpoint is invalid. Expected: Valid non-trainable params flatbuffer. Actual: nullptr.");

+ ORT_RETURN_IF(module_state->is_nominal_state(),
+               "Cannot load a nominal checkpoint to a model proto. "
+               "Expected: Complete checkpoint. Actual: Nominal checkpoint.");
+
InlinedHashMap<std::string, ONNX_NAMESPACE::TensorProto> param_tensor_protos;
param_tensor_protos.reserve(
static_cast<size_t>(requires_grad_params->size()) + static_cast<size_t>(frozen_params->size()));
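The guard added to ToModelProto makes it fail fast on nominal checkpoints, which describe parameter names and shapes but carry no tensor data, so there is nothing to copy into a model proto. A simplified, self-contained illustration of the same early-return pattern (Status, ModuleState, and the function below are stand-ins for this sketch, not the actual ORT types):

#include <iostream>
#include <string>

// Stand-in for the status type the real function returns.
struct Status {
  bool ok;
  std::string message;
};

// Stand-in for the module state flag read by the new guard.
struct ModuleState {
  bool is_nominal_state = false;  // true => parameter names/shapes only, no data
};

Status ToModelProtoSketch(const ModuleState& module_state) {
  if (module_state.is_nominal_state) {
    return {false,
            "Cannot load a nominal checkpoint to a model proto. "
            "Expected: Complete checkpoint. Actual: Nominal checkpoint."};
  }
  // ...copy trainable and frozen parameter tensors into the model proto...
  return {true, ""};
}

int main() {
  ModuleState nominal;
  nominal.is_nominal_state = true;
  Status status = ToModelProtoSketch(nominal);
  std::cout << (status.ok ? "ok" : status.message) << "\n";
  return status.ok ? 0 : 1;
}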
