forked from LLNL/lbann
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_shuffled_indices.cpp
158 lines (137 loc) · 6.06 KB
/
test_shuffled_indices.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2014-2021, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory.
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
// the CONTRIBUTORS file. <lbann-dev@llnl.gov>
//
// LLNL-CODE-697807.
// All rights reserved.
//
// This file is part of LBANN: Livermore Big Artificial Neural Network
// Toolkit. For details, see http://software.llnl.gov/LBANN or
// https://github.com/LLNL/LBANN.
//
// Licensed under the Apache License, Version 2.0 (the "Licensee"); you
// may not use this file except in compliance with the License. You may
// obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the license.
//
// test_shuffled_indices.cpp - checks whether a data reader's indices are shuffled
////////////////////////////////////////////////////////////////////////////////
#include "lbann/lbann.hpp"
#include "lbann/proto/proto_common.hpp"
#include <lbann.pb.h>
#include <reader.pb.h>
#include <string>
using namespace lbann;
// Number of leading indices that test_is_shuffled() inspects; main() also uses
// it as the length of the identity index list it installs on the reader.
int mini_batch_size = 128;
// Reports whether the reader's leading indices deviate from 0, 1, 2, ... and
// whether that observation agrees with the expected `is_shuffled`. `msg` is an
// optional label printed with the report. Defined after main().
void test_is_shuffled(const generic_data_reader &reader, bool is_shuffled, const char *msg = nullptr);
int main(int argc, char *argv[]) {
world_comm_ptr comm = initialize(argc, argv);
// Initialize the general RNGs and the data sequence RNGs
int random_seed = lbann_default_random_seed;
init_random(random_seed);
init_data_seq_random(random_seed);
const bool master = comm->am_world_master();
try {
// Initialize options db (this parses the command line)
auto& arg_parser = global_argument_parser();
construct_all_options();
arg_parser.add_flag("fn", {"--fn"}, "TODO");
arg_parser.parse(argc, argv);
if (arg_parser.help_requested() or argc == 1) {
if (master)
std::cout << arg_parser << std::endl;
return EXIT_SUCCESS;
}
//read data_reader prototext file
if (arg_parser.get<std::string>("fn") == "") {
std::cerr << __FILE__ << " " << __LINE__ << " :: "
<< "you must run with: --fn=<string> where <string> is\n"
<< "a data_reader prototext filePathName\n";
return EXIT_FAILURE;
}
lbann_data::LbannPB pb;
std::string reader_fn = arg_parser.get<std::string>("fn");
read_prototext_file(reader_fn.c_str(), pb, master);
const lbann_data::DataReader & d_reader = pb.data_reader();
int size = d_reader.reader_size();
for (int j=0; j<size; j++) {
const lbann_data::Reader& readme = d_reader.reader(j);
if (readme.role() == "train") {
bool shuffle = true;
auto reader = std::make_unique<mnist_reader>(shuffle);
if (readme.data_filename() != "") { reader->set_data_filename( readme.data_filename() ); }
if (readme.label_filename() != "") { reader->set_label_filename( readme.label_filename() ); }
if (readme.data_filedir() != "") { reader->set_file_dir( readme.data_filedir() ); }
reader->load();
test_is_shuffled(*reader, true, "TEST #1");
//test: indices should not be shuffled; same as previous, except we call
// shuffle(true);
shuffle = false;
reader = std::make_unique<mnist_reader>(shuffle);
if (readme.data_filename() != "") { reader->set_data_filename( readme.data_filename() ); }
if (readme.label_filename() != "") { reader->set_label_filename( readme.label_filename() ); }
if (readme.data_filedir() != "") { reader->set_file_dir( readme.data_filedir() ); }
reader->set_shuffle(shuffle);
reader->load();
test_is_shuffled(*reader, false, "TEST #2");
//test: indices should not be shuffled, due to ctor argument
shuffle = false;
reader = std::make_unique<mnist_reader>(shuffle);
if (readme.data_filename() != "") { reader->set_data_filename( readme.data_filename() ); }
if (readme.label_filename() != "") { reader->set_label_filename( readme.label_filename() ); }
if (readme.data_filedir() != "") { reader->set_file_dir( readme.data_filedir() ); }
reader->load();
test_is_shuffled(*reader, false, "TEST #3");
//test: set_shuffled_indices; indices should not be shuffled
shuffle = true;
reader = std::make_unique<mnist_reader>(shuffle);
if (readme.data_filename() != "") { reader->set_data_filename( readme.data_filename() ); }
if (readme.label_filename() != "") { reader->set_label_filename( readme.label_filename() ); }
if (readme.data_filedir() != "") { reader->set_file_dir( readme.data_filedir() ); }
reader->load();
//at this point the indices should be shuffled (same as first test)
test_is_shuffled(*reader, true, "TEST #4");
std::vector<int> indices(mini_batch_size);
std::iota(indices.begin(), indices.end(), 0);
reader->set_shuffled_indices(indices);
test_is_shuffled(*reader, false, "TEST #5");
break;
}
}
} catch (lbann_exception& e) {
e.print_report();
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
// Prints a report on whether the reader's leading indices look shuffled and
// whether that matches the caller's expectation.
//
// reader      - data reader whose get_shuffled_indices() result is inspected.
// is_shuffled - the outcome the caller expects.
// msg         - optional label appended to the banner line; may be nullptr.
//
// The indices count as shuffled when any of the first mini_batch_size entries
// differs from its own position. If the reader holds fewer than
// mini_batch_size indices, only the entries that exist are examined, so the
// vector is never read out of bounds. The result is printed as PASSED!/FAILED!
// and is not returned to the caller.
void test_is_shuffled(const generic_data_reader &reader, bool is_shuffled, const char *msg) {
  const std::vector<int> &indices = reader.get_shuffled_indices();

  // The whole report goes to std::cout so the banner line is not split
  // between two streams.
  std::cout << "\nstarting test_is_suffled; mini_batch_size: " << mini_batch_size
            << " indices.size(): " << indices.size();
  if (msg) {
    std::cout << " :: " << msg;
  }
  std::cout << std::endl;

  // Clamp the scan to the entries that actually exist.
  const std::size_t wanted =
    mini_batch_size > 0 ? static_cast<std::size_t>(mini_batch_size) : 0;
  const std::size_t limit = indices.size() < wanted ? indices.size() : wanted;

  bool observed_shuffled = false; // true once any index is out of place
  for (std::size_t pos = 0; pos < limit; ++pos) {
    if (indices[pos] != static_cast<int>(pos)) {
      observed_shuffled = true;
      break;
    }
  }

  std::cout << "testing for is_shuffled = " << is_shuffled << " test shows the shuffled is actually "
            << observed_shuffled << " :: ";
  if (observed_shuffled == is_shuffled) {
    std::cout << "PASSED!\n";
  } else {
    std::cout << "FAILED!\n";
  }
}