Fix the Keras optimizer naming issue in canned estimator test.
PiperOrigin-RevId: 252225882
yhliang2018 authored and mihaimaruseac committed Jun 11, 2019
1 parent 395d8d2 commit a031b45
Showing 3 changed files with 38 additions and 23 deletions.
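
The "Keras optimizer naming issue" in the commit title is the failure noted in the test tags removed below ("Key SGD/iter not found in checkpoint"): as the new code comments in the diff explain, the names under which an optimizer's variables are saved are uniquified per graph, so whichever SGD instance is constructed first claims the bare "SGD" name and the next one becomes "SGD_1". If a warm-started estimator builds its optimizers in a different order, or against different default-graph state, than the estimator that wrote the checkpoint, name-based warm starting looks up keys that do not exist. The sketch below is illustrative only, using plain name scopes rather than the optimizer internals, to show the order-dependent uniquification:

import tensorflow.compat.v1 as tf

# Illustration of per-graph name uniquification, not the estimator code path:
# the first "SGD" scope gets the bare name, the second becomes "SGD_1", so a
# checkpoint key like "SGD/iter" depends on construction order.
with tf.Graph().as_default():
  with tf.name_scope("SGD"):
    first_iter = tf.Variable(0, name="iter")
  with tf.name_scope("SGD"):
    second_iter = tf.Variable(0, name="iter")
  print(first_iter.op.name)   # SGD/iter
  print(second_iter.op.name)  # SGD_1/iter
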
2 changes: 0 additions & 2 deletions tensorflow_estimator/python/estimator/BUILD
@@ -604,7 +604,6 @@ py_test(
srcs = ["canned/dnn_estimator_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss", # for 1.14 release, fails with Key SGD/iter not found in checkpoint
"no_pip",
"notsan",
"optonly", # times out http://b/79220679
@@ -724,7 +723,6 @@ py_test(
shard_count = 32,
srcs_version = "PY2AND3",
tags = [
"no_oss", # b/134391415
"no_pip",
"notsan", # TODO(b/67510291)
],
@@ -848,17 +848,24 @@ def test_classifier_basic_warm_starting(self, fc_impl):
# Create a second DNNLinearCombinedClassifier, warm-started from the first.
# Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
# have accumulator values that change).
-    warm_started_dnn_lc_classifier = (
-        dnn_linear_combined.DNNLinearCombinedClassifierV2(
-            linear_feature_columns=[age],
-            dnn_feature_columns=[city],
-            dnn_hidden_units=[256, 128],
-            n_classes=4,
-            linear_optimizer=gradient_descent_v2.SGD(
-                learning_rate=0.0),
-            dnn_optimizer=gradient_descent_v2.SGD(
-                learning_rate=0.0),
-            warm_start_from=dnn_lc_classifier.model_dir))
+    # To avoid the Keras optimizer naming issue during warm start, the
+    # dnn_optimizer needs to be created before the linear_optimizer,
+    # since that is the order predefined in the model function.
+    # Use a default graph context so that the optimizer instances are
+    # created in a v1 Graph, which keeps them consistent with the
+    # estimator's graph.
+    with ops.Graph().as_default():
+      warm_started_dnn_lc_classifier = (
+          dnn_linear_combined.DNNLinearCombinedClassifierV2(
+              linear_feature_columns=[age],
+              dnn_feature_columns=[city],
+              dnn_hidden_units=[256, 128],
+              n_classes=4,
+              dnn_optimizer=gradient_descent_v2.SGD(
+                  learning_rate=0.0),
+              linear_optimizer=gradient_descent_v2.SGD(
+                  learning_rate=0.0),
+              warm_start_from=dnn_lc_classifier.model_dir))

warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)
for variable_name in warm_started_dnn_lc_classifier.get_variable_names():
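
The with ops.Graph().as_default(): wrapper added above addresses the second half of the problem: the uniquification counters live on the graph object, so constructing the optimizers inside a fresh v1 Graph restarts the "SGD", "SGD_1", ... numbering and keeps it consistent with what the estimator later builds in its own graph. A minimal, illustrative sketch of that per-graph behavior (again using plain name scopes, not the optimizer internals):

import tensorflow.compat.v1 as tf

# Each Graph keeps its own name-scope counters, so a fresh graph hands out the
# bare "SGD" scope again instead of continuing with "SGD_1".
with tf.Graph().as_default():
  with tf.name_scope("SGD") as scope_a:
    pass
print(scope_a)  # "SGD/"

with tf.Graph().as_default():
  with tf.name_scope("SGD") as scope_b:
    pass
print(scope_b)  # "SGD/" again, not "SGD_1/"
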
@@ -892,16 +899,23 @@ def test_regressor_basic_warm_starting(self, fc_impl):
# Create a second DNNLinearCombinedRegressor, warm-started from the first.
# Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
# have accumulator values that change).
-    warm_started_dnn_lc_regressor = (
-        dnn_linear_combined.DNNLinearCombinedRegressorV2(
-            linear_feature_columns=[age],
-            dnn_feature_columns=[city],
-            dnn_hidden_units=[256, 128],
-            linear_optimizer=gradient_descent_v2.SGD(
-                learning_rate=0.0),
-            dnn_optimizer=gradient_descent_v2.SGD(
-                learning_rate=0.0),
-            warm_start_from=dnn_lc_regressor.model_dir))
+    # To avoid the Keras optimizer naming issue during warm start, the
+    # dnn_optimizer needs to be created before the linear_optimizer,
+    # since that is the order predefined in the model function.
+    # Use a default graph context so that the optimizer instances are
+    # created in a v1 Graph, which keeps them consistent with the
+    # estimator's graph.
+    with ops.Graph().as_default():
+      warm_started_dnn_lc_regressor = (
+          dnn_linear_combined.DNNLinearCombinedRegressorV2(
+              linear_feature_columns=[age],
+              dnn_feature_columns=[city],
+              dnn_hidden_units=[256, 128],
+              dnn_optimizer=gradient_descent_v2.SGD(
+                  learning_rate=0.0),
+              linear_optimizer=gradient_descent_v2.SGD(
+                  learning_rate=0.0),
+              warm_start_from=dnn_lc_regressor.model_dir))

warm_started_dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)
for variable_name in warm_started_dnn_lc_regressor.get_variable_names():
@@ -915,6 +915,9 @@ def __init__(self,
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
self._ckpt_and_vocab_dir = tempfile.mkdtemp()
+    # Reset the default graph in each test method to avoid the Keras optimizer
+    # naming issue during warm starting.
+    ops.reset_default_graph()

# Make a dummy input_fn.
def _input_fn():
Expand Down
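
The setUp change above applies the same fix across test methods: every test starts from a clean default graph, so uniquification state from one test cannot change the names built in the next. A hedged skeleton of the pattern follows; the class name is hypothetical, and it uses the public tf.compat.v1 alias for the same reset_default_graph call that the diff reaches through the internal ops module.

import unittest

import tensorflow.compat.v1 as tf


class WarmStartNamingTestSketch(unittest.TestCase):
  """Hypothetical skeleton, not part of this commit."""

  def setUp(self):
    super(WarmStartNamingTestSketch, self).setUp()
    # Start each test method from a fresh default graph so the per-graph name
    # counters (the source of the "SGD" vs. "SGD_1" suffixes) cannot leak
    # between tests and break name-based warm starting.
    tf.reset_default_graph()
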
