Skip to content

Commit

Permalink
Simplify class LSTMTrainer
Browse files Browse the repository at this point in the history
The function pointers and callbacks file_reader_, file_writer_,
checkpoint_reader_ and checkpoint_writer_ are always set to
the same values. Replacing them by direct function calls
simplifies the code and allows removing more code from tesscallback.h.

Signed-off-by: Stefan Weil <[email protected]>
  • Loading branch information
stweil committed Jun 23, 2019
1 parent c5525c4 commit 563a171
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 172 deletions.
105 changes: 0 additions & 105 deletions src/ccutil/tesscallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,72 +607,6 @@ NewPermanentTessCallback(R (*function)(P1, A1),
return new _TessFunctionResultCallback_1_1<false, R, P1, A1>(function, p1);
}

// Result-returning callback that binds a const object pointer to one of its
// const member functions taking two arguments (A1, A2) and returning R.
// Run() forwards both arguments to the bound member and returns its result.
// When the template flag `del` is true the callback is single-use: it
// destroys itself at the end of the first Run().
template <bool del, class R, class T, class A1, class A2>
class _ConstTessMemberResultCallback_0_2
    : public TessResultCallback2<R, A1, A2> {
 public:
  using base = TessResultCallback2<R, A1, A2>;
  using MemberSignature = R (T::*)(A1, A2) const;

 private:
  const T* object_;         // Receiver of the member call.
  MemberSignature member_;  // Bound const member function.

 public:
  inline _ConstTessMemberResultCallback_0_2(const T* object,
                                            MemberSignature member)
      : object_(object), member_(member) {}

  R Run(A1 a1, A2 a2) override {
    // The call itself is the same for permanent and self-deleting callbacks.
    R result = (object_->*member_)(a1, a2);
    if (del) {
      // zero out the pointer to ensure segfault if used again
      member_ = nullptr;
      delete this;
    }
    return result;
  }
};

// Specialization for const member functions that return void: the callback
// derives from TessCallback2 and Run() yields no value.
// When the template flag `del` is true the callback destroys itself at the
// end of the first Run().
template <bool del, class T, class A1, class A2>
class _ConstTessMemberResultCallback_0_2<del, void, T, A1, A2>
    : public TessCallback2<A1, A2> {
 public:
  using base = TessCallback2<A1, A2>;
  using MemberSignature = void (T::*)(A1, A2) const;

 private:
  const T* object_;         // Receiver of the member call.
  MemberSignature member_;  // Bound const member function.

 public:
  inline _ConstTessMemberResultCallback_0_2(const T* object,
                                            MemberSignature member)
      : object_(object), member_(member) {}

  virtual void Run(A1 a1, A2 a2) {
    // The call itself is the same for permanent and self-deleting callbacks.
    (object_->*member_)(a1, a2);
    if (del) {
      // zero out the pointer to ensure segfault if used again
      member_ = nullptr;
      delete this;
    }
  }
};

#ifndef SWIG
// Creates a permanent (non self-deleting) callback that invokes the const
// two-argument member function `member` on `obj`. The callback is allocated
// with new and returned as a pointer to its callback base type.
template <class T1, class T2, class R, class A1, class A2>
inline typename _ConstTessMemberResultCallback_0_2<false, R, T1, A1, A2>::base*
NewPermanentTessCallback(const T1* obj, R (T2::*member)(A1, A2) const) {
  using PermanentCallback =
      _ConstTessMemberResultCallback_0_2<false, R, T1, A1, A2>;
  return new PermanentCallback(obj, member);
}
#endif

template <bool del, class R, class T, class A1, class A2>
class _TessMemberResultCallback_0_2 : public TessResultCallback2<R, A1, A2> {
public:
Expand Down Expand Up @@ -793,45 +727,6 @@ NewPermanentTessCallback(R (*function)(A1, A2)) {
return new _TessFunctionResultCallback_0_2<false, R, A1, A2>(function);
}

// Result-returning callback that binds a const object pointer to one of its
// const member functions taking three arguments (A1, A2, A3) and returning R.
// Run() forwards the arguments to the bound member and returns its result.
// When the template flag `del` is true the callback is single-use: it
// destroys itself at the end of the first Run().
template <bool del, class R, class T, class A1, class A2, class A3>
class _ConstTessMemberResultCallback_0_3
    : public TessResultCallback3<R, A1, A2, A3> {
 public:
  using base = TessResultCallback3<R, A1, A2, A3>;
  using MemberSignature = R (T::*)(A1, A2, A3) const;

 private:
  const T* object_;         // Receiver of the member call.
  MemberSignature member_;  // Bound const member function.

 public:
  inline _ConstTessMemberResultCallback_0_3(const T* object,
                                            MemberSignature member)
      : object_(object), member_(member) {}

  R Run(A1 a1, A2 a2, A3 a3) override {
    // The call itself is the same for permanent and self-deleting callbacks.
    R result = (object_->*member_)(a1, a2, a3);
    if (del) {
      // zero out the pointer to ensure segfault if used again
      member_ = nullptr;
      delete this;
    }
    return result;
  }
};

#ifndef SWIG
// Creates a permanent (non self-deleting) callback that invokes the const
// three-argument member function `member` on `obj`. The callback is allocated
// with new and returned as a pointer to its callback base type.
template <class T1, class T2, class R, class A1, class A2, class A3>
inline
    typename _ConstTessMemberResultCallback_0_3<false, R, T1, A1, A2, A3>::base*
    NewPermanentTessCallback(const T1* obj,
                             R (T2::*member)(A1, A2, A3) const) {
  using PermanentCallback =
      _ConstTessMemberResultCallback_0_3<false, R, T1, A1, A2, A3>;
  return new PermanentCallback(obj, member);
}
#endif

template <bool del, class R, class T, class A1, class A2, class A3, class A4>
class _TessMemberResultCallback_0_4
: public TessResultCallback4<R, A1, A2, A3, A4> {
Expand Down
58 changes: 16 additions & 42 deletions src/lstm/lstmtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,41 +74,17 @@ const int kTargetYScale = 100;
// Default constructor.
// Turns random rotation off, constructs training_data_ with 0, installs
// LoadDataFromFile/SaveDataToFile as the file reader/writer, and binds the
// checkpoint reader/writer callbacks to this trainer's ReadTrainingDump and
// SaveTrainingDump members. Starts without a sub-trainer, runs
// EmptyConstructor() and sets debug_interval_ to 0.
LSTMTrainer::LSTMTrainer()
: randomly_rotate_(false),
training_data_(0),
file_reader_(LoadDataFromFile),
file_writer_(SaveDataToFile),
// The checkpoint callbacks are created by NewPermanentTessCallback and bound
// to `this`.
checkpoint_reader_(
NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)),
checkpoint_writer_(
NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)),
sub_trainer_(nullptr) {
EmptyConstructor();
debug_interval_ = 0;
}

LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer,
CheckPointReader checkpoint_reader,
CheckPointWriter checkpoint_writer,
const char* model_base, const char* checkpoint_name,
LSTMTrainer::LSTMTrainer(const char* model_base, const char* checkpoint_name,
int debug_interval, int64_t max_memory)
: randomly_rotate_(false),
training_data_(max_memory),
file_reader_(file_reader),
file_writer_(file_writer),
checkpoint_reader_(checkpoint_reader),
checkpoint_writer_(checkpoint_writer),
sub_trainer_(nullptr),
mgr_(file_reader) {
sub_trainer_(nullptr) {
EmptyConstructor();
if (file_reader_ == nullptr) file_reader_ = LoadDataFromFile;
if (file_writer_ == nullptr) file_writer_ = SaveDataToFile;
if (checkpoint_reader_ == nullptr) {
checkpoint_reader_ =
NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump);
}
if (checkpoint_writer_ == nullptr) {
checkpoint_writer_ =
NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump);
}
debug_interval_ = debug_interval;
model_base_ = model_base;
checkpoint_name_ = checkpoint_name;
Expand All @@ -119,8 +95,6 @@ LSTMTrainer::~LSTMTrainer() {
delete target_win_;
delete ctc_win_;
delete recon_win_;
delete checkpoint_reader_;
delete checkpoint_writer_;
delete sub_trainer_;
}

Expand All @@ -129,9 +103,9 @@ LSTMTrainer::~LSTMTrainer() {
bool LSTMTrainer::TryLoadingCheckpoint(const char* filename,
const char* old_traineddata) {
GenericVector<char> data;
if (!(*file_reader_)(filename, &data)) return false;
if (!LoadDataFromFile(filename, &data)) return false;
tprintf("Loaded file %s, unpacking...\n", filename);
if (!checkpoint_reader_->Run(data, this)) return false;
if (!ReadTrainingDump(data, this)) return false;
StaticShape shape = network_->OutputShape(network_->InputShape());
if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
network_->NumOutputs() == recoder_.code_range()) ||
Expand Down Expand Up @@ -303,7 +277,8 @@ bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
bool randomly_rotate) {
randomly_rotate_ = randomly_rotate;
training_data_.Clear();
return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
return training_data_.LoadDocuments(filenames, cache_strategy,
LoadDataFromFile);
}

// Keeps track of best and locally worst char error_rate and launches tests
Expand Down Expand Up @@ -345,10 +320,10 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
if (TransitionTrainingStage(kStageTransitionThreshold)) {
log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
}
checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_);
SaveTrainingDump(NO_BEST_TRAINER, this, &best_trainer_);
if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
STRING best_model_name = DumpFilename();
if (!(*file_writer_)(best_trainer_, best_model_name)) {
if (!SaveDataToFile(best_trainer_, best_model_name)) {
*log_msg += " failed to write best model:";
} else {
*log_msg += " wrote best model:";
Expand All @@ -366,7 +341,7 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
*log_msg += "\nDivergence! ";
// Copy best_trainer_ before reading it, as it will get overwritten.
GenericVector<char> revert_data(best_trainer_);
if (checkpoint_reader_->Run(revert_data, this)) {
if (ReadTrainingDump(revert_data, this)) {
LogIterations("Reverted to", log_msg);
ReduceLearningRates(this, log_msg);
} else {
Expand All @@ -376,18 +351,17 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
// Re-save the best trainer with the new learning rates and stall
// iteration.
checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_);
SaveTrainingDump(NO_BEST_TRAINER, this, &best_trainer_);
}
} else {
// Something interesting happened only if the sub_trainer_ was trained.
result = sub_trainer_result != STR_NONE;
}
if (checkpoint_writer_ != nullptr && file_writer_ != nullptr &&
checkpoint_name_.length() > 0) {
if (checkpoint_name_.length() > 0) {
// Write a current checkpoint.
GenericVector<char> checkpoint;
if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
!(*file_writer_)(checkpoint, checkpoint_name_)) {
if (!SaveTrainingDump(FULL, this, &checkpoint) ||
!SaveDataToFile(checkpoint, checkpoint_name_)) {
*log_msg += " failed to write checkpoint.";
} else {
*log_msg += " wrote checkpoint.";
Expand Down Expand Up @@ -518,7 +492,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
void LSTMTrainer::StartSubtrainer(STRING* log_msg) {
delete sub_trainer_;
sub_trainer_ = new LSTMTrainer();
if (!checkpoint_reader_->Run(best_trainer_, sub_trainer_)) {
if (!ReadTrainingDump(best_trainer_, sub_trainer_)) {
*log_msg += " Failed to revert to previous best for trial!";
delete sub_trainer_;
sub_trainer_ = nullptr;
Expand All @@ -533,7 +507,7 @@ void LSTMTrainer::StartSubtrainer(STRING* log_msg) {
stall_iteration_ = learning_iteration() + 2 * stall_offset;
sub_trainer_->stall_iteration_ = stall_iteration_;
// Re-save the best trainer with the new learning rates and stall iteration.
checkpoint_writer_->Run(NO_BEST_TRAINER, sub_trainer_, &best_trainer_);
SaveTrainingDump(NO_BEST_TRAINER, sub_trainer_, &best_trainer_);
}
}

Expand Down Expand Up @@ -926,7 +900,7 @@ bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
SaveRecognitionDump(&recognizer_data);
mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
recognizer_data.size());
return mgr_.SaveFile(filename, file_writer_);
return mgr_.SaveFile(filename, SaveDataToFile);
}

// Writes the recognizer to memory, so that it can be used for testing later.
Expand Down
23 changes: 1 addition & 22 deletions src/lstm/lstmtrainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// File: lstmtrainer.h
// Description: Top-level line trainer class for LSTM-based networks.
// Author: Ray Smith
// Created: Fri May 03 09:07:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -67,15 +66,6 @@ enum SubTrainerResult {
};

class LSTMTrainer;
// Function to restore the trainer state from a given checkpoint.
// Returns false on failure.
typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>*
CheckPointReader;
// Function to save a checkpoint of the current trainer state.
// Returns false on failure. SerializeAmount determines the amount of the
// trainer to serialize, typically used for saving the best state.
typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
GenericVector<char>*>* CheckPointWriter;
// Function to compute and record error rates on some external test set(s).
// Args are: iteration, mean errors, model, training stage.
// Returns a STRING containing logging information about the tests.
Expand All @@ -89,11 +79,7 @@ typedef TessResultCallback4<STRING, int, const double*, const TessdataManager&,
class LSTMTrainer : public LSTMRecognizer {
public:
LSTMTrainer();
// Callbacks may be null, in which case defaults are used.
LSTMTrainer(FileReader file_reader, FileWriter file_writer,
CheckPointReader checkpoint_reader,
CheckPointWriter checkpoint_writer,
const char* model_base, const char* checkpoint_name,
LSTMTrainer(const char* model_base, const char* checkpoint_name,
int debug_interval, int64_t max_memory);
virtual ~LSTMTrainer();

Expand Down Expand Up @@ -416,13 +402,6 @@ class LSTMTrainer : public LSTMRecognizer {
STRING best_model_name_;
// Number of available training stages.
int num_training_stages_;
// Checkpointing callbacks.
FileReader file_reader_;
FileWriter file_writer_;
// TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
// when we can commit to c++11.
CheckPointReader checkpoint_reader_;
CheckPointWriter checkpoint_writer_;

// ===Serialized data to ensure that a restart produces the same results.===
// These members are only serialized when serialize_amount != LIGHT.
Expand Down
2 changes: 1 addition & 1 deletion src/training/lstmtraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ int main(int argc, char **argv) {
checkpoint_file += "_checkpoint";
STRING checkpoint_bak = checkpoint_file + ".bak";
tesseract::LSTMTrainer trainer(
nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(),
FLAGS_model_output.c_str(),
checkpoint_file.c_str(), FLAGS_debug_interval,
static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
trainer.InitCharSet(FLAGS_traineddata.c_str());
Expand Down
3 changes: 1 addition & 2 deletions unittest/lstm_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class LSTMTrainerTest : public testing::Test {
nullptr, nullptr));
std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
std::string checkpoint_path = model_path + "_checkpoint";
trainer_.reset(new LSTMTrainer(nullptr, nullptr, nullptr, nullptr,
model_path.c_str(), checkpoint_path.c_str(),
trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(),
0, 0));
trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang,
absl::StrCat(kLang, ".traineddata")));
Expand Down

11 comments on commit 563a171

@zdenop
Copy link
Contributor

@zdenop zdenop commented on 563a171 Jul 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit caused an API break...

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LSTMTrainer is not part of the 4.0 API (I searched the include files in /usr/include/tesseract), so I am not sure whether this is really relevant. Nevertheless I reverted that commit now for 4.1.

@zdenop
Copy link
Contributor

@zdenop zdenop commented on 563a171 Jul 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stweil: the problem is related to tesscallback.h - this file is part of the API, and the report about the API break concerns this file... So only the changes to this file (after 4.1.0-rc4) should be reverted.

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compatibility report does not complain because tesscallback.h was changed but because symbols in libtesseract.so.4.0.1 were removed. As tesscallback.h declares template classes, those symbols depend on tesscallback.h and the class used for the template (here LSTMTrainer). I wonder whether these special reported high priority problems with removed symbols are relevant at all.

tesscallback.h is only part of the API header files because it is required by other header files. Third party applications won't use it.

@zdenop
Copy link
Contributor

@zdenop zdenop commented on 563a171 Jul 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I mean removal of symbols with modification of tesscallback.h.
My understanding is that we agreed to make the 4.1 release backward compatible, using the abi-laboratory report as the reference.
The current result is this:
image
so more commits should be reverted...

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether more commits should be reverted.

How often is the compatibility report updated? Ideally it should be possible to reproduce it locally at any time, but my results look different and did not complain about removed symbols (maybe because the removed symbols were not exposed via public API header files before).

Debian will start with a new stable release in a few days, and as far as I see that new release will include Tesseract 4.0 for the next few years. Let's discuss at issue #1423 what consequences that has for Tesseract 4.1.

@zdenop
Copy link
Contributor

@zdenop zdenop commented on 563a171 Jul 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am able to run the report locally on my Linux machine (no big deal, the machine is just too slow). abi-laboratory.pro ran it on 2019-07-02 at 06:39.

Removal of any function from tesscallback.h will be reported as a problem, as it is part of the API... Removing unused functionality at this stage of the release has no benefit for end users...

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Earlier removals like 3ae4069 were no problem. See the discussion for pull request #2422 which was the first one to reduce the size of tesscallback.h.

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason to reduce tesscallback.h is that it's a very large header file (more than 350 KB!) without any direct use for API users. More than 9000 lines of undocumented code are not something which must be preserved for API compatibility. Binary compatibility needs separate consideration.

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zdenop, if you can reproduce the API check locally, a short description might help me to get the same results. Up to now, my local check does not indicate API problems.

@stweil
Copy link
Contributor Author

@stweil stweil commented on 563a171 Jul 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the description. I think I could fix the binary ABI compatibility now, see report.

Please sign in to comment.