Allow models with non-trainable variables in functional_model_from_keras. #4683

Open · wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -1,5 +1,8 @@
# Unreleased

* Enable support for models with non-trainable variables in
`functional_model_from_keras`.

## Breaking Changes

* Updated `com_github_grpc_grpc` to version `1.50.0`.
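As a quick illustration of what this release note enables, here is a minimal sketch that converts a Keras model containing a frozen (and therefore non-trainable) layer, mirroring the updated unit test at the bottom of this diff. The `tff` import alias is assumed; the call itself follows the `tff.learning.models` public path named in the error messages below.

```python
import tensorflow as tf
import tensorflow_federated as tff

# A Keras model with a frozen Dense layer: its kernel and bias become
# non-trainable variables, which functional_model_from_keras now accepts.
inputs = tf.keras.layers.Input(shape=[1])
dense = tf.keras.layers.Dense(1)
dense.trainable = False
keras_model = tf.keras.Model(inputs=inputs, outputs=dense(inputs))

functional_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    tf.keras.losses.MeanSquaredError(),
    input_spec=(
        tf.TensorSpec(shape=[None, 1]),
        tf.TensorSpec(shape=[None, 1]),
    ),
)

# initial_weights is a (trainable, non_trainable) pair; with the whole
# layer frozen, all weights land in the non-trainable half.
trainable, non_trainable = functional_model.initial_weights
assert len(trainable) == 0 and len(non_trainable) == 2
```

Before this change, the same construction raised `KerasFunctionalModelError` at the guard deleted in the first `functional.py` hunk below.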
1 change: 0 additions & 1 deletion tensorflow_federated/python/learning/models/BUILD
@@ -55,7 +55,6 @@ py_library(
"//tensorflow_federated/python/learning/metrics:keras_finalizer",
"//tensorflow_federated/python/learning/metrics:keras_utils",
"//tensorflow_federated/python/learning/metrics:types",
"//tensorflow_federated/python/tensorflow_libs:variable_utils",
],
)

81 changes: 36 additions & 45 deletions tensorflow_federated/python/learning/models/functional.py
@@ -38,7 +38,6 @@
from tensorflow_federated.python.learning.metrics import keras_utils
from tensorflow_federated.python.learning.metrics import types
from tensorflow_federated.python.learning.models import variable
from tensorflow_federated.python.tensorflow_libs import variable_utils


Weight = Union[np.ndarray, int, float]
@@ -477,15 +476,6 @@ def functional_model_from_keras(
'incompatible with `tff.learning.models.FunctionalModel`. Consider '
'using group normalization instead.'
)
if keras_model.non_trainable_variables:
raise KerasFunctionalModelError(
'Received a Keras model with non-trainable variables. Keras models'
' with non-trainable variables are currently not supported by'
' FunctionalModel. Most training algorithms (e.g. Federated'
' Averaging) will not aggregate them, and they are not updated'
' locally by the optimizer. We can relax this in the future if we'
' have APIs that support updating non-trainable variables.'
)
elif not callable(keras_model):
raise ValueError(
'`keras_model` must be a `tf.keras.Model` or a no-arg '
@@ -508,42 +498,43 @@
# also setup ops to inject the current model weights, because the cloned model
# will be re-initialized from scratch.
with tf.Graph().as_default() as g:
with variable_utils.record_variable_creation_scope() as captured_variables:
if isinstance(keras_model, tf.keras.Model):
try:
cloned_model = tf.keras.models.clone_model(keras_model)
except RuntimeError as e:
raise KerasFunctionalModelError(
'Encountered an error converting the Keras model. Often this '
'occurs when the `tf.keras.Model` has a layer that receives '
'inputs from other layers directly (e.g. shared embeddings). '
'To avoid the problem, wrap the `tf.keras.Model` construction in '
'a no-arg callable (e.g. lambda) and pass that callable to '
'`functional_model_from_keras`'
) from e
if len(cloned_model.variables) != len(keras_model.variables):
raise KerasFunctionalModelError(
'The input Keras model is likely sharing variables across layers '
'which is unsupported. Cloning the model will duplicate these '
'variables and result in unexpected training gradients.'
)
else:
cloned_model = keras_model()

# Ensure our cloned model has the same weights as the current model.
# We'll feed in the current model weights into the placeholders for
# assignment in a session below.
def assign_placeholder(v):
p = tf.compat.v1.placeholder(dtype=v.dtype)
return v.assign(p), p
if isinstance(keras_model, tf.keras.Model):
try:
cloned_model = tf.keras.models.clone_model(keras_model)
except RuntimeError as e:
raise KerasFunctionalModelError(
'Encountered an error converting the Keras model. Often this '
'occurs when the `tf.keras.Model` has a layer that receives '
'inputs from other layers directly (e.g. shared embeddings). '
'To avoid the problem, wrap the `tf.keras.Model` construction in '
'a no-arg callable (e.g. lambda) and pass that callable to '
'`functional_model_from_keras`'
) from e
if len(cloned_model.variables) != len(keras_model.variables):
raise KerasFunctionalModelError(
'The input Keras model is likely sharing variables across layers '
'which is unsupported. Cloning the model will duplicate these '
'variables and result in unexpected training gradients.'
)
else:
cloned_model = keras_model()
captured_variables = cloned_model.variables
captured_trainable_variables = cloned_model.trainable_variables
captured_nontrainable_variables = cloned_model.non_trainable_variables

# Ensure our cloned model has the same weights as the current model.
# We'll feed in the current model weights into the placeholders for
# assignment in a session below.
def assign_placeholder(v):
p = tf.compat.v1.placeholder(dtype=v.dtype)
return v.assign(p), p

assign_ops, placeholders = zip(
*(assign_placeholder(v) for v in cloned_model.variables)
)

assign_ops, placeholders = zip(
*(assign_placeholder(v) for v in cloned_model.variables)
)
trainable_variables = tuple(v for v in captured_variables if v.trainable)
non_trainable_variables = tuple(
v for v in captured_variables if not v.trainable
)
trainable_variables = tuple(v for v in captured_trainable_variables)
non_trainable_variables = tuple(v for v in captured_nontrainable_variables)

# Here we get the initial weights from the incoming keras model in the order
# they are constructed; and also ensure that the values are set to the
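For readers skimming the hunk above, the weight-injection trick it relies on is easy to miss: the model is cloned inside a fresh `tf.Graph`, so the clone's variables are re-initialized from scratch, and the original weights are copied in through placeholder-fed assign ops run in a session. A standalone sketch of that pattern follows; the two-variable `source_model` is illustrative, not part of this PR.

```python
import tensorflow as tf

# Source model built eagerly; its weights will be injected into the clone.
source_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=[1])]
)

with tf.Graph().as_default():
  # The clone's variables start uninitialized in this graph.
  cloned = tf.keras.models.clone_model(source_model)

  def assign_placeholder(v):
    # One placeholder per variable, fed with the source value below.
    p = tf.compat.v1.placeholder(dtype=v.dtype)
    return v.assign(p), p

  assign_ops, placeholders = zip(
      *(assign_placeholder(v) for v in cloned.variables)
  )

  with tf.compat.v1.Session() as sess:
    # get_weights() yields values in the same order as cloned.variables
    # (here: kernel, then bias), so a positional zip is sufficient.
    sess.run(
        assign_ops,
        feed_dict=dict(zip(placeholders, source_model.get_weights())),
    )
```

The change itself only drops the `record_variable_creation_scope` wrapper and reads the trainable/non-trainable split directly from the cloned model, so this injection step is unchanged.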
25 changes: 13 additions & 12 deletions tensorflow_federated/python/learning/models/functional_test.py
@@ -899,23 +899,24 @@ def train():
self.assertGreater(initial_loss, 2.0)
self.assertLess(final_loss, 0.2)

def test_keras_model_with_non_trainable_variables_fails(self):
def test_keras_model_with_non_trainable_variables(self):
inputs = tf.keras.layers.Input(shape=[1])
d = tf.keras.layers.Dense(1)
d.trainable = False
outputs = d(inputs)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
with self.assertRaisesRegex(
functional.KerasFunctionalModelError, 'non-trainable variables'
):
functional.functional_model_from_keras(
keras_model,
tf.keras.losses.MeanSquaredError(),
input_spec=(
tf.TensorSpec(shape=[None, 1]),
tf.TensorSpec(shape=[None, 1]),
),
)
functional_model = functional.functional_model_from_keras(
keras_model,
tf.keras.losses.MeanSquaredError(),
input_spec=(
tf.TensorSpec(shape=[None, 1]),
tf.TensorSpec(shape=[None, 1]),
),
)
self.assertEmpty(functional_model.initial_weights[0])
# We expect there to be two non-trainable variables: the kernel and bias
# of the dense layer.
self.assertLen(functional_model.initial_weights[1], 2)

def test_keras_model_with_batch_normalization_fails(self):
model = tf.keras.models.Sequential([
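One detail worth noting in the updated test: freezing a layer moves its kernel and bias from `trainable_variables` to `non_trainable_variables`, which is why the test expects an empty trainable tuple and exactly two non-trainable weights. A quick check of that Keras behavior, independent of TFF:

```python
import tensorflow as tf

inputs = tf.keras.layers.Input(shape=[1])
d = tf.keras.layers.Dense(1)
d.trainable = False  # Freezing the layer demotes its kernel and bias.
model = tf.keras.Model(inputs=inputs, outputs=d(inputs))

assert not model.trainable_variables
assert len(model.non_trainable_variables) == 2  # kernel and bias
```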