Skip to content

Commit

Permalink
Add input_records flag that allows to restrict maximum number of reco…
Browse files Browse the repository at this point in the history
…rds to be used
  • Loading branch information
onponomarev committed Nov 29, 2017
1 parent a9f1234 commit 69d9558
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 18 deletions.
36 changes: 27 additions & 9 deletions base/readerutil.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ class InputRecordReader {
template <class ProtoClass>
class FileInputRecordReader : public InputRecordReader<ProtoClass> {
public:
explicit FileInputRecordReader(const std::string& filename) : reader(filename), has_prefetched(false) {
explicit FileInputRecordReader(const std::string& filename, const int64 max_records=-1) :
reader(filename),
has_prefetched(false),
max_records_(max_records){
}
virtual ~FileInputRecordReader() override {
reader.Close();
Expand All @@ -57,12 +60,16 @@ class FileInputRecordReader : public InputRecordReader<ProtoClass> {
}
*proto = std::move(prefetched_proto);
has_prefetched = false;
return true;
if (max_records_ != 0) {
max_records_--;
return true;
}
return false;
}

virtual bool ReachedEnd() override {
std::lock_guard<std::mutex> lock(reader_mutex);
return !has_prefetched && !PrefetchProto();
return (!has_prefetched && !PrefetchProto()) || max_records_ == 0;
}

private:
Expand All @@ -79,13 +86,16 @@ class FileInputRecordReader : public InputRecordReader<ProtoClass> {
ProtoClass prefetched_proto;
bool has_prefetched;
std::mutex reader_mutex;
int64 max_records_;
};


template <>
class FileInputRecordReader<std::string> : public InputRecordReader<std::string> {
public:
explicit FileInputRecordReader(const std::string& filename) : file(filename) {
explicit FileInputRecordReader(const std::string& filename, const int64 max_records=-1) :
file(filename),
max_records_(max_records) {
CHECK(exists(filename)) << "File '" << filename << "' does not exist!";
}
virtual ~FileInputRecordReader() override {
Expand All @@ -103,17 +113,23 @@ class FileInputRecordReader<std::string> : public InputRecordReader<std::string>
}
std::getline(file, *s); // Read until we get a non-empty line.
}
return true;
if (max_records_ != 0) {
max_records_--;
return true;
}
return false;
}

virtual bool ReachedEnd() override {
std::lock_guard<std::mutex> lock(filemutex);
return file.eof();
return file.eof() || max_records_ == 0;
}
private:
inline bool exists (const std::string& name) {
return ( access( name.c_str(), F_OK ) != -1 );
}

int64 max_records_;
};

template <class T>
Expand Down Expand Up @@ -188,17 +204,20 @@ class RecordInput {
template <class T>
class FileRecordInput : public RecordInput<T> {
public:
explicit FileRecordInput(const std::string& filename) : filename_(filename) {
explicit FileRecordInput(const std::string& filename, const int64 max_records=-1) :
filename_(filename),
max_records_(max_records) {
}
virtual ~FileRecordInput() override {
}

virtual InputRecordReader<T>* CreateReader() override {
return new FileInputRecordReader<T>(filename_);
return new FileInputRecordReader<T>(filename_, max_records_);
}

private:
std::string filename_;
int64 max_records_;
};

/**
Expand Down Expand Up @@ -261,7 +280,6 @@ class CrossValidationReader : public InputRecordReader<T> {
if ((training_ && (row_id_ % num_folds_) != fold_id_) ||
(!training_ && (row_id_ % num_folds_) == fold_id_)) {
return underlying_reader_->Read(s);
break;
} else {
T tmp;
underlying_reader_->Read(&tmp);
Expand Down
2 changes: 1 addition & 1 deletion n2p/training/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);

return LearningMain<Query>([](const Query &record) {
return EvalMain<Query>([](const Query &record) {
return record;
});
}
5 changes: 3 additions & 2 deletions n2p/training/eval_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
using nice2protos::Query;

DEFINE_string(model, "model", "File prefix for model to evaluate.");
DEFINE_int64(input_records, -1, "Number of input records to use.")

DEFINE_string(input, "testdata", "Input file with objects to be used for evaluation.");
DEFINE_bool(debug_stats, false, "If specifies, only outputs debug stats of a trained model.");
Expand Down Expand Up @@ -117,7 +118,7 @@ void Evaluate(RecordInput<InputType>* evaluation_data, GraphInference* inference
}

template <class InputType>
int LearningMain(Adapter<InputType> adapter) {
int EvalMain(Adapter<InputType> adapter) {
if (FLAGS_debug_stats) {
GraphInference inference;
inference.LoadModel(FLAGS_model);
Expand All @@ -128,7 +129,7 @@ int LearningMain(Adapter<InputType> adapter) {
GraphInference inference;
std::unique_ptr<RecordInput<InputType>> input;

input.reset(new FileRecordInput<InputType>(FLAGS_input));
input.reset(new FileRecordInput<InputType>(FLAGS_input, FLAGS_input_records));
inference.LoadModel(FLAGS_model);
PrecisionStats total_stats;
Evaluate(input.get(), &inference, &total_stats, error_stats.get(), adapter);
Expand Down
2 changes: 1 addition & 1 deletion n2p/training/eval_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);

return LearningMain<Query>([](const Query &record) {
return EvalMain<Query>([](const Query &record) {
return record;
});
}
12 changes: 7 additions & 5 deletions n2p/training/train_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const std::string PROP_INITIAL_LEARN_RATE_AND_PASS_LEARN_RATE_UPDATE_PL = "prop_
DEFINE_string(input, "testdata", "Input file with training data objects");
DEFINE_string(out_model, "model", "File prefix for output models");
DEFINE_int32(num_training_passes, 24, "Number of passes in training.");
DEFINE_int64(input_records, -1, "Number of input records to use.")

DEFINE_double(start_learning_rate, 0.1, "Initial learning rate");
DEFINE_double(stop_learning_rate, 0.0001, "Stop learning if learning rate falls below the value");
Expand Down Expand Up @@ -249,11 +250,11 @@ int LearningMain(Adapter<InputType> adapter) {
for (int fold_id = 0; fold_id < FLAGS_cross_validation_folds; ++fold_id) {
GraphInference inference;
std::unique_ptr<RecordInput<InputType>> training_data(
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(FLAGS_input),
fold_id, FLAGS_cross_validation_folds, true)));
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(
FLAGS_input, FLAGS_input_records), fold_id, FLAGS_cross_validation_folds, true)));
std::unique_ptr<RecordInput<InputType>> validation_data(
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(FLAGS_input),
fold_id, FLAGS_cross_validation_folds, false)));
new ShuffledCacheInput<InputType>(new CrossValidationInput<InputType>(new FileRecordInput<InputType>(
FLAGS_input, FLAGS_input_records), fold_id, FLAGS_cross_validation_folds, false)));
LOG(INFO) << "Training fold " << fold_id;
InitTrain(training_data.get(), &inference, adapter);
if (FLAGS_training_method.compare(PL_TRAIN_NAME) == 0) {
Expand Down Expand Up @@ -286,7 +287,8 @@ int LearningMain(Adapter<InputType> adapter) {
LOG(INFO) << "Running structured training...";
// Structured training.
GraphInference inference;
std::unique_ptr<RecordInput<InputType>> input(new ShuffledCacheInput<InputType>(new FileRecordInput<InputType>(FLAGS_input)));
std::unique_ptr<RecordInput<InputType>> input(new ShuffledCacheInput<InputType>(
new FileRecordInput<InputType>(FLAGS_input, FLAGS_input_records)));
InitTrain(input.get(), &inference, adapter);
LOG(INFO) << "Training inited...";
if (FLAGS_training_method.compare(PL_TRAIN_NAME) == 0) {
Expand Down

0 comments on commit 69d9558

Please sign in to comment.