-
Notifications
You must be signed in to change notification settings - Fork 216
/
common.py
355 lines (293 loc) · 12.1 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
""" A bunch of general utilities shared by train/embed/eval """
from argparse import ArgumentTypeError
import logging
import os
import numpy as np
import tensorflow as tf
# Commandline argument parsing
###
def check_directory(arg, access=os.W_OK, access_str="writeable"):
    """ Check for directory-type argument validity.

    Checks whether the given `arg` commandline argument is either an existing
    path with the requested access rights, or a not-yet-existing path whose
    closest existing ancestor has them (so that it could be created there).
    A bare relative name such as `foo` is checked against the current
    directory, exactly like `./foo` is.

    Args:
        arg (string): The commandline argument to check.
        access (constant): What access rights to the directory are requested.
        access_str (string): Used for the error message.

    Returns:
        The string passed in `arg` if the checks succeed.

    Raises:
        ArgumentTypeError if the checks fail.
    """
    path_head = arg
    while path_head:
        if os.path.exists(path_head):
            if os.access(path_head, access):
                # Seems legit, but it still doesn't guarantee a valid path.
                # We'll just go with it for now though.
                return arg
            raise ArgumentTypeError(
                'The provided string `{0}` is not a valid {1} path '
                'since {2} is an existing folder without {1} access.'
                ''.format(arg, access_str, path_head))

        parent = os.path.dirname(path_head)
        if parent == path_head:
            # Splitting a filesystem root yields the root itself. If that
            # root does not exist we would spin here forever, so give up.
            break
        if not parent and path_head != os.curdir:
            # A relative path without any directory part lives in the
            # current directory, so that is the ancestor to check next.
            parent = os.curdir
        path_head = parent

    # No part of the provided string exists and can be written on.
    raise ArgumentTypeError('The provided string `{}` is not a valid {}'
                            ' path.'.format(arg, access_str))
def writeable_directory(arg):
    """ `ArgumentParser.add_argument` type: validate `arg` for write access.

    Delegates to `check_directory` and returns `arg` unchanged on success.
    """
    return check_directory(arg, access=os.W_OK, access_str="writeable")
def readable_directory(arg):
    """ `ArgumentParser.add_argument` type: validate `arg` for read access.

    Delegates to `check_directory` and returns `arg` unchanged on success.
    """
    return check_directory(arg, access=os.R_OK, access_str="readable")
def number_greater_x(arg, type_, x):
    """ Convert `arg` with `type_` and require the result to exceed `x`.

    Args:
        arg (string): The commandline argument to convert.
        type_ (callable): Converter such as `int` or `float`; its `__name__`
            is used in the error messages.
        x: The exclusive lower bound.

    Returns:
        The converted value if it is strictly greater than `x`.

    Raises:
        ArgumentTypeError if the conversion raises ValueError or the value is
        not greater than `x`.
    """
    try:
        parsed = type_(arg)
    except ValueError:
        raise ArgumentTypeError(
            'The argument "{}" is not an {}.'.format(arg, type_.__name__))

    # Deliberately `not parsed > x` instead of `parsed <= x`: values that
    # compare False both ways (NaN) must be rejected as well.
    if not parsed > x:
        raise ArgumentTypeError(
            'Found {} where an {} greater than {} was required'.format(
                arg, type_.__name__, x))
    return parsed
def positive_int(arg):
    """ Argparse type: an integer strictly greater than zero. """
    return number_greater_x(arg, type_=int, x=0)
def nonnegative_int(arg):
    """ Argparse type: an integer greater than -1, i.e. zero or larger. """
    return number_greater_x(arg, type_=int, x=-1)
def positive_float(arg):
    """ Argparse type: a float strictly greater than zero. """
    return number_greater_x(arg, type_=float, x=0)
def float_or_string(arg):
    """Return `float(arg)` when that conversion works, else `arg` itself."""
    try:
        converted = float(arg)
    except (TypeError, ValueError):
        # Not interpretable as a number: hand the input back untouched.
        converted = arg
    return converted
# Dataset handling
###
def load_dataset(csv_file, image_root, fail_on_missing=True):
    """ Loads a dataset .csv file, returning PIDs and FIDs.

    PIDs are the "person IDs", i.e. class names/labels.
    FIDs are the "file IDs", which are individual relative filenames.

    Args:
        csv_file (string, file-like object): The csv data file to load.
        image_root (string): The path to which the image files as stored in the
            csv file are relative to. Used for verification purposes.
            If this is `None`, no verification at all is made.
        fail_on_missing (bool or None): If one or more files from the dataset
            are not present in the `image_root`, either raise an IOError (if
            True) or remove it from the returned dataset (if False).

    Returns:
        (pids, fids) a tuple of numpy string arrays corresponding to the PIDs,
        i.e. the identities/classes/labels and the FIDs, i.e. the filenames.

    Raises:
        IOError if any one file is missing and `fail_on_missing` is True.
    """
    dataset = np.genfromtxt(csv_file, delimiter=',', dtype='|U')
    # A file with a single row comes back as a 1-D array. Without promoting
    # it to 2-D, the unpacking below would produce two scalar strings and
    # everything after would operate on characters instead of entries.
    dataset = np.atleast_2d(dataset)
    pids, fids = dataset.T

    # Possibly check if all files exist
    if image_root is not None:
        missing = np.array(
            [not os.path.isfile(os.path.join(image_root, fid))
             for fid in fids],
            dtype=bool)

        missing_count = np.sum(missing)
        if missing_count > 0:
            if fail_on_missing:
                raise IOError('Using the `{}` file and `{}` as an image root {}/'
                              '{} images are missing'.format(
                                  csv_file, image_root, missing_count, len(fids)))
            else:
                print('[Warning] removing {} missing file(s) from the'
                      ' dataset.'.format(missing_count))
                # We simply remove the missing files.
                present = np.logical_not(missing)
                fids = fids[present]
                pids = pids[present]

    return pids, fids
def fid_to_image(fid, pid, image_root, image_size):
    """ Read the image named by `fid`, resize it, and pass `fid`/`pid` along.

    Returns:
        A tuple `(image_resized, fid, pid)`.
    """
    # There is no symbolic path.join, so the pieces are glued together with
    # an explicit '/' in between to be on the safe side.
    image_path = tf.reduce_join([image_root, '/', fid])
    raw_bytes = tf.read_file(image_path)

    # tf.image.decode_image doesn't set the shape, not even the dimensionality,
    # because it potentially loads animated .gif files. Instead, we use either
    # decode_jpeg or decode_png, each of which can decode both.
    # Sounds ridiculous, but is true:
    # https://github.com/tensorflow/tensorflow/issues/9356#issuecomment-309144064
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)

    return tf.image.resize_images(pixels, image_size), fid, pid
def get_logging_dict(name):
    """ Build a `logging.config.dictConfig`-style configuration dictionary.

    Two handlers are described: a colored one on stderr at INFO level and a
    file handler at DEBUG level appending to `name + '.log'`. Both are
    attached to the root logger and to the `tensorflow` logger.

    Args:
        name (string): Prefix of the log file name.

    Returns:
        The configuration dictionary.
    """
    formatters = {
        'standard': {
            'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
        },
    }

    console_handler = {
        'level': 'INFO',
        'formatter': 'standard',
        'class': 'common.ColorStreamHandler',
        'stream': 'ext://sys.stderr',
    }
    file_handler = {
        'level': 'DEBUG',
        'formatter': 'standard',
        'class': 'logging.FileHandler',
        'filename': name + '.log',
        'mode': 'a',
    }

    root_logger = {
        'handlers': ['stderr', 'logfile'],
        'level': 'DEBUG',
        'propagate': True
    }
    # extra ones to shut up.
    tensorflow_logger = {
        'handlers': ['stderr', 'logfile'],
        'level': 'INFO',
    }

    config = {'version': 1, 'disable_existing_loggers': False}
    config['formatters'] = formatters
    config['handlers'] = {'stderr': console_handler, 'logfile': file_handler}
    config['loggers'] = {'': root_logger, 'tensorflow': tensorflow_logger}
    return config
# Source for the remainder: https://gist.github.com/mooware/a1ed40987b6cc9ab9c65
# Fixed some things mentioned in the comments there.
# colored stream handler for python logging framework (use the ColorStreamHandler class).
#
# based on:
# http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output/1336640#1336640
# how to use:
# i used a dict-based logging configuration, not sure what else would work.
#
# import logging, logging.config, colorstreamhandler
#
# _LOGCONFIG = {
# "version": 1,
# "disable_existing_loggers": False,
#
# "handlers": {
# "console": {
# "class": "colorstreamhandler.ColorStreamHandler",
# "stream": "ext://sys.stderr",
# "level": "INFO"
# }
# },
#
# "root": {
# "level": "INFO",
# "handlers": ["console"]
# }
# }
#
# logging.config.dictConfig(_LOGCONFIG)
# mylogger = logging.getLogger("mylogger")
# mylogger.warning("foobar")
# Copyright (c) 2014 Markus Pointner
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
class _AnsiColorStreamHandler(logging.StreamHandler):
    """ Stream handler that wraps formatted records in ANSI color codes.

    Coloring is only applied when the underlying stream reports being a tty.
    """

    # ANSI escape sequences.
    DEFAULT = '\x1b[0m'
    RED = '\x1b[31m'
    GREEN = '\x1b[32m'
    YELLOW = '\x1b[33m'
    CYAN = '\x1b[36m'

    # Color used for each logging level.
    CRITICAL = RED
    ERROR = RED
    WARNING = YELLOW
    INFO = DEFAULT  # GREEN
    DEBUG = CYAN

    @classmethod
    def _get_color(cls, level):
        """ Return the color of the highest level threshold `level` reaches. """
        # Ordered from most to least severe; the first match wins.
        thresholds = (
            (logging.CRITICAL, cls.CRITICAL),
            (logging.ERROR, cls.ERROR),
            (logging.WARNING, cls.WARNING),
            (logging.INFO, cls.INFO),
            (logging.DEBUG, cls.DEBUG),
        )
        for threshold, color in thresholds:
            if level >= threshold:
                return color
        return cls.DEFAULT

    def __init__(self, stream=None):
        super(_AnsiColorStreamHandler, self).__init__(stream)

    def format(self, record):
        """ Format `record`, adding color codes only when writing to a tty. """
        plain = super(_AnsiColorStreamHandler, self).format(record)
        prefix = self._get_color(record.levelno)
        if not self.is_tty():
            return plain
        return prefix + plain + self.DEFAULT

    def is_tty(self):
        """ Return a truthy value if the stream has `isatty` and it says so. """
        probe = getattr(self.stream, 'isatty', None)
        if not probe:
            return probe
        return probe()
class _WinColorStreamHandler(logging.StreamHandler):
    """ Stream handler that colors records via the Windows console API.

    The console text attribute is switched before a record is emitted and
    set back to `FOREGROUND_WHITE` afterwards.
    """

    # wincon.h
    FOREGROUND_BLACK = 0x0000
    FOREGROUND_BLUE = 0x0001
    FOREGROUND_GREEN = 0x0002
    FOREGROUND_CYAN = 0x0003
    FOREGROUND_RED = 0x0004
    FOREGROUND_MAGENTA = 0x0005
    FOREGROUND_YELLOW = 0x0006
    FOREGROUND_GREY = 0x0007
    FOREGROUND_INTENSITY = 0x0008  # foreground color is intensified.
    FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED

    BACKGROUND_BLACK = 0x0000
    BACKGROUND_BLUE = 0x0010
    BACKGROUND_GREEN = 0x0020
    BACKGROUND_CYAN = 0x0030
    BACKGROUND_RED = 0x0040
    BACKGROUND_MAGENTA = 0x0050
    BACKGROUND_YELLOW = 0x0060
    BACKGROUND_GREY = 0x0070
    BACKGROUND_INTENSITY = 0x0080  # background color is intensified.

    # Text attribute used for each logging level.
    DEFAULT = FOREGROUND_WHITE
    CRITICAL = (BACKGROUND_YELLOW | FOREGROUND_RED
                | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY)
    ERROR = FOREGROUND_RED | FOREGROUND_INTENSITY
    WARNING = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
    INFO = FOREGROUND_GREEN
    DEBUG = FOREGROUND_CYAN

    @classmethod
    def _get_color(cls, level):
        """ Return the attribute of the highest level threshold reached. """
        # Ordered from most to least severe; the first match wins.
        thresholds = (
            (logging.CRITICAL, cls.CRITICAL),
            (logging.ERROR, cls.ERROR),
            (logging.WARNING, cls.WARNING),
            (logging.INFO, cls.INFO),
            (logging.DEBUG, cls.DEBUG),
        )
        for threshold, attribute in thresholds:
            if level >= threshold:
                return attribute
        return cls.DEFAULT

    def _set_color(self, code):
        """ Apply the text attribute `code` to the console handle. """
        import ctypes
        kernel32 = ctypes.windll.kernel32
        kernel32.SetConsoleTextAttribute(self._outhdl, code)

    def __init__(self, stream=None):
        super(_WinColorStreamHandler, self).__init__(stream)
        # get file handle for the stream
        import ctypes
        import ctypes.util
        # for some reason find_msvcrt() sometimes doesn't find msvcrt.dll on
        # my system? Fall back to a generic library lookup in that case.
        runtime_name = (ctypes.util.find_msvcrt()
                        or ctypes.util.find_library("msvcrt"))
        runtime = ctypes.cdll.LoadLibrary(runtime_name)
        self._outhdl = runtime._get_osfhandle(self.stream.fileno())

    def emit(self, record):
        """ Emit `record` in its level color, then restore white text. """
        self._set_color(self._get_color(record.levelno))
        super(_WinColorStreamHandler, self).emit(record)
        self._set_color(self.FOREGROUND_WHITE)
# Choose the ColorStreamHandler implementation for the current platform: the
# console-API based one on Windows, the ANSI escape code one everywhere else.
import platform

ColorStreamHandler = (_WinColorStreamHandler
                      if platform.system() == 'Windows'
                      else _AnsiColorStreamHandler)