From 69d95581b672fe09b35cd4a9fdc912794a2e922c Mon Sep 17 00:00:00 2001 From: Oleg Ponomarev Date: Wed, 29 Nov 2017 14:26:55 +0100 Subject: [PATCH] Add input_records flag that allows to restrict maximum number of records to be used --- base/readerutil.h | 36 ++++++++++++++++++++++++++--------- n2p/training/eval.cpp | 2 +- n2p/training/eval_internal.h | 5 +++-- n2p/training/eval_json.cpp | 2 +- n2p/training/train_internal.h | 12 +++++++----- 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/base/readerutil.h b/base/readerutil.h index 0c81a53..ec8bde5 100644 --- a/base/readerutil.h +++ b/base/readerutil.h @@ -43,7 +43,10 @@ class InputRecordReader { template class FileInputRecordReader : public InputRecordReader { public: - explicit FileInputRecordReader(const std::string& filename) : reader(filename), has_prefetched(false) { + explicit FileInputRecordReader(const std::string& filename, const int64 max_records=-1) : + reader(filename), + has_prefetched(false), + max_records_(max_records){ } virtual ~FileInputRecordReader() override { reader.Close(); @@ -57,12 +60,16 @@ class FileInputRecordReader : public InputRecordReader { } *proto = std::move(prefetched_proto); has_prefetched = false; - return true; + if (max_records_ != 0) { + max_records_--; + return true; + } + return false; } virtual bool ReachedEnd() override { std::lock_guard lock(reader_mutex); - return !has_prefetched && !PrefetchProto(); + return (!has_prefetched && !PrefetchProto()) || max_records_ == 0; } private: @@ -79,13 +86,16 @@ class FileInputRecordReader : public InputRecordReader { ProtoClass prefetched_proto; bool has_prefetched; std::mutex reader_mutex; + int64 max_records_; }; template <> class FileInputRecordReader : public InputRecordReader { public: - explicit FileInputRecordReader(const std::string& filename) : file(filename) { + explicit FileInputRecordReader(const std::string& filename, const int64 max_records=-1) : + file(filename), + max_records_(max_records) { CHECK(exists(filename)) << "File '" << filename << "' does not exist!"; } virtual ~FileInputRecordReader() override { @@ -103,17 +113,23 @@ class FileInputRecordReader : public InputRecordReader } std::getline(file, *s); // Read until we get a non-empty line. } - return true; + if (max_records_ != 0) { + max_records_--; + return true; + } + return false; } virtual bool ReachedEnd() override { std::lock_guard lock(filemutex); - return file.eof(); + return file.eof() || max_records_ == 0; } private: inline bool exists (const std::string& name) { return ( access( name.c_str(), F_OK ) != -1 ); } + + int64 max_records_; }; template @@ -188,17 +204,20 @@ class RecordInput { template class FileRecordInput : public RecordInput { public: - explicit FileRecordInput(const std::string& filename) : filename_(filename) { + explicit FileRecordInput(const std::string& filename, const int64 max_records=-1) : + filename_(filename), + max_records_(max_records) { } virtual ~FileRecordInput() override { } virtual InputRecordReader* CreateReader() override { - return new FileInputRecordReader(filename_); + return new FileInputRecordReader(filename_, max_records_); } private: std::string filename_; + int64 max_records_; }; /** @@ -261,7 +280,6 @@ class CrossValidationReader : public InputRecordReader { if ((training_ && (row_id_ % num_folds_) != fold_id_) || (!training_ && (row_id_ % num_folds_) == fold_id_)) { return underlying_reader_->Read(s); - break; } else { T tmp; underlying_reader_->Read(&tmp); diff --git a/n2p/training/eval.cpp b/n2p/training/eval.cpp index 9ab9a2f..8b986f6 100644 --- a/n2p/training/eval.cpp +++ b/n2p/training/eval.cpp @@ -26,7 +26,7 @@ int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - return LearningMain([](const Query &record) { + return EvalMain([](const Query &record) { return record; }); } diff --git a/n2p/training/eval_internal.h b/n2p/training/eval_internal.h index 10e67b6..5377dab 100644 --- a/n2p/training/eval_internal.h +++ b/n2p/training/eval_internal.h @@ -35,6 +35,7 @@ using nice2protos::Query; DEFINE_string(model, "model", "File prefix for model to evaluate."); +DEFINE_int64(input_records, -1, "Number of input records to use.") DEFINE_string(input, "testdata", "Input file with objects to be used for evaluation."); DEFINE_bool(debug_stats, false, "If specifies, only outputs debug stats of a trained model."); @@ -117,7 +118,7 @@ void Evaluate(RecordInput* evaluation_data, GraphInference* inference } template -int LearningMain(Adapter adapter) { +int EvalMain(Adapter adapter) { if (FLAGS_debug_stats) { GraphInference inference; inference.LoadModel(FLAGS_model); @@ -128,7 +129,7 @@ int LearningMain(Adapter adapter) { GraphInference inference; std::unique_ptr> input; - input.reset(new FileRecordInput(FLAGS_input)); + input.reset(new FileRecordInput(FLAGS_input, FLAGS_input_records)); inference.LoadModel(FLAGS_model); PrecisionStats total_stats; Evaluate(input.get(), &inference, &total_stats, error_stats.get(), adapter); diff --git a/n2p/training/eval_json.cpp b/n2p/training/eval_json.cpp index b71e924..678222b 100644 --- a/n2p/training/eval_json.cpp +++ b/n2p/training/eval_json.cpp @@ -26,7 +26,7 @@ int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - return LearningMain([](const Query &record) { + return EvalMain([](const Query &record) { return record; }); } diff --git a/n2p/training/train_internal.h b/n2p/training/train_internal.h index d64287f..5a175d2 100644 --- a/n2p/training/train_internal.h +++ b/n2p/training/train_internal.h @@ -45,6 +45,7 @@ const std::string PROP_INITIAL_LEARN_RATE_AND_PASS_LEARN_RATE_UPDATE_PL = "prop_ DEFINE_string(input, "testdata", "Input file with training data objects"); DEFINE_string(out_model, "model", "File prefix for output models"); DEFINE_int32(num_training_passes, 24, "Number of passes in training."); +DEFINE_int64(input_records, -1, "Number of input records to use.") DEFINE_double(start_learning_rate, 0.1, "Initial learning rate"); DEFINE_double(stop_learning_rate, 0.0001, "Stop learning if learning rate falls below the value"); @@ -249,11 +250,11 @@ int LearningMain(Adapter adapter) { for (int fold_id = 0; fold_id < FLAGS_cross_validation_folds; ++fold_id) { GraphInference inference; std::unique_ptr> training_data( - new ShuffledCacheInput(new CrossValidationInput(new FileRecordInput(FLAGS_input), - fold_id, FLAGS_cross_validation_folds, true))); + new ShuffledCacheInput(new CrossValidationInput(new FileRecordInput( + FLAGS_input, FLAGS_input_records), fold_id, FLAGS_cross_validation_folds, true))); std::unique_ptr> validation_data( - new ShuffledCacheInput(new CrossValidationInput(new FileRecordInput(FLAGS_input), - fold_id, FLAGS_cross_validation_folds, false))); + new ShuffledCacheInput(new CrossValidationInput(new FileRecordInput( + FLAGS_input, FLAGS_input_records), fold_id, FLAGS_cross_validation_folds, false))); LOG(INFO) << "Training fold " << fold_id; InitTrain(training_data.get(), &inference, adapter); if (FLAGS_training_method.compare(PL_TRAIN_NAME) == 0) { @@ -286,7 +287,8 @@ int LearningMain(Adapter adapter) { LOG(INFO) << "Running structured training..."; // Structured training. GraphInference inference; - std::unique_ptr> input(new ShuffledCacheInput(new FileRecordInput(FLAGS_input))); + std::unique_ptr> input(new ShuffledCacheInput( + new FileRecordInput(FLAGS_input, FLAGS_input_records))); InitTrain(input.get(), &inference, adapter); LOG(INFO) << "Training inited..."; if (FLAGS_training_method.compare(PL_TRAIN_NAME) == 0) {