Added embedding learning rate multiplier support for dnn-linear combined classifier (aka wide-n-deep)

Change: 140487205
ispirmustafa authored and tensorflower-gardener committed Nov 29, 2016
1 parent 3f8fc30 commit fb1303a
Showing 4 changed files with 153 additions and 37 deletions.
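
The new `embedding_lr_multipliers` argument lets callers scale the learning rate used for particular embedding variables. A minimal usage sketch, mirroring the constructor signature and the testEmbeddingMultiplier tests added in this commit (the column setup, dimension, and hidden units are illustrative):

import tensorflow as tf

# Illustrative embedding column; any embedding_column works here.
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 10)
embedding_language = tf.contrib.layers.embedding_column(language, dimension=8)

classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
    dnn_feature_columns=[embedding_language],
    dnn_hidden_units=[32, 16],
    # Train the language embedding at 80% of the DNN learning rate.
    embedding_lr_multipliers={embedding_language: 0.8})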
36 changes: 5 additions & 31 deletions tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -25,7 +25,6 @@
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
@@ -37,7 +36,6 @@
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python import summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
@@ -68,31 +66,6 @@ def _add_hidden_layer_summary(value, tag):
summary.histogram("%s_activation" % tag, value)


def _get_embedding_variable(column, collection_key, input_layer_scope):
return ops.get_collection(collection_key,
input_layer_scope + "/" + column.name)


def _extract_embedding_lr_multipliers(embedding_lr_multipliers, collection_key,
input_layer_scope):
"""Convert embedding lr multipliers to variable based gradient multiplier."""
if not embedding_lr_multipliers:
return None
gradient_multipliers = {}
for column, lr_mult in embedding_lr_multipliers.items():
if not isinstance(column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access
raise ValueError(
"learning rate multipler can be defined for embedding columns. "
"It is defined for {}".format(column))
embedding = _get_embedding_variable(
column, collection_key, input_layer_scope)
if not embedding:
raise ValueError("Couldn't find a variable for column {}".format(column))
for v in embedding:
gradient_multipliers[v] = lr_mult
return gradient_multipliers


def _dnn_model_fn(features, labels, mode, params):
"""Deep Neural Net model_fn.
@@ -119,7 +92,7 @@ def _dnn_model_fn(features, labels, mode, params):
clipped to their global norm with this clipping ratio.
* num_ps_replicas: The number of parameter server replicas.
* embedding_lr_multipliers: Optional. A dictionary from
`EbeddingColumn` to a `float` multiplier. Multiplier will be used to
`EmbeddingColumn` to a `float` multiplier. Multiplier will be used to
multiply with learning rate for the embedding variables.
Returns:
@@ -194,8 +167,9 @@ def _train_op_fn(loss):
global_step=contrib_variables.get_global_step(),
learning_rate=_LEARNING_RATE,
optimizer=_get_optimizer(optimizer),
gradient_multipliers=_extract_embedding_lr_multipliers(
embedding_lr_multipliers, parent_scope, input_layer_scope),
gradient_multipliers=(
dnn_linear_combined._extract_embedding_lr_multipliers( # pylint: disable=protected-access
embedding_lr_multipliers, parent_scope, input_layer_scope)),
clip_gradients=gradient_clip_norm,
name=parent_scope,
# Empty summaries to prevent optimizers from logging the training_loss.
@@ -308,7 +282,7 @@ def __init__(self,
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
embedding_lr_multipliers: Optional. A dictionary from `EbeddingColumn` to
embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
a `float` multiplier. Multiplier will be used to multiply with
learning rate for the embedding variables.
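
The extracted multipliers are handed to `optimizers.optimize_loss` as `gradient_multipliers`, a dictionary from variable to float: each matched variable's gradient is scaled by its multiplier before the update step. A toy sketch of the effect, using plain SGD for clarity (the real update depends on the configured optimizer):

import numpy as np

def sgd_update(var, grad, lr, multiplier=1.0):
  # The gradient is scaled by the per-variable multiplier before the step.
  return var - lr * (multiplier * grad)

embedding = np.full(3, 0.1)
grad = np.array([0.5, -0.3, 0.2])

# A multiplier of 0. freezes the variable (the case the tests below rely on);
# a multiplier of 0.8 takes an 80% step.
assert np.allclose(sgd_update(embedding, grad, 0.05, multiplier=0.0), embedding)
print(sgd_update(embedding, grad, 0.05, multiplier=0.8))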
tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -27,6 +27,7 @@
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
@@ -365,6 +366,31 @@ def _add_hidden_layer_summary(value, tag):
logging_ops.histogram_summary("%s:activation" % tag, value)


def _get_embedding_variable(column, collection_key, input_layer_scope):
return ops.get_collection(collection_key,
input_layer_scope + "/" + column.name)


def _extract_embedding_lr_multipliers(embedding_lr_multipliers, collection_key,
input_layer_scope):
"""Converts embedding lr multipliers to variable based gradient multiplier."""
if not embedding_lr_multipliers:
return None
gradient_multipliers = {}
for column, lr_mult in embedding_lr_multipliers.items():
if not isinstance(column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access
raise ValueError(
"learning rate multipler can only be defined for embedding columns. "
"It is defined for {}".format(column))
embedding = _get_embedding_variable(
column, collection_key, input_layer_scope)
if not embedding:
raise ValueError("Couldn't find a variable for column {}".format(column))
for v in embedding:
gradient_multipliers[v] = lr_mult
return gradient_multipliers


def _dnn_linear_combined_model_fn(features, labels, mode, params):
"""Deep Neural Net and Linear combined model_fn.
@@ -396,6 +422,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
* gradient_clip_norm: A float > 0. If provided, gradients are
clipped to their global norm with this clipping ratio.
* num_ps_replicas: The number of parameter server replicas.
* embedding_lr_multipliers: Optional. A dictionary from
`EmbeddingColumn` to a `float` multiplier. Multiplier will be used to
multiply with learning rate for the embedding variables.
Returns:
`ModelFnOps`
@@ -414,7 +443,8 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
dnn_activation_fn = params.get("dnn_activation_fn")
dnn_dropout = params.get("dnn_dropout")
gradient_clip_norm = params.get("gradient_clip_norm")
num_ps_replicas = params["num_ps_replicas"]
num_ps_replicas = params.get("num_ps_replicas", 0)
embedding_lr_multipliers = params.get("embedding_lr_multipliers", {})

if not linear_feature_columns and not dnn_feature_columns:
raise ValueError(
@@ -432,8 +462,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
input_layer_scope = dnn_parent_scope + "/input_from_feature_columns"
with variable_scope.variable_scope(
dnn_parent_scope + "/input_from_feature_columns",
input_layer_scope,
values=features.values(),
partitioner=input_layer_partitioner) as scope:
net = layers.input_from_feature_columns(
@@ -521,6 +552,9 @@ def _make_training_op(training_loss):
global_step=contrib_variables.get_global_step(),
learning_rate=_DNN_LEARNING_RATE,
optimizer=_get_optimizer(dnn_optimizer),
gradient_multipliers=_extract_embedding_lr_multipliers( # pylint: disable=protected-access
embedding_lr_multipliers, dnn_parent_scope,
input_layer_scope),
clip_gradients=gradient_clip_norm,
variables=ops.get_collection(dnn_parent_scope),
name=dnn_parent_scope,
@@ -612,7 +646,8 @@ def __init__(self, # _joint_linear_weights pylint: disable=invalid-name
gradient_clip_norm=None,
enable_centered_bias=False,
config=None,
feature_engineering_fn=None):
feature_engineering_fn=None,
embedding_lr_multipliers=None):
"""Constructs a DNNLinearCombinedClassifier instance.
Args:
Expand Down Expand Up @@ -656,6 +691,9 @@ def __init__(self, # _joint_linear_weights pylint: disable=invalid-name
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
a `float` multiplier. Multiplier will be used to multiply with
learning rate for the embedding variables.
Raises:
ValueError: If `n_classes` < 2.
@@ -695,6 +733,7 @@ def __init__(self, # _joint_linear_weights pylint: disable=invalid-name
"dnn_dropout": dnn_dropout,
"gradient_clip_norm": gradient_clip_norm,
"num_ps_replicas": config.num_ps_replicas if config else 0,
"embedding_lr_multipliers": embedding_lr_multipliers,
},
feature_engineering_fn=feature_engineering_fn)

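
`_get_embedding_variable` locates the embedding weights by filtering the collection named after the parent scope (`'dnn'`) with the input-layer scope plus the column name as a prefix. A hypothetical sketch of that lookup, with a hand-registered variable standing in for what `input_from_feature_columns` creates (the scope and variable names are illustrative):

import tensorflow as tf

# Stand-in for the embedding weights that input_from_feature_columns
# would create and add to the 'dnn' collection.
with tf.variable_scope('dnn/input_from_feature_columns/language_embedding'):
  weights = tf.get_variable('weights', shape=[10, 1])
tf.add_to_collection('dnn', weights)

# Same shape of call as _get_embedding_variable(column, 'dnn',
# 'dnn/input_from_feature_columns'): collection key plus scope prefix.
found = tf.get_collection(
    'dnn', scope='dnn/input_from_feature_columns/language_embedding')
assert found == [weights]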
tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -27,7 +27,9 @@
import tensorflow as tf

from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec

@@ -39,6 +41,82 @@ def _assert_metrics_in_range(keys, metrics):
0.0 - epsilon, 1.0 + epsilon, key, metrics)


class EmbeddingMultiplierTest(tf.test.TestCase):
"""dnn_model_fn tests."""

def testRaisesNonEmbeddingColumn(self):
one_hot_language = tf.contrib.layers.one_hot_column(
tf.contrib.layers.sparse_column_with_hash_bucket('language', 10))

params = {
'dnn_feature_columns': [one_hot_language],
'head': head_lib._multi_class_head(2),
'dnn_hidden_units': [1],
# A multiplier on a non-embedding column should raise below.
'embedding_lr_multipliers': {
one_hot_language: 0.0
},
'dnn_optimizer': 'Adagrad',
}
features = {
'language':
tf.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [1, 0], [2, 0]],
shape=[3, 1]),
}
labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
with self.assertRaisesRegexp(
ValueError, 'can only be defined for embedding columns'):
dnn_linear_combined._dnn_linear_combined_model_fn(
features, labels, tf.contrib.learn.ModeKeys.TRAIN, params)

def testMultipliesGradient(self):
embedding_language = tf.contrib.layers.embedding_column(
tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
dimension=1, initializer=tf.constant_initializer(0.1))
embedding_wire = tf.contrib.layers.embedding_column(
tf.contrib.layers.sparse_column_with_hash_bucket('wire', 10),
dimension=1, initializer=tf.constant_initializer(0.1))

params = {
'dnn_feature_columns': [embedding_language, embedding_wire],
'head': head_lib._multi_class_head(2),
'dnn_hidden_units': [1],
# Set lr mult to 0. to keep embeddings constant.
'embedding_lr_multipliers': {
embedding_language: 0.0
},
'dnn_optimizer': 'Adagrad',
}
features = {
'language':
tf.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [1, 0], [2, 0]],
shape=[3, 1]),
'wire':
tf.SparseTensor(
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [2, 0]],
shape=[3, 1]),
}
labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
model_ops = dnn_linear_combined._dnn_linear_combined_model_fn(
features, labels, tf.contrib.learn.ModeKeys.TRAIN, params)
with tf.train.MonitoredSession() as sess:
language_var = dnn_linear_combined._get_embedding_variable(
embedding_language, 'dnn', 'dnn/input_from_feature_columns')
wire_var = dnn_linear_combined._get_embedding_variable(
embedding_wire, 'dnn', 'dnn/input_from_feature_columns')
for _ in range(2):
_, language_value, wire_value = sess.run(
[model_ops.train_op, language_var, wire_var])
initial_value = np.full_like(language_value, 0.1)
self.assertTrue(np.all(np.isclose(language_value, initial_value)))
self.assertFalse(np.all(np.isclose(wire_value, initial_value)))


class DNNLinearCombinedClassifierTest(tf.test.TestCase):

def testEstimatorContract(self):
@@ -54,6 +132,18 @@ def testNoFeatureColumns(self):
dnn_feature_columns=None,
dnn_hidden_units=[3, 3])

def testEmbeddingMultiplier(self):
embedding_language = tf.contrib.layers.embedding_column(
tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
dimension=1, initializer=tf.constant_initializer(0.1))
classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
dnn_feature_columns=[embedding_language],
dnn_hidden_units=[3, 3],
embedding_lr_multipliers={embedding_language: 0.8})
self.assertEqual(
{embedding_language: 0.8},
classifier._estimator.params['embedding_lr_multipliers'])

def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
iris = test_data.prepare_iris_data_for_logistic_regression()
19 changes: 16 additions & 3 deletions tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
@@ -28,6 +28,7 @@

from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import dnn
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
@@ -60,7 +61,7 @@ def testRaisesNonEmbeddingColumn(self):
}
labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
with self.assertRaisesRegexp(
ValueError, 'can be defined for embedding columns'):
ValueError, 'can only be defined for embedding columns'):
dnn._dnn_model_fn(features, labels,
tf.contrib.learn.ModeKeys.TRAIN, params)

@@ -97,9 +98,9 @@ def testMultipliesGradient(self):
model_ops = dnn._dnn_model_fn(features, labels,
tf.contrib.learn.ModeKeys.TRAIN, params)
with tf.train.MonitoredSession() as sess:
language_var = dnn._get_embedding_variable(
language_var = dnn_linear_combined._get_embedding_variable(
embedding_language, 'dnn', 'dnn/input_from_feature_columns')
wire_var = dnn._get_embedding_variable(
wire_var = dnn_linear_combined._get_embedding_variable(
embedding_wire, 'dnn', 'dnn/input_from_feature_columns')
for _ in range(2):
_, language_value, wire_value = sess.run(
@@ -119,6 +120,18 @@ def testEstimatorContract(self):
estimator_test_utils.assert_estimator_contract(
self, tf.contrib.learn.DNNClassifier)

def testEmbeddingMultiplier(self):
embedding_language = tf.contrib.layers.embedding_column(
tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
dimension=1, initializer=tf.constant_initializer(0.1))
classifier = tf.contrib.learn.DNNClassifier(
feature_columns=[embedding_language],
hidden_units=[3, 3],
embedding_lr_multipliers={embedding_language: 0.8})
self.assertEqual(
{embedding_language: 0.8},
classifier._estimator.params['embedding_lr_multipliers'])

def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
cont_features = [
