Allow models with non-trainable variables in `functional_model_from_keras`.

Note that this is done by avoiding capturing variables via a creation-based context manager, and instead grabbing them directly from the Keras model within a graph. This allows us to disambiguate trainable and non-trainable variables.

PiperOrigin-RevId: 634053525
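A rough sketch of the new capture strategy described above (the toy model here is illustrative and mirrors the updated test below; it is not part of the commit):

```python
import tensorflow as tf

# A toy Keras model with a frozen (non-trainable) dense layer.
inputs = tf.keras.layers.Input(shape=[1])
dense = tf.keras.layers.Dense(1)
dense.trainable = False
keras_model = tf.keras.Model(inputs=inputs, outputs=dense(inputs))

# Clone the model inside a fresh graph and read the trainable/non-trainable
# split directly off the cloned model, instead of capturing variables with a
# creation-time context manager.
with tf.Graph().as_default():
  cloned_model = tf.keras.models.clone_model(keras_model)
  trainable = tuple(cloned_model.trainable_variables)
  non_trainable = tuple(cloned_model.non_trainable_variables)

print(len(trainable), len(non_trainable))  # 0 and 2: the frozen kernel and bias.
```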
zcharles8 authored and copybara-github committed May 15, 2024
1 parent 62fd042 commit e67098b
Showing 4 changed files with 52 additions and 58 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -1,5 +1,8 @@
 # Unreleased
 
+* Enable support for models with non-trainable variables in
+  `functional_model_from_keras`.
+
 ## Breaking Changes
 
 * Updated `com_github_grpc_grpc` to version `1.50.0`.
1 change: 0 additions & 1 deletion tensorflow_federated/python/learning/models/BUILD
@@ -55,7 +55,6 @@ py_library(
         "//tensorflow_federated/python/learning/metrics:keras_finalizer",
         "//tensorflow_federated/python/learning/metrics:keras_utils",
         "//tensorflow_federated/python/learning/metrics:types",
-        "//tensorflow_federated/python/tensorflow_libs:variable_utils",
     ],
 )

81 changes: 36 additions & 45 deletions tensorflow_federated/python/learning/models/functional.py
@@ -38,7 +38,6 @@
 from tensorflow_federated.python.learning.metrics import keras_utils
 from tensorflow_federated.python.learning.metrics import types
 from tensorflow_federated.python.learning.models import variable
-from tensorflow_federated.python.tensorflow_libs import variable_utils
 
 
 Weight = Union[np.ndarray, int, float]
@@ -477,15 +476,6 @@ def functional_model_from_keras(
             'incompatible with `tff.learning.models.FunctionalModel`. Consider '
             'using group normalization instead.'
         )
-    if keras_model.non_trainable_variables:
-      raise KerasFunctionalModelError(
-          'Received a Keras model with non-trainable variables. Keras models'
-          ' with non-trainable variables are currently not supported by'
-          ' FunctionalModel. Most training algorithms (e.g. Federated'
-          ' Averaging) will not aggregate them, and they are not updated'
-          ' locally by the optimizer. We can relax this in the future if we'
-          ' have APIs that support updating non-trainable variables.'
-      )
   elif not callable(keras_model):
     raise ValueError(
         '`keras_model` must be a `tf.keras.Model` or a no-arg '
@@ -508,42 +498,43 @@ def functional_model_from_keras(
   # also set up ops to inject the current model weights, because the cloned
   # model will be re-initialized from scratch.
   with tf.Graph().as_default() as g:
-    with variable_utils.record_variable_creation_scope() as captured_variables:
-      if isinstance(keras_model, tf.keras.Model):
-        try:
-          cloned_model = tf.keras.models.clone_model(keras_model)
-        except RuntimeError as e:
-          raise KerasFunctionalModelError(
-              'Encountered an error converting the Keras model. Often this '
-              'occurs when the `tf.keras.Model` has a layer that receives '
-              'inputs from other layers directly (e.g. shared embeddings). '
-              'To avoid the problem, wrap the `tf.keras.Model` construction in '
-              'a no-arg callable (e.g. lambda) and pass that callable to '
-              '`functional_model_from_keras`'
-          ) from e
-        if len(cloned_model.variables) != len(keras_model.variables):
-          raise KerasFunctionalModelError(
-              'The input Keras model is likely sharing variables across layers '
-              'which is unsupported. Cloning the model will duplicate these '
-              'variables and result in unexpected training gradients.'
-          )
-      else:
-        cloned_model = keras_model()
-
-      # Ensure our cloned model has the same weights as the current model.
-      # We'll feed the current model weights into the placeholders for
-      # assignment in a session below.
-      def assign_placeholder(v):
-        p = tf.compat.v1.placeholder(dtype=v.dtype)
-        return v.assign(p), p
+    if isinstance(keras_model, tf.keras.Model):
+      try:
+        cloned_model = tf.keras.models.clone_model(keras_model)
+      except RuntimeError as e:
+        raise KerasFunctionalModelError(
+            'Encountered an error converting the Keras model. Often this '
+            'occurs when the `tf.keras.Model` has a layer that receives '
+            'inputs from other layers directly (e.g. shared embeddings). '
+            'To avoid the problem, wrap the `tf.keras.Model` construction in '
+            'a no-arg callable (e.g. lambda) and pass that callable to '
+            '`functional_model_from_keras`'
+        ) from e
+      if len(cloned_model.variables) != len(keras_model.variables):
+        raise KerasFunctionalModelError(
+            'The input Keras model is likely sharing variables across layers '
+            'which is unsupported. Cloning the model will duplicate these '
+            'variables and result in unexpected training gradients.'
+        )
+    else:
+      cloned_model = keras_model()
+    captured_variables = cloned_model.variables
+    captured_trainable_variables = cloned_model.trainable_variables
+    captured_nontrainable_variables = cloned_model.non_trainable_variables
+
+    # Ensure our cloned model has the same weights as the current model.
+    # We'll feed the current model weights into the placeholders for
+    # assignment in a session below.
+    def assign_placeholder(v):
+      p = tf.compat.v1.placeholder(dtype=v.dtype)
+      return v.assign(p), p
 
-      assign_ops, placeholders = zip(
-          *(assign_placeholder(v) for v in cloned_model.variables)
-      )
-  trainable_variables = tuple(v for v in captured_variables if v.trainable)
-  non_trainable_variables = tuple(
-      v for v in captured_variables if not v.trainable
-  )
+    assign_ops, placeholders = zip(
+        *(assign_placeholder(v) for v in cloned_model.variables)
+    )
+  trainable_variables = tuple(v for v in captured_trainable_variables)
+  non_trainable_variables = tuple(v for v in captured_nontrainable_variables)
 
   # Here we get the initial weights from the incoming keras model in the order
   # they are constructed; and also ensure that the values are set to the
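As the shared-embeddings error message above suggests, a model whose layers share variables should be passed as a no-arg callable rather than as a constructed instance. A minimal sketch of that calling convention, reusing the loss and input spec from the test below and assuming the test file's imports (`tensorflow as tf` and the `functional` module); `build_model` is an illustrative name:

```python
def build_model() -> tf.keras.Model:
  # Constructing the model inside the callable lets `functional_model_from_keras`
  # build it directly in its own graph instead of cloning an existing instance.
  inputs = tf.keras.layers.Input(shape=[1])
  outputs = tf.keras.layers.Dense(1)(inputs)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

functional_model = functional.functional_model_from_keras(
    build_model,  # A no-arg callable, not a `tf.keras.Model` instance.
    tf.keras.losses.MeanSquaredError(),
    input_spec=(
        tf.TensorSpec(shape=[None, 1]),
        tf.TensorSpec(shape=[None, 1]),
    ),
)
```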
25 changes: 13 additions & 12 deletions tensorflow_federated/python/learning/models/functional_test.py
@@ -899,23 +899,24 @@ def train():
     self.assertGreater(initial_loss, 2.0)
     self.assertLess(final_loss, 0.2)
 
-  def test_keras_model_with_non_trainable_variables_fails(self):
+  def test_keras_model_with_non_trainable_variables(self):
     inputs = tf.keras.layers.Input(shape=[1])
     d = tf.keras.layers.Dense(1)
     d.trainable = False
     outputs = d(inputs)
     keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
-    with self.assertRaisesRegex(
-        functional.KerasFunctionalModelError, 'non-trainable variables'
-    ):
-      functional.functional_model_from_keras(
-          keras_model,
-          tf.keras.losses.MeanSquaredError(),
-          input_spec=(
-              tf.TensorSpec(shape=[None, 1]),
-              tf.TensorSpec(shape=[None, 1]),
-          ),
-      )
+    functional_model = functional.functional_model_from_keras(
+        keras_model,
+        tf.keras.losses.MeanSquaredError(),
+        input_spec=(
+            tf.TensorSpec(shape=[None, 1]),
+            tf.TensorSpec(shape=[None, 1]),
+        ),
+    )
+    self.assertEmpty(functional_model.initial_weights[0])
+    # We expect there to be two non-trainable variables: the kernel and bias
+    # of the dense layer.
+    self.assertLen(functional_model.initial_weights[1], 2)
 
   def test_keras_model_with_batch_normalization_fails(self):
     model = tf.keras.models.Sequential([
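With the restriction lifted, the `(trainable, non_trainable)` split the test asserts on is visible directly in `initial_weights`. A short sketch, continuing from the `functional_model` built in the test above:

```python
# `initial_weights` is a 2-tuple of (trainable, non_trainable) weight values.
trainable_weights, non_trainable_weights = functional_model.initial_weights
assert len(trainable_weights) == 0      # The only dense layer is frozen.
assert len(non_trainable_weights) == 2  # Its kernel and bias.
```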
