From a600836814834c2c1980fac2179999d6c92267ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Aug 2024 12:21:44 -0700 Subject: [PATCH] Sparsity Preserving DP-SGD in TF Privacy Move get_registry_generator_fn from clip_grads.py to gradient_clipping_utils.py and change return type to dataclass. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660490424 --- .../privacy/fast_gradient_clipping/BUILD | 16 +- .../fast_gradient_clipping/clip_grads.py | 52 ++---- .../gradient_clipping_utils.py | 166 +++++++----------- .../gradient_clipping_utils_test.py | 55 ------ .../fast_gradient_clipping/noise_utils.py | 115 ++++++++++++ .../noise_utils_test.py | 72 ++++++++ tensorflow_privacy/privacy/keras_models/BUILD | 2 +- .../privacy/keras_models/dp_keras_model.py | 3 +- 8 files changed, 280 insertions(+), 201 deletions(-) create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py create mode 100644 tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index ac3e47d7..adb5a763 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -54,10 +54,7 @@ py_test( python_version = "PY3", shard_count = 8, srcs_version = "PY3", - deps = [ - ":gradient_clipping_utils", - ":type_aliases", - ], + deps = [":gradient_clipping_utils"], ) py_library( @@ -83,6 +80,11 @@ py_library( ], ) +py_library( + name = "noise_utils", + srcs = ["noise_utils.py"], +) + py_test( name = "clip_grads_test", srcs = ["clip_grads_test.py"], @@ -96,3 +98,9 @@ py_test( ":type_aliases", ], ) + +py_test( + name = "noise_utils_test", + srcs = ["noise_utils_test.py"], + deps = [":noise_utils"], +) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index a3168499..100e66ab 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -32,43 +32,6 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases -def get_registry_generator_fn( - tape: tf.GradientTape, - layer_registry: lr.LayerRegistry, - num_microbatches: Optional[type_aliases.BatchSize] = None, -): - """Creates the generator function for `compute_gradient_norms()`.""" - if layer_registry is None: - # Needed for backwards compatibility. - registry_generator_fn = None - else: - - def registry_generator_fn(layer_instance, args, kwargs): - if layer_instance.trainable_variables: - # Only trainable variables factor into the gradient. - if not layer_registry.is_elem(layer_instance): - raise NotImplementedError( - 'Layer %s is not in the registry of known layers that can ' - 'be used for efficient gradient clipping.' - % layer_instance.__class__.__name__ - ) - registry_fn = layer_registry.lookup(layer_instance) - (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn( - layer_instance, args, kwargs, tape, num_microbatches - ) - return layer_outputs, ( - str(id(layer_instance)), - layer_vars, - layer_sqr_norm_fn, - layer_instance.trainable_weights, - ) - else: - # Non-trainable layer. 
- return layer_instance(*args, **kwargs), None - - return registry_generator_fn - - def _infer_per_example_loss_fn(model: tf.keras.Model): """Infer the per-example loss from model config.""" @@ -190,7 +153,7 @@ def compute_gradient_norms( are applied prior to clipping. """ tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) - registry_generator_fn = get_registry_generator_fn( + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( tape, layer_registry, num_microbatches ) # First loop computes the model outputs, summed loss, and generator outputs. @@ -241,12 +204,17 @@ def compute_gradient_norms( # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th # occurrence. This is an over-estimate of the actual norm. For more details, # see the explanation in go/dp-sgd-shared-weights. - for layer_id, v, f, weights_list in filtered_outputs: + for registry_fn_output in filtered_outputs: if trainable_vars is None or any( - w.ref() in trainable_vars for w in weights_list + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights ): - layer_vars[layer_id].append(v) - layer_sqr_norm_fns[layer_id].append(f) + layer_vars[registry_fn_output.layer_id].append( + registry_fn_output.layer_vars + ) + layer_sqr_norm_fns[registry_fn_output.layer_id].append( + registry_fn_output.layer_sqr_norm_fn + ) # Second loop evaluates the squared L2 norm functions and appends the results. layer_grad_vars = tape.gradient( summed_loss, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 74163187..7a060e9b 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,15 +13,23 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" -from collections.abc import Sequence, Set -from typing import Any, Literal, Optional +from collections.abc import Callable, Sequence, Set +import dataclasses +from typing import Any, Optional, Tuple -from absl import logging import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases +@dataclasses.dataclass(frozen=True) +class RegistryGeneratorFunctionOutput: + layer_id: str + layer_vars: Optional[Sequence[tf.Variable]] + layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction] + layer_trainable_weights: Optional[Sequence[tf.Variable]] + + def has_internal_compute_graph(input_object: Any): """Checks if input is a TF model and has a TF internal compute graph.""" return ( @@ -33,6 +41,63 @@ def has_internal_compute_graph(input_object: Any): ) +def get_registry_generator_fn( + tape: tf.GradientTape, + layer_registry: lr.LayerRegistry, + num_microbatches: Optional[type_aliases.BatchSize] = None, +) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]: + """Creates the generator function for `model_forward_backward_pass()`. + + Args: + tape: The `tf.GradientTape` to use for the gradient computation. + layer_registry: A `dict` of layers that support "fast" gradient norm + computations. 
The key is the class of the layer and the value is a
+      function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
+      `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
+      squared norms of a layer's pre-activation tensor, and `vars` are relevant
+      trainable variables.
+    num_microbatches: An optional number or scalar `tf.Tensor` for the number of
+      microbatches. If not None, indicates that the loss is grouped into
+      num_microbatches (in this case, the batch dimension needs to be a multiple
+      of num_microbatches).
+
+  Returns:
+    A function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where
+    `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
+    squared norms of a layer's pre-activation tensor, and `vars` are relevant
+    trainable variables.
+  """
+  if layer_registry is None:
+    # Needed for backwards compatibility.
+    registry_generator_fn = None
+  else:
+
+    def registry_generator_fn(layer_instance, args, kwargs):
+      if layer_instance.trainable_variables:
+        # Only trainable variables factor into the gradient.
+        if not layer_registry.is_elem(layer_instance):
+          raise NotImplementedError(
+              'Layer %s is not in the registry of known layers that can '
+              'be used for efficient gradient clipping.'
+              % layer_instance.__class__.__name__
+          )
+        registry_fn = layer_registry.lookup(layer_instance)
+        (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
+            layer_instance, args, kwargs, tape, num_microbatches
+        )
+        return layer_outputs, RegistryGeneratorFunctionOutput(
+            layer_id=str(id(layer_instance)),
+            layer_vars=layer_vars,
+            layer_sqr_norm_fn=layer_sqr_norm_fn,
+            layer_trainable_weights=layer_instance.trainable_weights,
+        )
+      else:
+        # Non-trainable layer.
+        return layer_instance(*args, **kwargs), None
+
+  return registry_generator_fn
+
+
 def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: type_aliases.PackedTensors,
@@ -144,101 +209,6 @@ def all_trainable_layers_are_registered(
   return True
 
 
-def _infer_loss_reduction_type(model: tf.keras.Model):
-  """Infers what type of loss reduction is being performed."""
-  model_loss = model.loss
-  if isinstance(model_loss, tf.keras.losses.Loss):
-    return model_loss.reduction
-  elif isinstance(model.loss, dict):
-    reductions = set()
-    compiled_loss = model.compiled_loss
-    if compiled_loss is None:
-      raise ValueError('Model must be compiled for adding noise')
-    new_config_list = compiled_loss.get_config()['losses']
-    for loss_config in new_config_list:
-      reductions.add(loss_config['config']['reduction'])
-    if len(reductions) > 1:
-      raise ValueError(
-          'Reductions in models with multiple losses must all be the same'
-      )
-    return reductions.pop()
-  else:
-    raise ValueError(
-        'Unsupported type for adding noise: {}'.format(type(model_loss))
-    )
-
-
-def add_aggregate_noise(
-    clipped_grads: list[tf.Tensor],
-    batch_size: tf.Tensor,
-    l2_norm_clip: float,
-    noise_multiplier: float,
-    loss_reduction: Optional[Literal['mean', 'sum']] = None,
-    loss_model: Optional[tf.keras.Model] = None,
-) -> Sequence[tf.Tensor]:
-  """Adds noise to a collection of clipped gradients.
-
-  The magnitude of the noise depends on the aggregation strategy of the
-  input model's loss function.
-
-  Args:
-    clipped_grads: A list of `tf.Tensor`s representing the clipped gradients.
-    batch_size: The batch size. Used for normalizing the noise when
-      `loss_reduction` is 'sum'.
-    l2_norm_clip: Clipping norm (max L2 norm of each gradient).
-    noise_multiplier: Ratio of the standard deviation to the clipping norm.
- loss_reduction: An string description of how the loss is reduced over - examples. Currently supports 'mean' and 'sum'. If `None`, then the - aggregation type must be inferred from `input_model.loss`. - loss_model: An optional `tf.keras.Model` used to infer the loss reduction - strategy from if `loss_reduction` is `None`. - - Returns: - A list of tensors containing the clipped gradients, but with the right - amount of Gaussian noise added to them (depending on the reduction - strategy of the loss function). - - Raises: - ValueError: If both `loss_model` and `loss_reduction` are `None` or if - they are both not `None`. - """ - if loss_reduction is None and loss_model is None: - raise ValueError( - 'Exactly one of `loss_reduction` and `loss_model` must be populated.' - ' Instead, both arguments were `None`.' - ) - if loss_reduction is not None and loss_model is not None: - raise ValueError( - 'Exactly one of `loss_reduction` and `loss_model` must be populated.' - ' Instead, both arguments were not `None`.' - ) - - if loss_reduction is None and loss_model is not None: - implicit_mean_reductions = [ - tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, - tf.keras.losses.Reduction.AUTO, - ] - model_reduction = _infer_loss_reduction_type(loss_model) - loss_reduction = ( - 'mean' if model_reduction in implicit_mean_reductions else 'sum' - ) - if model_reduction == tf.keras.losses.Reduction.AUTO: - logging.info( - 'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.' - ) - - scale = l2_norm_clip - if loss_reduction == 'mean': - scale /= tf.cast(batch_size, tf.float32) - - def add_noise(g): - return g + tf.random.normal( - tf.shape(g), mean=0.0, stddev=noise_multiplier * scale - ) - - return tf.nest.map_structure(add_noise, clipped_grads) - - def generate_model_outputs_using_core_keras_layers( input_model: tf.keras.Model, custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py index f7cf3b0f..7069273d 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py @@ -15,7 +15,6 @@ from typing import Any from absl.testing import parameterized -import numpy as np import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils @@ -135,60 +134,6 @@ def test_outputs_are_consistent( self.assertAllClose(computed_outputs, true_outputs) -class AddAggregateNoise(tf.test.TestCase, parameterized.TestCase): - - @parameterized.product( - l2_norm_clip=[3.0, 5.0], - noise_multiplier=[2.0, 4.0], - batch_size=[1, 2, 10], - model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'], - noise_fn_reduction=[None, 'mean', 'sum'], - ) - def test_noise_is_computed_correctly( - self, - l2_norm_clip, - noise_multiplier, - batch_size, - model_fn_reduction, - noise_fn_reduction, - ): - # Skip invalid combinations. - if model_fn_reduction is None and noise_fn_reduction is None: - return - if model_fn_reduction is not None and noise_fn_reduction is not None: - return - # Make an simple model container for storing the loss. 
- if model_fn_reduction is not None: - linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)]) - linear_model.compile( - loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction) - ) - else: - linear_model = None - # The main computation is done on a deterministic dummy vector. - num_units = 100 - clipped_grads = [ - tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1) - ] - noised_grads = gradient_clipping_utils.add_aggregate_noise( - clipped_grads, - batch_size, - l2_norm_clip, - noise_multiplier, - noise_fn_reduction, - linear_model, - ) - # The only measure that varies is the standard deviation of the variation. - scale = ( - 1.0 - if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum' - else 1.0 / batch_size - ) - computed_std = np.std(noised_grads[0] - clipped_grads[0]) - expected_std = l2_norm_clip * noise_multiplier * scale - self.assertNear(computed_std, expected_std, 0.1 * expected_std) - - class GenerateOutputsUsingCoreKerasLayers( tf.test.TestCase, parameterized.TestCase ): diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py new file mode 100644 index 00000000..7dd2f157 --- /dev/null +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py @@ -0,0 +1,115 @@ +# Copyright 2024, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions that help in adding noise to gradients.""" + +from collections.abc import Sequence +from typing import Literal, Optional + +from absl import logging +import tensorflow as tf + + +def _infer_loss_reduction_type(model: tf.keras.Model): + """Infers what type of loss reduction is being performed.""" + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return model_loss.reduction + elif isinstance(model.loss, dict): + reductions = set() + compiled_loss = model.compiled_loss + if compiled_loss is None: + raise ValueError('Model must be compiled for adding noise') + new_config_list = compiled_loss.get_config()['losses'] + for loss_config in new_config_list: + reductions.add(loss_config['config']['reduction']) + if len(reductions) > 1: + raise ValueError( + 'Reductions in models with multiple losses must all be the same' + ) + return reductions.pop() + else: + raise ValueError( + 'Unsupported type for adding noise: {}'.format(type(model_loss)) + ) + + +def add_aggregate_noise( + clipped_grads: list[tf.Tensor], + batch_size: tf.Tensor, + l2_norm_clip: float, + noise_multiplier: float, + loss_reduction: Optional[Literal['mean', 'sum']] = None, + loss_model: Optional[tf.keras.Model] = None, +) -> Sequence[tf.Tensor]: + """Adds noise to a collection of clipped gradients. + + The magnitude of the noise depends on the aggregation strategy of the + input model's loss function. + + Args: + clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. + batch_size: The batch size. Used for normalizing the noise when + `loss_reduction` is 'sum'. 
+    l2_norm_clip: Clipping norm (max L2 norm of each gradient).
+    noise_multiplier: Ratio of the standard deviation to the clipping norm.
+    loss_reduction: A string description of how the loss is reduced over
+      examples. Currently supports 'mean' and 'sum'. If `None`, then the
+      aggregation type must be inferred from `loss_model.loss`.
+    loss_model: An optional `tf.keras.Model` used to infer the loss reduction
+      strategy when `loss_reduction` is `None`.
+
+  Returns:
+    A list of tensors containing the clipped gradients, but with the right
+    amount of Gaussian noise added to them (depending on the reduction
+    strategy of the loss function).
+
+  Raises:
+    ValueError: If both `loss_model` and `loss_reduction` are `None` or if
+      they are both not `None`.
+  """
+  if loss_reduction is None and loss_model is None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were `None`.'
+    )
+  if loss_reduction is not None and loss_model is not None:
+    raise ValueError(
+        'Exactly one of `loss_reduction` and `loss_model` must be populated.'
+        ' Instead, both arguments were not `None`.'
+    )
+
+  if loss_reduction is None and loss_model is not None:
+    implicit_mean_reductions = [
+        tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+        tf.keras.losses.Reduction.AUTO,
+    ]
+    model_reduction = _infer_loss_reduction_type(loss_model)
+    loss_reduction = (
+        'mean' if model_reduction in implicit_mean_reductions else 'sum'
+    )
+    if model_reduction == tf.keras.losses.Reduction.AUTO:
+      logging.info(
+          'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.'
+      )
+
+  scale = l2_norm_clip
+  if loss_reduction == 'mean':
+    scale /= tf.cast(batch_size, tf.float32)
+
+  def add_noise(g):
+    return g + tf.random.normal(
+        tf.shape(g), mean=0.0, stddev=noise_multiplier * scale
+    )
+
+  return tf.nest.map_structure(add_noise, clipped_grads)
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py
new file mode 100644
index 00000000..9d5cac3b
--- /dev/null
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py
@@ -0,0 +1,72 @@
+# Copyright 2024, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+
+
+class NoiseUtilsTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.product(
+      l2_norm_clip=[3.0, 5.0],
+      noise_multiplier=[2.0, 4.0],
+      batch_size=[1, 2, 10],
+      model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'],
+      noise_fn_reduction=[None, 'mean', 'sum'],
+  )
+  def test_noise_is_computed_correctly(
+      self,
+      l2_norm_clip,
+      noise_multiplier,
+      batch_size,
+      model_fn_reduction,
+      noise_fn_reduction,
+  ):
+    # Skip invalid combinations.
+    if model_fn_reduction is None and noise_fn_reduction is None:
+      return
+    if model_fn_reduction is not None and noise_fn_reduction is not None:
+      return
+    # Make a simple model container for storing the loss.
+    if model_fn_reduction is not None:
+      linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+      linear_model.compile(
+          loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction)
+      )
+    else:
+      linear_model = None
+    # The main computation is done on a deterministic dummy vector.
+    num_units = 100
+    clipped_grads = [
+        tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1)
+    ]
+    noised_grads = noise_utils.add_aggregate_noise(
+        clipped_grads,
+        batch_size,
+        l2_norm_clip,
+        noise_multiplier,
+        noise_fn_reduction,
+        linear_model,
+    )
+    # The only statistic that varies is the standard deviation of the added noise.
+    scale = (
+        1.0
+        if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum'
+        else 1.0 / batch_size
+    )
+    computed_std = np.std(noised_grads[0] - clipped_grads[0])
+    expected_std = l2_norm_clip * noise_multiplier * scale
+    self.assertNear(computed_std, expected_std, 0.1 * expected_std)
diff --git a/tensorflow_privacy/privacy/keras_models/BUILD b/tensorflow_privacy/privacy/keras_models/BUILD
index a4ae42c8..93afc859 100644
--- a/tensorflow_privacy/privacy/keras_models/BUILD
+++ b/tensorflow_privacy/privacy/keras_models/BUILD
@@ -17,7 +17,7 @@ py_library(
         "//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
         "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
-        "//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
     ],
 )
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index 4879b3bb..2a28d69b 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -18,6 +18,7 @@
 from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
 
 _PRIVATIZED_LOSS_NAME = 'privatized_loss'
 
@@ -287,7 +288,7 @@ def train_step(self, data):
       )
       output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
       if self._noise_multiplier > 0:
-        grads = gradient_clipping_utils.add_aggregate_noise(
+        grads = noise_utils.add_aggregate_noise(
            clipped_grads,
            num_microbatches,
            self._l2_norm_clip,
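
For reference, here is a minimal usage sketch (not part of the patch) of the relocated noise helper; the gradient shapes and constants below are made up for illustration. After this change the registry generator function returns a `RegistryGeneratorFunctionOutput` dataclass rather than a bare tuple, so `compute_gradient_norms()` reads the named fields `layer_id`, `layer_vars`, `layer_sqr_norm_fn`, and `layer_trainable_weights` instead of unpacking by position.

# Illustrative sketch only -- not part of this patch. The tensors and constants
# are made up; the scaling comment mirrors the add_aggregate_noise body above.
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

# Pretend these are aggregated per-layer gradients that have already been
# clipped to an L2 norm of at most l2_norm_clip.
clipped_grads = [tf.ones([4, 2]), tf.ones([2])]
batch_size = 8
l2_norm_clip = 1.0
noise_multiplier = 1.1

# With loss_reduction='mean' the added Gaussian noise has standard deviation
# l2_norm_clip * noise_multiplier / batch_size; with 'sum' the division by
# batch_size is skipped.
noised_grads = noise_utils.add_aggregate_noise(
    clipped_grads,
    batch_size,
    l2_norm_clip,
    noise_multiplier,
    loss_reduction='mean',
)
assert all(g.shape == n.shape for g, n in zip(clipped_grads, noised_grads))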