From ed12d2f02a804cc4f23577e2def8c3c870f9520f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Aug 2024 13:40:23 -0700 Subject: [PATCH] Sparsity Preserving DP-SGD in TF Privacy Add support for adding sparsity preserving noise in add_aggregate_noise See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660521257 --- .../privacy/fast_gradient_clipping/BUILD | 3 + .../fast_gradient_clipping/clip_grads.py | 267 +++++++----------- .../fast_gradient_clipping/clip_grads_test.py | 35 ++- .../gradient_clipping_utils.py | 145 ++++++++++ .../fast_gradient_clipping/noise_utils.py | 56 +++- .../noise_utils_test.py | 92 ++++++ .../fast_gradient_clipping/type_aliases.py | 2 + .../privacy/keras_models/dp_keras_model.py | 24 +- 8 files changed, 451 insertions(+), 173 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD index adb5a763..acabca76 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD @@ -43,6 +43,7 @@ py_library( srcs = ["gradient_clipping_utils.py"], srcs_version = "PY3", deps = [ + ":common_manip_utils", ":layer_registry", ":type_aliases", ], @@ -83,6 +84,7 @@ py_library( py_library( name = "noise_utils", srcs = ["noise_utils.py"], + deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils"], ) py_test( @@ -94,6 +96,7 @@ py_test( deps = [ ":clip_grads", ":common_test_utils", + ":gradient_clipping_utils", ":layer_registry", ":type_aliases", ], diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index 100e66ab..09dd2872 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -22,7 +22,7 @@ """ import collections -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Optional import tensorflow as tf @@ -32,73 +32,81 @@ from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases -def _infer_per_example_loss_fn(model: tf.keras.Model): - """Infer the per-example loss from model config.""" +def _compute_gradient_norms_internal( + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]], + trainable_vars: Optional[Sequence[tf.Variable]] = None, +): + """Computes the per-example loss gradient norms for given data. - def _convert(loss_fn): - loss_config = loss_fn.get_config() - loss_config['reduction'] = tf.keras.losses.Reduction.NONE - return loss_fn.from_config(loss_config) + Args: + registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput + containing information required to compute the gradient norms and + contribution counts. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + layer_grad_vars: A mapping of layer id to a list of gradients for each + trainablev ariable in the layer. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + trainable_vars: The list of variables included in computing the gradient + norm. When a layer has multiple variables, we include all the variables if + any of the variables is in the list. If `trainable_vars` is None, all the + variables are included. 
- model_loss = model.loss - if isinstance(model_loss, tf.keras.losses.Loss): - return _convert(model_loss) - elif isinstance(model_loss, dict): - # Note that we cannot call the public method `.get_compile_config()` because - # it calls a numpy function, which is not supported inside a `tf.function` - # wrapped function. - compile_config = model._compile_config.config # pylint: disable=protected-access - if compile_config is None: - raise ValueError('Model must be compiled for loss function conversion') - # Does a weighted mean of the configured losses. Note that we cannot build - # from the config of the compiled loss because (i) it builds a - # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s - # during its construction, (ii) non-unique `tf.Variables` cannot be used - # inside a `tf.function`, which is usually where this function is used. - if 'loss_weights' not in compile_config: - raise ValueError( - 'Models with multiple loss must have corresponding loss weights for' - ' loss function conversion' - ) - weights = compile_config['loss_weights'] - per_example_losses = {k: _convert(v) for k, v in model_loss.items()} - num_losses = len(weights) + Returns: + A scalar vector, whose i-th entry is the norm of the gradient of the i-th + weighted example loss (when num_microbatches is None) or the norm of the + gradient of the i-th microbatch loss (define as a mean over the microbatch). + Note that when the loss is weighted (`weight_batch` is not None), weights + are applied prior to clipping. - def _per_example_loss_fn(y_true, y_pred, sample_weight=None): - loss_values = [] - if model_loss.keys() - y_pred.keys(): - raise ValueError( - 'y_pred must contain the same keys and the model losses, but ' - 'got %s and %s' % (y_pred.keys(), model_loss.keys()) - ) - if model_loss.keys() - y_true.keys(): - raise ValueError( - 'y_true must contain the same keys and the model losses, but ' - 'got %s and %s' % (y_true.keys(), model_loss.keys()) - ) - if sample_weight is not None: - if model_loss.keys() - sample_weight.keys(): - raise ValueError( - 'sample_weight must contain the same keys and the model losses,' - ' but got %s and %s' % (y_true.keys(), model_loss.keys()) - ) - for k in y_true.keys(): - sgl_sample_weight = None if sample_weight is None else sample_weight[k] - sgl_value = ( - weights[k] - * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) - / num_losses - ) - loss_values.append(tf.reshape(sgl_value, shape=[-1])) - return tf.math.add_n(loss_values) + Raises: + ValueError: If `layer_grad_vars` is empty. + ValueError: If the number of gradients for a layer is not equal to the + number of squared norm functions for that layer. + """ + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. + trainable_vars = set([v.ref() for v in trainable_vars]) - return _per_example_loss_fn - else: - raise ValueError( - 'Unsupported type for loss function conversion: {}'.format( - type(model_loss) - ) - ) + layer_sqr_norm_fns = collections.defaultdict(list) + # The case of shared weights: + # If a layer is called k times, it will appear k times in filtered_outputs, + # with the same id, but potentially with different v and f. The code below + # groups filtered_outputs by layer_id, so we can correctly compute gradient + # norms. The gradient norm of a layer that occurs k times is computed as + # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th + # occurrence. 
This is an over-estimate of the actual norm. For more details, + # see the explanation in go/dp-sgd-shared-weights. + for registry_fn_output in registry_fn_outputs_list: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + layer_sqr_norm_fns[registry_fn_output.layer_id].append( + registry_fn_output.layer_sqr_norm_fn + ) + + if not layer_grad_vars: + raise ValueError('The gradient list cannot be empty.') + sqr_norm_list = [] + for layer_id in layer_sqr_norm_fns.keys(): + fns = layer_sqr_norm_fns[layer_id] + grads = layer_grad_vars[layer_id] + # Number of duplicates for this layer in `filtered_outputs`. + num_passes = len(fns) + if len(fns) != len(grads): + raise ValueError( + 'There must be as many gradients as squared norm functions.' + ) + # See go/dp-sgd-shared-weights for more details. + for fn, grad in zip(fns, grads): + sqr_norm_list.append(num_passes * fn(grad)) + sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) + gradient_norms = tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + return gradient_norms def compute_gradient_norms( @@ -110,7 +118,7 @@ def compute_gradient_norms( per_example_loss_fn: Optional[type_aliases.LossFn] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, trainable_vars: Optional[Sequence[tf.Variable]] = None, -): +) -> tf.Tensor: """Computes the per-example loss gradient norms for given data. Applies a variant of the approach given in @@ -154,90 +162,27 @@ def compute_gradient_norms( """ tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( - tape, layer_registry, num_microbatches + tape=tape, + layer_registry=layer_registry, + num_microbatches=num_microbatches, ) - # First loop computes the model outputs, summed loss, and generator outputs. - with tape: - model_outputs, generator_outputs_list = ( - gradient_clipping_utils.model_forward_pass( - input_model, x_batch, generator_fn=registry_generator_fn - ) - ) - - # Ignore the original loss function's reduction to get per-example loss. - if per_example_loss_fn is None: - per_example_loss_fn = _infer_per_example_loss_fn(input_model) - - losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) - if losses.shape is None: - raise NotImplementedError( - "The unreduced (or per-example) loss's shape cannot be `None`" - ) - if len(losses.shape) != 1: - raise NotImplementedError( - 'The unreduced (or per-example) loss needs to have a shape of length ' - 'one, but received an unreduced loss of shape length %s' - % len(losses.shape) - ) - if num_microbatches is not None: - losses = tf.reduce_mean( - common_manip_utils.maybe_add_microbatch_axis( - losses, num_microbatches - ), - axis=1, - ) - summed_loss = tf.reduce_sum(losses) - # Unwrap the generator outputs so that the next loop avoids duplicating - # backprop ops. - filtered_outputs = [t for t in generator_outputs_list if t is not None] - if trainable_vars is not None: - # Create a set using `ref()` for fast set membership check. tf.Variable - # itself is not hashable. - trainable_vars = set([v.ref() for v in trainable_vars]) - layer_vars = collections.defaultdict(list) - layer_sqr_norm_fns = collections.defaultdict(list) - # The case of shared weights: - # If a layer is called k times, it will appear k times in filtered_outputs, - # with the same id, but potentially with different v and f. 
The code below - # groups filtered_outputs by layer_id, so we can correctly compute gradient - # norms. The gradient norm of a layer that occurs k times is computed as - # $sqrt(k * \sum_i c_i^2)$ where $c_i$ is the norm estimate of its i-th - # occurrence. This is an over-estimate of the actual norm. For more details, - # see the explanation in go/dp-sgd-shared-weights. - for registry_fn_output in filtered_outputs: - if trainable_vars is None or any( - w.ref() in trainable_vars - for w in registry_fn_output.layer_trainable_weights - ): - layer_vars[registry_fn_output.layer_id].append( - registry_fn_output.layer_vars + layer_grad_vars, generator_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=input_model, + x_batch=x_batch, + y_batch=y_batch, + registry_generator_fn=registry_generator_fn, + weight_batch=weight_batch, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, ) - layer_sqr_norm_fns[registry_fn_output.layer_id].append( - registry_fn_output.layer_sqr_norm_fn - ) - # Second loop evaluates the squared L2 norm functions and appends the results. - layer_grad_vars = tape.gradient( - summed_loss, - layer_vars, - unconnected_gradients=tf.UnconnectedGradients.ZERO, ) - if not layer_grad_vars: - raise ValueError('The gradient list cannot be empty.') - sqr_norm_list = [] - for layer_id in layer_sqr_norm_fns.keys(): - fns = layer_sqr_norm_fns[layer_id] - grads = layer_grad_vars[layer_id] - # Number of duplicates for this layer in `filtered_outputs`. - num_passes = len(fns) - if len(fns) != len(grads): - raise ValueError( - 'There must be as many gradients as squared norm functions.' - ) - # See go/dp-sgd-shared-weights for more details. - for fn, grad in zip(fns, grads): - sqr_norm_list.append(num_passes * fn(grad)) - sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) - return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) + return _compute_gradient_norms_internal( + registry_fn_outputs_list=generator_outputs_list, + layer_grad_vars=layer_grad_vars, + trainable_vars=trainable_vars, + ) def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): @@ -267,14 +212,17 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): def compute_clipped_gradients_and_outputs( input_model: tf.keras.Model, + registry_fn_outputs_list: Sequence[ + gradient_clipping_utils.RegistryGeneratorFunctionOutput + ], + layer_grad_vars: Mapping[str, Sequence[type_aliases.Tensor]], l2_norm_clip: float, - layer_registry: lr.LayerRegistry, x_batch: type_aliases.InputTensors, y_batch: type_aliases.OutputTensors, weight_batch: Optional[tf.Tensor] = None, num_microbatches: Optional[type_aliases.BatchSize] = None, clipping_loss: Optional[type_aliases.LossFn] = None, -) -> tuple[Sequence[tf.Tensor], tf.Tensor, tf.Tensor]: +) -> tuple[Sequence[type_aliases.Tensor], tf.Tensor, tf.Tensor]: """Computes the per-example clipped loss gradient and other useful outputs. Given a batch of observations `(x_batch, y_batch, weight_batch)`, the main @@ -287,15 +235,16 @@ def compute_clipped_gradients_and_outputs( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. + registry_fn_outputs_list: A `list` of RegistryGeneratorFunctionOutput + containing information required to compute the gradient norms and + contribution counts. Output from + `gradient_clipping_utils.model_forward_backward_pass()`. + layer_grad_vars: A mapping of layer id to a list of gradients for each + trainablev ariable in the layer. 
Output from + `gradient_clipping_utils.model_forward_backward_pass()`. l2_norm_clip: A `float` indicating the norm to which per-example gradients will be clipped. That is, all gradients of the per-example loss functions will have norm at most `l2_norm_clip`. - layer_registry: A `dict` of layers that support "fast" gradient norm - computations. The key is the class of the layer and the value is a - function that returns a `tuple` `(output, sqr_grad_norms, vars)`, where - `output` is the pre-activator tensor, `sqr_grad_norms` is related to the - squared norms of a layer's pre-activation tensor, and `vars` are relevant - trainable weights (see `layer_registry_factories.py` for examples). x_batch: An `InputTensor` representing a batch of inputs to the model. The first axes of each tensor must be the batch dimension. y_batch: An `OutputTensor` representing a batch of output labels. The first @@ -330,13 +279,9 @@ def compute_clipped_gradients_and_outputs( ) if clipping_loss is None: clipping_loss = input_model.compiled_loss - gradient_norms = compute_gradient_norms( - input_model, - layer_registry, - x_batch, - y_batch, - weight_batch, - num_microbatches=num_microbatches, + gradient_norms = _compute_gradient_norms_internal( + registry_fn_outputs_list=registry_fn_outputs_list, + layer_grad_vars=layer_grad_vars, trainable_vars=input_model.trainable_variables, ) clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index a028affe..6ec47941 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils +from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -122,6 +123,30 @@ def test_gradient_norms_on_various_models( self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) +def _run_model_forward_backward_pass( + model: tf.keras.Model, + x_batch: type_aliases.InputTensors, + y_batch: type_aliases.OutputTensors, +): + tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=layer_registry.make_default_layer_registry(), + sparse_noise_layer_registry=None, + num_microbatches=None, + ) + layer_grad_vars, registry_fn_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + registry_generator_fn=registry_generator_fn, + ) + ) + return layer_grad_vars, registry_fn_outputs_list + + class ComputeClippedGradsAndOutputsTest( tf.test.TestCase, parameterized.TestCase ): @@ -153,13 +178,17 @@ def test_clipped_gradients_on_different_losses( y_batch = tf.reshape( 1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1] ) + layer_grad_vars, registry_fn_outputs_list = ( + _run_model_forward_backward_pass(self._model, x_batch, y_batch) + ) # Stop early for efficiency. 
if reduction == 'none': with self.assertRaises(NotImplementedError): clip_grads.compute_clipped_gradients_and_outputs( self._model, + registry_fn_outputs_list, + layer_grad_vars, l2_norm_clip, - layer_registry.make_default_layer_registry(), x_batch, y_batch, ) @@ -169,10 +198,12 @@ def test_clipped_gradients_on_different_losses( y_pred = self._model(x_batch) loss_value = loss_fn(y_pred, y_batch) true_grads = tape.gradient(loss_value, self._model.trainable_variables) + clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs( self._model, + registry_fn_outputs_list, + layer_grad_vars, l2_norm_clip, - layer_registry.make_default_layer_registry(), x_batch, y_batch, ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 7a060e9b..bac323ce 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -13,11 +13,13 @@ # limitations under the License. """Utility functions that help in the computation of per-example gradient norms.""" +import collections from collections.abc import Callable, Sequence, Set import dataclasses from typing import Any, Optional, Tuple import tensorflow as tf +from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases @@ -98,6 +100,149 @@ def registry_generator_fn(layer_instance, args, kwargs): return registry_generator_fn +def _infer_per_example_loss_fn(model: tf.keras.Model): + """Infer the per-example loss from model config.""" + + def _convert(loss_fn): + loss_config = loss_fn.get_config() + loss_config['reduction'] = tf.keras.losses.Reduction.NONE + return loss_fn.from_config(loss_config) + + model_loss = model.loss + if isinstance(model_loss, tf.keras.losses.Loss): + return _convert(model_loss) + elif isinstance(model_loss, dict): + # Note that we cannot call the public method `.get_compile_config()` because + # it calls a numpy function, which is not supported inside a `tf.function` + # wrapped function. + compile_config = model._compile_config.config # pylint: disable=protected-access + if compile_config is None: + raise ValueError('Model must be compiled for loss function conversion') + # Does a weighted mean of the configured losses. Note that we cannot build + # from the config of the compiled loss because (i) it builds a + # `keras.metrics.Mean` class, which generates non-unique `tf.Variable`s + # during its construction, (ii) non-unique `tf.Variables` cannot be used + # inside a `tf.function`, which is usually where this function is used. 
+ if 'loss_weights' not in compile_config: + raise ValueError( + 'Models with multiple loss must have corresponding loss weights for' + ' loss function conversion' + ) + weights = compile_config['loss_weights'] + per_example_losses = {k: _convert(v) for k, v in model_loss.items()} + num_losses = len(weights) + + def _per_example_loss_fn(y_true, y_pred, sample_weight=None): + loss_values = [] + if model_loss.keys() - y_pred.keys(): + raise ValueError( + 'y_pred must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_pred.keys(), model_loss.keys()) + ) + if model_loss.keys() - y_true.keys(): + raise ValueError( + 'y_true must contain the same keys and the model losses, but ' + 'got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + if sample_weight is not None: + if model_loss.keys() - sample_weight.keys(): + raise ValueError( + 'sample_weight must contain the same keys and the model losses,' + ' but got %s and %s' % (y_true.keys(), model_loss.keys()) + ) + for k in y_true.keys(): + sgl_sample_weight = None if sample_weight is None else sample_weight[k] + sgl_value = ( + weights[k] + * per_example_losses[k](y_true[k], y_pred[k], sgl_sample_weight) + / num_losses + ) + loss_values.append(tf.reshape(sgl_value, shape=[-1])) + return tf.math.add_n(loss_values) + + return _per_example_loss_fn + else: + raise ValueError( + 'Unsupported type for loss function conversion: {}'.format( + type(model_loss) + ) + ) + + +def model_forward_backward_pass( + tape: tf.GradientTape, + input_model: tf.keras.Model, + x_batch: type_aliases.InputTensors, + y_batch: type_aliases.OutputTensors, + registry_generator_fn: Optional[ + Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]] + ], + weight_batch: Optional[tf.Tensor] = None, + per_example_loss_fn: Optional[type_aliases.LossFn] = None, + num_microbatches: Optional[type_aliases.BatchSize] = None, + trainable_vars: Optional[Sequence[tf.Variable]] = None, +) -> tuple[ + dict[str, list[type_aliases.Tensor]], list[RegistryGeneratorFunctionOutput] +]: + """Does a forward and backward pass of a model and returns useful intermediates.""" + # First loop computes the model outputs, summed loss, and generator outputs. + with tape: + model_outputs, generator_outputs_list = model_forward_pass( + input_model, x_batch, generator_fn=registry_generator_fn + ) + + # Ignore the original loss function's reduction to get per-example loss. + if per_example_loss_fn is None: + per_example_loss_fn = _infer_per_example_loss_fn(input_model) + + losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) + if losses.shape is None: + raise NotImplementedError( + "The unreduced (or per-example) loss's shape cannot be `None`" + ) + if len(losses.shape) != 1: + raise NotImplementedError( + 'The unreduced (or per-example) loss needs to have a shape of length ' + 'one, but received an unreduced loss of shape length %s' + % len(losses.shape) + ) + if num_microbatches is not None: + losses = tf.reduce_mean( + common_manip_utils.maybe_add_microbatch_axis( + losses, num_microbatches + ), + axis=1, + ) + summed_loss = tf.reduce_sum(losses) + # Unwrap the generator outputs so that the next loop avoids duplicating + # backprop ops. + filtered_outputs = [t for t in generator_outputs_list if t is not None] + + if trainable_vars is not None: + # Create a set using `ref()` for fast set membership check. tf.Variable + # itself is not hashable. 
+ trainable_vars = set([v.ref() for v in trainable_vars]) + layer_vars = collections.defaultdict(list) + for registry_fn_output in filtered_outputs: + if trainable_vars is None or any( + w.ref() in trainable_vars + for w in registry_fn_output.layer_trainable_weights + ): + layer_vars[registry_fn_output.layer_id].append( + registry_fn_output.layer_vars + ) + + layer_grad_vars = tape.gradient( + summed_loss, + layer_vars, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + if not layer_grad_vars: + raise ValueError('The gradient list cannot be empty.') + + return layer_grad_vars, filtered_outputs + + def model_forward_pass( input_model: tf.keras.Model, inputs: type_aliases.PackedTensors, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py index 7dd2f157..7349a812 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils.py @@ -14,10 +14,21 @@ """Utility functions that help in adding noise to gradients.""" from collections.abc import Sequence +import dataclasses from typing import Literal, Optional from absl import logging import tensorflow as tf +from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils + + +@dataclasses.dataclass +class SparsityPreservingNoiseConfig: + """Configuration for adding noise to gradients.""" + + sparse_noise_multiplier: float = 0.0 + sparse_selection_threshold: int = 0 + sparse_contribution_counts: Optional[Sequence[tf.SparseTensor]] = None def _infer_loss_reduction_type(model: tf.keras.Model): @@ -45,20 +56,22 @@ def _infer_loss_reduction_type(model: tf.keras.Model): def add_aggregate_noise( - clipped_grads: list[tf.Tensor], + clipped_grads: list[tf.Tensor | tf.IndexedSlices], batch_size: tf.Tensor, l2_norm_clip: float, noise_multiplier: float, loss_reduction: Optional[Literal['mean', 'sum']] = None, loss_model: Optional[tf.keras.Model] = None, -) -> Sequence[tf.Tensor]: + sparse_noise_config: Optional[SparsityPreservingNoiseConfig] = None, +) -> Sequence[tf.Tensor | tf.IndexedSlices]: """Adds noise to a collection of clipped gradients. The magnitude of the noise depends on the aggregation strategy of the input model's loss function. Args: - clipped_grads: A list of `tf.Tensor`s representing the clipped gradients. + clipped_grads: A list of `tf.Tensor`s or `tf.IndexedSlices`s representing + the clipped gradients. batch_size: The batch size. Used for normalizing the noise when `loss_reduction` is 'sum'. l2_norm_clip: Clipping norm (max L2 norm of each gradient). @@ -68,11 +81,14 @@ def add_aggregate_noise( aggregation type must be inferred from `input_model.loss`. loss_model: An optional `tf.keras.Model` used to infer the loss reduction strategy from if `loss_reduction` is `None`. + sparse_noise_config: A `SparsityPreservingNoiseConfig` instance containing + the configuration for adding sparse noise. If None, all noise added is + dense. Returns: A list of tensors containing the clipped gradients, but with the right - amount of Gaussian noise added to them (depending on the reduction - strategy of the loss function). + amount of Gaussian or sparse Gaussain noise added to them (depending on + the reduction strategy of the loss function). Raises: ValueError: If both `loss_model` and `loss_reduction` are `None` or if @@ -103,13 +119,35 @@ def add_aggregate_noise( 'Assuming that the model loss reduction is `SUM_OVER_BATCH_SIZE`.' 
) + if sparse_noise_config is None: + sparse_contribution_counts = tf.nest.map_structure( + lambda x: None, clipped_grads + ) + else: + sparse_contribution_counts = sparse_noise_config.sparse_contribution_counts + scale = l2_norm_clip if loss_reduction == 'mean': scale /= tf.cast(batch_size, tf.float32) - def add_noise(g): - return g + tf.random.normal( - tf.shape(g), mean=0.0, stddev=noise_multiplier * scale + def add_noise(grad, contribution_counts): + if ( + sparse_noise_config is not None + and isinstance(grad, tf.IndexedSlices) + and contribution_counts is not None + ): + return sparse_noise_utils.add_sparse_noise( + grad, + contribution_counts, + noise_multiplier, + sparse_noise_config.sparse_noise_multiplier, + l2_norm_clip, + sparse_noise_config.sparse_selection_threshold, + ) + return grad + tf.random.normal( + tf.shape(grad), mean=0.0, stddev=noise_multiplier * scale ) - return tf.nest.map_structure(add_noise, clipped_grads) + return tf.nest.map_structure( + add_noise, clipped_grads, sparse_contribution_counts + ) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py index 9d5cac3b..880b8d3a 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/noise_utils_test.py @@ -70,3 +70,95 @@ def test_noise_is_computed_correctly( computed_std = np.std(noised_grads[0] - clipped_grads[0]) expected_std = l2_norm_clip * noise_multiplier * scale self.assertNear(computed_std, expected_std, 0.1 * expected_std) + + @parameterized.product( + l2_norm_clip=[3.0, 5.0], + noise_multiplier=[2.0, 4.0], + sparse_noise_multiplier=[1.0], + batch_size=[1, 2, 10], + model_fn_reduction=[None, 'auto', 'sum_over_batch_size', 'sum'], + noise_fn_reduction=[None, 'mean', 'sum'], + ) + def test_sparse_noise_is_computed_correctly( + self, + l2_norm_clip, + noise_multiplier, + sparse_noise_multiplier, + batch_size, + model_fn_reduction, + noise_fn_reduction, + ): + # Skip invalid combinations. + if model_fn_reduction is None and noise_fn_reduction is None: + return + if model_fn_reduction is not None and noise_fn_reduction is not None: + return + # Make an simple model container for storing the loss. + if model_fn_reduction is not None: + linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)]) + linear_model.compile( + loss=tf.keras.losses.MeanSquaredError(reduction=model_fn_reduction) + ) + else: + linear_model = None + # The main computation is done on a deterministic dummy vector. 
+ num_units = 100 + dense_grad = tf.expand_dims(np.arange(num_units, dtype=np.float32), axis=-1) + sparse_grad = tf.IndexedSlices( + values=tf.ones((3, 4)), + indices=tf.constant([0, 3, 5]), + dense_shape=tf.constant([8, 4]), + ) + sparse_grad_contribution_counts = tf.SparseTensor( + indices=[[0], [3], [5]], + values=[10.0, 10.0, 20.0], + dense_shape=[8], + ) + + sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig( + sparse_noise_multiplier=sparse_noise_multiplier, + sparse_selection_threshold=8, + sparse_contribution_counts=[None, sparse_grad_contribution_counts], + ) + + sparse_noised_grad, dense_noised_grad = noise_utils.add_aggregate_noise( + clipped_grads=[dense_grad, sparse_grad], + batch_size=batch_size, + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + loss_model=linear_model, + sparse_noise_config=sparse_noise_config, + ) + self.assertContainsSubset( + sparse_grad.indices.numpy().tolist(), + sparse_noised_grad.indices.numpy().tolist(), + ) + sparse_noised_grad_dense = tf.scatter_nd( + tf.reshape(sparse_noised_grad.indices, (-1, 1)), + sparse_noised_grad.values, + shape=(8, 4), + ).numpy() + sparse_noised_grad_valid_indices = sparse_noised_grad_dense[ + sparse_grad.indices.numpy() + ] + sparse_grad_values = sparse_grad.values.numpy() + self.assertTrue( + np.all( + np.not_equal(sparse_noised_grad_valid_indices, sparse_grad_values) + ) + ) + scale = ( + 1.0 + if noise_fn_reduction == 'sum' or model_fn_reduction == 'sum' + else 1.0 / batch_size + ) + # The only measure that varies is the standard deviation of the variation. + expected_std = l2_norm_clip * noise_multiplier * scale + + sparse_computed_std = np.std( + sparse_noised_grad_valid_indices - sparse_grad_values + ) + self.assertNear(sparse_computed_std, expected_std, 0.1 * expected_std) + + dense_computed_std = np.std(dense_noised_grad - dense_grad) + self.assertNear(dense_computed_std, expected_std, 0.1 * expected_std) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py index 1e602a01..99468123 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py @@ -19,6 +19,8 @@ # Tensorflow aliases. +Tensor = tf.Tensor + PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]] InputTensors = PackedTensors diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 2a28d69b..b7104f4d 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -274,14 +274,36 @@ def train_step(self, data): # trick, and uses these norms to clip the per-example gradients. # NOTE: Reshaping of the input according to the effective number of # microbatches is done here. 
+ tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) + + registry_generator_fn = ( + gradient_clipping_utils.get_registry_generator_fn( + tape=tape, + layer_registry=self._layer_registry, + num_microbatches=num_microbatches, + ) + ) + layer_grad_vars, registry_fn_outputs_list = ( + gradient_clipping_utils.model_forward_backward_pass( + tape=tape, + input_model=self, + x_batch=x, + y_batch=y, + registry_generator_fn=registry_generator_fn, + weight_batch=weights, + num_microbatches=num_microbatches, + trainable_vars=self.trainable_variables, + ) + ) clipped_grads, y_pred, clipping_loss = ( clip_grads.compute_clipped_gradients_and_outputs( input_model=self, + registry_fn_outputs_list=registry_fn_outputs_list, + layer_grad_vars=layer_grad_vars, x_batch=x, y_batch=y, weight_batch=weights, l2_norm_clip=self._l2_norm_clip, - layer_registry=self._layer_registry, num_microbatches=self._num_microbatches, clipping_loss=self._clipping_loss, )
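
Usage sketch (illustrative, not part of the patch): the hunks above split per-example gradient norm computation into `gradient_clipping_utils.model_forward_backward_pass()` plus `clip_grads.compute_clipped_gradients_and_outputs()`, and extend `noise_utils.add_aggregate_noise()` with an optional `SparsityPreservingNoiseConfig`. The Python snippet below is a minimal sketch of how these pieces fit together, using only the call patterns visible in this diff. The toy model, batch data, and hyperparameter values are placeholders, and the hand-built `tf.IndexedSlices` gradient and `tf.SparseTensor` contribution counts in the last step mirror `noise_utils_test.py`; in real training the contribution counts would be produced by the sparsity_preserving_noise utilities rather than constructed by hand.

    # Illustrative sketch only; module paths and signatures follow the diff
    # above, while the toy model, data, and hyperparameters are placeholders.
    import tensorflow as tf

    from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
    from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
    from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
    from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(loss=tf.keras.losses.MeanSquaredError())
    x_batch = tf.random.normal([8, 4])
    y_batch = tf.random.normal([8, 1])

    # 1) Single forward/backward pass that records per-layer squared-norm
    #    functions and per-layer gradients for fast per-example clipping.
    tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
    registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
        tape=tape,
        layer_registry=layer_registry.make_default_layer_registry(),
        num_microbatches=None,
    )
    layer_grad_vars, registry_fn_outputs_list = (
        gradient_clipping_utils.model_forward_backward_pass(
            tape=tape,
            input_model=model,
            x_batch=x_batch,
            y_batch=y_batch,
            registry_generator_fn=registry_generator_fn,
        )
    )

    # 2) Clipping now consumes those intermediates instead of redoing the
    #    forward/backward pass.
    clipped_grads, y_pred, clipping_loss = (
        clip_grads.compute_clipped_gradients_and_outputs(
            input_model=model,
            registry_fn_outputs_list=registry_fn_outputs_list,
            layer_grad_vars=layer_grad_vars,
            l2_norm_clip=1.0,
            x_batch=x_batch,
            y_batch=y_batch,
        )
    )

    # 3a) Dense-only gradients: omitting sparse_noise_config keeps the
    #     existing Gaussian mechanism.
    noised_grads = noise_utils.add_aggregate_noise(
        clipped_grads=clipped_grads,
        batch_size=tf.shape(x_batch)[0],
        l2_norm_clip=1.0,
        noise_multiplier=1.0,
        loss_model=model,
    )

    # 3b) Sparsity-preserving noise on an embedding-style tf.IndexedSlices
    #     gradient, mirroring noise_utils_test.py; the contribution counts
    #     are hand-built here for illustration only.
    sparse_grad = tf.IndexedSlices(
        values=tf.ones((3, 4)),
        indices=tf.constant([0, 3, 5]),
        dense_shape=tf.constant([8, 4]),
    )
    contribution_counts = tf.SparseTensor(
        indices=[[0], [3], [5]],
        values=[10.0, 10.0, 20.0],
        dense_shape=[8],
    )
    sparse_config = noise_utils.SparsityPreservingNoiseConfig(
        sparse_noise_multiplier=1.0,
        sparse_selection_threshold=8,
        sparse_contribution_counts=[contribution_counts],
    )
    noised_sparse_grads = noise_utils.add_aggregate_noise(
        clipped_grads=[sparse_grad],
        batch_size=8,
        l2_norm_clip=1.0,
        noise_multiplier=1.0,
        loss_reduction='mean',
        sparse_noise_config=sparse_config,
    )

The same split is what lets `DPModel.train_step` above run the forward/backward pass once and reuse `layer_grad_vars` and `registry_fn_outputs_list` for clipping, with noise (dense or sparsity preserving) applied afterwards in `add_aggregate_noise`.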