forked from tmbdev/clstm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.h
325 lines (285 loc) · 7.48 KB
/
utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
// -*- C++ -*-
// Additional functions and utilities for CLSTM networks.
// These may use the array classes from "multidim.h"
#ifndef ocropus_clstm_utils_
#define ocropus_clstm_utils_
#include <assert.h>
#include <glob.h>
#include <math.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "pstring.h"
namespace ocropus {
using std::string;
using std::wstring;
using std::vector;
using std::istream;
using std::ostream;
using std::ifstream;
using std::ofstream;
using std::endl;
using std::cout;
using std::cerr;
// Write `arg` to stderr wrapped as "EXCEPTION (...)" and terminate the
// process with exit status 255. `A` only needs to be streamable.
template <class A>
inline void die(const A &arg) {
  cerr << "EXCEPTION (";
  cerr << arg;
  cerr << ")\n";
  exit(255);
}
// Current time as reported by gettimeofday(), returned as a double holding
// whole seconds plus the microsecond field scaled to a fraction of a second.
inline double now() {
  timeval stamp;
  gettimeofday(&stamp, nullptr);
  const double fraction = 1e-6 * stamp.tv_usec;
  return stamp.tv_sec + fraction;
}
// Expand the shell pattern `arg` with the C library glob() (GLOB_TILDE set)
// and store the matching paths in `result`.
//
// `result` is always cleared first. If the library call reports anything
// other than success (no match, read error, out of memory) `result` is left
// empty instead of being filled from a possibly incomplete glob_t.
inline void glob(vector<string> &result, const string &arg) {
  result.clear();
  // Zero-initialise so globfree() is safe on every path, including failure.
  glob_t buf{};
  const int status = ::glob(arg.c_str(), GLOB_TILDE, nullptr, &buf);
  if (status == 0) {
    result.reserve(buf.gl_pathc);
    // gl_pathc is a size_t; use a matching index type.
    for (size_t i = 0; i < buf.gl_pathc; ++i) {
      result.emplace_back(buf.gl_pathv[i]);
    }
  }
  globfree(&buf);
}
// Strip the extension(s) from the last path component of `s`.
//
// Everything from the first '.' that follows the last '/' is removed; the
// directory part is kept. If the last component contains no '.', `s` is
// returned unchanged. Example: "a/b/c.txt.gz" -> "a/b/c".
inline string basename(string s) {
  const size_t slash = s.rfind('/');
  const size_t name_start = (slash == string::npos) ? 0 : slash + 1;
  const size_t dot = s.find('.', name_start);
  return (dot == string::npos) ? s : s.substr(0, dot);
}
// Read up to `maxsize - 1` bytes from the file `fname` and return them with
// all trailing '\n' characters removed.
//
// Returns an empty string when the file cannot be opened or read, and also
// when `maxsize` is 1 or less (previously a non-positive `maxsize` indexed
// an empty buffer). Content beyond `maxsize - 1` bytes is silently dropped.
inline string read_text(string fname, int maxsize = 65536) {
  if (maxsize <= 1) return string();
  // Read straight into the result string to avoid a second copy.
  string text(static_cast<size_t>(maxsize - 1), '\0');
  ifstream stream(fname);
  stream.read(&text[0], maxsize - 1);
  size_t n = static_cast<size_t>(stream.gcount());
  while (n > 0 && text[n - 1] == '\n') n--;
  text.resize(n);
  return text;
}
// Read up to `maxsize - 1` bytes from the file `fname`, drop all trailing
// '\n' characters and return the result of utf8_to_utf32() on those bytes.
//
// An unreadable file, or a `maxsize` of 1 or less, yields the conversion of
// an empty string (previously a non-positive `maxsize` indexed an empty
// buffer). Content beyond `maxsize - 1` bytes is silently dropped.
inline wstring read_text32(string fname, int maxsize = 65536) {
  if (maxsize <= 1) return utf8_to_utf32(string());
  string text(static_cast<size_t>(maxsize - 1), '\0');
  ifstream stream(fname);
  stream.read(&text[0], maxsize - 1);
  size_t n = static_cast<size_t>(stream.gcount());
  while (n > 0 && text[n - 1] == '\n') n--;
  text.resize(n);
  return utf8_to_utf32(text);
}
// Replace the contents of `lines` with the lines of the file `fname`, one
// element per getline() result (line terminators are not stored). If the
// file cannot be opened, `lines` ends up empty.
inline void read_lines(vector<string> &lines, string fname) {
  lines.clear();
  ifstream input(fname);
  for (string current; getline(input, current);) {
    lines.push_back(current);
  }
}
// Convert `data` with utf32_to_utf8() and write it, followed by a newline,
// to the file `fname`. The conversion happens before the file is opened.
inline void write_text(const string fname, const wstring &data) {
  const string encoded = utf32_to_utf8(data);
  ofstream out(fname);
  out << encoded;
  out << endl;
}
// Write `data`, followed by a newline, to the file `fname`. An existing
// file is overwritten (default ofstream open mode).
inline void write_text(const string fname, const string &data) {
  ofstream out(fname);
  out << data;
  out << endl;
}
// print the arguments to cout
// Zero-argument case: write just the line terminator and flush cout.
inline void print() { endl(cout); }
// Stream a wide string by converting it with utf32_to_utf8() first.
// The text is written to `stream` (it used to go to cout regardless of the
// stream passed in), and `stream` is returned for chaining.
inline ostream &operator<<(ostream &stream, const std::wstring &arg) {
  stream << utf32_to_utf8(arg);
  return stream;
}
// Single-argument case: write `arg` to cout, then end the line and flush.
template <class T>
inline void print(const T &arg) {
  using namespace std;
  cout << arg;
  cout << endl;
}
// Two or more arguments: write the first one and a single space to cout,
// then hand the remaining arguments to print() again.
template <class T, typename... Args>
inline void print(T arg, Args... args) {
  cout << arg;
  cout << " ";
  print(args...);
}
// Calls print() with the current source file name and line number placed in
// front of the given arguments. At least one argument is required: with
// none, the expansion ends in a dangling comma.
#define PRINT(...) print(__FILE__, __LINE__, __VA_ARGS__)
// Look up `key` in `m` and return a copy of its value, or `dflt` when the
// key is absent. The map is not modified (no default insertion).
inline string getdef(std::map<string, string> &m, const string &key,
                     const string &dflt) {
  const auto entry = m.find(key);
  return entry != m.end() ? entry->second : dflt;
}
// Zero-argument case of dprint: write just the line terminator to cerr.
inline void dprint() { endl(cerr); }
// Single-argument case of dprint: write `arg` to cerr and end the line.
template <class T>
inline void dprint(const T &arg) {
  cerr << arg;
  cerr << endl;
}
// Two or more arguments: write the first one and a single space to cerr,
// then hand the remaining arguments to dprint() again.
template <class T, typename... Args>
inline void dprint(T arg, Args... args) {
  cerr << arg;
  cerr << " ";
  dprint(args...);
}
// get values from the environment, with defaults
// Declared here, defined elsewhere. report_params() prints a parameter only
// when this returns false for its name.
// NOTE(review): presumably this records names that were already reported so
// each one is printed once -- confirm against the definition.
bool reported_params(const char *name);
// Log a parameter as "#: name = value" on cerr.
//
// Nothing is printed when the environment variable "params" is set to a
// value that atoi() reads as 0, or when reported_params(name) returns true.
// reported_params() is not consulted at all in the first case.
template <class T>
inline void report_params(const char *name, const T &value) {
  const char *setting = getenv("params");
  const bool silenced = setting != nullptr && atoi(setting) == 0;
  if (silenced || reported_params(name)) return;
  cerr << "#: " << name << " = " << value << endl;
}
// Return the value of environment variable `name`, or `dflt` when it is not
// set, and pass the chosen value to report_params().
//
// The returned pointer is either the one obtained from getenv() or `dflt`
// itself, so it may be null when `dflt` is null. In that case the text
// "(null)" is reported instead of streaming a null C string, which is
// undefined behaviour for ostream.
inline const char *getsenv(const char *name, const char *dflt) {
  const char *from_env = getenv(name);  // single lookup
  const char *result = from_env ? from_env : dflt;
  report_params(name, result ? result : "(null)");
  return result;
}
// Split `s` at every occurrence of the delimiter `c` and append the pieces
// to `tokens` (existing elements are kept). Empty pieces are preserved, so
// an empty `s` contributes one empty token. Returns the total number of
// elements in `tokens` afterwards.
inline int split(vector<string> &tokens, string s, char c = ':') {
  size_t begin = 0;
  size_t hit = s.find(c, begin);
  while (hit != string::npos) {
    tokens.push_back(s.substr(begin, hit - begin));
    begin = hit + 1;
    hit = s.find(c, begin);
  }
  // Whatever follows the last delimiter (possibly nothing) is the final token.
  tokens.push_back(s.substr(begin));
  return tokens.size();
}
// Pick one entry from a ':'-separated list of alternatives.
//
// The list comes from environment variable `name` when it is set, otherwise
// from `dflt`. The index is derived from lrand48(); the chosen entry is
// passed to report_params() and returned.
inline string getoneof(const char *name, const char *dflt) {
  string choices(dflt);
  if (const char *from_env = getenv(name)) choices = from_env;
  vector<string> options;
  const int count = split(options, choices);
  const int pick = (lrand48() / 1792) % count;
  string chosen = options[pick];
  report_params(name, chosen);
  return chosen;
}
// Integer parameter from the environment: atoi() of variable `name` when it
// is set, otherwise `dflt`. The value is passed to report_params().
inline int getienv(const char *name, int dflt = 0) {
  const char *text = getenv(name);
  const int value = text ? atoi(text) : dflt;
  report_params(name, value);
  return value;
}
// Floating-point parameter from the environment: atof() of variable `name`
// when it is set, otherwise `dflt`. The value is passed to report_params().
inline double getdenv(const char *name, double dflt = 0) {
  const char *text = getenv(name);
  const double value = text ? atof(text) : dflt;
  report_params(name, value);
  return value;
}
// get a value or random value from the environment (var=7.3 or var=2,8)
//
// When variable `name` is unset, `dflt` is returned (and nothing is
// reported). When it holds "lo,hi", a random value between the two is drawn
// with drand48(): log-uniformly if `logscale` is true (the default),
// uniformly otherwise. `logscale` used to be ignored, so every draw was
// log-uniform. When it holds a single number, that number is returned.
// Any other content triggers THROW. Values are parsed as float.
inline double getrenv(const char *name, double dflt = 0, bool logscale = true) {
  const char *s = getenv(name);
  if (!s) return dflt;
  float lo, hi;
  if (sscanf(s, "%g,%g", &lo, &hi) == 2) {
    double x;
    if (logscale) {
      // Interpolate in log space, then map back.
      x = exp(log(lo) + drand48() * (log(hi) - log(lo)));
    } else {
      x = lo + drand48() * (hi - lo);
    }
    report_params(name, x);
    return x;
  } else if (sscanf(s, "%g", &lo) == 1) {
    report_params(name, lo);
    return lo;
  } else {
    THROW("bad format for getrenv");
    return 0;
  }
}
// Uniform counterpart of getrenv.
//
// When variable `name` is unset, `dflt` is returned. "lo,hi" yields
// lo + drand48() * (hi - lo); a single number yields that number. The value
// is passed to report_params(). Any other content triggers THROW.
inline double getuenv(const char *name, double dflt = 0) {
  const char *spec = getenv(name);
  if (spec == nullptr) return dflt;
  float lo, hi;
  const int pair_fields = sscanf(spec, "%g,%g", &lo, &hi);
  if (pair_fields == 2) {
    double sample = lo + drand48() * (hi - lo);
    report_params(name, sample);
    return sample;
  }
  if (sscanf(spec, "%g", &lo) == 1) {
    report_params(name, lo);
    return lo;
  }
  THROW("bad format for getuenv");
  return 0;
}
// printf-style formatting into a std::string.
//
// The result is sized from a first measuring pass, so output longer than
// the former fixed 4096-byte buffer is no longer truncated, and no shared
// static buffer is involved. Returns an empty string if vsnprintf reports
// an error.
inline string stringf(const char *format, ...) {
  va_list args;
  va_start(args, format);
  // A va_list can be consumed only once; measure on a copy.
  va_list probe;
  va_copy(probe, args);
  const int needed = vsnprintf(nullptr, 0, format, probe);
  va_end(probe);
  if (needed < 0) {
    va_end(args);
    return string();
  }
  // One extra byte for the terminating NUL that vsnprintf writes.
  string out(static_cast<size_t>(needed) + 1, '\0');
  vsnprintf(&out[0], out.size(), format, args);
  va_end(args);
  out.resize(static_cast<size_t>(needed));
  return out;
}
// Format a message printf-style and hand it to THROW.
//
// The message is limited to the 1024-byte buffer; longer output is now
// truncated by vsnprintf instead of overrunning the buffer.
// NOTE(review): the buffer stays static because THROW (defined elsewhere)
// presumably throws the pointer, which must outlive this call -- confirm.
inline void throwf(const char *format, ...) {
  static char buf[1024];
  va_list arglist;
  va_start(arglist, format);
  vsnprintf(buf, sizeof(buf), format, arglist);
  va_end(arglist);
  THROW(buf);
}
// "Report every N steps" helper.
//
// Feed it a non-decreasing step counter through operator(), += or ++; the
// call returns true when the counter has reached the next multiple of
// `every` being tracked in `next`, and also once when `upto` is positive
// and the counter reaches `upto - 1`. After that final hit the trigger is
// `finished` and must not be used again (asserted).
// NOTE(review): `enabled` is stored and can be changed with enable(), but
// check() in this file never reads it -- callers must test it themselves.
struct Trigger {
  bool finished = false;
  bool enabled = true;
  int count = 0;
  int every = 1;
  int upto = 0;
  int next = 0;
  int last_trigger = 0;
  int current_trigger = 0;

  Trigger(int every, int upto = -1, int start = 0)
      : count(start), every(every), upto(upto) {}

  // Push the first threshold one interval ahead (suppresses a hit at 0 when
  // called on a fresh trigger). Chainable.
  Trigger &skip0() {
    next = next + every;
    return *this;
  }

  // Set the `enabled` flag. Chainable.
  Trigger &enable(bool flag) {
    enabled = flag;
    return *this;
  }

  // Remember the previous hit position and record the current one.
  void rotate() {
    last_trigger = current_trigger;
    current_trigger = count;
  }

  // Steps between the current counter and the hit before the latest one.
  int since() { return count - last_trigger; }

  // Evaluate the trigger for the current value of `count`.
  bool check() {
    assert(!finished);
    const bool reached_end = upto > 0 && count >= upto - 1;
    if (reached_end) {
      finished = true;
      rotate();
      return true;
    }
    // every == 0 means "never fire periodically".
    if (every == 0 || count < next) return false;
    // Advance the threshold past the counter; at least one step is needed.
    do {
      next += every;
    } while (count >= next);
    rotate();
    return true;
  }

  // Move the counter to `current` (must not go backwards) and evaluate.
  bool operator()(int current) {
    assert(!finished);
    assert(current >= count);
    count = current;
    return check();
  }

  // Advance the counter by `incr` and evaluate.
  bool operator+=(int incr) { return (*this)(count + incr); }

  // Advance the counter by one and evaluate.
  bool operator++() { return (*this)(count + 1); }
};
}
#endif