forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net.cc
244 lines (216 loc) · 7.12 KB
/
net.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#include "caffe2/core/net.h"
#include "caffe2/core/net_simple.h"
#include <set>
#include <unordered_map>
#include <unordered_set>
#include "caffe2/core/init.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
C10_DEFINE_string(
caffe2_override_executor,
"",
"Comma-separated list of executor overrides");
namespace caffe2 {
C10_DEFINE_REGISTRY(
NetRegistry,
NetBase,
const std::shared_ptr<const NetDef>&,
Workspace*);
NetBase::NetBase(
const std::shared_ptr<const NetDef>& def,
Workspace* /* unused */)
: external_input_(
def->external_input().begin(),
def->external_input().end()),
external_output_(
def->external_output().begin(),
def->external_output().end()),
name_(def->name()),
net_def_(def) {
static GlobalInitIsCalledGuard guard;
C10_LOG_API_USAGE_ONCE("caffe2.net.create");
// Check that node_name is empty for all ops
for (const OperatorDef& op : def->op()) {
if (op.has_device_option()) {
CAFFE_ENFORCE(
!op.device_option().has_node_name(),
"node_name must be empty for all operators at execution time.");
}
}
// Go through the operators and make sure that blobs are correctly made.
std::set<string> known_blobs(
external_input_.begin(), external_input_.end());
std::set<string> remaining_output(
external_output_.begin(), external_output_.end());
for (const auto& blob : known_blobs) {
remaining_output.erase(blob);
}
for (const OperatorDef& op : def->op()) {
for (const string& in : op.input()) {
if (!known_blobs.count(in)) {
if (external_input_.size()) {
CAFFE_THROW(
"op ",
op.type(),
": Source for input ",
in,
" is unknown for net ",
def->name(),
", operator ",
ProtoDebugString(op));
} else {
// If we are not declaring input and output, we will simply VLOG it
// for debugging purposes.
VLOG(1) << "op " << op.type() << ": input " << in << " is unknown.";
}
}
}
for (const string& out : op.output()) {
known_blobs.insert(out);
remaining_output.erase(out);
}
}
// Finally, check if all declared outputs are being created.
CAFFE_ENFORCE(
remaining_output.size() == 0,
"Some of the blobs are declared as output but never produced by the "
"net ",
def->name(),
", the first one is ",
*remaining_output.begin());
}
bool NetBase::RunAsync() {
for (auto& op : GetOperators()) {
op->ResetEvent();
}
return DoRunAsync();
}
void NetBase::Cancel() {
for (auto& op : GetOperators()) {
op->Cancel();
}
}
namespace {
const std::string kSimpleNet = "simple";
std::vector<NetObserverCreator>* GetNetObserverCreators() {
static std::vector<NetObserverCreator> creators;
return &creators;
}
const std::unordered_map<std::string, std::string>& defaultOverrides() {
// redirecting legacy net types to async_scheduling (except for 'simple');
// async_scheduling checks net type for backward compatibility
static const std::unordered_map<std::string, std::string> overrides = {
{"dag", "async_scheduling"},
{"prof_dag", "async_scheduling"},
{"async_dag", "async_scheduling"},
{"async_polling", "async_scheduling"},
{"async_simple", "simple"}, // "async_simple" impl has been removed.
{"rnn", "simple"}, // "rnn" impl has been removed.
};
return overrides;
}
void ApplyPotentialExecutorOverride(std::string* net_type) {
auto executors = caffe2::split(',', FLAGS_caffe2_override_executor);
CAFFE_ENFORCE(
executors.size() % 2 == 0, "Invalid override executors flag value");
std::unordered_map<std::string, std::string> overrides;
for (const auto& kv : defaultOverrides()) {
overrides[kv.first] = kv.second;
}
for (size_t idx = 0; idx < executors.size(); idx += 2) {
overrides[executors[idx]] = executors[idx + 1];
}
if (overrides.count(*net_type)) {
VLOG(1) << "Overrode net type '" << *net_type << "' with '"
<< overrides[*net_type] << "'";
*net_type = overrides[*net_type];
}
}
} // namespace
void AddGlobalNetObserverCreator(NetObserverCreator creator) {
GetNetObserverCreators()->push_back(creator);
VLOG(1) << "Have set a custom GlobalNetObserverCreator";
}
void ClearGlobalNetObservers() {
GetNetObserverCreators()->clear();
VLOG(1) << "All net observers cleared";
}
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws);
}
unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const NetDef>& net_def,
Workspace* ws) {
std::string net_type;
if (net_def->has_type() && !net_def->type().empty()) {
net_type = net_def->type();
} else {
// By default, we will return a simple network that just runs all operators
// sequentially.
net_type = kSimpleNet;
}
ApplyPotentialExecutorOverride(&net_type);
unique_ptr<NetBase> net = NetRegistry()->Create(net_type, net_def, ws);
VLOG(1) << "Adding a global observer to a net";
if (net) {
auto* observer_creators = GetNetObserverCreators();
for (auto& creator : *observer_creators) {
net->AttachObserver(creator(net.get()));
}
}
return net;
}
TaskThreadPoolBase* ExecutorHelper::GetPool(
const DeviceOption& /* unused */) const {
CAFFE_THROW("Not implemented");
}
std::vector<OperatorBase*> ExecutorHelper::GetOperators() const {
CAFFE_THROW("Not implemented");
}
int ExecutorHelper::GetNumWorkers() const {
CAFFE_THROW("Not implemented");
}
// benchmark an individual run so that we can FeedBlobs with new inputs
// no warmup
// return time taken in microseconds
float NetBase::TEST_Benchmark_One_Run() {
Timer timer;
CAFFE_ENFORCE(Run(), "Run has failed.");
return timer.MicroSeconds();
}
std::vector<float> NetBase::TEST_Benchmark(
const int warmup_runs,
const int main_runs,
const bool run_individual) {
LOG(INFO) << "Starting benchmark, running warmup runs";
CAFFE_ENFORCE(
warmup_runs >= 0,
"Number of warm up runs should be non negative, provided ",
warmup_runs);
for (int run_idx = 0; run_idx < warmup_runs; ++run_idx) {
CAFFE_ENFORCE(Run(), "Warmup run ", run_idx, " has failed");
}
LOG(INFO) << "Running main runs";
CAFFE_ENFORCE(
main_runs >= 0,
"Number of main runs should be non negative, provided ",
main_runs);
Timer timer;
for (int run_idx = 0; run_idx < main_runs; ++run_idx) {
CAFFE_ENFORCE(Run(), "Main run ", run_idx, " has failed");
}
auto millis = timer.MilliSeconds();
LOG(INFO) << "Main runs finished. Milliseconds per iter: "
<< millis / main_runs
<< ". Iters per second: " << 1000.0 * main_runs / millis;
if (run_individual) {
LOG(INFO) << "Net does not support per-op benchmark; "
"to run it, switch to a simple net type";
}
return std::vector<float>{millis / main_runs};
}
} // namespace caffe2