emotion_eval.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import models
import losses
from menpo.visualize import print_progress
from pathlib import Path
from tensorflow.python.platform import tf_logging as logging
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '',
                           '''If specified, restore this pretrained model '''
                           '''before beginning any training.''')
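# Note: pretrained_model_checkpoint_path is defined here but is not referenced
# anywhere else in this evaluation script.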
tf.app.flags.DEFINE_integer('batch_size', 32, '''The batch size to use.''')
tf.app.flags.DEFINE_string('model', 'audio', '''Which model to use: audio, video, or both.''')
tf.app.flags.DEFINE_string('dataset_dir', './tf_records/', 'The tfrecords directory.')
tf.app.flags.DEFINE_string('checkpoint_dir', './ckpt/train/', 'The directory containing the training checkpoints.')
tf.app.flags.DEFINE_string('log_dir', './ckpt/logs/valid/', 'The directory to write evaluation summaries to.')
tf.app.flags.DEFINE_integer('num_examples', 10000, 'The number of examples in the test set.')
tf.app.flags.DEFINE_integer('eval_interval_secs', 300, 'How often (in seconds) to run the evaluation.')
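# Example invocation (flag values are illustrative, not prescribed by the repository;
# adjust the paths to your setup):
#   python emotion_eval.py --dataset_dir=./tf_records/ \
#       --checkpoint_dir=./ckpt/train/ --log_dir=./ckpt/logs/valid/ \
#       --model=audio --batch_size=32 --num_examples=10000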
def evaluate(data_folder):
  g = tf.Graph()
  with g.as_default():
    # Load the dataset. Note: as written, this reads the 'train' split of the tfrecords.
    frames, audio, ground_truth = data_provider.get_split(
        data_folder, 'train', FLAGS.batch_size)

    # Define the model graph with batch norm and dropout in inference mode.
    with slim.arg_scope([slim.batch_norm, slim.layers.dropout],
                        is_training=False):
      prediction = models.get_model(FLAGS.model)(frames, audio)

    # Streaming (running-average) mean squared error between predictions and labels.
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        "eval/mean_squared_error":
            slim.metrics.streaming_mean_squared_error(prediction, ground_truth)
    })

    # Create the summary ops such that they also print out to std output.
    summary_ops = []
    metric_name = list(names_to_values.keys())[0]
    update_op = list(names_to_updates.values())[0]
    op = tf.summary.scalar(metric_name, update_op)
    op = tf.Print(op, [update_op], metric_name)
    summary_ops.append(op)

    num_examples = FLAGS.num_examples
    num_batches = num_examples // FLAGS.batch_size

    logging.set_verbosity(1)

    # Set up the global step.
    slim.get_or_create_global_step()

    # How often to run the evaluation.
    eval_interval_secs = FLAGS.eval_interval_secs

    slim.evaluation.evaluation_loop(
        '',
        FLAGS.checkpoint_dir,
        FLAGS.log_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        summary_op=tf.summary.merge(summary_ops),
        eval_interval_secs=eval_interval_secs)


def main(_):
  evaluate(FLAGS.dataset_dir)


if __name__ == '__main__':
  tf.app.run()