diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 60515a3fc68778..32c7787a9b803e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -19,6 +19,7 @@ from __future__ import print_function import abc +import functools import six from tensorflow.contrib import losses @@ -31,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.session_bundle import exporter from tensorflow.python import summary +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -78,7 +80,7 @@ def _regression_head(label_name=None, def _multi_class_head(n_classes, label_name=None, weight_column_name=None, enable_centered_bias=False, head_name=None, - thresholds=None): + thresholds=None, metric_class_ids=None): """Creates a _Head for multi class single label classification. The Head uses softmax cross entropy loss. @@ -96,18 +98,24 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None, head_name: name of the head. If provided, predictions, summary and metrics keys will be prefixed by the head_name and an underscore. thresholds: thresholds for eval metrics, defaults to [.5] + metric_class_ids: List of class IDs for which we should report per-class + metrics. Must all be in the range `[0, n_classes)`. Invalid if + `n_classes` is 2. Returns: An instance of _MultiClassHead. Raises: - ValueError: if n_classes is < 2 + ValueError: if `n_classes` is < 2, or `metric_class_ids` is provided when + `n_classes` is 2. """ if (n_classes is None) or (n_classes < 2): raise ValueError( "n_classes must be > 1 for classification: %s." % n_classes) if n_classes == 2: + if metric_class_ids: + raise ValueError("metric_class_ids invalid for n_classes==2.") return _BinaryLogisticHead(label_name=label_name, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, @@ -119,7 +127,8 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, head_name=head_name, - thresholds=thresholds) + thresholds=thresholds, + metric_class_ids=metric_class_ids) def _binary_svm_head(label_name=None, weight_column_name=None, @@ -155,7 +164,7 @@ def _binary_svm_head(label_name=None, weight_column_name=None, def _multi_label_head(n_classes, label_name=None, weight_column_name=None, enable_centered_bias=False, head_name=None, - thresholds=None): + thresholds=None, metric_class_ids=None): """Creates a _Head for multi label classification. The Head uses softmax cross entropy loss. @@ -173,6 +182,8 @@ def _multi_label_head(n_classes, label_name=None, weight_column_name=None, head_name: name of the head. If provided, predictions, summary and metrics keys will be prefixed by the head_name and an underscore. thresholds: thresholds for eval metrics, defaults to [.5] + metric_class_ids: List of class IDs for which we should report per-class + metrics. Must all be in the range `[0, n_classes)`. Returns: An instance of _MultiClassHead. @@ -187,7 +198,8 @@ def _multi_label_head(n_classes, label_name=None, weight_column_name=None, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, head_name=head_name, - thresholds=thresholds) + thresholds=thresholds, + metric_class_ids=metric_class_ids) # TODO(zakaria): Make the classes public once we are ready for users to subclass @@ -353,7 +365,9 @@ def _logits_to_predictions(self, logits): def _signature_fn(self): """Returns the signature_fn to be used in exporting.""" - def _regression_signature_fn(examples, unused_features, predictions): + def _regression_signature_fn(examples, features, predictions): + # pylint: disable=missing-docstring + del features if isinstance(predictions, dict): score = predictions[prediction_key.PredictionKey.SCORES] else: @@ -485,8 +499,9 @@ def _logits_to_predictions(self, logits): def _signature_fn(self): """Returns the signature_fn to be used in exporting.""" - def _classification_signature_fn(examples, unused_features, predictions): + def _classification_signature_fn(examples, features, predictions): """Servo signature function.""" + del features if isinstance(predictions, dict): default_signature = exporter.classification_signature( input_tensor=examples, @@ -527,12 +542,13 @@ def _add_binary_metric(key, metric_fn): _add_binary_metric( metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean) _add_binary_metric( - metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean) + metric_key.MetricKey.LABEL_MEAN, _indicator_labels_streaming_mean) # Also include the streaming mean of the label as an accuracy baseline, as # a reminder to users. _add_binary_metric( - metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean) + metric_key.MetricKey.ACCURACY_BASELINE, + _indicator_labels_streaming_mean) _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc) @@ -571,7 +587,8 @@ class _MultiClassHead(_Head): def __init__(self, n_classes, label_name, weight_column_name, enable_centered_bias, head_name, - loss_fn=_softmax_cross_entropy_loss, thresholds=None): + loss_fn=_softmax_cross_entropy_loss, thresholds=None, + metric_class_ids=None): """Base type for all single heads. Args: @@ -589,9 +606,11 @@ def __init__(self, n_classes, label_name, keys will be prefixed by the head_name and an underscore. loss_fn: Loss function. thresholds: thresholds for eval. + metric_class_ids: List of class IDs for which we should report per-class + metrics. Must all be in the range `[0, n_classes)`. Raises: - ValueError: if n_classes is invalid. + ValueError: if `n_classes` or `metric_class_ids` is invalid. """ super(_MultiClassHead, self).__init__(head_name=head_name) @@ -604,6 +623,11 @@ def __init__(self, n_classes, label_name, self._loss_fn = loss_fn self._enable_centered_bias = enable_centered_bias self._problem_type = constants.ProblemType.CLASSIFICATION + self._metric_class_ids = tuple( + [] if metric_class_ids is None else metric_class_ids) + for class_id in self._metric_class_ids: + if (class_id < 0) or (class_id >= n_classes): + raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes)) @property def logits_dimension(self): @@ -667,8 +691,9 @@ def _logits_to_predictions(self, logits): def _signature_fn(self): """Returns the signature_fn to be used in exporting.""" - def _classification_signature_fn(examples, unused_features, predictions): + def _classification_signature_fn(examples, features, predictions): """Servo signature function.""" + del features if isinstance(predictions, dict): default_signature = exporter.classification_signature( input_tensor=examples, @@ -684,24 +709,104 @@ def _classification_signature_fn(examples, unused_features, predictions): return default_signature, {} return _classification_signature_fn + def _metric_spec(self, metric_fn, prediction_name): + return metric_spec.MetricSpec( + metric_fn, prediction_name, self._label_name, self._weight_column_name) + def _default_metrics(self): """Returns a dict of `MetricSpec` objects keyed by name.""" - metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): - _weighted_average_loss_metric_spec( - self._loss_fn, - prediction_key.PredictionKey.LOGITS, - self._label_name, - self._weight_column_name)} - - # TODO(b/29366811): This currently results in both an "accuracy" and an - # "accuracy/threshold_0.500000_mean" metric for binary classification. - metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = ( - metric_spec.MetricSpec(metrics_lib.streaming_accuracy, - prediction_key.PredictionKey.CLASSES, - self._label_name, - self._weight_column_name)) - - # TODO(b/32953199): Add multiclass metrics. + def _streaming_auc_with_class_id_label(predictions, labels, weights=None): + indicator_labels = _class_id_labels_to_indicator( + labels, num_classes=self.logits_dimension) + return _streaming_auc(predictions, indicator_labels, weights) + + loss_key = _head_prefixed(self._head_name, metric_key.MetricKey.LOSS) + accuracy_key = _head_prefixed( + self._head_name, metric_key.MetricKey.ACCURACY) + auc_key = _head_prefixed(self._head_name, metric_key.MetricKey.AUC) + + metrics = { + loss_key: _weighted_average_loss_metric_spec( + self._loss_fn, + prediction_key.PredictionKey.LOGITS, + self._label_name, + self._weight_column_name), + # TODO(b/29366811): This currently results in both an "accuracy" and an + # "accuracy/threshold_0.500000_mean" metric for binary classification. + accuracy_key: self._metric_spec( + metrics_lib.streaming_accuracy, + prediction_key.PredictionKey.CLASSES), + auc_key: self._metric_spec( + _streaming_auc_with_class_id_label, + prediction_key.PredictionKey.PROBABILITIES) + } + + def _class_predictions_streaming_mean( + predictions, labels, weights=None, class_id=None): + del labels + return metrics_lib.streaming_mean( + math_ops.select( + math_ops.equal( + math_ops.to_int32(class_id), + math_ops.to_int32(predictions)), + array_ops.ones_like(predictions), + array_ops.zeros_like(predictions)), + weights=weights) + + def _class_labels_streaming_mean( + predictions, labels, weights=None, class_id=None): + del predictions + assert class_id is not None + return metrics_lib.streaming_mean( + math_ops.select( + math_ops.equal( + math_ops.to_int32(class_id), + math_ops.to_int32(labels)), + array_ops.ones_like(labels), + array_ops.zeros_like(labels)), + weights=weights) + + def _class_streaming_auc(predictions, labels, weights=None, class_id=None): + assert class_id is not None + indicator_labels = _class_id_labels_to_indicator( + labels, num_classes=self.logits_dimension) + return _streaming_auc( + predictions, indicator_labels, weights=weights, class_id=class_id) + + for class_id in self._metric_class_ids: + + # TODO(ptucker): Add per-class accuracy, precision, recall. + + prediction_mean_key = _head_prefixed( + self._head_name, + metric_key.MetricKey.CLASS_PREDICTION_MEAN % class_id) + label_mean_key = _head_prefixed( + self._head_name, metric_key.MetricKey.CLASS_LABEL_MEAN % class_id) + probability_mean_key = _head_prefixed( + self._head_name, + metric_key.MetricKey.CLASS_PROBABILITY_MEAN % class_id) + logits_mean_key = _head_prefixed( + self._head_name, + metric_key.MetricKey.CLASS_LOGITS_MEAN % class_id) + auc_key = _head_prefixed( + self._head_name, metric_key.MetricKey.CLASS_AUC % class_id) + + metrics[prediction_mean_key] = self._metric_spec( + functools.partial( + _class_predictions_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.CLASSES) + metrics[label_mean_key] = self._metric_spec( + functools.partial(_class_labels_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.PROBABILITIES) + metrics[probability_mean_key] = self._metric_spec( + functools.partial(_predictions_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.PROBABILITIES) + metrics[logits_mean_key] = self._metric_spec( + functools.partial(_predictions_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.LOGITS) + metrics[auc_key] = self._metric_spec( + functools.partial(_class_streaming_auc, class_id=class_id), + prediction_key.PredictionKey.LOGITS) return metrics @@ -713,6 +818,12 @@ def _to_labels_tensor(labels, label_name): return labels +def _assert_labels_rank(labels): + return control_flow_ops.Assert( + math_ops.less_equal(array_ops.rank(labels), 2), + ("labels shape should be either [batch_size, 1] or [batch_size]",)) + + class _BinarySvmHead(_BinaryLogisticHead): """_Head for binary classification using SVMs.""" @@ -720,12 +831,8 @@ def __init__(self, label_name, weight_column_name, enable_centered_bias, head_name, thresholds): def _loss_fn(logits, labels): with ops.name_scope(None, "hinge_loss", (logits, labels)) as name: - check_shape_op = control_flow_ops.Assert( - math_ops.less_equal(array_ops.rank(labels), 2), - ("labels shape should be either [batch_size, 1] or [batch_size]",)) - with ops.control_dependencies((check_shape_op,)): - labels = array_ops.reshape( - labels, shape=(array_ops.shape(labels)[0], 1)) + with ops.control_dependencies((_assert_labels_rank(labels),)): + labels = array_ops.reshape(labels, shape=(-1, 1)) return losses.hinge_loss(logits, labels, scope=name) super(_BinarySvmHead, self).__init__( @@ -769,7 +876,7 @@ class _MultiLabelHead(_MultiClassHead): # TODO(zakaria): add signature and metric for multilabel. def __init__(self, n_classes, label_name, weight_column_name, enable_centered_bias, head_name, - thresholds): + thresholds, metric_class_ids=None): super(_MultiLabelHead, self).__init__( n_classes=n_classes, @@ -778,7 +885,8 @@ def __init__(self, n_classes, label_name, enable_centered_bias=enable_centered_bias, head_name=head_name, loss_fn=_sigmoid_cross_entropy_loss, - thresholds=thresholds) + thresholds=thresholds, + metric_class_ids=metric_class_ids) def _logits_to_predictions(self, logits): """See `_MultiClassHead`.""" @@ -792,19 +900,79 @@ def _logits_to_predictions(self, logits): name=prediction_key.PredictionKey.CLASSES) } + def _metric_spec(self, metric_fn, prediction_name): + return metric_spec.MetricSpec( + metric_fn, prediction_name, self._label_name, self._weight_column_name) + + def _default_metrics(self): + """Returns a dict of `MetricSpec` objects keyed by name.""" + loss_key = _head_prefixed(self._head_name, metric_key.MetricKey.LOSS) + accuracy_key = _head_prefixed( + self._head_name, metric_key.MetricKey.ACCURACY) + auc_key = _head_prefixed(self._head_name, metric_key.MetricKey.AUC) + + metrics = { + loss_key: _weighted_average_loss_metric_spec( + self._loss_fn, + prediction_key.PredictionKey.LOGITS, + self._label_name, + self._weight_column_name), + # TODO(b/29366811): This currently results in both an "accuracy" and an + # "accuracy/threshold_0.500000_mean" metric for binary classification. + accuracy_key: self._metric_spec( + metrics_lib.streaming_accuracy, + prediction_key.PredictionKey.CLASSES), + auc_key: self._metric_spec( + _streaming_auc, prediction_key.PredictionKey.PROBABILITIES), + } + + for class_id in self._metric_class_ids: + + # TODO(ptucker): Add per-class accuracy, precision, recall. + + prediction_mean_key = _head_prefixed( + self._head_name, + metric_key.MetricKey.CLASS_PREDICTION_MEAN % class_id) + label_mean_key = _head_prefixed( + self._head_name, metric_key.MetricKey.CLASS_LABEL_MEAN % class_id) + probability_mean_key = _head_prefixed( + self._head_name, + metric_key.MetricKey.CLASS_PROBABILITY_MEAN % class_id) + logits_mean_key = _head_prefixed( + self._head_name, metric_key.MetricKey.CLASS_LOGITS_MEAN % class_id) + auc_key = _head_prefixed( + self._head_name, metric_key.MetricKey.CLASS_AUC % class_id) + + metrics[prediction_mean_key] = self._metric_spec( + functools.partial(_predictions_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.CLASSES) + metrics[label_mean_key] = self._metric_spec( + functools.partial( + _indicator_labels_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.CLASSES) + metrics[probability_mean_key] = self._metric_spec( + functools.partial(_predictions_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.PROBABILITIES) + metrics[logits_mean_key] = self._metric_spec( + functools.partial(_predictions_streaming_mean, class_id=class_id), + prediction_key.PredictionKey.LOGITS) + metrics[auc_key] = self._metric_spec( + functools.partial(_streaming_auc, class_id=class_id), + prediction_key.PredictionKey.LOGITS) + + return metrics + def _weighted_loss(loss, weight): - """Returns cumulative weighted loss.""" + """Returns cumulative weighted loss as 1d `Tensor`.""" with ops.name_scope(None, "weighted_loss", (loss, weight)) as name: - unweighted_loss = array_ops.reshape(loss, shape=(-1,)) - weighted_loss = math_ops.mul(unweighted_loss, - array_ops.reshape( - weight, shape=(-1,)), - name=name) - return weighted_loss + return math_ops.mul(array_ops.reshape(loss, shape=(-1,)), + array_ops.reshape(weight, shape=(-1,)), + name=name) def _weight_tensor(features, weight_column_name): + """Returns weights as 1d `Tensor`.""" if not weight_column_name: return None with ops.name_scope( @@ -982,17 +1150,49 @@ def _streaming_weighted_average_loss(predictions, labels, weights=None): pred_key, label_key, weight_key) -def _labels_streaming_mean(unused_predictions, labels, weights=None): +def _indicator_labels_streaming_mean( + predictions, labels, weights=None, class_id=None): + del predictions + if class_id is not None: + labels = labels[:, class_id] return metrics_lib.streaming_mean(labels, weights=weights) -def _predictions_streaming_mean(predictions, unused_labels, weights=None): +def _predictions_streaming_mean( + predictions, labels, weights=None, class_id=None): + del labels + if class_id is not None: + predictions = predictions[:, class_id] return metrics_lib.streaming_mean(predictions, weights=weights) -def _streaming_auc(predictions, labels, weights=None): - return metrics_lib.streaming_auc(predictions, labels, - weights=_float_weights_or_none(weights)) +# TODO(ptucker): Add support for SparseTensor labels. +def _class_id_labels_to_indicator(labels, num_classes): + if (num_classes is None) or (num_classes < 2): + raise ValueError("Invalid num_classes %s." % num_classes) + with ops.control_dependencies((_assert_labels_rank(labels),)): + labels = array_ops.reshape(labels, (-1,)) + return array_ops.one_hot(labels, depth=num_classes, axis=-1) + + +def _streaming_auc(predictions, labels, weights=None, class_id=None): + if class_id is not None: + predictions = predictions[:, class_id] + labels = labels[:, class_id] + return metrics_lib.streaming_auc( + predictions, math_ops.cast(labels, dtypes.bool), + weights=_float_weights_or_none(weights)) + + +def _assert_class_id(class_id, num_classes=None): + """Average label value for class `class_id`.""" + if (class_id is None) or (class_id < 0): + raise ValueError("Invalid class_id %s." % class_id) + if num_classes is not None: + if num_classes < 2: + raise ValueError("Invalid num_classes %s." % num_classes) + if class_id >= num_classes: + raise ValueError("Invalid class_id %s." % class_id) def _accuracy_at_threshold(threshold): @@ -1013,6 +1213,6 @@ def _streaming_metrics(predictions, labels, weights=None): precision_tensor, update_op = streaming_metrics_fn( predictions, labels=labels, thresholds=(threshold,), weights=_float_weights_or_none(weights)) - return array_ops.squeeze(precision_tensor), update_op + return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) return _streaming_metrics diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 40eb7d17de297a..b84a8ce3c2081a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import math import numpy as np import six import tensorflow as tf @@ -40,77 +41,94 @@ def _assert_variables( def _assert_no_variables(test_case): - _assert_variables(test_case, set([]), set([]), set([])) + _assert_variables(test_case) -class RegressionModelHeadTest(tf.test.TestCase): +# This must be called from within a tf.Session. +def _assert_metrics( + test_case, expected_loss, expected_eval_metrics, model_fn_ops): + test_case.assertAlmostEqual(expected_loss, model_fn_ops.loss.eval(), places=4) + for k in six.iterkeys(expected_eval_metrics): + test_case.assertIn(k, six.iterkeys(model_fn_ops.eval_metric_ops)) + tf.initialize_local_variables().run() + for key, expected_value in six.iteritems(expected_eval_metrics): + value_tensor, update_tensor = model_fn_ops.eval_metric_ops[key] + update = update_tensor.eval() + test_case.assertAlmostEqual( + expected_value, update, places=4, + msg="%s: update, expected %s, got %s." % (key, expected_value, update)) + value = value_tensor.eval() + test_case.assertAlmostEqual( + expected_value, value, places=4, + msg="%s: value, expected %s, got %s." % (key, expected_value, value)) + - def _assert_metrics(self, model_fn_ops): - self.assertItemsEqual(( - "loss", - ), six.iterkeys(model_fn_ops.eval_metric_ops)) +def _sigmoid(x): + return 1. / (1. + math.exp(-1 * x)) - # TODO(zakaria): test multilabel regresssion. + +class RegressionModelHeadTest(tf.test.TestCase): + + # TODO(zakaria): test multilabel regression. def testRegression(self): head = head_lib._regression_head() - with tf.Graph().as_default(), tf.Session() as sess: + with tf.Graph().as_default(), tf.Session(): prediction = tf.constant([[1.], [1.], [3.]]) labels = tf.constant([[0.], [1.], [1.]]) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=prediction) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(5. / 3, sess.run(model_fn_ops.loss)) + _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) + def testRegressionEvalMode(self): + head = head_lib._regression_head() + with tf.Graph().as_default(), tf.Session(): + prediction = tf.constant([[1.], [1.], [3.]]) + labels = tf.constant([[0.], [1.], [1.]]) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.EVAL, _noop_train_op, logits=prediction) self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) def testRegressionWithLabelName(self): label_name = "my_label" head = head_lib._regression_head(label_name=label_name) - with tf.Graph().as_default(), tf.Session() as sess: + with tf.Graph().as_default(), tf.Session(): prediction = tf.constant([[1.], [1.], [3.]]) labels = {label_name: tf.constant([[0.], [1.], [1.]])} model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=prediction) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(5. / 3, sess.run(model_fn_ops.loss)) - - model_fn_ops = head.head_ops({}, labels, - tf.contrib.learn.ModeKeys.EVAL, - _noop_train_op, logits=prediction) - self.assertIsNone(model_fn_ops.train_op) + _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) def testRegressionWithWeights(self): head = head_lib._regression_head( weight_column_name="label_weight") - with tf.Graph().as_default(), tf.Session() as sess: - features = {"label_weight": tf.constant([[2.], [5.], [0.]])} + with tf.Graph().as_default(), tf.Session(): + weights = ((2.,), (5.,), (0.,)) + features = {"label_weight": tf.constant(weights)} prediction = tf.constant([[1.], [1.], [3.]]) labels = tf.constant([[0.], [1.], [1.]]) model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=prediction) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(2. / 3, sess.run(model_fn_ops.loss), places=3) + _assert_metrics(self, 2. / len(weights), { + "loss": 2. / np.sum(weights) + }, model_fn_ops) def testRegressionWithCenteredBias(self): - head = head_lib._regression_head( - weight_column_name="label_weight", enable_centered_bias=True) - with tf.Graph().as_default(), tf.Session() as sess: - features = {"label_weight": tf.constant([[2.], [5.], [0.]])} + head = head_lib._regression_head(enable_centered_bias=True) + with tf.Graph().as_default(), tf.Session(): prediction = tf.constant([[1.], [1.], [3.]]) labels = tf.constant([[0.], [1.], [1.]]) - model_fn_ops = head.head_ops(features, labels, + model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=prediction) - self._assert_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", "centered_bias_weight/Adagrad:0", @@ -118,7 +136,7 @@ def testRegressionWithCenteredBias(self): "centered_bias_weight:0", )) tf.global_variables_initializer().run() - self.assertAlmostEqual(2. / 3, sess.run(model_fn_ops.loss), places=3) + _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) def testErrorInSparseTensorLabels(self): head = head_lib._regression_head() @@ -136,70 +154,111 @@ def testErrorInSparseTensorLabels(self): class MultiLabelModelHeadTest(tf.test.TestCase): - def _assert_metrics(self, model_fn_ops): - self.assertItemsEqual(( - "accuracy", - "loss", - ), six.iterkeys(model_fn_ops.eval_metric_ops)) + def setUp(self): + self._logits = ((1., 0., 0.),) + self._labels = ((0, 0, 1),) + + def _expected_eval_metrics(self, expected_loss): + return { + "accuracy": 1. / 3, + "auc": 1. / 4, + "loss": expected_loss, + "auc/class0": 1., + "auc/class1": 1., + "auc/class2": 0., + "labels/actual_label_mean/class0": self._labels[0][0], + "labels/actual_label_mean/class1": self._labels[0][1], + "labels/actual_label_mean/class2": self._labels[0][2], + "labels/logits_mean/class0": self._logits[0][0], + "labels/logits_mean/class1": self._logits[0][1], + "labels/logits_mean/class2": self._logits[0][2], + "labels/prediction_mean/class0": self._logits[0][0], + "labels/prediction_mean/class1": self._logits[0][1], + "labels/prediction_mean/class2": self._logits[0][2], + "labels/probability_mean/class0": _sigmoid(self._logits[0][0]), + "labels/probability_mean/class1": _sigmoid(self._logits[0][1]), + "labels/probability_mean/class2": _sigmoid(self._logits[0][2]), + } def testMultiLabel(self): - head = head_lib._multi_label_head(n_classes=3) - with tf.Graph().as_default(), tf.Session() as sess: - logits = tf.constant([[1., 0., 0.]]) - labels = tf.constant([[0, 0, 1]]) + n_classes = 3 + head = head_lib._multi_label_head( + n_classes=n_classes, metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss)) + expected_loss = .89985204 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) + def testMultiLabelEvalMode(self): + n_classes = 3 + head = head_lib._multi_label_head( + n_classes=n_classes, metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant([[1., 0., 0.]]) + labels = tf.constant([[0, 0, 1]]) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.EVAL, _noop_train_op, logits=logits) self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + expected_loss = .89985204 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) def testMultiLabelWithLabelName(self): + n_classes = 3 label_name = "my_label" - head = head_lib._multi_label_head(n_classes=3, label_name=label_name) - with tf.Graph().as_default(), tf.Session() as sess: + head = head_lib._multi_label_head( + n_classes=n_classes, label_name=label_name, + metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): logits = tf.constant([[1., 0., 0.]]) labels = {label_name: tf.constant([[0, 0, 1]])} model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss)) - - model_fn_ops = head.head_ops({}, labels, - tf.contrib.learn.ModeKeys.EVAL, - _noop_train_op, logits=logits) - self.assertIsNone(model_fn_ops.train_op) + expected_loss = .89985204 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) def testMultiLabelWithWeight(self): + n_classes = 3 head = head_lib._multi_label_head( - n_classes=3, weight_column_name="label_weight") - with tf.Graph().as_default(), tf.Session() as sess: - features = {"label_weight": tf.constant([0.1])} + n_classes=n_classes, weight_column_name="label_weight", + metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): + features = {"label_weight": tf.constant([.1])} logits = tf.constant([[1., 0., 0.]]) labels = tf.constant([[0, 0, 1]]) model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(0.089985214, sess.run(model_fn_ops.loss)) + _assert_metrics( + self, .089985214, self._expected_eval_metrics(2.69956), + model_fn_ops) def testMultiLabelWithCenteredBias(self): - head = head_lib._multi_label_head(n_classes=3, enable_centered_bias=True) - with tf.Graph().as_default(), tf.Session() as sess: + n_classes = 3 + head = head_lib._multi_label_head( + n_classes=n_classes, enable_centered_bias=True, + metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): logits = tf.constant([[1., 0., 0.]]) labels = tf.constant([[0, 0, 1]]) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", "centered_bias_weight/Adagrad:0", @@ -207,45 +266,69 @@ def testMultiLabelWithCenteredBias(self): "centered_bias_weight:0", )) tf.global_variables_initializer().run() - self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss)) + expected_loss = .89985204 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) -class MultiClassModelHeadTest(tf.test.TestCase): +class BinaryClassificationModelHeadTest(tf.test.TestCase): - def _assert_binary_metrics(self, model_fn_ops): - self.assertItemsEqual(( - "accuracy", - "accuracy/baseline_label_mean", - "accuracy/threshold_0.500000_mean", - "auc", - "labels/actual_label_mean", - "labels/prediction_mean", - "loss", - "precision/positive_threshold_0.500000_mean", - "recall/positive_threshold_0.500000_mean", - ), six.iterkeys(model_fn_ops.eval_metric_ops)) + def setUp(self): + self._logits = ((1.,), (1.,)) + self._labels = ((1.,), (0.,)) + + def _expected_eval_metrics(self, expected_loss): + return { + "accuracy": 1. / 2, + "accuracy/baseline_label_mean": np.mean(self._labels), + "accuracy/threshold_0.500000_mean": 1. / 2, + "auc": 1. / 2, + "labels/actual_label_mean": np.mean(self._labels), + "labels/prediction_mean": .731059, # softmax + "loss": expected_loss, + "precision/positive_threshold_0.500000_mean": 1. / 2, + "recall/positive_threshold_0.500000_mean": 1. / 1, + } def testBinaryClassification(self): - head = head_lib._multi_class_head(n_classes=2) - with tf.Graph().as_default(), tf.Session() as sess: - logits = tf.constant([[1.], [1.]]) - labels = tf.constant([[1.], [0.]]) + n_classes = 2 + head = head_lib._multi_class_head(n_classes=n_classes) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_binary_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss), - delta=1e-6) + expected_loss = .81326175 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) + + def testBinaryClassificationEvalMode(self): + n_classes = 2 + head = head_lib._multi_class_head(n_classes=n_classes) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) + # logloss: z:label, x:logit + # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.EVAL, _noop_train_op, logits=logits) self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + expected_loss = .81326175 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) def testErrorInSparseTensorLabels(self): - head = head_lib._multi_class_head(n_classes=2) + n_classes = 2 + head = head_lib._multi_class_head(n_classes=n_classes) with tf.Graph().as_default(): prediction = tf.constant([[1.], [1.], [3.]]) labels = tf.SparseTensor( @@ -260,51 +343,60 @@ def testErrorInSparseTensorLabels(self): def testBinaryClassificationWithLabelName(self): label_name = "my_label" head = head_lib._multi_class_head(n_classes=2, label_name=label_name) - with tf.Graph().as_default(), tf.Session() as sess: - logits = tf.constant([[1.], [1.]]) - labels = {label_name: tf.constant([[1.], [0.]])} + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = {label_name: tf.constant(self._labels)} # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_binary_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss), - delta=1e-6) - model_fn_ops = head.head_ops({}, labels, - tf.contrib.learn.ModeKeys.EVAL, - _noop_train_op, logits=logits) - self.assertIsNone(model_fn_ops.train_op) + expected_loss = .81326175 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) def testBinaryClassificationWithWeights(self): + n_classes = 2 head = head_lib._multi_class_head( - n_classes=2, weight_column_name="label_weight") - with tf.Graph().as_default(), tf.Session() as sess: - features = {"label_weight": tf.constant([[1.], [0.]])} - logits = tf.constant([[1.], [1.]]) - labels = tf.constant([[1.], [0.]]) + n_classes=n_classes, weight_column_name="label_weight") + with tf.Graph().as_default(), tf.Session(): + weights = ((1.,), (0.,)) + features = {"label_weight": tf.constant(weights)} + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_binary_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(.31326166 / 2, sess.run(model_fn_ops.loss), - delta=1e-6) + expected_total_loss = .31326166 + _assert_metrics( + self, expected_total_loss / len(weights), { + "accuracy": 1. / 1, + "accuracy/baseline_label_mean": 1. / 1, + "accuracy/threshold_0.500000_mean": 1. / 1, + "auc": 0. / 1, + "labels/actual_label_mean": 1. / 1, + "labels/prediction_mean": .731059, # softmax + # TODO(ptucker): Is this the correct eval loss, sum not average? + "loss": expected_total_loss, + "precision/positive_threshold_0.500000_mean": 1. / 1, + "recall/positive_threshold_0.500000_mean": 1. / 1, + }, model_fn_ops) def testBinaryClassificationWithCenteredBias(self): head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True) - with tf.Graph().as_default(), tf.Session() as sess: - logits = tf.constant([[1.], [1.]]) - labels = tf.constant([[1.], [0.]]) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_binary_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", "centered_bias_weight/Adagrad:0", @@ -312,50 +404,97 @@ def testBinaryClassificationWithCenteredBias(self): "centered_bias_weight:0", )) tf.global_variables_initializer().run() - self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss), - delta=1e-6) + expected_loss = .81326175 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) - def _assert_multi_class_metrics(self, model_fn_ops): - self.assertItemsEqual(( - "accuracy", - "loss", - ), six.iterkeys(model_fn_ops.eval_metric_ops)) + +class MultiClassModelHeadTest(tf.test.TestCase): + + def setUp(self): + self._logits = ((1., 0., 0.),) + self._labels = (2,) + + def _expected_eval_metrics(self, expected_loss): + return { + "accuracy": 0., + "auc": 1. / 4, + "loss": expected_loss, + "auc/class0": 1., + "auc/class1": 1., + "auc/class2": 0., + "labels/actual_label_mean/class0": 0. / 1, + "labels/actual_label_mean/class1": 0. / 1, + "labels/actual_label_mean/class2": 1. / 1, + "labels/logits_mean/class0": self._logits[0][0], + "labels/logits_mean/class1": self._logits[0][1], + "labels/logits_mean/class2": self._logits[0][2], + "labels/prediction_mean/class0": self._logits[0][0], + "labels/prediction_mean/class1": self._logits[0][1], + "labels/prediction_mean/class2": self._logits[0][2], + "labels/probability_mean/class0": 0.576117, # softmax + "labels/probability_mean/class1": 0.211942, # softmax + "labels/probability_mean/class2": 0.211942, # softmax + } def testMultiClass(self): n_classes = 3 - head = head_lib._multi_class_head(n_classes=n_classes) - with tf.Graph().as_default(), tf.Session() as sess: - logits = tf.constant([[1., 0., 0.]]) - labels = tf.constant([2]) + head = head_lib._multi_class_head( + n_classes=n_classes, metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_multi_class_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(1.5514446, sess.run(model_fn_ops.loss)) + expected_loss = 1.5514446 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) + + def testMultiClassEvalMode(self): + n_classes = 3 + head = head_lib._multi_class_head( + n_classes=n_classes, metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) + # logloss: z:label, x:logit + # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.EVAL, _noop_train_op, logits=logits) self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + expected_loss = 1.5514446 + _assert_metrics( + self, expected_loss, self._expected_eval_metrics(expected_loss), + model_fn_ops) def testMultiClassWithWeight(self): n_classes = 3 head = head_lib._multi_class_head( - n_classes=n_classes, weight_column_name="label_weight") - with tf.Graph().as_default(), tf.Session() as sess: - features = {"label_weight": tf.constant([0.1])} - logits = tf.constant([[1., 0., 0.]]) - labels = tf.constant([2]) + n_classes=n_classes, weight_column_name="label_weight", + metric_class_ids=range(n_classes)) + with tf.Graph().as_default(), tf.Session(): + weight = .1 + features = {"label_weight": tf.constant([weight])} + logits = tf.constant(self._logits) + labels = tf.constant(self._labels) # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) - self._assert_multi_class_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss)) + expected_loss = 1.5514446 + _assert_metrics( + self, expected_loss * weight, + self._expected_eval_metrics(expected_loss), model_fn_ops) def testInvalidNClasses(self): for n_classes in (None, -1, 0, 1): @@ -370,15 +509,9 @@ def setUp(self): # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss # incurred by this example. The 2nd prediction is outside the margin so it # incurs no loss at all. - self._predictions = ((-0.5,), (1.2,)) + self._predictions = ((-.5,), (1.2,)) self._labels = (0, 1) - self._expected_losses = (0.5, 0.0) - - def _assert_metrics(self, model_fn_ops): - self.assertItemsEqual(( - "accuracy", - "loss", - ), six.iterkeys(model_fn_ops.eval_metric_ops)) + self._expected_losses = (.5, 0.) def testBinarySVMDefaultWeights(self): head = head_lib._binary_svm_head() @@ -388,15 +521,28 @@ def testBinarySVMDefaultWeights(self): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=predictions) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual( - np.average(self._expected_losses), model_fn_ops.loss.eval()) + expected_loss = np.average(self._expected_losses) + _assert_metrics(self, expected_loss, { + "accuracy": 1., + "loss": expected_loss, + }, model_fn_ops) - model_fn_ops = head.head_ops({}, labels, - tf.contrib.learn.ModeKeys.EVAL, - _noop_train_op, logits=predictions) - self.assertIsNone(model_fn_ops.train_op) + def testBinarySVMEvalMode(self): + head = head_lib._binary_svm_head() + with tf.Graph().as_default(), tf.Session(): + predictions = tf.constant(self._predictions) + labels = tf.constant(self._labels) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.EVAL, + _noop_train_op, logits=predictions) + self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + expected_loss = np.average(self._expected_losses) + _assert_metrics(self, expected_loss, { + "accuracy": 1., + "loss": expected_loss, + }, model_fn_ops) def testBinarySVMWithLabelName(self): label_name = "my_label" @@ -407,31 +553,30 @@ def testBinarySVMWithLabelName(self): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=predictions) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual( - np.average(self._expected_losses), model_fn_ops.loss.eval()) - - model_fn_ops = head.head_ops({}, labels, - tf.contrib.learn.ModeKeys.EVAL, - _noop_train_op, logits=predictions) - self.assertIsNone(model_fn_ops.train_op) + expected_loss = np.average(self._expected_losses) + _assert_metrics(self, expected_loss, { + "accuracy": 1., + "loss": expected_loss, + }, model_fn_ops) def testBinarySVMWithWeights(self): head = head_lib._binary_svm_head(weight_column_name="weights") with tf.Graph().as_default(), tf.Session(): predictions = tf.constant(self._predictions) labels = tf.constant(self._labels) - weights = (7.0, 11.0) + weights = (7., 11.) features = {"weights": tf.constant(weights)} model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=predictions) - self._assert_metrics(model_fn_ops) _assert_no_variables(self) - self.assertAlmostEqual( - np.sum(np.multiply(weights, self._expected_losses)) / 2.0, - model_fn_ops.loss.eval()) + expected_weighted_sum = np.sum(np.multiply( + weights, self._expected_losses)) + _assert_metrics(self, expected_weighted_sum / len(weights), { + "accuracy": 1., + "loss": expected_weighted_sum / np.sum(weights), + }, model_fn_ops) def testBinarySVMWithCenteredBias(self): head = head_lib._binary_svm_head(enable_centered_bias=True) @@ -441,7 +586,6 @@ def testBinarySVMWithCenteredBias(self): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=predictions) - self._assert_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", "centered_bias_weight/Adagrad:0", @@ -449,8 +593,11 @@ def testBinarySVMWithCenteredBias(self): "centered_bias_weight:0", )) tf.global_variables_initializer().run() - self.assertAlmostEqual( - np.average(self._expected_losses), model_fn_ops.loss.eval()) + expected_loss = np.average(self._expected_losses) + _assert_metrics(self, expected_loss, { + "accuracy": 1., + "loss": expected_loss, + }, model_fn_ops) def _noop_train_op(unused_loss): diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py index 8df08e507fed33..10ac888eca7a0f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py @@ -19,10 +19,16 @@ class MetricKey(object): + """Metric key strings.""" LOSS = "loss" AUC = "auc" + CLASS_AUC = "auc/class%d" PREDICTION_MEAN = "labels/prediction_mean" + CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d" + CLASS_LOGITS_MEAN = "labels/logits_mean/class%d" + CLASS_PROBABILITY_MEAN = "labels/probability_mean/class%d" LABEL_MEAN = "labels/actual_label_mean" + CLASS_LABEL_MEAN = "labels/actual_label_mean/class%d" ACCURACY = "accuracy" ACCURACY_BASELINE = "accuracy/baseline_label_mean" ACCURACY_MEAN = "accuracy/threshold_%f_mean" diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index a4df7ba658c724..1c404903e53fc5 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -194,6 +194,9 @@ def _get_dict(name, dict_or_tensor, key): raise ValueError('MetricSpec with ' + name + '_key specified' ' requires ' + name + 's dict, got %s' % dict_or_tensor) + if key not in dict_or_tensor: + raise KeyError( + 'Key \'%s\' missing from %s.' % (key, dict_or_tensor.keys())) return dict_or_tensor[key] else: if isinstance(dict_or_tensor, dict):